@@ -380,6 +380,32 @@ class CodeGenerator(
}
}

/**
* Generates a setAggregateContext method that forwards the AggregateContext to every
* RichAggregateFunction among the aggregate functions.
*/
def genSetAggregateContext: String = {

  val sig: String =
    j"""
    |  public final void setAggregateContext(
    |    org.apache.flink.table.functions.AggregateContext aggregateContext)""".stripMargin

  val setAggs: String = {
    for (i <- aggs.indices) yield
      j"""
      |    if (${aggs(i)} instanceof org.apache.flink.table.functions.RichAggregateFunction) {
      |      ((org.apache.flink.table.functions.RichAggregateFunction) ${aggs(i)})
      |        .setAggregateContext(aggregateContext);
      |    }""".stripMargin
  }.mkString("\n")

  j"""
  |$sig {
  |$setAggs
  |  }""".stripMargin
}
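For a single aggregate with the hypothetical field name function0 (real field names are substituted from ${aggs(i)}), the emitted Java would look roughly like this sketch:

public final void setAggregateContext(
    org.apache.flink.table.functions.AggregateContext aggregateContext) {
  // only rich functions receive the context; plain AggregateFunctions are skipped
  if (function0 instanceof org.apache.flink.table.functions.RichAggregateFunction) {
    ((org.apache.flink.table.functions.RichAggregateFunction) function0)
        .setAggregateContext(aggregateContext);
  }
}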

def genSetAggregationResults: String = {

val sig: String =
@@ -637,6 +663,7 @@ class CodeGenerator(
|
""".stripMargin

funcCode += genSetAggregateContext + "\n"
funcCode += genSetAggregationResults + "\n"
funcCode += genAccumulate + "\n"
funcCode += genRetract + "\n"
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions

import org.apache.flink.annotation.PublicEvolving
import org.apache.flink.api.common.functions.RuntimeContext
import org.apache.flink.api.common.state._

/**
* An AggregateContext allows obtaining global runtime information about the context in which
* the aggregate function is executed. This includes methods for accessing state.
*
* @param context the runtime context in which the Flink function is executed
*/
class AggregateContext(context: RuntimeContext) {

// ------------------------------------------------------------------------
// Methods for accessing state
// ------------------------------------------------------------------------

/**
* Gets a handle to the [[ValueState]].
*
* @param stateProperties The descriptor defining the properties of the state.
* @tparam T The type of value stored in the state.
* @return The partitioned state object.
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
@PublicEvolving
def getState[T](stateProperties: ValueStateDescriptor[T]): ValueState[T] =
context.getState(stateProperties)

/**
* Gets a handle to the [[ListState]].
*
* @param stateProperties The descriptor defining the properties of the state.
* @tparam T The type of value stored in the state.
* @return The partitioned state object.
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
@PublicEvolving
def getListState[T](stateProperties: ListStateDescriptor[T]): ListState[T] =
context.getListState(stateProperties)

/**
* Gets a handle to the [[MapState]].
*
* @param stateProperties The descriptor defining the properties of the state.
* @tparam UK The type of the user keys stored in the state.
* @tparam UV The type of the user values stored in the state.
* @return The partitioned state object.
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
@PublicEvolving
def getMapState[UK, UV](stateProperties: MapStateDescriptor[UK, UV]): MapState[UK, UV] =
context.getMapState(stateProperties)
}
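As a usage sketch (assuming a keyed RuntimeContext named ctx, e.g. obtained inside a RichFunction's open(); the state name "count" and the descriptor are illustrative, not part of the patch):

// the wrapper simply forwards state lookups to the underlying RuntimeContext
val countDescriptor = new ValueStateDescriptor[java.lang.Long]("count", classOf[java.lang.Long])
val aggregateContext = new AggregateContext(ctx)
val countState: ValueState[java.lang.Long] = aggregateContext.getState(countDescriptor)
// value() returns null before the first update for the current key
val current = Option(countState.value()).map(_.longValue()).getOrElse(0L)
countState.update(current + 1)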
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions

import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer

import scala.collection.mutable

/**
* Rich variant of [[AggregateFunction]]. It encapsulates access to Flink's managed state.
*
* @tparam T   the type of the aggregation result
* @tparam ACC the type of the accumulator
*/
abstract class RichAggregateFunction[T, ACC] extends AggregateFunction[T, ACC] {
private var aggContext: AggregateContext = _
private val descriptorMapping = mutable.Map[String, StateDescriptor[_, _]]()

private[flink] def setAggregateContext(context: AggregateContext): Unit = {
this.aggContext = context
}

def registerValue[K](name: String, typeClass: Class[K]): Unit = {
descriptorMapping.put(name, new ValueStateDescriptor[K](name, typeClass))
}

def registerValue[K](name: String, typeInfo: TypeInformation[K]): Unit = {
descriptorMapping.put(name, new ValueStateDescriptor[K](name, typeInfo))
}

def registerValue[K](name: String, typeSerializer: TypeSerializer[K]): Unit = {
descriptorMapping.put(name, new ValueStateDescriptor[K](name, typeSerializer))
}

def registerList[K](name: String, elementTypeClass: Class[K]): Unit = {
descriptorMapping.put(name, new ListStateDescriptor[K](name, elementTypeClass))
}

def registerList[K](name: String, elementTypeInfo: TypeInformation[K]): Unit = {
descriptorMapping.put(name, new ListStateDescriptor[K](name, elementTypeInfo))
}

def registerList[K](name: String, typeSerializer: TypeSerializer[K]): Unit = {
descriptorMapping.put(name, new ListStateDescriptor[K](name, typeSerializer))
}

def registerMap[UK, UV](name: String, keySerializer: TypeSerializer[UK],
valueSerializer: TypeSerializer[UV]): Unit = {
descriptorMapping.put(name, new MapStateDescriptor[UK, UV](name, keySerializer, valueSerializer))
}

def registerMap[UK, UV](name: String, keyTypeInfo: TypeInformation[UK],
valueTypeInfo: TypeInformation[UV]): Unit = {
descriptorMapping.put(name, new MapStateDescriptor[UK, UV](name, keyTypeInfo, valueTypeInfo))
}

def registerMap[UK, UV](name: String, keyClass: Class[UK], valueClass: Class[UV]): Unit = {
descriptorMapping.put(name, new MapStateDescriptor[UK, UV](name, keyClass, valueClass))
}

def getValueByStateName[K](name: String): ValueState[K] = {
aggContext.getState(descriptorMapping(name).asInstanceOf[ValueStateDescriptor[K]])
}

def getListByStateName[K](name: String): ListState[K] = {
aggContext.getListState(descriptorMapping(name).asInstanceOf[ListStateDescriptor[K]])
}

def getMapByStateName[UK, UV](name: String): MapState[UK, UV] = {
aggContext.getMapState(descriptorMapping(name).asInstanceOf[MapStateDescriptor[UK, UV]])
}
}
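To sketch the intended end-to-end usage: a counting aggregate could register a ValueState by name and update it from its accumulate logic. All names below (EventCount, CountAcc, "seen") are illustrative, and the usual table AggregateFunction contract (createAccumulator/getValue plus a convention-based accumulate method) is assumed:

// hypothetical accumulator holding a running count
class CountAcc { var count: Long = 0L }

// hypothetical rich aggregate that mirrors its count into keyed state
class EventCount extends RichAggregateFunction[Long, CountAcc] {

  // register the descriptor once; the state itself is resolved through the
  // AggregateContext injected via setAggregateContext at runtime
  registerValue("seen", classOf[java.lang.Long])

  override def createAccumulator(): CountAcc = new CountAcc

  // called per input row by the runtime (by convention, not an override)
  def accumulate(acc: CountAcc, value: String): Unit = {
    acc.count += 1
    getValueByStateName[java.lang.Long]("seen").update(acc.count)
  }

  override def getValue(acc: CountAcc): Long = acc.count
}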
@@ -75,5 +75,6 @@ class AggregateAggFunction(genAggregations: GeneratedAggregationsFunction)
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()

}
}
@@ -43,6 +43,7 @@ import org.apache.flink.table.functions.aggfunctions._
import org.apache.flink.table.functions.utils.AggSqlFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
import org.apache.flink.table.functions.{RichAggregateFunction => TableRichAggregateFunction}
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.typeutils.TypeCheckUtils._
@@ -332,6 +333,8 @@ object AggregateUtil {
inputType,
needRetract)

validateRichAggregate(aggregates)

val mapReturnType: RowTypeInfo =
createRowTypeForKeysAndAggregates(
groupings,
@@ -437,6 +440,8 @@
physicalInputRowType,
needRetract)

validateRichAggregate(aggregates)

val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates(
groupings,
aggregates,
@@ -550,6 +555,8 @@
physicalInputRowType,
needRetract)

validateRichAggregate(aggregates)

val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)

val genPreAggFunction = generator.generateAggregations(
@@ -697,6 +704,8 @@
physicalInputRowType,
needRetract)

validateRichAggregate(aggregates)

val aggMapping = aggregates.indices.map(_ + groupings.length).toArray

val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -770,6 +779,8 @@
physicalInputRowType,
needRetract)

validateRichAggregate(aggregates)

val aggMapping = aggregates.indices.map(_ + groupings.length).toArray

val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -836,6 +847,8 @@
inputType,
needRetract)

validateRichAggregate(aggregates)

val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
namedAggregates,
groupings,
@@ -1009,6 +1022,8 @@
inputType,
needRetract)

validateRichAggregate(aggregates)

val aggMapping = aggregates.indices.toArray
val outputArity = aggregates.length

@@ -1038,6 +1053,14 @@
(aggFunction, accumulatorRowType, aggResultRowType)
}

private def validateRichAggregate(aggregates: Array[TableAggregateFunction[_, _]]): Unit = {
aggregates.foreach {
case _: TableRichAggregateFunction[_, _] =>
throw new TableException("RichAggregateFunction is currently not supported.")
case _ => // ok
}
}

/**
* Return true if all aggregates can be partially merged. False otherwise.
*/
@@ -1417,9 +1440,9 @@
if (accType != null) {
accType
} else {
try {
// infer the accumulator type from the declared return type of createAccumulator()
// instead of instantiating an accumulator just to inspect its class
val method = agg.getClass.getMethod("createAccumulator")
TypeInformation.of(method.getReturnType)
} catch {
case ite: InvalidTypesException =>
throw new TableException(
@@ -19,13 +19,19 @@
package org.apache.flink.table.runtime.aggregate

import org.apache.flink.api.common.functions.Function
import org.apache.flink.table.functions.AggregateContext
import org.apache.flink.types.Row

/**
* Base class for code-generated aggregations.
*/
abstract class GeneratedAggregations extends Function {

/**
* Sets the [[AggregateContext]] on the aggregations, giving rich aggregate functions
* access to state.
*/
def setAggregateContext(aggregateContext: AggregateContext): Unit

/**
* Sets the results of the aggregations (partial or final) to the output row.
* Final results are computed with the aggregation function.
@@ -29,6 +29,7 @@ import org.apache.flink.api.common.state.ValueState
import org.apache.flink.table.api.{StreamQueryConfig, Types}
import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.slf4j.{Logger, LoggerFactory}
import org.apache.flink.table.functions.AggregateContext
import org.apache.flink.table.runtime.types.CRow

/**
@@ -65,6 +66,7 @@ class GroupAggProcessFunction(
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
function.setAggregateContext(new AggregateContext(getRuntimeContext))

newRow = new CRow(function.createOutputRow(), true)
prevRow = new CRow(function.createOutputRow(), false)
@@ -33,6 +33,7 @@ import java.util.{ArrayList, List => JList}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.api.StreamQueryConfig
import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.table.functions.AggregateContext
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.slf4j.LoggerFactory

@@ -70,6 +71,7 @@ class ProcTimeBoundedRangeOver(
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
function.setAggregateContext(new AggregateContext(getRuntimeContext))
output = new CRow(function.createOutputRow(), true)

// We keep the elements received in a MapState indexed based on their ingestion time
@@ -35,6 +35,7 @@ import java.util.{List => JList}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.api.StreamQueryConfig
import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.table.functions.AggregateContext
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.slf4j.LoggerFactory

@@ -75,6 +76,7 @@ class ProcTimeBoundedRowsOver(
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
function.setAggregateContext(new AggregateContext(getRuntimeContext))

output = new CRow(function.createOutputRow(), true)
// We keep the elements received in a Map state keyed
@@ -28,6 +28,7 @@ import org.apache.flink.util.Collector
import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.types.Row
import org.apache.flink.table.functions.AggregateContext
import org.slf4j.LoggerFactory

/**
@@ -60,6 +61,7 @@ class ProcTimeUnboundedNonPartitionedOver(
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
function.setAggregateContext(new AggregateContext(getRuntimeContext))

output = new CRow(function.createOutputRow(), true)
if (null == accumulators) {