diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
index 0d812133274a44..409ea2319bc350 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
@@ -21,6 +21,7 @@ package org.apache.flink.table.codegen
 import java.lang.{Long => JLong}
 import java.util
 
+import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rex._
 import org.apache.calcite.sql.SqlAggFunction
 import org.apache.calcite.sql.fun.SqlStdOperatorTable._
@@ -32,7 +33,7 @@ import org.apache.flink.configuration.Configuration
 import org.apache.flink.table.api.dataview.DataViewSpec
 import org.apache.flink.table.api.{TableConfig, TableException, ValidationException}
 import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue}
+import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue, primitiveTypeTermForTypeInfo}
 import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
 import org.apache.flink.table.codegen.Indenter.toISC
 import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
@@ -49,6 +50,90 @@ import scala.collection.mutable
 /**
   * A code generator for generating CEP related functions.
   *
+  * Aggregates are generated as follows:
+  *  1. all aggregate [[RexCall]]s are grouped by the corresponding pattern variable
+  *  2. even if the same aggregation is used multiple times in an expression
+  *     (e.g. SUM(A.price) > SUM(A.price) + 1) it is calculated only once. To do so,
+  *     [[AggBuilder]] keeps a set of already seen aggregation calls and reuses the code that
+  *     accesses the appropriate field of the aggregation result
+  *  3. after translating every expression (either in [[generateCondition]] or in
+  *     [[generateOneRowPerMatchExpression]]) code is generated for
+  *     * a [[GeneratedFunction]], which will be an inner class; said [[GeneratedFunction]]
+  *       will be instantiated in the constructor and opened/closed in the corresponding
+  *       methods of the top-level generated class
+  *     * a function that transforms input rows (row by row) into aggregate input rows
+  *     * a function that calculates all aggregates for a single pattern variable, using the
+  *       previous function
+  * The generated code will look similar to this:
+  *
+  * {{{
+  *
+  * public class MatchRecognizePatternSelectFunction$175 extends RichPatternSelectFunction {
+  *
+  *     // Class used to calculate aggregates for a single pattern variable
+  *     public final class AggFunction_variable$115$151 extends GeneratedAggregations {
+  *       ...
+  *     }
+  *
+  *     private final AggFunction_variable$115$151 aggregator_variable$115;
+  *
+  *     public MatchRecognizePatternSelectFunction$175() {
+  *       aggregator_variable$115 = new AggFunction_variable$115$151();
+  *     }
+  *
+  *     public void open() {
+  *       aggregator_variable$115.open();
+  *       ...
+  *     }
+  *
+  *     // Function to transform an incoming row into an aggregate-specific row.
+  *     // It can e.g. calculate the inner expression of said aggregate
+  *     private Row transformRowForAgg_variable$115(Row inAgg) {
+  *       ...
+  *     }
+  *
+  *     // Function to calculate all aggregates for a single pattern variable
+  *     private Row calculateAgg_variable$115(List<Row> input) {
+  *       Acc accumulator = aggregator_variable$115.createAccumulator();
+  *       for (Row row : input) {
+  *         aggregator_variable$115.accumulate(accumulator, transformRowForAgg_variable$115(row));
+  *       }
+  *
+  *       return aggregator_variable$115.getResult(accumulator);
+  *     }
+  *
+  *     @Override
+  *     public Object select(Map<String, List<Row>> in1) throws Exception {
+  *
+  *       // Extract the list of rows assigned to a single pattern variable
+  *       java.util.List patternEvents$130 = (java.util.List) in1.get("A");
+  *       ...
+  *
+  *       // Calculate aggregates
+  *       Row aggRow_variable$115 = calculateAgg_variable$115(patternEvents$130);
+  *
+  *       // Every aggregation (e.g. SUM(A.price) and AVG(A.price)) is extracted to a variable
+  *       double result$135 = aggRow_variable$115.getField(0);
+  *       long result$137 = aggRow_variable$115.getField(1);
+  *
+  *       // The aggregation results are then used in expression evaluation
+  *       out.setField(0, result$135);
+  *
+  *       long result$140 = result$137 * 2;
+  *       out.setField(1, result$140);
+  *
+  *       double result$144 = result$135 + result$137;
+  *       out.setField(2, result$144);
+  *     }
+  *
+  *     public void close() {
+  *       aggregator_variable$115.close();
+  *       ...
+  *     }
+  *
+  * }
+  * }}}
+  *
   * @param config configuration that determines runtime behavior
   * @param patternNames sorted sequence of pattern variables
   * @param input type information about the first input of the Function
@@ -64,6 +149,8 @@ class MatchCodeGenerator(
 
   private case class GeneratedPatternList(resultTerm: String, code: String)
 
+  private val ALL_PATTERN_VARIABLE = "*"
+
   /**
     * Used to assign unique names for list of events per pattern variable name. Those lists
     * are treated as inputs and are needed by input access code.
@@ -82,18 +169,18 @@ class MatchCodeGenerator(
     * Flags that tells if we generate expressions inside an aggregate. It tells how to access input
     * row.
     */
-  private var innerAggExpr: Boolean = false
+  private var isWithinAggExprState: Boolean = false
 
   /**
     * Name of term in function used to transform input row into aggregate input row.
     */
-  private val inputAggRowTerm = newName("inAgg")
+  private val inputAggRowTerm = "inAgg"
 
   /** Term for row for key extraction */
-  private val keyRowTerm = newName("keyRow")
+  private val keyRowTerm = "keyRow"
 
   /** Term for list of all pattern names */
-  private val patternNamesTerm = newName("patternNames")
+  private val patternNamesTerm = "patternNames"
 
   /**
     * Used to collect all aggregates per pattern variable.
@@ -221,7 +308,6 @@ class MatchCodeGenerator(
   private def generateKeyRow() : GeneratedExpression = {
     val exp = reusableInputUnboxingExprs
       .get((keyRowTerm, 0)) match {
-      // input access and unboxing has already been generated
       case Some(expr) =>
         expr
 
@@ -331,7 +417,7 @@ class MatchCodeGenerator(
           agg
         }
 
-        matchAgg.getOrAddAggregation(call)
+        matchAgg.generateAggAccess(call)
       case _ =>
         super.visitCall(call)
     }
@@ -349,10 +435,10 @@ class MatchCodeGenerator(
   }
 
   override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
-    if (innerAggExpr) {
+    if (isWithinAggExprState) {
       generateFieldAccess(input, inputAggRowTerm, fieldRef.getIndex)
     } else {
-      if (fieldRef.getAlpha.equals("*") && currentPattern.isDefined && offset == 0 && !first) {
+      if (fieldRef.getAlpha.equals(ALL_PATTERN_VARIABLE) && currentPattern.isDefined && offset == 0 && !first) {
        generateInputAccess(input, input1Term, fieldRef.getIndex)
      } else {
        generatePatternFieldRef(fieldRef)
@@ -368,14 +454,14 @@ class MatchCodeGenerator(
     val eventTypeTerm = boxedTypeTermForTypeInfo(input)
     val eventNameTerm = newName("event")
 
-    val addCurrent = if (currentPattern == patternName || patternName == "*") {
+    val addCurrent = if (currentPattern == patternName || patternName == ALL_PATTERN_VARIABLE) {
       j"""
          |$listName.add($input1Term);
          """.stripMargin
     } else {
       ""
     }
-    val listCode = if (patternName == "*") {
+    val listCode = if (patternName == ALL_PATTERN_VARIABLE) {
       addReusablePatternNames()
       val patternTerm = newName("pattern")
       j"""
@@ -410,7 +496,7 @@ class MatchCodeGenerator(
   private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
     val listName = newName("patternEvents")
 
-    val code = if (patternName == "*") {
+    val code = if (patternName == ALL_PATTERN_VARIABLE) {
       addReusablePatternNames()
 
       val patternTerm = newName("pattern")
@@ -468,7 +554,6 @@ class MatchCodeGenerator(
       patternFieldAlpha: String)
     : GeneratedPatternList = {
     reusablePatternLists.get(patternFieldAlpha) match {
-      // input access and unboxing has already been generated
       case Some(expr) =>
         expr
 
@@ -486,7 +571,6 @@ class MatchCodeGenerator(
     val escapedAlpha = EncodingUtils.escapeJava(fieldRef.getAlpha)
     val patternVariableRef = reusableInputUnboxingExprs
       .get((s"$escapedAlpha#$first", offset)) match {
-        // input access and unboxing has already been generated
         case Some(expr) =>
           expr
 
@@ -509,24 +593,25 @@ class MatchCodeGenerator(
 
     private val rowTypeTerm = "org.apache.flink.types.Row"
 
-    def getOrAddAggregation(call: RexCall): GeneratedExpression = {
-      reusableInputUnboxingExprs.get((call.toString, 0)) match {
+    def generateAggAccess(aggCall: RexCall): GeneratedExpression = {
+      reusableInputUnboxingExprs.get((aggCall.toString, 0)) match {
         case Some(expr) =>
           expr
 
         case None =>
-          val exp: GeneratedExpression = generateAggAccess(call)
-          aggregates += call
-          reusableInputUnboxingExprs((call.toString, 0)) = exp
+          val exp: GeneratedExpression = doGenerateAggAccess(aggCall)
+          aggregates += aggCall
+          reusableInputUnboxingExprs((aggCall.toString, 0)) = exp
           exp.copy(code = NO_CODE)
       }
     }
 
-    private def generateAggAccess(call: RexCall) = {
+    private def doGenerateAggAccess(call: RexCall) = {
       val singleResultTerm = newName("result")
       val singleResultNullTerm = newName("nullTerm")
       val singleResultType = FlinkTypeFactory.toTypeInfo(call.`type`)
-      val singleResultTypeTerm = boxedTypeTermForTypeInfo(singleResultType)
+      val primitiveSingleResultTypeTerm = primitiveTypeTermForTypeInfo(singleResultType)
+      val boxedSingleResultTypeTerm = boxedTypeTermForTypeInfo(singleResultType)
 
       val patternName = findEventsByPatternName(variable)
 
@@ -538,45 +623,49 @@ class MatchCodeGenerator(
       reusablePerRecordStatements += codeForAgg
 
       val defaultValue = primitiveDefaultValue(singleResultType)
-      val codeForSingleAgg =
+      val codeForSingleAgg = if (nullCheck) {
         j"""
            |boolean $singleResultNullTerm;
-           |$singleResultTypeTerm $singleResultTerm = ($singleResultTypeTerm) $resultRowTerm
-           |  .getField(${aggregates.size});
-           |if ($singleResultTerm != null) {
+           |$primitiveSingleResultTypeTerm $singleResultTerm;
+           |if ($resultRowTerm.getField(${aggregates.size}) != null) {
+           |  $singleResultTerm = ($boxedSingleResultTypeTerm) $resultRowTerm
+           |    .getField(${aggregates.size});
            |  $singleResultNullTerm = false;
            |} else {
            |  $singleResultNullTerm = true;
            |  $singleResultTerm = $defaultValue;
            |}
            |""".stripMargin
+      } else {
+        j"""
+           |boolean $singleResultNullTerm = false;
+           |$primitiveSingleResultTypeTerm $singleResultTerm =
+           |  ($boxedSingleResultTypeTerm) $resultRowTerm.getField(${aggregates.size});
+           |""".stripMargin
+      }
 
       reusablePerRecordStatements += codeForSingleAgg
 
-      val exp = GeneratedExpression(singleResultTerm,
-        singleResultNullTerm,
-        NO_CODE,
-        singleResultType)
-      exp
+      GeneratedExpression(singleResultTerm, singleResultNullTerm, NO_CODE, singleResultType)
     }
 
     def generateAggFunction() : Unit = {
-      val (aggs, exprs) = extractAggregatesAndExpressions
+      val matchAgg = extractAggregatesAndExpressions
 
       val aggGenerator = new AggregationCodeGenerator(config, false, input, None)
 
       val aggFunc = aggGenerator.generateAggregations(
         s"AggFunction_$variableUID",
-        exprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
-        aggs.map(_.aggFunction).toArray,
-        aggs.map(_.inputIndices).toArray,
-        aggs.indices.toArray,
-        Array.fill(aggs.size)(false),
+        matchAgg.inputExprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
+        matchAgg.aggregations.map(_.aggFunction).toArray,
+        matchAgg.aggregations.map(_.inputIndices).toArray,
+        matchAgg.aggregations.indices.toArray,
+        Array.fill(matchAgg.aggregations.size)(false),
         isStateBackedDataViews = false,
         partialResults = false,
         Array.emptyIntArray,
         None,
-        aggs.size,
+        matchAgg.aggregations.size,
         needRetract = false,
         needMerge = false,
         needReset = false,
@@ -587,22 +676,33 @@ class MatchCodeGenerator(
 
       val transformFuncName = s"transformRowForAgg_$variableUID"
       val inputTransform: String = generateAggInputExprEvaluation(
-        exprs,
+        matchAgg.inputExprs,
         transformFuncName)
 
       generateAggCalculation(aggFunc, transformFuncName, inputTransform)
     }
 
+    private case class LogicalMatchAggCall(
+      function: SqlAggFunction,
+      inputTypes: Seq[RelDataType],
+      exprIndices: Seq[Int]
+    )
+
     private case class MatchAggCall(
      aggFunction: TableAggregateFunction[_, _],
      inputIndices: Array[Int],
      dataViews: Seq[DataViewSpec[_]]
    )
 
-    private def extractAggregatesAndExpressions: (Seq[MatchAggCall], Seq[RexNode]) = {
+    private case class MatchAgg(
+      aggregations: Seq[MatchAggCall],
+      inputExprs: Seq[RexNode]
+    )
+
+    private def extractAggregatesAndExpressions: MatchAgg = {
       val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]
 
-      val aggs = aggregates.map(rexAggCall => {
+      val logicalAggregates = aggregates.map(rexAggCall => {
        val callsWithIndices = rexAggCall.operands.asScala.map(innerCall => {
          inputRows.get(innerCall.toString) match {
            case Some(x) =>
@@ -616,22 +716,26 @@ class MatchCodeGenerator(
        }).toList
 
        val agg = rexAggCall.getOperator.asInstanceOf[SqlAggFunction]
-        (agg, callsWithIndices.map(_._1), callsWithIndices.map(_._2).toArray)
-      }).zipWithIndex.map {
+        LogicalMatchAggCall(agg,
+          callsWithIndices.map(_._1.getType),
+          callsWithIndices.map(_._2).toArray)
+      })
+
+      val aggs = logicalAggregates.zipWithIndex.map {
        case (agg, index) =>
          val result = AggregateUtil.extractAggregateCallMetadata(
-            agg._1,
+            agg.function,
            isDistinct = false,
-            agg._2.map(_.getType),
+            agg.inputTypes,
            needRetraction = false,
            config,
            isStateBackedDataViews = false,
            index)
 
-          MatchAggCall(result.aggregateFunction, agg._3, result.accumulatorSpecs)
+          MatchAggCall(result.aggregateFunction, agg.exprIndices.toArray, result.accumulatorSpecs)
      }
 
-      (aggs, inputRows.values.map(_._1).toSeq)
+      MatchAgg(aggs, inputRows.values.map(_._1).toSeq)
    }
 
    private def generateAggCalculation(
@@ -668,7 +772,7 @@ class MatchCodeGenerator(
       inputExprs: Seq[RexNode],
       funcName: String)
     : String = {
-      innerAggExpr = true
+      isWithinAggExprState = true
       val resultTerm = newName("result")
       val exprs = inputExprs.zipWithIndex.map(row => {
         val expr = generateExpression(row._1)
@@ -681,7 +785,7 @@ class MatchCodeGenerator(
            |}
            """.stripMargin
       }).mkString("\n")
-      innerAggExpr = false
+      isWithinAggExprState = false
 
       j"""
          |private $rowTypeTerm $funcName($rowTypeTerm $inputAggRowTerm) {
@@ -696,12 +800,14 @@ class MatchCodeGenerator(
 
 class PatternVariableFinder extends RexDefaultVisitor[Option[String]] {
 
+  val ALL_PATTERN_VARIABLE = "*"
+
   override def visitPatternFieldRef(patternFieldRef: RexPatternFieldRef): Option[String] = Some(
     patternFieldRef.getAlpha)
 
   override def visitCall(call: RexCall): Option[String] = {
     if (call.operands.size() == 0) {
-      Some("*")
+      Some(ALL_PATTERN_VARIABLE)
     } else {
       call.operands.asScala.map(n => n.accept(this)).reduce((op1, op2) => (op1, op2) match {
         case (None, None) => None
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
similarity index 97%
rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala
rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
index d5b74c9e45b127..b77a60e1af5195 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
@@ -27,7 +27,7 @@ import org.apache.flink.table.utils.TableTestBase
 import org.apache.flink.types.Row
 import org.junit.Test
 
-class MatchOperatorValidationTest extends TableTestBase {
+class MatchRecognizeValidationTest extends TableTestBase {
 
   private val streamUtils = streamTestUtil()
   streamUtils.addTable[(String, Long, Int, Int)]("Ticker",
@@ -131,6 +131,7 @@ class MatchOperatorValidationTest extends TableTestBase {
   @Test
   def testAggregatesOnMultiplePatternVariablesNotSupported(): Unit = {
     thrown.expect(classOf[ValidationException])
+    thrown.expectMessage("SQL validation failed.")
 
     val sqlQuery =
       s"""
@@ -152,6 +153,7 @@ class MatchOperatorValidationTest extends TableTestBase {
   @Test
   def testAggregatesOnMultiplePatternVariablesNotSupportedInUDAGs(): Unit = {
     thrown.expect(classOf[ValidationException])
+    thrown.expectMessage("Aggregation must be applied to a single pattern variable")
 
     streamUtils.tableEnv.registerFunction("weightedAvg", new WeightedAvg)
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
index 60fecf0d5ba948..67401a3bca885e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
@@ -23,12 +23,11 @@ import java.util.TimeZone
 
 import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.TimeCharacteristic
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.api.{TableConfig, TableEnvironment}
+import org.apache.flink.table.api.{TableConfig, TableEnvironment, Types}
 import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, ScalarFunction}
 import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
 import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
@@ -42,7 +41,7 @@ import scala.collection.mutable
 class MatchRecognizeITCase extends StreamingWithStateTestBase {
 
   @Test
-  def testSimpleCEP(): Unit = {
+  def testSimplePattern(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(1)
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -89,7 +88,7 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
   }
 
   @Test
-  def testSimpleCEPWithNulls(): Unit = {
+  def testSimplePatternWithNulls(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(1)
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -476,7 +475,7 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
    * 4. aggregates with expressions work
    */
   @Test
-  def testCepAggregates(): Unit = {
+  def testAggregates(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(1)
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -492,7 +491,10 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
     data.+=((6, "a", 2, 1.5, 2))
     data.+=((7, "b", 2, 0.8, 3))
     data.+=((8, "c", 1, 0.8, 2))
-    data.+=((9, "h", 2, 0.8, 3))
+    data.+=((9, "h", 4, 0.8, 3))
+    data.+=((10, "h", 4, 0.8, 3))
+    data.+=((11, "h", 2, 0.8, 3))
+    data.+=((12, "h", 2, 0.8, 3))
 
     val t = env.fromCollection(data)
       .toTable(tEnv, 'id, 'name, 'price, 'rate, 'weight, 'proctime.proctime)
@@ -530,12 +532,12 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
-    val expected = mutable.MutableList("1,5,0,null,2,3,3.4,8")
+    val expected = mutable.MutableList("1,5,0,null,2,3,3.4,8", "9,4,0,null,3,4,3.2,12")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
   @Test
-  def testCepAggregatesWithNullInputs(): Unit = {
+  def testAggregatesWithNullInputs(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(1)
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -543,17 +545,17 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
     StreamITCase.clear
 
     val data = new mutable.MutableList[Row]
-    data.+=(Row.of(1:java.lang.Integer, "a", 10:java.lang.Integer))
-    data.+=(Row.of(2:java.lang.Integer, "z", 10:java.lang.Integer))
-    data.+=(Row.of(3:java.lang.Integer, "b", null))
-    data.+=(Row.of(4:java.lang.Integer, "c", null))
-    data.+=(Row.of(5:java.lang.Integer, "d", 3:java.lang.Integer))
-    data.+=(Row.of(6:java.lang.Integer, "c", 3:java.lang.Integer))
-    data.+=(Row.of(7:java.lang.Integer, "c", 3:java.lang.Integer))
-    data.+=(Row.of(8:java.lang.Integer, "c", 3:java.lang.Integer))
-    data.+=(Row.of(9:java.lang.Integer, "c", 2:java.lang.Integer))
-
-    val t = env.fromCollection(data)(new RowTypeInfo(
+    data.+=(Row.of(Int.box(1), "a", Int.box(10)))
+    data.+=(Row.of(Int.box(2), "z", Int.box(10)))
+    data.+=(Row.of(Int.box(3), "b", null))
+    data.+=(Row.of(Int.box(4), "c", null))
+    data.+=(Row.of(Int.box(5), "d", Int.box(3)))
+    data.+=(Row.of(Int.box(6), "c", Int.box(3)))
+    data.+=(Row.of(Int.box(7), "c", Int.box(3)))
+    data.+=(Row.of(Int.box(8), "c", Int.box(3)))
+    data.+=(Row.of(Int.box(9), "c", Int.box(2)))
+
+    val t = env.fromCollection(data)(Types.ROW(
       BasicTypeInfo.INT_TYPE_INFO,
       BasicTypeInfo.STRING_TYPE_INFO,
       BasicTypeInfo.INT_TYPE_INFO))
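Editor's note: as a rough illustration of the feature exercised above (the literal queries of testAggregates and testAggregatesWithNullInputs are not shown in this excerpt; the table and column names below are assumed), a MATCH_RECOGNIZE query that drives the generated aggregation code paths could look like the following sketch, written in the same style as the tests:

    val sqlQuery =
      s"""
         |SELECT *
         |FROM Ticker
         |MATCH_RECOGNIZE (
         |  ORDER BY proctime
         |  MEASURES
         |    FIRST(A.id) AS startId,
         |    SUM(A.price) AS sumPrice,        -- aggregate over a single pattern variable
         |    AVG(A.price) AS avgPrice,
         |    COUNT(A.price) AS priceCount,
         |    SUM(A.price) * 2 AS doubledSum   -- the repeated SUM(A.price) is computed only once
         |  ONE ROW PER MATCH
         |  AFTER MATCH SKIP PAST LAST ROW
         |  PATTERN (A+ C)
         |  DEFINE
         |    A AS SUM(A.price) < 10,          -- aggregate used inside a DEFINE condition
         |    C AS C.name = 'c'
         |) AS T
         |""".stripMargin

All aggregates here refer only to pattern variable A; mixing variables inside one aggregate (e.g. SUM(A.price + C.price)) is rejected, which is what the validation tests above assert.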