Skip to content

Commit

Permalink
[FLINK-7599] Support for aggregates in MATCH_RECOGNIZE
Browse files Browse the repository at this point in the history
  • Loading branch information
dawidwys committed Nov 27, 2018
1 parent da6854d commit 134abbc
Show file tree
Hide file tree
Showing 5 changed files with 443 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ abstract class CodeGenerator(
GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
}

private def generateFieldAccess(
protected def generateFieldAccess(
inputType: TypeInformation[_],
inputTerm: String,
index: Int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@ import java.lang.{Long => JLong}
import java.util

import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.cep.pattern.conditions.IterativeCondition
import org.apache.flink.cep.{PatternFlatSelectFunction, PatternSelectFunction}
import org.apache.flink.table.api.{TableConfig, TableException}
import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName}
import org.apache.flink.table.api.dataview.DataViewSpec
import org.apache.flink.table.api.{TableConfig, TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue}
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.plan.util.RexDefaultVisitor
import org.apache.flink.table.runtime.aggregate.AggregateUtil
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.util.Collector
import org.apache.flink.util.MathUtils.checkedDownCast
Expand Down Expand Up @@ -71,12 +77,28 @@ class MatchCodeGenerator(
private var offset: Int = 0
private var first : Boolean = false

/**
* Flag that tells whether we are generating expressions inside an aggregate. It determines how
* the input row is accessed.
*/
private var innerAggExpr: Boolean = false

/**
* Name of term in function used to transform input row into aggregate input row.
*/
private val inputAggRowTerm = newName("inAgg")

/** Term for row for key extraction */
private val keyRowTerm = newName("keyRow")

/** Term for list of all pattern names */
private val patternNamesTerm = newName("patternNames")

/**
* Used to collect all aggregates per pattern variable.
*/
private val aggregatesPerVariable = new mutable.HashMap[String, AggBuilder]

/**
* Sets the new reference variable indexing context. This should be used when resolving logical
* offsets = LAST/FIRST
Expand Down Expand Up @@ -254,10 +276,18 @@ class MatchCodeGenerator(
generateExpression(measures.get(fieldName))
}

generateResultExpression(
val exp = generateResultExpression(
resultExprs,
returnType.typeInfo,
returnType.fieldNames)
aggregatesPerVariable.values.foreach(_.generateAggFunction())
exp
}

/**
  * Generates the expression for the given condition node and afterwards emits the aggregate
  * functions of every pattern variable for which aggregates were collected.
  *
  * @param call the condition to generate code for
  * @return the expression produced by visiting `call` with this generator
  */
def generateCondition(call: RexNode): GeneratedExpression = {
  // Visiting the condition first registers the aggregates it uses.
  val condition = call.accept(this)
  for (aggBuilder <- aggregatesPerVariable.values) {
    aggBuilder.generateAggFunction()
  }
  condition
}

override def visitCall(call: RexCall): GeneratedExpression = {
Expand Down Expand Up @@ -285,6 +315,21 @@ class MatchCodeGenerator(
case FINAL =>
call.getOperands.get(0).accept(this)

case _ : SqlAggFunction =>

val variable = call.accept(new PatternVariableFinder)
.getOrElse(throw new TableException("No pattern variable specified in aggregate"))

val matchAgg = aggregatesPerVariable.get(variable) match {
case Some(agg) => agg
case None =>
val agg = new AggBuilder(variable)
aggregatesPerVariable(variable) = agg
agg
}

matchAgg.getOrAddAggregation(call)

case _ => super.visitCall(call)
}
}
Expand All @@ -301,10 +346,14 @@ class MatchCodeGenerator(
}

override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
if (fieldRef.getAlpha.equals("*") && currentPattern.isDefined && offset == 0 && !first) {
generateInputAccess(input, input1Term, fieldRef.getIndex)
if (innerAggExpr) {
generateFieldAccess(input, inputAggRowTerm, fieldRef.getIndex)
} else {
generatePatternFieldRef(fieldRef)
if (fieldRef.getAlpha.equals("*") && currentPattern.isDefined && offset == 0 && !first) {
generateInputAccess(input, input1Term, fieldRef.getIndex)
} else {
generatePatternFieldRef(fieldRef)
}
}
}

Expand Down Expand Up @@ -392,21 +441,7 @@ class MatchCodeGenerator(
val eventTypeTerm = boxedTypeTermForTypeInfo(input)
val isRowNull = newName("isRowNull")

val findEventsByPatternName = reusablePatternLists.get(patternFieldAlpha) match {
// input access and unboxing has already been generated
case Some(expr) =>
expr

case None =>
val exp = currentPattern match {
case Some(p) => generateDefinePatternVariableExp(patternFieldAlpha, p)
case None => generateMeasurePatternVariableExp(patternFieldAlpha)
}
reusablePatternLists(patternFieldAlpha) = exp
exp
}

val listName = findEventsByPatternName.resultTerm
val listName = findEventsByPatternName(patternFieldAlpha).resultTerm
val resultIndex = if (first) {
j"""$offset"""
} else {
Expand All @@ -426,6 +461,24 @@ class MatchCodeGenerator(
GeneratedExpression(rowNameTerm, isRowNull, funcCode, input)
}

/**
  * Returns the generated list of events for the given pattern variable, generating and caching
  * it in `reusablePatternLists` on first access.
  *
  * @param patternFieldAlpha name of the pattern variable
  * @return the cached or newly generated pattern list
  */
private def findEventsByPatternName(
    patternFieldAlpha: String)
  : GeneratedPatternList = {
  // When present, input access and unboxing has already been generated.
  reusablePatternLists.get(patternFieldAlpha).getOrElse {
    val generated = currentPattern
      .map(p => generateDefinePatternVariableExp(patternFieldAlpha, p))
      .getOrElse(generateMeasurePatternVariableExp(patternFieldAlpha))
    reusablePatternLists(patternFieldAlpha) = generated
    generated
  }
}

private def generatePatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
val escapedAlpha = EncodingUtils.escapeJava(fieldRef.getAlpha)
val patternVariableRef = reusableInputUnboxingExprs
Expand All @@ -442,4 +495,221 @@ class MatchCodeGenerator(

generateFieldAccess(patternVariableRef.copy(code = NO_CODE), fieldRef.getIndex)
}

/**
  * Collects the aggregate calls applied to a single pattern variable and generates the code
  * needed to evaluate them: an aggregation function, a function that transforms an input row
  * into the aggregate input row, and a function that runs the aggregation over a list of rows.
  *
  * @param variable name of the pattern variable the collected aggregates are applied to
  */
class AggBuilder(variable: String) {

  /**
    * Aggregate calls registered so far. The position of a call in this buffer is the index of
    * its value in the aggregation result row.
    */
  private val aggregates = new mutable.ListBuffer[RexCall]()

  /** Unique suffix used in the names of all generated members of this builder. */
  private val variableUID = newName("variable")

  /** Term of the row holding the results of all aggregates of this variable. */
  private val resultRowTerm = newName(s"aggRow_$variableUID")

  private val rowTypeTerm = "org.apache.flink.types.Row"

  /**
    * Returns the expression giving access to the result of the given aggregate call. The call is
    * registered on first use; subsequent requests for a call with the same string representation
    * return the cached expression.
    *
    * @param call the aggregate call
    * @return expression referring to the aggregate result
    */
  def getOrAddAggregation(call: RexCall): GeneratedExpression = {
    val cacheKey = (call.toString, 0)
    reusableInputUnboxingExprs.get(cacheKey) match {
      case Some(expr) =>
        expr

      case None =>
        // The access code must be generated before the call is appended, because the result
        // field index is the number of aggregates registered so far.
        val exp: GeneratedExpression = generateAggAccess(call)
        aggregates += call
        reusableInputUnboxingExprs(cacheKey) = exp
        exp.copy(code = NO_CODE)
    }
  }

  /**
    * Registers the per-record statements that calculate the aggregation result row for this
    * variable and extract the value of the given call from it.
    */
  private def generateAggAccess(call: RexCall): GeneratedExpression = {
    val singleResultTerm = newName("result")
    val singleResultNullTerm = newName("nullTerm")
    val singleResultType = FlinkTypeFactory.toTypeInfo(call.`type`)
    val singleResultTypeTerm = boxedTypeTermForTypeInfo(singleResultType)

    val patternName = findEventsByPatternName(variable)

    val codeForAgg =
      j"""
         |$rowTypeTerm $resultRowTerm = calculateAgg_$variableUID(${patternName.resultTerm});
         |""".stripMargin

    reusablePerRecordStatements += codeForAgg

    val defaultValue = primitiveDefaultValue(singleResultType)
    val codeForSingleAgg =
      j"""
         |boolean $singleResultNullTerm;
         |$singleResultTypeTerm $singleResultTerm = ($singleResultTypeTerm) $resultRowTerm
         | .getField(${aggregates.size});
         |if ($singleResultTerm != null) {
         | $singleResultNullTerm = false;
         |} else {
         | $singleResultNullTerm = true;
         | $singleResultTerm = $defaultValue;
         |}
         |""".stripMargin

    reusablePerRecordStatements += codeForSingleAgg

    GeneratedExpression(singleResultTerm,
      singleResultNullTerm,
      NO_CODE,
      singleResultType)
  }

  /**
    * Generates the aggregation function for all registered aggregates together with the helper
    * functions that transform the input rows and calculate the aggregation result, and adds
    * them to the reusable member statements.
    */
  def generateAggFunction() : Unit = {
    val (aggs, exprs) = extractAggregatesAndExpressions

    val aggGenerator = new AggregationCodeGenerator(config, false, input, None)

    val aggFunc = aggGenerator.generateAggregations(
      s"AggFunction_$variableUID",
      exprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
      aggs.map(_.aggFunction).toArray,
      aggs.map(_.inputIndices).toArray,
      aggs.indices.toArray,
      Array.fill(aggs.size)(false),
      isStateBackedDataViews = false,
      partialResults = false,
      Array.emptyIntArray,
      None,
      aggs.size,
      needRetract = false,
      needMerge = false,
      needReset = false,
      None
    )

    reusableMemberStatements.add(aggFunc.code)

    val transformFuncName = s"transformRowForAgg_$variableUID"
    val inputTransform: String = generateAggInputExprEvaluation(
      exprs,
      transformFuncName)

    generateAggCalculation(aggFunc, transformFuncName, inputTransform)
  }

  /**
    * A registered aggregate after translation to a table aggregate function.
    *
    * @param aggFunction  the aggregate function implementation
    * @param inputIndices indices of the function's operands within the aggregate input row
    * @param dataViews    data view specs returned by the aggregate function translation
    */
  private case class MatchAggCall(
    aggFunction: TableAggregateFunction[_, _],
    inputIndices: Array[Int],
    dataViews: Seq[DataViewSpec[_]]
  )

  /**
    * Translates the registered aggregate calls and collects their operand expressions.
    *
    * Operands are de-duplicated by their string representation; every distinct operand is
    * assigned the next free index of the aggregate input row.
    *
    * @return the translated aggregates and the distinct operand expressions in index order
    */
  private def extractAggregatesAndExpressions: (Seq[MatchAggCall], Seq[RexNode]) = {
    val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]

    val aggs = aggregates.map(rexAggCall => {
      val callsWithIndices = rexAggCall.operands.asScala.map(innerCall => {
        inputRows.get(innerCall.toString) match {
          case Some(x) =>
            x

          case None =>
            val callWithIndex = (innerCall, inputRows.size)
            inputRows(innerCall.toString) = callWithIndex
            callWithIndex
        }
      }).toList

      val agg = rexAggCall.getOperator.asInstanceOf[SqlAggFunction]
      (agg, callsWithIndices.map(_._1), callsWithIndices.map(_._2).toArray)
    }).zipWithIndex.map {
      case ((sqlAggFunction, operands, operandIndices), index) =>
        val result = AggregateUtil
          .transformToAggregateFunction(sqlAggFunction,
            isDistinct = false,
            operands.map(_.getType),
            needRetraction = false,
            config,
            isStateBackedDataViews = false,
            index)

        MatchAggCall(result._1, operandIndices, result._3)
    }

    (aggs, inputRows.values.map(_._1).toSeq)
  }

  /**
    * Adds the member code that declares the aggregator field, the input transformation function
    * and the function calculating the aggregation result for a list of rows. The aggregator
    * field is instantiated through a reusable init statement.
    */
  private def generateAggCalculation(
      aggFunc: GeneratedAggregationsFunction,
      transformFuncName: String,
      inputTransform: String)
    : Unit = {
    val aggregatorTerm = s"aggregator_$variableUID"
    val code =
      j"""
         |private final ${aggFunc.name} $aggregatorTerm;
         |
         |$inputTransform
         |
         |private $rowTypeTerm calculateAgg_$variableUID(java.util.List input)
         | throws Exception {
         | $rowTypeTerm accumulator = $aggregatorTerm.createAccumulators();
         | for ($rowTypeTerm row : input) {
         | $aggregatorTerm.accumulate(accumulator, $transformFuncName(row));
         | }
         | $rowTypeTerm result = $aggregatorTerm.createOutputRow();
         | $aggregatorTerm.setAggregationResults(accumulator, result);
         | return result;
         |}
       """.stripMargin

    reusableInitStatements.add(s"$aggregatorTerm = new ${aggFunc.name}();")
    reusableMemberStatements.add(code)
  }

  /**
    * Generates the function that evaluates the given operand expressions on an input row and
    * stores the results in a new row that serves as input of the aggregation function.
    *
    * @param inputExprs operand expressions, in the order of their index in the resulting row
    * @param funcName   name of the generated function
    * @return code of the generated function
    */
  private def generateAggInputExprEvaluation(
      inputExprs: Seq[RexNode],
      funcName: String)
    : String = {
    innerAggExpr = true
    val resultTerm = newName("result")
    // The flag is shared mutable state of the enclosing generator. It must be reset even if
    // expression generation fails, otherwise later field references would keep being resolved
    // against the aggregate input row.
    val exprs = try {
      inputExprs.zipWithIndex.map { case (inputExpr, fieldIndex) =>
        val expr = generateExpression(inputExpr)
        s"""
           |${expr.code}
           |if (${expr.nullTerm}) {
           | $resultTerm.setField($fieldIndex, null);
           |} else {
           | $resultTerm.setField($fieldIndex, ${expr.resultTerm});
           |}
         """.stripMargin
      }.mkString("\n")
    } finally {
      innerAggExpr = false
    }

    j"""
       |private $rowTypeTerm $funcName($rowTypeTerm $inputAggRowTerm) {
       | $rowTypeTerm $resultTerm = new $rowTypeTerm(${inputExprs.size});
       | $exprs
       | return $resultTerm;
       |}
     """.stripMargin
  }
}
}

/**
  * Visitor that finds the pattern variable referenced by an expression.
  *
  * A pattern field reference yields the name of its pattern variable, a call without operands
  * yields `"*"`, and any other node yields `None`. For a call with operands the results of all
  * operands are combined; a [[ValidationException]] is thrown if they refer to different
  * pattern variables.
  */
class PatternVariableFinder extends RexDefaultVisitor[Option[String]] {

  override def visitPatternFieldRef(patternFieldRef: RexPatternFieldRef): Option[String] =
    Some(patternFieldRef.getAlpha)

  override def visitCall(call: RexCall): Option[String] = {
    if (call.operands.isEmpty) {
      Some("*")
    } else {
      // All operands are visited first, then their results are merged from left to right.
      val variablePerOperand = call.operands.asScala.map(_.accept(this))
      variablePerOperand.reduceLeft[Option[String]] {
        case (None, right) => right
        case (left, None) => left
        case (left @ Some(leftVar), Some(rightVar)) if leftVar.equals(rightVar) => left
        case _ =>
          throw new ValidationException(s"Aggregation must be applied to a single pattern " +
            s"variable. Malformed expression: $call")
      }
    }
  }

  override def visitNode(rexNode: RexNode): Option[String] = None
}
Loading

0 comments on commit 134abbc

Please sign in to comment.