diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 036889ff78128..ab02d33f39913 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -380,6 +380,32 @@ class CodeGenerator(
     }
   }
 
+  def genSetAggregatesContext: String = {
+
+    val sig: String =
+      j"""
+         |  public final void setAggregateContext(
+         |    org.apache.flink.table.functions.AggregateContext aggregateContext)""".stripMargin
+
+    val setAggs: String = {
+      for (i <- aggs.indices) yield
+        j"""
+           |    if (${aggs(i)} instanceof
+           |        org.apache.flink.table.functions.RichAggregateFunction) {
+           |      Object obj = ${aggs(i)};
+           |      org.apache.flink.table.functions.RichAggregateFunction richAggFunction =
+           |        (org.apache.flink.table.functions.RichAggregateFunction) obj;
+           |      richAggFunction.setAggregateContext(aggregateContext);
+           |    }""".stripMargin
+    }.mkString("\n")
+
+    j"""
+       |$sig {
+       |$setAggs
+       |  }""".stripMargin
+  }
+
   def genSetAggregationResults: String = {
 
     val sig: String =
@@ -637,6 +663,7 @@ class CodeGenerator(
       |  """.stripMargin
 
+    funcCode += genSetAggregatesContext + "\n"
     funcCode += genSetAggregationResults + "\n"
     funcCode += genAccumulate + "\n"
     funcCode += genRetract + "\n"
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateContext.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateContext.scala
new file mode 100644
index 0000000000000..4299feca19f29
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateContext.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions
+
+import org.apache.flink.annotation.PublicEvolving
+import org.apache.flink.api.common.functions.RuntimeContext
+import org.apache.flink.api.common.state._
+
+/**
+ * An AggregateContext gives access to global runtime information about the context in which
+ * the aggregate function is executed, in particular methods for accessing state.
+ *
+ * @param context the runtime context in which the Flink function is executed
+ */
+class AggregateContext(context: RuntimeContext) {
+
+  // ------------------------------------------------------------------------
+  //  Methods for accessing state
+  // ------------------------------------------------------------------------
+
+  /**
+   * Gets a handle to the [[ValueState]].
+   *
+   * @param stateProperties The descriptor defining the properties of the state.
+   * @tparam T The type of value stored in the state.
+   * @return The partitioned state object.
+   * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+   *                                       function (function is not part of a KeyedStream).
+   */
+  @PublicEvolving
+  def getState[T](stateProperties: ValueStateDescriptor[T]): ValueState[T] =
+    context.getState(stateProperties)
+
+  /**
+   * Gets a handle to the [[ListState]].
+   *
+   * @param stateProperties The descriptor defining the properties of the state.
+   * @tparam T The type of value stored in the state.
+   * @return The partitioned state object.
+   * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+   *                                       function (function is not part of a KeyedStream).
+   */
+  @PublicEvolving
+  def getListState[T](stateProperties: ListStateDescriptor[T]): ListState[T] =
+    context.getListState(stateProperties)
+
+  /**
+   * Gets a handle to the [[MapState]].
+   *
+   * @param stateProperties The descriptor defining the properties of the state.
+   * @tparam UK The type of the user keys stored in the state.
+   * @tparam UV The type of the user values stored in the state.
+   * @return The partitioned state object.
+   * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+   *                                       function (function is not part of a KeyedStream).
+   */
+  @PublicEvolving
+  def getMapState[UK, UV](stateProperties: MapStateDescriptor[UK, UV]): MapState[UK, UV] =
+    context.getMapState(stateProperties)
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/RichAggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/RichAggregateFunction.scala
new file mode 100644
index 0000000000000..0d674d2e681d6
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/RichAggregateFunction.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeutils.TypeSerializer
+
+import scala.collection.mutable
+
+/**
+ * Rich variant of the [[AggregateFunction]]. It provides access to Flink's managed state.
+ *
+ */
+abstract class RichAggregateFunction[T, ACC] extends AggregateFunction[T, ACC] {
+  private var aggContext: AggregateContext = _
+  private val descriptorMapping = mutable.Map[String, StateDescriptor[_, _]]()
+
+  private[flink] def setAggregateContext(context: AggregateContext): Unit = {
+    this.aggContext = context
+  }
+
+  def registerValue[K](name: String, typeClass: Class[K]): Unit = {
+    descriptorMapping.put(name, new ValueStateDescriptor[K](name, typeClass))
+  }
+
+  def registerValue[K](name: String, typeInfo: TypeInformation[K]): Unit = {
+    descriptorMapping.put(name, new ValueStateDescriptor[K](name, typeInfo))
+  }
+
+  def registerValue[K](name: String, typeSerializer: TypeSerializer[K]): Unit = {
+    descriptorMapping.put(name, new ValueStateDescriptor[K](name, typeSerializer))
+  }
+
+  def registerList[K](name: String, elementTypeClass: Class[K]): Unit = {
+    descriptorMapping.put(name, new ListStateDescriptor[K](name, elementTypeClass))
+  }
+
+  def registerList[K](name: String, elementTypeInfo: TypeInformation[K]): Unit = {
+    descriptorMapping.put(name, new ListStateDescriptor[K](name, elementTypeInfo))
+  }
+
+  def registerList[K](name: String, typeSerializer: TypeSerializer[K]): Unit = {
+    descriptorMapping.put(name, new ListStateDescriptor[K](name, typeSerializer))
+  }
+
+  def registerMap[UK, UV](name: String, keySerializer: TypeSerializer[UK],
+      valueSerializer: TypeSerializer[UV]): Unit = {
+    descriptorMapping.put(
+      name, new MapStateDescriptor[UK, UV](name, keySerializer, valueSerializer))
+  }
+
+  def registerMap[UK, UV](name: String, keyTypeInfo: TypeInformation[UK],
+      valueTypeInfo: TypeInformation[UV]): Unit = {
+    descriptorMapping.put(name, new MapStateDescriptor[UK, UV](name, keyTypeInfo, valueTypeInfo))
+  }
+
+  def registerMap[UK, UV](name: String, keyClass: Class[UK], valueClass: Class[UV]): Unit = {
+    descriptorMapping.put(name, new MapStateDescriptor[UK, UV](name, keyClass, valueClass))
+  }
+
+  def getValueByStateName[K](name: String): ValueState[K] = {
+    aggContext.getState(descriptorMapping(name).asInstanceOf[ValueStateDescriptor[K]])
+  }
+
+  def getListByStateName[K](name: String): ListState[K] = {
+    aggContext.getListState(descriptorMapping(name).asInstanceOf[ListStateDescriptor[K]])
+  }
+
+  def getMapByStateName[UK, UV](name: String): MapState[UK, UV] = {
+    aggContext.getMapState(descriptorMapping(name).asInstanceOf[MapStateDescriptor[UK, UV]])
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
index dd9c015c2d98e..639775d14e577 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
@@ -75,5 +75,6 @@ class AggregateAggFunction(genAggregations: GeneratedAggregationsFunction)
         genAggregations.code)
       LOG.debug("Instantiating AggregateHelper.")
       function = clazz.newInstance()
+
     }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 8073959ae0bd5..c0827b8a78790 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -43,6 +43,7 @@ import org.apache.flink.table.functions.aggfunctions._
 import org.apache.flink.table.functions.utils.AggSqlFunction
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
 import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
+import org.apache.flink.table.functions.{RichAggregateFunction => TableRichAggregateFunction}
 import org.apache.flink.table.plan.logical._
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
 import org.apache.flink.table.typeutils.TypeCheckUtils._
@@ -332,6 +333,8 @@ object AggregateUtil {
         inputType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val mapReturnType: RowTypeInfo =
       createRowTypeForKeysAndAggregates(
         groupings,
@@ -437,6 +440,8 @@ object AggregateUtil {
         physicalInputRowType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates(
       groupings,
       aggregates,
@@ -550,6 +555,8 @@ object AggregateUtil {
         physicalInputRowType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
 
     val genPreAggFunction = generator.generateAggregations(
@@ -697,6 +704,8 @@ object AggregateUtil {
         physicalInputRowType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -770,6 +779,8 @@ object AggregateUtil {
         physicalInputRowType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -836,6 +847,8 @@ object AggregateUtil {
         inputType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
       namedAggregates,
       groupings,
@@ -1009,6 +1022,8 @@ object AggregateUtil {
         inputType,
         needRetract)
 
+    validateRichAggregate(aggregates)
+
     val aggMapping = aggregates.indices.toArray
     val outputArity = aggregates.length
 
@@ -1038,6 +1053,14 @@ object AggregateUtil {
     (aggFunction, accumulatorRowType, aggResultRowType)
   }
 
+  private def validateRichAggregate(aggregates: Array[TableAggregateFunction[_, _]]): Unit = {
+    aggregates.foreach {
+      case _: TableRichAggregateFunction[_, _] =>
+        throw new TableException("RichAggregateFunction is currently not supported")
+      case _ => // ok
+    }
+  }
+
   /**
     * Return true if all aggregates can be partially merged. False otherwise.
    */
@@ -1417,9 +1440,9 @@ object AggregateUtil {
       if (accType != null) {
         accType
       } else {
-        val accumulator = agg.createAccumulator()
         try {
-          TypeInformation.of(accumulator.getClass)
+          val method = agg.getClass.getMethod("createAccumulator")
+          TypeInformation.of(method.getReturnType)
         } catch {
           case ite: InvalidTypesException =>
             throw new TableException(
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
index 5f48e091996e5..56315eb740ad0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.table.runtime.aggregate
 
 import org.apache.flink.api.common.functions.Function
+import org.apache.flink.table.functions.AggregateContext
 import org.apache.flink.types.Row
 
 /**
@@ -26,6 +27,11 @@ import org.apache.flink.types.Row
   */
 abstract class GeneratedAggregations extends Function {
 
+  /**
+    * Sets the aggregate context on the aggregation functions, giving
+    * [[org.apache.flink.table.functions.RichAggregateFunction]]s access to state.
+    */
+  def setAggregateContext(aggregateContext: AggregateContext): Unit
+
   /**
     * Sets the results of the aggregations (partial or final) to the output row.
     * Final results are computed with the aggregation function.
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
index 57ea86e70cbef..a2ceda49b5c37 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
@@ -29,6 +29,7 @@ import org.apache.flink.api.common.state.ValueState
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.slf4j.{Logger, LoggerFactory}
+import org.apache.flink.table.functions.AggregateContext
 import org.apache.flink.table.runtime.types.CRow
 
 /**
@@ -65,6 +66,7 @@ class GroupAggProcessFunction(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.setAggregateContext(new AggregateContext(getRuntimeContext))
 
     newRow = new CRow(function.createOutputRow(), true)
     prevRow = new CRow(function.createOutputRow(), false)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
index d50912ccf9021..16081dccbbc2a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
@@ -33,6 +33,7 @@ import java.util.{ArrayList, List => JList}
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.table.api.StreamQueryConfig
 import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
+import org.apache.flink.table.functions.AggregateContext
 import org.apache.flink.table.runtime.types.{CRow,
CRowTypeInfo} import org.slf4j.LoggerFactory @@ -70,6 +71,7 @@ class ProcTimeBoundedRangeOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.setAggregateContext(new AggregateContext(getRuntimeContext)) output = new CRow(function.createOutputRow(), true) // We keep the elements received in a MapState indexed based on their ingestion time diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala index e388c93bf6526..ca9f7cc74a495 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala @@ -35,6 +35,7 @@ import java.util.{List => JList} import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.functions.AggregateContext import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.slf4j.LoggerFactory @@ -75,6 +76,7 @@ class ProcTimeBoundedRowsOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.setAggregateContext(new AggregateContext(getRuntimeContext)) output = new CRow(function.createOutputRow(), true) // We keep the elements received in a Map state keyed diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala index 2a6c9c85f95ba..35d41029805c7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala @@ -28,6 +28,7 @@ import org.apache.flink.util.Collector import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.types.Row +import org.apache.flink.table.functions.AggregateContext import org.slf4j.LoggerFactory /** @@ -60,6 +61,7 @@ class ProcTimeUnboundedNonPartitionedOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.setAggregateContext(new AggregateContext(getRuntimeContext)) output = new CRow(function.createOutputRow(), true) if (null == accumulators) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala index 97f0ad78fdc79..f4a226cb88351 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala @@ -27,6 +27,7 @@ import org.apache.flink.api.common.state.ValueState import org.apache.flink.table.api.StreamQueryConfig import 
org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.functions.AggregateContext import org.slf4j.LoggerFactory /** @@ -56,6 +57,7 @@ class ProcTimeUnboundedPartitionedOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.setAggregateContext(new AggregateContext(getRuntimeContext)) output = new CRow(function.createOutputRow(), true) val stateDescriptor: ValueStateDescriptor[Row] = diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala index 65edf6d662fd8..12362439d9d30 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala @@ -27,6 +27,7 @@ import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.functions.AggregateContext import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} import org.slf4j.LoggerFactory @@ -76,6 +77,7 @@ class RowTimeBoundedRangeOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.setAggregateContext(new AggregateContext(getRuntimeContext)) output = new CRow(function.createOutputRow(), true) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala index 395ae3986076a..480118fd7d92f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala @@ -30,6 +30,7 @@ import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.functions.AggregateContext import org.slf4j.LoggerFactory /** @@ -81,6 +82,7 @@ class RowTimeBoundedRowsOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.setAggregateContext(new AggregateContext(getRuntimeContext)) output = new CRow(function.createOutputRow(), true) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala index 741d2b48fea53..0e710cba6737b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala @@ -31,6 +31,7 @@ import org.apache.flink.streaming.api.operators.TimestampedCollector import 
org.apache.flink.table.api.StreamQueryConfig
 import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.functions.AggregateContext
 import org.slf4j.LoggerFactory
 
@@ -69,6 +70,7 @@ abstract class RowTimeUnboundedOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.setAggregateContext(new AggregateContext(getRuntimeContext))
     output = new CRow(function.createOutputRow(), true)
     sortedTimestamps = new util.LinkedList[Long]()
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
index cfddc57272343..f5e28cb50cc2c 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
@@ -17,9 +17,13 @@
  */
 package org.apache.flink.table.api.java.utils;
 
+import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.RichAggregateFunction;
 
+import java.io.IOException;
 import java.util.Iterator;
 
 public class UserDefinedAggFunctions {
@@ -92,4 +96,92 @@ public void retract(WeightedAvgAccum accumulator, int iValue, int iWeight) {
 		accumulator.count -= iWeight;
 	}
 }
+
+	// A WeightedAvg class with retract and reset methods
+	public static class WeightedAvgWithRetractAndReset extends WeightedAvgWithRetract {
+		public void resetAccumulator(WeightedAvgAccum acc) {
+			acc.count = 0;
+			acc.sum = 0L;
+		}
+	}
+
+	// Accumulator for WeightedAvg with state
+	public static class WeightedStateAvgAccum {
+		final String valueName = "valuestate";
+	}
+
+	// Base class for WeightedAvg with state
+	public static class WeightedStateAvg
+			extends RichAggregateFunction<Long, WeightedStateAvgAccum> {
+
+		@Override
+		public WeightedStateAvgAccum createAccumulator() {
+			WeightedStateAvgAccum accum = new WeightedStateAvgAccum();
+			registerValue(accum.valueName, new TupleTypeInfo<>(Types.LONG, Types.INT));
+
+			try {
+				getValueByStateName(accum.valueName).update(new Tuple2<>(0L, 0));
+			} catch (IOException e) {
+				throw new RuntimeException("init accumulator value failed!", e);
+			}
+			return accum;
+		}
+
+		@Override
+		public Long getValue(WeightedStateAvgAccum accumulator) {
+			try {
+				Tuple2<Long, Integer> avgPair =
+					(Tuple2<Long, Integer>) getValueByStateName(accumulator.valueName).value();
+				if (avgPair.f1 == 0) {
+					return null;
+				} else {
+					return avgPair.f0 / avgPair.f1;
+				}
+			} catch (IOException e) {
+				throw new RuntimeException("getValue failed!", e);
+			}
+		}
+
+		public void accumulate(WeightedStateAvgAccum accumulator, long iValue, int iWeight) {
+			try {
+				Tuple2<Long, Integer> avgPair =
+					(Tuple2<Long, Integer>) getValueByStateName(accumulator.valueName).value();
+				avgPair.f0 += iValue * iWeight;
+				avgPair.f1 += iWeight;
+				getValueByStateName(accumulator.valueName).update(avgPair);
+			} catch (IOException e) {
+				throw new RuntimeException("accumulate failed!", e);
+			}
+		}
+	}
+
+	// A WeightedStateAvg class with a retract method
+	public static class WeightedStateAvgWithRetract extends WeightedStateAvg {
+		// Overloaded retract method
+		public void retract(WeightedStateAvgAccum accumulator, long iValue, int iWeight) {
+			try {
+				Tuple2<Long, Integer> avgPair =
+					(Tuple2<Long, Integer>) getValueByStateName(accumulator.valueName).value();
+				avgPair.f0 -= iValue * iWeight;
+				avgPair.f1 -= iWeight;
+				getValueByStateName(accumulator.valueName).update(avgPair);
+			} catch (IOException e) {
+				throw new RuntimeException("retract failed!", e);
+			}
+		}
+	}
+
+	// A WeightedStateAvg class with retract and reset methods
+	public static class WeightedStateAvgWithRetractAndReset extends WeightedStateAvgWithRetract {
+		public void resetAccumulator(WeightedStateAvgAccum acc) {
+			try {
+				Tuple2<Long, Integer> avgPair =
+					(Tuple2<Long, Integer>) getValueByStateName(acc.valueName).value();
+				avgPair.f0 = 0L;
+				avgPair.f1 = 0;
+				getValueByStateName(acc.valueName).update(avgPair);
+			} catch (IOException e) {
+				throw new RuntimeException("reset failed!", e);
+			}
+		}
+	}
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala
index 9da2c445a9d06..bb5015757f079 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment}
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedStateAvg}
 import org.apache.flink.table.api.scala.stream.utils.StreamITCase.RetractingSink
 import org.apache.flink.types.Row
 import org.junit.Assert.assertEquals
@@ -45,14 +46,17 @@ class GroupAggregationsITCase extends StreamingWithStateTestBase {
     val tEnv = TableEnvironment.getTableEnvironment(env)
     StreamITCase.clear
 
+    val weightAvg = new WeightedAvg
+    val weightStateAvg = new WeightedStateAvg
+
     val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
-      .select('a.sum, 'b.sum)
+      .select('a.sum, 'b.sum, weightAvg('b, 'a), weightStateAvg('b, 'a))
 
     val results = t.toRetractStream[Row](queryConfig)
     results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
     env.execute()
 
-    val expected = List("231,91")
+    val expected = List("231,91,5,5")
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
 
@@ -63,15 +67,18 @@
     val tEnv = TableEnvironment.getTableEnvironment(env)
     StreamITCase.clear
 
+    val weightAvg = new WeightedAvg
+    val weightStateAvg = new WeightedStateAvg
+
     val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
       .groupBy('b)
-      .select('b, 'a.sum)
+      .select('b, 'a.sum, weightAvg('b, 'a), weightStateAvg('b, 'a))
 
     val results = t.toRetractStream[Row](queryConfig)
     results.addSink(new StreamITCase.RetractingSink)
     env.execute()
 
-    val expected = List("1,1", "2,5", "3,15", "4,34", "5,65", "6,111")
+    val expected = List("1,1,1,1", "2,5,2,2", "3,15,3,3", "4,34,4,4", "5,65,5,5", "6,111,6,6")
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala index b097767202165..631587e6151d7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala @@ -25,7 +25,7 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceCont import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvg +import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions._ import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.scala.stream.table.OverWindowITCase.RowTimeSourceFunction import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamingWithStateTestBase} @@ -61,20 +61,22 @@ class OverWindowITCase extends StreamingWithStateTestBase { val table = stream.toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val weightStateAvg = new WeightedStateAvg val windowedTable = table .window( Over partitionBy 'c orderBy 'proctime preceding UNBOUNDED_ROW as 'w) - .select('c, countFun('b) over 'w as 'mycount, weightAvgFun('a, 'b) over 'w as 'wAvg) - .select('c, 'mycount, 'wAvg) + .select('c, countFun('b) over 'w as 'mycount, weightAvgFun('a, 'b) over 'w as 'wAvg, + weightStateAvg('a, 'b) over 'w as 'wStateAvg) + .select('c, 'mycount, 'wAvg, 'wStateAvg) val results = windowedTable.toDataStream[Row] results.addSink(new StreamITCase.StringSink) env.execute() val expected = Seq( - "Hello World,1,7", "Hello World,2,7", "Hello World,3,14", - "Hello,1,1", "Hello,2,1", "Hello,3,2", "Hello,4,3", "Hello,5,3", "Hello,6,4") + "Hello World,1,7,7", "Hello World,2,7,7", "Hello World,3,14,14", + "Hello,1,1,1", "Hello,2,1,1", "Hello,3,2,2", "Hello,4,3,3", "Hello,5,3,3", "Hello,6,4,4") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -110,6 +112,7 @@ class OverWindowITCase extends StreamingWithStateTestBase { .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val weightStateAvg = new WeightedStateAvg val windowedTable = table .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_RANGE following @@ -121,26 +124,27 @@ class OverWindowITCase extends StreamingWithStateTestBase { 'b.avg over 'w, 'b.max over 'w, 'b.min over 'w, - weightAvgFun('b, 'a) over 'w) + weightAvgFun('b, 'a) over 'w, + weightStateAvg('b, 'a) over 'w) val result = windowedTable.toDataStream[Row] result.addSink(new StreamITCase.StringSink) env.execute() val expected = mutable.MutableList( - "1,1,Hello,6,3,2,3,1,2", - "1,2,Hello,6,3,2,3,1,2", - "1,3,Hello world,6,3,2,3,1,2", - "1,1,Hi,7,4,1,3,1,1", - "2,1,Hello,1,1,1,1,1,1", - "2,2,Hello world,6,3,2,3,1,2", - "2,3,Hello world,6,3,2,3,1,2", - "1,4,Hello world,11,5,2,4,1,2", - "1,5,Hello world,29,8,3,7,1,3", - "1,6,Hello world,29,8,3,7,1,3", - "1,7,Hello world,29,8,3,7,1,3", - "2,4,Hello world,15,5,3,5,1,3", - "2,5,Hello world,15,5,3,5,1,3" + "1,1,Hello,6,3,2,3,1,2,2", + "1,2,Hello,6,3,2,3,1,2,2", + "1,3,Hello world,6,3,2,3,1,2,2", + "1,1,Hi,7,4,1,3,1,1,1", + "2,1,Hello,1,1,1,1,1,1,1", + "2,2,Hello world,6,3,2,3,1,2,2", + "2,3,Hello world,6,3,2,3,1,2,2", + 
"1,4,Hello world,11,5,2,4,1,2,2", + "1,5,Hello world,29,8,3,7,1,3,3", + "1,6,Hello world,29,8,3,7,1,3,3", + "1,7,Hello world,29,8,3,7,1,3,3", + "2,4,Hello world,15,5,3,5,1,3,3", + "2,5,Hello world,15,5,3,5,1,3,3" ) assertEquals(expected.sorted, StreamITCase.testResults.sorted) @@ -174,30 +178,33 @@ class OverWindowITCase extends StreamingWithStateTestBase { val stream = env.fromCollection(data) val table = stream.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime) + val weightAvgFun = new WeightedAvgWithRetractAndReset + val weightStateAvg = new WeightedStateAvgWithRetractAndReset val windowedTable = table .window(Over partitionBy 'a orderBy 'proctime preceding 4.rows following CURRENT_ROW as 'w) - .select('a, 'c.sum over 'w, 'c.min over 'w) + .select('a, 'c.sum over 'w, 'c.min over 'w, weightAvgFun('b, 'a) over 'w, + weightStateAvg('b, 'a) over 'w) val result = windowedTable.toDataStream[Row] result.addSink(new StreamITCase.StringSink) env.execute() val expected = mutable.MutableList( - "1,0,0", - "2,1,1", - "2,3,1", - "3,3,3", - "3,7,3", - "3,12,3", - "4,6,6", - "4,13,6", - "4,21,6", - "4,30,6", - "5,10,10", - "5,21,10", - "5,33,10", - "5,46,10", - "5,60,10") + "1,0,0,1,1", + "2,1,1,2,2", + "2,3,1,2,2", + "3,3,3,4,4", + "3,7,3,4,4", + "3,12,3,5,5", + "4,6,6,7,7", + "4,13,6,7,7", + "4,21,6,8,8", + "4,30,6,8,8", + "5,10,10,11,11", + "5,21,10,11,11", + "5,33,10,12,12", + "5,46,10,12,12", + "5,60,10,13,13") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -236,22 +243,25 @@ class OverWindowITCase extends StreamingWithStateTestBase { val table = env.addSource[(Long, Int, String)]( new RowTimeSourceFunction[(Long, Int, String)](data)) .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) + val weightAvgFun = new WeightedAvgWithRetractAndReset + val weightStateAvg = new WeightedStateAvgWithRetractAndReset val windowedTable = table .window(Over partitionBy 'c orderBy 'rowtime preceding 2.rows following CURRENT_ROW as 'w) - .select('c, 'a, 'a.count over 'w, 'a.sum over 'w) + .select('c, 'a, 'a.count over 'w, 'a.sum over 'w, weightAvgFun('a, 'b) over 'w, + weightStateAvg('a, 'b) over 'w) val result = windowedTable.toDataStream[Row] result.addSink(new StreamITCase.StringSink) env.execute() val expected = mutable.MutableList( - "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3", - "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6", - "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12", - "Hello,6,3,15", - "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21", - "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35") + "Hello,1,1,1,1,1", "Hello,1,2,2,1,1", "Hello,1,3,3,1,1", + "Hello,2,3,4,1,1", "Hello,2,3,5,1,1", "Hello,2,3,6,2,2", + "Hello,3,3,7,2,2", "Hello,4,3,9,3,3", "Hello,5,3,12,4,4", + "Hello,6,3,15,5,5", + "Hello World,7,1,7,7,7", "Hello World,7,2,14,7,7", "Hello World,7,3,21,7,7", + "Hello World,7,3,21,7,7", "Hello World,8,3,22,7,7", "Hello World,20,3,35,14,14") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -298,27 +308,31 @@ class OverWindowITCase extends StreamingWithStateTestBase { val table = env.addSource[(Long, Int, String)]( new RowTimeSourceFunction[(Long, Int, String)](data)) .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) + val weightAvgFun = new WeightedAvgWithRetractAndReset + val weightStateAvg = new WeightedStateAvgWithRetractAndReset val windowedTable = table .window( Over partitionBy 'c orderBy 'rowtime preceding 1.seconds following CURRENT_RANGE as 'w) - .select('c, 'b, 'a.count over 'w, 'a.sum over 'w) + .select('c, 'b, 'a.count over 'w, 
'a.sum over 'w, weightAvgFun('a, 'b) over 'w, + weightStateAvg('a, 'b) over 'w) val result = windowedTable.toDataStream[Row] result.addSink(new StreamITCase.StringSink) env.execute() val expected = mutable.MutableList( - "Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3", - "Hello,2,6,9", "Hello,3,6,9", "Hello,2,6,9", - "Hello,3,4,9", - "Hello,4,2,7", - "Hello,5,2,9", - "Hello,6,2,11", "Hello,65,2,12", - "Hello,9,2,12", "Hello,9,2,12", "Hello,18,3,18", - "Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7", - "Hello World,8,2,15", - "Hello World,20,1,20") + "Hello,1,1,1,1,1", "Hello,15,2,2,1,1", "Hello,16,3,3,1,1", + "Hello,2,6,9,1,1", "Hello,3,6,9,1,1", "Hello,2,6,9,1,1", + "Hello,3,4,9,2,2", + "Hello,4,2,7,3,3", + "Hello,5,2,9,4,4", + "Hello,6,2,11,5,5", "Hello,65,2,12,6,6", + "Hello,9,2,12,6,6", "Hello,9,2,12,6,6", "Hello,18,3,18,6,6", + "Hello World,7,1,7,7,7", "Hello World,17,3,21,7,7", "Hello World,77,3,21,7,7", + "Hello World,18,1,7,7,7", + "Hello World,8,2,15,7,7", + "Hello World,20,1,20,20,20") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala index 77798f9ddcebb..7a51f64e61b70 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala @@ -176,6 +176,10 @@ class HarnessTestBase { | return new org.apache.flink.types.Row(7); | } | + | public final void setAggregateContext( + | org.apache.flink.table.functions.AggregateContext aggregateContext) { + | } + | |/******* This test does not use the following methods *******/ | public org.apache.flink.types.Row mergeAccumulatorsPair( | org.apache.flink.types.Row a, @@ -297,6 +301,10 @@ class HarnessTestBase { | public final void resetAccumulator( | org.apache.flink.types.Row accs) { | } + | + | public final void setAggregateContext( + | org.apache.flink.table.functions.AggregateContext aggregateContext) { + | } |} |""".stripMargin
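
For reference, a minimal Scala sketch of how a user-defined function would use the new RichAggregateFunction API. The names StateBackedCount and CountAccum are illustrative only and not part of this patch; the pattern (register state in createAccumulator, read and write it via getValueByStateName) follows the WeightedStateAvg test class above:

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.functions.RichAggregateFunction

// Illustrative example, not part of this patch.
// The accumulator only carries the state name; the value itself lives in Flink's managed state.
class CountAccum {
  val stateName = "count-value"
}

class StateBackedCount extends RichAggregateFunction[Long, CountAccum] {

  override def createAccumulator(): CountAccum = {
    val acc = new CountAccum
    // Register a ValueState under the name stored in the accumulator and initialize it.
    registerValue(acc.stateName, BasicTypeInfo.LONG_TYPE_INFO)
    getValueByStateName[java.lang.Long](acc.stateName).update(0L)
    acc
  }

  // Invoked by the generated aggregation code for each input row.
  def accumulate(acc: CountAccum, value: Long): Unit = {
    val state = getValueByStateName[java.lang.Long](acc.stateName)
    state.update(state.value() + 1L)
  }

  override def getValue(acc: CountAccum): Long =
    getValueByStateName[java.lang.Long](acc.stateName).value()
}

Note that this relies on the AggregateContext being injected before createAccumulator is called, which is why the runtime operators above call function.setAggregateContext(new AggregateContext(getRuntimeContext)) immediately after instantiating the generated function.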