From 05ce262cb454d7c09ea6f9160f45f2e0cabdbb57 Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Tue, 8 Aug 2017 19:10:24 +0800 Subject: [PATCH 1/2] [FLINK-7062] [table, cep] Support the basic functionality of MATCH_RECOGNIZE --- flink-libraries/flink-table/pom.xml | 7 + .../calcite/RelTimeIndicatorConverter.scala | 37 +- .../flink/table/codegen/CodeGenerator.scala | 20 +- .../table/codegen/MatchCodeGenerator.scala | 589 ++++++++++++++++++ .../flink/table/codegen/generated.scala | 24 + .../nodes/datastream/DataStreamMatch.scala | 324 ++++++++++ .../nodes/logical/FlinkLogicalMatch.scala | 132 ++++ .../table/plan/rules/FlinkRuleSets.scala | 6 +- .../datastream/DataStreamMatchRule.scala | 64 ++ .../table/runtime/match/ConvertToRow.scala | 32 + .../match/IterativeConditionRunner.scala | 58 ++ .../flink/table/runtime/match/MatchUtil.scala | 114 ++++ .../PatternFlatSelectFunctionRunner.scala | 65 ++ .../match/PatternSelectFunctionRunner.scala | 63 ++ .../table/validate/FunctionCatalog.scala | 9 + .../table/api/stream/sql/CepITCase.scala | 410 ++++++++++++ 16 files changed, 1949 insertions(+), 5 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamMatch.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalMatch.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/ConvertToRow.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/IterativeConditionRunner.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/MatchUtil.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternFlatSelectFunctionRunner.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternSelectFunctionRunner.scala create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala diff --git a/flink-libraries/flink-table/pom.xml b/flink-libraries/flink-table/pom.xml index d45e997529a19..ee8c831380522 100644 --- a/flink-libraries/flink-table/pom.xml +++ b/flink-libraries/flink-table/pom.xml @@ -85,6 +85,13 @@ under the License. + + org.apache.flink + flink-cep_${scala.binary.version} + ${project.version} + provided + + org.codehaus.janino janino diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala index 4f3fbaa8edeb3..56f700dc25558 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala @@ -33,6 +33,7 @@ import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType import org.apache.flink.table.validate.BasicOperatorTable import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable /** @@ -100,7 +101,7 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { } override def visit(`match`: LogicalMatch): RelNode = - throw new TableException("Logical match in a stream environment is not supported yet.") + convertMatch(`match`) override def visit(other: RelNode): RelNode = other match { @@ -207,6 +208,40 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { correlate.getJoinType) } + private def convertMatch(`match`: Match): LogicalMatch = { + val rowType = `match`.getInput.getRowType + + val measures = `match`.getMeasures.foldLeft(mutable.Map[String, RexNode]()) { + case (m, (k, v)) => + m += k -> RelTimeIndicatorConverter.convertExpression(v, rowType, rexBuilder) + } + + val outputTypeBuilder = rexBuilder + .getTypeFactory + .asInstanceOf[FlinkTypeFactory] + .builder() + `match`.getRowType.getFieldList.asScala + .foreach(x => measures.get(x.getName) match { + case Some(measure) => outputTypeBuilder.add(x.getName, measure.getType) + case None => outputTypeBuilder.add(x) + }) + + LogicalMatch.create( + `match`.getInput, + outputTypeBuilder.build(), + `match`.getPattern, + `match`.isStrictStart, + `match`.isStrictEnd, + `match`.getPatternDefinitions, + measures, + `match`.getAfter, + `match`.getSubsets.asInstanceOf[java.util.Map[String, java.util.TreeSet[String]]], + `match`.isAllRows, + `match`.getPartitionKeys, + `match`.getOrderKeys, + `match`.getInterval) + } + private def convertAggregate(aggregate: Aggregate): LogicalAggregate = { // visit children and update inputs val input = aggregate.getInput.accept(this) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 6cabe21221383..43c3f8949a179 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -757,7 +757,15 @@ abstract class CodeGenerator( o.accept(this) } - call.getOperator match { + generateCallExpression(call.getOperator, operands, resultType) + } + + def generateCallExpression( + operator: SqlOperator, + operands: Seq[GeneratedExpression], + resultType: TypeInformation[_]) + : GeneratedExpression = { + operator match { // arithmetic case PLUS if isNumeric(resultType) => val left = operands.head @@ -1193,7 +1201,7 @@ abstract class CodeGenerator( GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType) } - private def generateFieldAccess( + def generateFieldAccess( inputType: TypeInformation[_], inputTerm: String, index: Int) @@ -1826,6 +1834,14 @@ abstract class CodeGenerator( fieldTerm } + def addReusableInitStatement(initStatement: String): Unit = { + reusableInitStatements.add(initStatement) + } + + def addReusableMemberStatement(memberStatement: String): Unit = { + reusableMemberStatements.add(memberStatement) + } + /** * Adds a reusable [[java.util.HashSet]] to the member area of the generated [[Function]]. * diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala new file mode 100644 index 0000000000000..dd434b3bf4770 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen + +import java.math.{BigDecimal => JBigDecimal} +import java.util + +import org.apache.calcite.rel.RelCollation +import org.apache.calcite.rex._ +import org.apache.calcite.sql.fun.SqlStdOperatorTable.{CLASSIFIER, FINAL, FIRST, LAST, MATCH_NUMBER, NEXT, PREV, RUNNING} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.cep.{PatternFlatSelectFunction, PatternSelectFunction} +import org.apache.flink.cep.pattern.conditions.IterativeCondition +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue} +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.types.Row + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * A code generator for generating CEP related functions. + * + * @param config configuration that determines runtime behavior + * @param nullableInput input(s) can be null. + * @param input type information about the first input of the Function + * @param patternNames the names of patterns + * @param generateCondition whether the code generator is generating [[IterativeCondition]] + * @param patternName the name of current pattern + */ +class MatchCodeGenerator( + config: TableConfig, + nullableInput: Boolean, + input: TypeInformation[_ <: Any], + patternNames: Seq[String], + generateCondition: Boolean, + patternName: Option[String] = None) + extends CodeGenerator(config, nullableInput, input){ + + /** + * @return term of pattern names + */ + private val patternNameListTerm = newName("patternNameList") + + /** + * @return term of current pattern which is processing + */ + private val currPatternTerm = newName("currPattern") + + /** + * @return term of current event which is processing + */ + private val currEventTerm = newName("currEvent") + + private val buildPatternNameList: String = { + for (patternName <- patternNames) yield + s""" + |$patternNameListTerm.add("$patternName"); + |""".stripMargin + }.mkString("\n") + + def addReusableStatements(): Unit = { + val eventTypeTerm = boxedTypeTermForTypeInfo(input) + val memberStatement = + s""" + |$eventTypeTerm $currEventTerm = null; + |String $currPatternTerm = null; + |java.util.List $patternNameListTerm = new java.util.ArrayList(); + |""".stripMargin + addReusableMemberStatement(memberStatement) + + addReusableInitStatement(buildPatternNameList) + } + + /** + * Generates a [[IterativeCondition]] that can be passed to Java compiler. + * + * @param name Class name of the function. Must not be unique but has to be a + * valid Java class identifier. + * @param bodyCode body code for the function + * @return a GeneratedIterativeCondition + */ + def generateIterativeCondition( + name: String, + bodyCode: String) + : GeneratedIterativeCondition = { + + val funcName = newName(name) + val inputTypeTerm = boxedTypeTermForTypeInfo(input) + + val funcCode = j""" + public class $funcName + extends ${classOf[IterativeCondition[_]].getCanonicalName} { + + ${reuseMemberCode()} + + public $funcName() throws Exception { + ${reuseInitCode()} + } + + @Override + public boolean filter( + Object _in1, ${classOf[IterativeCondition.Context[_]].getCanonicalName} $contextTerm) + throws Exception { + + $inputTypeTerm $input1Term = ($inputTypeTerm) _in1; + ${reusePerRecordCode()} + ${reuseInputUnboxingCode()} + $bodyCode + } + } + """.stripMargin + + GeneratedIterativeCondition(funcName, funcCode) + } + + /** + * Generates a [[PatternSelectFunction]] that can be passed to Java compiler. + * + * @param name Class name of the function. Must not be unique but has to be a + * valid Java class identifier. + * @param bodyCode body code for the function + * @return a GeneratedPatternSelectFunction + */ + def generatePatternSelectFunction( + name: String, + bodyCode: String) + : GeneratedPatternSelectFunction = { + + val funcName = newName(name) + val inputTypeTerm = + classOf[java.util.Map[java.lang.String, java.util.List[Row]]].getCanonicalName + + val funcCode = j""" + public class $funcName + implements ${classOf[PatternSelectFunction[_, _]].getCanonicalName} { + + ${reuseMemberCode()} + + public $funcName() throws Exception { + ${reuseInitCode()} + } + + @Override + public Object select(java.util.Map> _in1) + throws Exception { + + $inputTypeTerm $input1Term = ($inputTypeTerm) _in1; + ${reusePerRecordCode()} + ${reuseInputUnboxingCode()} + $bodyCode + } + } + """.stripMargin + + GeneratedPatternSelectFunction(funcName, funcCode) + } + + /** + * Generates a [[PatternFlatSelectFunction]] that can be passed to Java compiler. + * + * @param name Class name of the function. Must not be unique but has to be a + * valid Java class identifier. + * @param bodyCode body code for the function + * @return a GeneratedPatternFlatSelectFunction + */ + def generatePatternFlatSelectFunction( + name: String, + bodyCode: String) + : GeneratedPatternFlatSelectFunction = { + + val funcName = newName(name) + val inputTypeTerm = + classOf[java.util.Map[java.lang.String, java.util.List[Row]]].getCanonicalName + + val funcCode = j""" + public class $funcName + implements ${classOf[PatternFlatSelectFunction[_, _]].getCanonicalName} { + + ${reuseMemberCode()} + + public $funcName() throws Exception { + ${reuseInitCode()} + } + + @Override + public void flatSelect(java.util.Map> _in1, + org.apache.flink.util.Collector $collectorTerm) + throws Exception { + + $inputTypeTerm $input1Term = ($inputTypeTerm) _in1; + ${reusePerRecordCode()} + ${reuseInputUnboxingCode()} + $bodyCode + } + } + """.stripMargin + + GeneratedPatternFlatSelectFunction(funcName, funcCode) + } + + def generateSelectOutputExpression( + partitionKeys: util.List[RexNode], + measures: util.Map[String, RexNode], + returnType: RowSchema + ): GeneratedExpression = { + + val eventNameTerm = newName("event") + val eventTypeTerm = boxedTypeTermForTypeInfo(input) + + // For "ONE ROW PER MATCH", the output columns include: + // 1) the partition columns; + // 2) the columns defined in the measures clause. + val resultExprs = + partitionKeys.asScala.map { case inputRef: RexInputRef => + generateFieldAccess(input, eventNameTerm, inputRef.getIndex) + } ++ returnType.fieldNames.filter(measures.containsKey(_)).map { fieldName => + generateExpression(measures.get(fieldName)) + } + + val resultExpression = generateResultExpression( + resultExprs, + returnType.typeInfo, + returnType.fieldNames) + + val resultCode = + s""" + |$eventTypeTerm $eventNameTerm = null; + |if (${partitionKeys.size()} > 0) { + | for (java.util.Map.Entry entry : $input1Term.entrySet()) { + | java.util.List value = (java.util.List) entry.getValue(); + | if (value != null && value.size() > 0) { + | $eventNameTerm = ($eventTypeTerm) value.get(0); + | break; + | } + | } + |} + | + |${resultExpression.code} + |""".stripMargin + + resultExpression.copy(code = resultCode) + } + + def generateFlatSelectOutputExpression( + partitionKeys: util.List[RexNode], + orderKeys: RelCollation, + measures: util.Map[String, RexNode], + returnType: RowSchema) + : GeneratedExpression = { + + val patternNameTerm = newName("patternName") + val eventNameTerm = newName("event") + val eventNameListTerm = newName("eventList") + val eventTypeTerm = boxedTypeTermForTypeInfo(input) + val listTypeTerm = classOf[java.util.List[_]].getCanonicalName + + // For "ALL ROWS PER MATCH", the output columns include: + // 1) the partition columns; + // 2) the ordering columns; + // 3) the columns defined in the measures clause; + // 4) any remaining columns defined of the input. + val fieldsAccessed = mutable.Set[Int]() + val resultExprs = + partitionKeys.asScala.map { case inputRef: RexInputRef => + fieldsAccessed += inputRef.getIndex + generateFieldAccess(input, eventNameTerm, inputRef.getIndex) + } ++ orderKeys.getFieldCollations.asScala.map { fieldCollation => + fieldsAccessed += fieldCollation.getFieldIndex + generateFieldAccess(input, eventNameTerm, fieldCollation.getFieldIndex) + } ++ (0 until input.getArity).filterNot(fieldsAccessed.contains).map { idx => + generateFieldAccess(input, eventNameTerm, idx) + } ++ returnType.fieldNames.filter(measures.containsKey(_)).map { fieldName => + generateExpression(measures.get(fieldName)) + } + + val resultExpression = generateResultExpression( + resultExprs, + returnType.typeInfo, + returnType.fieldNames) + + val resultCode = + s""" + |for (String $patternNameTerm : $patternNameListTerm) { + | $currPatternTerm = $patternNameTerm; + | $listTypeTerm $eventNameListTerm = ($listTypeTerm) $input1Term.get($patternNameTerm); + | if ($eventNameListTerm != null) { + | for ($eventTypeTerm $eventNameTerm : $eventNameListTerm) { + | $currEventTerm = $eventNameTerm; + | ${resultExpression.code} + | $collectorTerm.collect(${resultExpression.resultTerm}); + | } + | } + |} + |$currPatternTerm = null; + |$currEventTerm = null; + |""".stripMargin + + GeneratedExpression("", "false", resultCode, null) + } + + override def visitCall(call: RexCall): GeneratedExpression = { + val resultType = FlinkTypeFactory.toTypeInfo(call.getType) + call.getOperator match { + case PREV => + val countLiteral = call.operands.get(1).asInstanceOf[RexLiteral] + val count = countLiteral.getValue3.asInstanceOf[JBigDecimal].intValue() + generatePrev( + call.operands.get(0), + count, + resultType) + + case NEXT | CLASSIFIER | MATCH_NUMBER => + throw new CodeGenException(s"Unsupported call: $call") + + case FIRST | LAST => + val countLiteral = call.operands.get(1).asInstanceOf[RexLiteral] + val count = countLiteral.getValue3.asInstanceOf[JBigDecimal].intValue() + generateFirstLast( + call.operands.get(0), + count, + resultType, + running = true, + call.getOperator == FIRST) + + case RUNNING | FINAL => + generateRunningFinal( + call.operands.get(0), + resultType, + call.getOperator == RUNNING) + + case _ => super.visitCall(call) + } + } + + private def generatePrev( + rexNode: RexNode, + count: Int, + resultType: TypeInformation[_]) + : GeneratedExpression = { + rexNode match { + case patternFieldRef: RexPatternFieldRef => + if (count == 0 && patternFieldRef.getAlpha == patternName.get) { + // return current one + return visitInputRef(patternFieldRef) + } + + val listName = newName("patternEvents") + val resultTerm = newName("result") + val nullTerm = newName("isNull") + val indexTerm = newName("eventIndex") + val visitedEventNumberTerm = newName("visitedEventNumber") + val eventTerm = newName("event") + val resultTypeTerm = boxedTypeTermForTypeInfo(resultType) + val defaultValue = primitiveDefaultValue(resultType) + + val eventTypeTerm = boxedTypeTermForTypeInfo(input) + + val patternNamesToVisit = patternNames + .take(patternNames.indexOf(patternFieldRef.getAlpha) + 1) + .reverse + def findEventByPhysicalPosition: String = { + val init: String = + s""" + |java.util.List $listName = new java.util.ArrayList(); + |""".stripMargin + + val getResult: String = { + for (tmpPatternName <- patternNamesToVisit) yield + s""" + |for ($eventTypeTerm $eventTerm : $contextTerm + | .getEventsForPattern("$tmpPatternName")) { + | $listName.add($eventTerm); + |} + | + |$indexTerm = $listName.size() - ($count - $visitedEventNumberTerm); + |if ($indexTerm >= 0) { + | $resultTerm = ($resultTypeTerm) (($eventTypeTerm) $listName.get($indexTerm)) + | .getField(${patternFieldRef.getIndex}); + | $nullTerm = false; + | break; + |} + | + |$visitedEventNumberTerm += $listName.size(); + |$listName.clear(); + |""".stripMargin + }.mkString("\n") + + s""" + |$init + |$getResult + |""".stripMargin + } + + val resultCode = + s""" + |int $visitedEventNumberTerm = 0; + |int $indexTerm; + |$resultTypeTerm $resultTerm = $defaultValue; + |boolean $nullTerm = true; + |do { + | $findEventByPhysicalPosition + |} while (false); + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) + + case rexCall: RexCall => + val operands = rexCall.operands.asScala.map { + operand => generatePrev( + operand, + count, + FlinkTypeFactory.toTypeInfo(operand.getType)) + } + + generateCallExpression(rexCall.getOperator, operands, resultType) + + case _ => + generateExpression(rexNode) + } + } + + private def generateFirstLast( + rexNode: RexNode, + count: Int, + resultType: TypeInformation[_], + running: Boolean, + first: Boolean) + : GeneratedExpression = { + rexNode match { + case patternFieldRef: RexPatternFieldRef => + + val eventNameTerm = newName("event") + val resultTerm = newName("result") + val listName = newName("patternEvents") + val nullTerm = newName("isNull") + val patternNameTerm = newName("patternName") + val eventNameListTerm = newName("eventNameList") + val resultTypeTerm = boxedTypeTermForTypeInfo(resultType) + val defaultValue = primitiveDefaultValue(resultType) + + val eventTypeTerm = boxedTypeTermForTypeInfo(input) + val listTypeTerm = classOf[java.util.List[_]].getCanonicalName + + def findEventByLogicalPosition: String = { + val init = + s""" + |java.util.List $listName = new java.util.ArrayList(); + |""".stripMargin + + val findEventsByPatterName = if (generateCondition) { + s""" + |for ($eventTypeTerm $eventNameTerm : $contextTerm + | .getEventsForPattern("${patternFieldRef.getAlpha}")) { + | $listName.add($eventNameTerm); + |} + |""".stripMargin + } else { + s""" + |for (String $patternNameTerm : $patternNameListTerm) { + | if ($patternNameTerm.equals("${patternFieldRef.getAlpha}") || + | ${patternFieldRef.getAlpha.equals("*")}) { + | boolean skipLoop = false; + | $listTypeTerm $eventNameListTerm = + | ($listTypeTerm) $input1Term.get($patternNameTerm); + | if ($eventNameListTerm != null) { + | for ($eventTypeTerm $eventNameTerm : $eventNameListTerm) { + | $listName.add($eventNameTerm); + | if ($running && $eventNameTerm == $currEventTerm) { + | skipLoop = true; + | break; + | } + | } + | } + | + | if (skipLoop) { + | break; + | } + | } + | + | if ($running && $patternNameTerm.equals($currPatternTerm)) { + | break; + | } + |} + |""".stripMargin + } + + val getResult = + s""" + |if ($listName.size() > $count) { + | if ($first) { + | $resultTerm = ($resultTypeTerm) (($eventTypeTerm) + | $listName.get($count)) + | .getField(${patternFieldRef.getIndex}); + | } else { + | $resultTerm = ($resultTypeTerm) (($eventTypeTerm) + | $listName.get($listName.size() - $count - 1)) + | .getField(${patternFieldRef.getIndex}); + | } + | $nullTerm = false; + |} + |""".stripMargin + + s""" + |$init + |$findEventsByPatterName + |$getResult + |""".stripMargin + } + + val resultCode = + s""" + |$resultTypeTerm $resultTerm = $defaultValue; + |boolean $nullTerm = true; + |$findEventByLogicalPosition + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) + + case rexCall: RexCall => + val operands = rexCall.operands.asScala.map { + operand => generateFirstLast( + operand, + count, + FlinkTypeFactory.toTypeInfo(operand.getType), + running, + first) + } + + generateCallExpression(rexCall.getOperator, operands, resultType) + + case _ => + generateExpression(rexNode) + } + } + + private def generateRunningFinal( + rexNode: RexNode, + resultType: TypeInformation[_], + running: Boolean) + : GeneratedExpression = { + rexNode match { + case _: RexPatternFieldRef => + generateFirstLast(rexNode, 0, resultType, running, first = false) + + case rexCall: RexCall if rexCall.getOperator == FIRST || rexCall.getOperator == LAST => + val countLiteral = rexCall.operands.get(1).asInstanceOf[RexLiteral] + val count = countLiteral.getValue3.asInstanceOf[JBigDecimal].intValue() + generateFirstLast( + rexCall.operands.get(0), + count, + resultType, + running, + rexCall.getOperator == FIRST) + + case rexCall: RexCall => + val operands = rexCall.operands.asScala.map { + operand => generateRunningFinal( + operand, + FlinkTypeFactory.toTypeInfo(operand.getType), + running) + } + + generateCallExpression(rexCall.getOperator, operands, resultType) + + case _ => + generateExpression(rexNode) + } + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala index c6d722a59a8d2..9b43b141c282f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala @@ -92,3 +92,27 @@ case class GeneratedInput[F <: InputFormat[_, _], T <: Any]( * @param code code of the generated Collector. */ case class GeneratedCollector(name: String, code: String) + +/** + * Describes a generated [[org.apache.flink.cep.pattern.conditions.IterativeCondition]]. + * + * @param name class name of the generated IterativeCondition. + * @param code code of the generated IterativeCondition. + */ +case class GeneratedIterativeCondition(name: String, code: String) + +/** + * Describes a generated [[org.apache.flink.cep.PatternSelectFunction]]. + * + * @param name class name of the generated PatternSelectFunction. + * @param code code of the generated PatternSelectFunction. + */ +case class GeneratedPatternSelectFunction(name: String, code: String) + +/** + * Describes a generated [[org.apache.flink.cep.PatternFlatSelectFunction]]. + * + * @param name class name of the generated PatternFlatSelectFunction. + * @param code code of the generated PatternFlatSelectFunction. + */ +case class GeneratedPatternFlatSelectFunction(name: String, code: String) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamMatch.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamMatch.scala new file mode 100644 index 0000000000000..f2267d1bc426c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamMatch.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.nodes.datastream + +import java.util +import java.math.{BigDecimal => JBigDecimal} + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel._ +import org.apache.calcite.rex._ +import org.apache.calcite.sql.`type`.SqlTypeName._ +import org.apache.calcite.sql.fun.SqlStdOperatorTable._ +import org.apache.flink.cep.{CEP, PatternStream} +import org.apache.flink.cep.pattern.Pattern +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.windowing.time.Time +import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.table.runtime.RowtimeProcessFunction +import org.apache.flink.table.runtime.`match`._ +import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.types.Row + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +/** + * Flink RelNode which matches along with LogicalMatch. + */ +class DataStreamMatch( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + pattern: RexNode, + strictStart: Boolean, + strictEnd: Boolean, + patternDefinitions: util.Map[String, RexNode], + measures: util.Map[String, RexNode], + after: RexNode, + subsets: util.Map[String, util.SortedSet[String]], + allRows: Boolean, + partitionKeys: util.List[RexNode], + orderKeys: RelCollation, + interval: RexNode, + schema: RowSchema, + inputSchema: RowSchema) + extends SingleRel(cluster, traitSet, input) + with DataStreamRel { + + override def deriveRowType(): RelDataType = schema.relDataType + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new DataStreamMatch( + cluster, + traitSet, + inputs.get(0), + pattern, + strictStart, + strictEnd, + patternDefinitions, + measures, + after, + subsets, + allRows, + partitionKeys, + orderKeys, + interval, + schema, + inputSchema) + } + + override def toString: String = { + s"Match(${ + if (!partitionKeys.isEmpty) { + s"PARTITION BY: ${partitionKeys.toArray.map(_.toString).mkString(", ")}, " + } else { + "" + } + }${ + if (!orderKeys.getFieldCollations.isEmpty) { + s"ORDER BY: ${orderKeys.getFieldCollations.asScala.map { + x => inputSchema.relDataType.getFieldList.get(x.getFieldIndex).getName + }.mkString(", ")}, " + } else { + "" + } + }${ + if (!measures.isEmpty) { + s"MEASURES: ${measures.asScala.map { + case (k, v) => s"${v.toString} AS $k" + }.mkString(", ")}, " + } else { + "" + } + }${ + if (allRows) { + s"ALL ROWS PER MATCH, " + } else { + s"ONE ROW PER MATCH, " + } + }${ + s"${after.toString}, " + }${ + s"PATTERN: (${pattern.toString})" + }${ + if (interval != null) { + s"WITHIN INTERVAL: $interval, " + } else { + s", " + } + }${ + if (!subsets.isEmpty) { + s"SUBSET: ${subsets.asScala.map { + case (k, v) => s"$k = (${v.toArray.mkString(", ")})" + }.mkString(", ")}, " + } else { + "" + } + }${ + s"DEFINE: ${patternDefinitions.asScala.map { + case (k, v) => s"$k AS ${v.toString}" + }.mkString(", ")}" + })" + } + + override def explainTerms(pw: RelWriter): RelWriter = { + pw.input("input", getInput()) + .itemIf("partitionBy", + partitionKeys.toArray.map(_.toString).mkString(", "), + !partitionKeys.isEmpty) + .itemIf("orderBy", + orderKeys.getFieldCollations.asScala.map { + x => inputSchema.relDataType.getFieldList.get(x.getFieldIndex).getName + }.mkString(", "), + !orderKeys.getFieldCollations.isEmpty) + .itemIf("measures", + measures.asScala.map { case (k, v) => s"${v.toString} AS $k"}.mkString(", "), + !measures.isEmpty) + .item("allrows", allRows) + .item("after", after.toString) + .item("pattern", pattern.toString) + .itemIf("within interval", + if (interval != null) { + interval.toString + } else { + null + }, + interval != null) + .itemIf("subset", + subsets.asScala.map { case (k, v) => s"$k = (${v.toArray.mkString(", ")})"}.mkString(", "), + !subsets.isEmpty) + .item("define", + patternDefinitions.asScala.map { case (k, v) => s"$k AS ${v.toString}"}.mkString(", ")) + } + + override def translateToPlan( + tableEnv: StreamTableEnvironment, + queryConfig: StreamQueryConfig): DataStream[CRow] = { + + val config = tableEnv.config + val inputTypeInfo = inputSchema.typeInfo + + val crowInput: DataStream[CRow] = getInput + .asInstanceOf[DataStreamRel] + .translateToPlan(tableEnv, queryConfig) + + val rowtimeFields = inputSchema.relDataType + .getFieldList.asScala + .filter(f => FlinkTypeFactory.isRowtimeIndicatorType(f.getType)) + + val timestampedInput = if (rowtimeFields.nonEmpty) { + // copy the rowtime field into the StreamRecord timestamp field + val timeIdx = rowtimeFields.head.getIndex + + crowInput + .process(new RowtimeProcessFunction(timeIdx, CRowTypeInfo(inputTypeInfo))) + .setParallelism(crowInput.getParallelism) + .name(s"rowtime field: (${rowtimeFields.head})") + } else { + crowInput + } + + val inputDS: DataStream[Row] = timestampedInput + .map(new ConvertToRow) + .setParallelism(timestampedInput.getParallelism) + .name("ConvertToRow") + .returns(inputTypeInfo) + + def translatePattern( + rexNode: RexNode, + currentPattern: Pattern[Row, Row], + patternNames: ListBuffer[String]): Pattern[Row, Row] = rexNode match { + case literal: RexLiteral => + val patternName = literal.getValue3.toString + patternNames += patternName + val newPattern = next(currentPattern, patternName) + + val patternDefinition = patternDefinitions.get(patternName) + if (patternDefinition != null) { + val condition = MatchUtil.generateIterativeCondition( + config, + inputSchema, + patternName, + patternNames, + patternDefinition, + inputTypeInfo) + + newPattern.where(condition) + } else { + newPattern + } + + case call: RexCall => + + call.getOperator match { + case PATTERN_CONCAT => + val left = call.operands.get(0) + val right = call.operands.get(1) + translatePattern(right, + translatePattern(left, currentPattern, patternNames), + patternNames) + + case PATTERN_QUANTIFIER => + val name = call.operands.get(0).asInstanceOf[RexLiteral] + val newPattern = translatePattern(name, currentPattern, patternNames) + + val startNum = call.operands.get(1).asInstanceOf[RexLiteral] + .getValue3.asInstanceOf[JBigDecimal].intValue() + val endNum = call.operands.get(2).asInstanceOf[RexLiteral] + .getValue3.asInstanceOf[JBigDecimal].intValue() + + if (startNum == 0 && endNum == -1) { // zero or more + newPattern.oneOrMore().optional().consecutive() + } else if (startNum == 1 && endNum == -1) { // one or more + newPattern.oneOrMore().consecutive() + } else if (startNum == 0 && endNum == 1) { // optional + newPattern.optional() + } else if (endNum != -1) { // times + newPattern.times(startNum, endNum).consecutive() + } else { // times or more + newPattern.timesOrMore(startNum).consecutive() + } + + case PATTERN_ALTER => + throw TableException("Currently, CEP doesn't support branching patterns.") + + case PATTERN_PERMUTE => + throw TableException("Currently, CEP doesn't support PERMUTE patterns.") + + case PATTERN_EXCLUDE => + throw TableException("Currently, CEP doesn't support '{-' '-}' patterns.") + } + + case _ => + throw TableException("") + } + + val patternNames: ListBuffer[String] = ListBuffer() + val cepPattern = translatePattern(pattern, null, patternNames) + if (interval != null) { + val intervalLiteral = interval.asInstanceOf[RexLiteral] + val intervalValue = interval.asInstanceOf[RexLiteral].getValueAs(classOf[java.lang.Long]) + val intervalMs: Long = intervalLiteral.getTypeName match { + case INTERVAL_YEAR | INTERVAL_YEAR_MONTH | INTERVAL_MONTH => + // convert from months to milliseconds, suppose 1 month = 30 days + intervalValue * 30L * 24 * 3600 * 1000 + case _ => intervalValue + } + + cepPattern.within(Time.milliseconds(intervalMs)) + } + val patternStream: PatternStream[Row] = CEP.pattern[Row](inputDS, cepPattern) + + val outTypeInfo = CRowTypeInfo(schema.typeInfo) + if (allRows) { + val patternFlatSelectFunction = + MatchUtil.generatePatternFlatSelectFunction( + config, + schema, + patternNames, + partitionKeys, + orderKeys, + measures, + inputTypeInfo) + patternStream.flatSelect[CRow](patternFlatSelectFunction, outTypeInfo) + } else { + val patternSelectFunction = + MatchUtil.generatePatternSelectFunction( + config, + schema, + patternNames, + partitionKeys, + measures, + inputTypeInfo) + patternStream.select[CRow](patternSelectFunction, outTypeInfo) + } + } + + private def next(currentPattern: Pattern[Row, Row], patternName: String): Pattern[Row, Row] = { + if (currentPattern == null) { + Pattern.begin(patternName) + } else { + currentPattern.next(patternName) + } + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalMatch.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalMatch.scala new file mode 100644 index 0000000000000..9e29b105dc83e --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalMatch.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.nodes.logical + +import java.util + +import org.apache.calcite.plan._ +import org.apache.calcite.rel.{RelCollation, RelNode} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.core.Match +import org.apache.calcite.rel.logical.LogicalMatch +import org.apache.calcite.rex.RexNode +import org.apache.flink.table.plan.nodes.FlinkConventions + +class FlinkLogicalMatch( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + rowType: RelDataType, + pattern: RexNode, + strictStart: Boolean, + strictEnd: Boolean, + patternDefinitions: util.Map[String, RexNode], + measures: util.Map[String, RexNode], + after: RexNode, + subsets: util.Map[String, _ <: util.SortedSet[String]], + allRows: Boolean, + partitionKeys: util.List[RexNode], + orderKeys: RelCollation, + interval: RexNode) + extends Match( + cluster, + traitSet, + input, + rowType, + pattern, + strictStart, + strictEnd, + patternDefinitions, + measures, + after, + subsets, + allRows, + partitionKeys, + orderKeys, + interval) + with FlinkLogicalRel { + + override def copy( + input: RelNode, + rowType: RelDataType, + pattern: RexNode, + strictStart: Boolean, + strictEnd: Boolean, + patternDefinitions: util.Map[String, RexNode], + measures: util.Map[String, RexNode], + after: RexNode, + subsets: util.Map[String, _ <: util.SortedSet[String]], + allRows: Boolean, + partitionKeys: util.List[RexNode], + orderKeys: RelCollation, + interval: RexNode): Match = { + new FlinkLogicalMatch( + cluster, + traitSet, + input, + rowType, + pattern, + strictStart, + strictEnd, + patternDefinitions, + measures, + after, + subsets, + allRows, + partitionKeys, + orderKeys, + interval) + } +} + +private class FlinkLogicalMatchConverter + extends ConverterRule( + classOf[LogicalMatch], + Convention.NONE, + FlinkConventions.LOGICAL, + "FlinkLogicalMatchConverter") { + + override def convert(rel: RelNode): RelNode = { + val logicalMatch = rel.asInstanceOf[LogicalMatch] + val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL) + val newInput = RelOptRule.convert(logicalMatch.getInput, FlinkConventions.LOGICAL) + + new FlinkLogicalMatch( + rel.getCluster, + traitSet, + newInput, + logicalMatch.getRowType, + logicalMatch.getPattern, + logicalMatch.isStrictStart, + logicalMatch.isStrictEnd, + logicalMatch.getPatternDefinitions, + logicalMatch.getMeasures, + logicalMatch.getAfter, + logicalMatch.getSubsets, + logicalMatch.isAllRows, + logicalMatch.getPartitionKeys, + logicalMatch.getOrderKeys, + logicalMatch.getInterval) + } +} + +object FlinkLogicalMatch { + val CONVERTER: ConverterRule = new FlinkLogicalMatchConverter() +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 52dab8b33792c..da22c50a0d37b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -133,7 +133,8 @@ object FlinkRuleSets { FlinkLogicalValues.CONVERTER, FlinkLogicalTableSourceScan.CONVERTER, FlinkLogicalTableFunctionScan.CONVERTER, - FlinkLogicalNativeTableScan.CONVERTER + FlinkLogicalNativeTableScan.CONVERTER, + FlinkLogicalMatch.CONVERTER ) /** @@ -211,7 +212,8 @@ object FlinkRuleSets { DataStreamCorrelateRule.INSTANCE, DataStreamWindowJoinRule.INSTANCE, DataStreamJoinRule.INSTANCE, - StreamTableSourceScanRule.INSTANCE + StreamTableSourceScanRule.INSTANCE, + DataStreamMatchRule.INSTANCE ) /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala new file mode 100644 index 0000000000000..f3c57e5f0d397 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.datastream.DataStreamMatch +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalMatch +import org.apache.flink.table.plan.schema.RowSchema + +class DataStreamMatchRule + extends ConverterRule( + classOf[FlinkLogicalMatch], + FlinkConventions.LOGICAL, + FlinkConventions.DATASTREAM, + "DataStreamMatchRule") { + + override def convert(rel: RelNode): RelNode = { + val logicalMatch: FlinkLogicalMatch = rel.asInstanceOf[FlinkLogicalMatch] + val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) + val convertInput: RelNode = + RelOptRule.convert(logicalMatch.getInput, FlinkConventions.DATASTREAM) + + new DataStreamMatch( + rel.getCluster, + traitSet, + convertInput, + logicalMatch.getPattern, + logicalMatch.isStrictStart, + logicalMatch.isStrictEnd, + logicalMatch.getPatternDefinitions, + logicalMatch.getMeasures, + logicalMatch.getAfter, + logicalMatch.getSubsets, + logicalMatch.isAllRows, + logicalMatch.getPartitionKeys, + logicalMatch.getOrderKeys, + logicalMatch.getInterval, + new RowSchema(logicalMatch.getRowType), + new RowSchema(logicalMatch.getInput.getRowType)) + } +} + +object DataStreamMatchRule { + val INSTANCE: RelOptRule = new DataStreamMatchRule +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/ConvertToRow.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/ConvertToRow.scala new file mode 100644 index 0000000000000..de73ddead1447 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/ConvertToRow.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.`match` + +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * MapFunction convert CRow to Row. + */ +class ConvertToRow extends MapFunction[CRow, Row] { + override def map(value: CRow): Row = { + value.row + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/IterativeConditionRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/IterativeConditionRunner.scala new file mode 100644 index 0000000000000..84077faa57d50 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/IterativeConditionRunner.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.`match` + +import org.apache.flink.cep.pattern.conditions.IterativeCondition +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.types.Row +import org.slf4j.LoggerFactory + +/** + * IterativeConditionRunner with [[Row]] value. + */ +class IterativeConditionRunner( + name: String, + code: String) + extends IterativeCondition[Row] + with Compiler[IterativeCondition[Row]]{ + + val LOG = LoggerFactory.getLogger(this.getClass) + + // IterativeCondition will be serialized as part of state, + // so make function as transient to avoid ClassNotFoundException when restore state, + // see FLINK-6939 for details + @transient private var function: IterativeCondition[Row] = _ + + def init(): Unit = { + LOG.debug(s"Compiling IterativeCondition: $name \n\n Code:\n$code") + // We cannot get user's classloader currently, see FLINK-6938 for details + val clazz = compile(Thread.currentThread().getContextClassLoader, name, code) + LOG.debug("Instantiating IterativeCondition.") + function = clazz.newInstance() + } + + override def filter(value: Row, ctx: IterativeCondition.Context[Row]): Boolean = { + + if (function == null) { + init() + } + + function.filter(value, ctx) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/MatchUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/MatchUtil.scala new file mode 100644 index 0000000000000..3e1d471b738ba --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/MatchUtil.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.`match` + +import java.util + +import org.apache.calcite.rel.RelCollation +import org.apache.calcite.rex.RexNode +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.cep.{PatternFlatSelectFunction, PatternSelectFunction} +import org.apache.flink.cep.pattern.conditions.IterativeCondition +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.MatchCodeGenerator +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * An util class to generate match functions. + */ +object MatchUtil { + + private[flink] def generateIterativeCondition( + config: TableConfig, + inputType: RowSchema, + patternName: String, + patternNames: Seq[String], + patternDefinition: RexNode, + inputTypeInfo: TypeInformation[_]): IterativeCondition[Row] = { + + val generator = new MatchCodeGenerator( + config, false, inputTypeInfo, patternNames, true, Some(patternName)) + val condition = generator.generateExpression(patternDefinition) + val body = + s""" + |${condition.code} + |return ${condition.resultTerm}; + |""".stripMargin + + val genCondition = generator.generateIterativeCondition("MatchRecognizeCondition", body) + new IterativeConditionRunner(genCondition.name, genCondition.code) + } + + private[flink] def generatePatternSelectFunction( + config: TableConfig, + returnType: RowSchema, + patternNames: Seq[String], + partitionKeys: util.List[RexNode], + measures: util.Map[String, RexNode], + inputTypeInfo: TypeInformation[_]): PatternSelectFunction[Row, CRow] = { + + val generator = new MatchCodeGenerator(config, false, inputTypeInfo, patternNames, false) + + val resultExpression = generator.generateSelectOutputExpression( + partitionKeys, + measures, + returnType) + val body = + s""" + |${resultExpression.code} + |return ${resultExpression.resultTerm}; + |""".stripMargin + + generator.addReusableStatements() + val genFunction = generator.generatePatternSelectFunction( + "MatchRecognizePatternSelectFunction", + body) + new PatternSelectFunctionRunner(genFunction.name, genFunction.code) + } + + private[flink] def generatePatternFlatSelectFunction( + config: TableConfig, + returnType: RowSchema, + patternNames: Seq[String], + partitionKeys: util.List[RexNode], + orderKeys: RelCollation, + measures: util.Map[String, RexNode], + inputTypeInfo: TypeInformation[_]): PatternFlatSelectFunction[Row, CRow] = { + + val generator = new MatchCodeGenerator(config, false, inputTypeInfo, patternNames, false) + + val resultExpression = generator.generateFlatSelectOutputExpression( + partitionKeys, + orderKeys, + measures, + returnType) + val body = + s""" + |${resultExpression.code} + |""".stripMargin + + generator.addReusableStatements() + val genFunction = generator.generatePatternFlatSelectFunction( + "MatchRecognizePatternFlatSelectFunction", + body) + new PatternFlatSelectFunctionRunner(genFunction.name, genFunction.code) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternFlatSelectFunctionRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternFlatSelectFunctionRunner.scala new file mode 100644 index 0000000000000..aeb25df117d1f --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternFlatSelectFunctionRunner.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.`match` + +import java.util + +import org.apache.flink.cep.PatternFlatSelectFunction +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.runtime.CRowWrappingCollector +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory + +/** + * PatternFlatSelectFunctionRunner with [[Row]] input and [[CRow]] output. + */ +class PatternFlatSelectFunctionRunner( + name: String, + code: String) + extends PatternFlatSelectFunction[Row, CRow] + with Compiler[PatternFlatSelectFunction[Row, Row]] { + + val LOG = LoggerFactory.getLogger(this.getClass) + + private var cRowWrapper: CRowWrappingCollector = _ + + private var function: PatternFlatSelectFunction[Row, Row] = _ + + def init(): Unit = { + LOG.debug(s"Compiling PatternFlatSelectFunction: $name \n\n Code:\n$code") + val clazz = compile(Thread.currentThread().getContextClassLoader, name, code) + LOG.debug("Instantiating PatternFlatSelectFunction.") + function = clazz.newInstance() + + this.cRowWrapper = new CRowWrappingCollector() + } + + override def flatSelect( + pattern: util.Map[String, util.List[Row]], + out: Collector[CRow]): Unit = { + if (function == null) { + init() + } + + cRowWrapper.out = out + function.flatSelect(pattern, cRowWrapper) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternSelectFunctionRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternSelectFunctionRunner.scala new file mode 100644 index 0000000000000..f15db680b5e82 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/match/PatternSelectFunctionRunner.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.`match` + +import java.util + +import org.apache.flink.cep.PatternSelectFunction +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row +import org.slf4j.LoggerFactory + +/** + * PatternSelectFunctionRunner with [[Row]] input and [[CRow]] output. + */ +class PatternSelectFunctionRunner( + name: String, + code: String) + extends PatternSelectFunction[Row, CRow] + with Compiler[PatternSelectFunction[Row, Row]] { + + val LOG = LoggerFactory.getLogger(this.getClass) + + private var outCRow: CRow = _ + + private var function: PatternSelectFunction[Row, Row] = _ + + def init(): Unit = { + LOG.debug(s"Compiling PatternSelectFunction: $name \n\n Code:\n$code") + val clazz = compile(Thread.currentThread().getContextClassLoader, name, code) + LOG.debug("Instantiating PatternSelectFunction.") + function = clazz.newInstance() + } + + override def select(pattern: util.Map[String, util.List[Row]]): CRow = { + if (outCRow == null) { + outCRow = new CRow(null, true) + } + + if (function == null) { + init() + } + + outCRow.row = function.select(pattern) + outCRow + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 3184e0001ea9a..4ceda1a6bf035 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -443,6 +443,15 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable { ScalarSqlFunctions.SHA384, ScalarSqlFunctions.SHA512, ScalarSqlFunctions.SHA2, + // MATCH_RECOGNIZE + SqlStdOperatorTable.FIRST, + SqlStdOperatorTable.LAST, + SqlStdOperatorTable.PREV, + SqlStdOperatorTable.NEXT, + SqlStdOperatorTable.CLASSIFIER, + SqlStdOperatorTable.MATCH_NUMBER, + SqlStdOperatorTable.FINAL, + SqlStdOperatorTable.RUNNING, // EXTENSIONS BasicOperatorTable.TUMBLE, BasicOperatorTable.HOP, diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala new file mode 100644 index 0000000000000..9fb9e6bc4dd23 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction +import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase} +import org.apache.flink.types.Row +import org.junit.Assert.assertEquals +import org.junit.Test + +import scala.collection.mutable + +class CepITCase extends StreamingWithStateTestBase { + + @Test + def testSimpleCEP() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(Int, String)] + data.+=((1, "a")) + data.+=((2, "z")) + data.+=((3, "b")) + data.+=((4, "c")) + data.+=((5, "d")) + data.+=((6, "a")) + data.+=((7, "b")) + data.+=((8, "c")) + data.+=((9, "h")) + + val t = env.fromCollection(data).toTable(tEnv).as('id, 'name) + tEnv.registerTable("MyTable", t) + + val sqlQuery = + s""" + |SELECT T.aid, T.bid, T.cid + |FROM MyTable + |MATCH_RECOGNIZE ( + | MEASURES + | A.id AS aid, + | B.id AS bid, + | C.id AS cid + | PATTERN (A B C) + | DEFINE + | A AS A.name = 'a', + | B AS B.name = 'b', + | C AS C.name = 'c' + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = mutable.MutableList("6,7,8") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testAllRowsPerMatch() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(Int, String)] + data.+=((1, "a")) + data.+=((2, "z")) + data.+=((3, "b")) + data.+=((4, "c")) + data.+=((5, "d")) + data.+=((6, "a")) + data.+=((7, "b")) + data.+=((8, "c")) + data.+=((9, "h")) + + val t = env.fromCollection(data).toTable(tEnv).as('id, 'name) + tEnv.registerTable("MyTable", t) + + val sqlQuery = + s""" + |SELECT * + |FROM MyTable + |MATCH_RECOGNIZE ( + | MEASURES + | A.id AS aid, + | B.id AS bid, + | C.id AS cid + | ALL ROWS PER MATCH + | PATTERN (A B C) + | DEFINE + | A AS A.name = 'a', + | B AS B.name = 'b', + | C AS C.name = 'c' + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = mutable.MutableList("6,a,6,null,null", "7,b,6,7,null", "8,c,6,7,8") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testFinalFirst() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(String, Long, Int, Int)] + data.+=(("ACME", 1L, 12, 1)) + data.+=(("ACME", 2L, 17, 2)) + data.+=(("ACME", 3L, 13, 3)) + data.+=(("ACME", 4L, 15, 4)) + data.+=(("ACME", 5L, 20, 5)) + data.+=(("ACME", 6L, 24, 6)) + data.+=(("ACME", 7L, 25, 7)) + data.+=(("ACME", 8L, 19, 8)) + + val t = env.fromCollection(data).toTable(tEnv).as('symbol, 'tstamp, 'price, 'tax) + tEnv.registerTable("Ticker", t) + + val sqlQuery = + s""" + |SELECT * + |FROM Ticker + |MATCH_RECOGNIZE ( + | MEASURES + | STRT.tstamp AS start_tstamp, + | FIRST(DOWN.tstamp) AS bottom_tstamp, + | FIRST(UP.tstamp) AS end_tstamp, + | FIRST(DOWN.price + DOWN.tax + 1) AS bottom_total, + | FIRST(UP.price + UP.tax) AS end_total + | ONE ROW PER MATCH + | PATTERN (STRT DOWN+ UP+) + | DEFINE + | DOWN AS DOWN.price < PREV(DOWN.price), + | UP AS UP.price > PREV(UP.price) + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("2,3,4,17,19", "2,3,4,17,19", "2,3,4,17,19", "2,3,4,17,19") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testFinalLast() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(String, Long, Int, Int)] + data.+=(("ACME", 1L, 12, 1)) + data.+=(("ACME", 2L, 17, 2)) + data.+=(("ACME", 3L, 13, 3)) + data.+=(("ACME", 4L, 15, 4)) + data.+=(("ACME", 5L, 20, 5)) + data.+=(("ACME", 6L, 24, 6)) + data.+=(("ACME", 7L, 25, 7)) + data.+=(("ACME", 8L, 19, 8)) + + val t = env.fromCollection(data).toTable(tEnv).as('symbol, 'tstamp, 'price, 'tax) + tEnv.registerTable("Ticker", t) + + val sqlQuery = + s""" + |SELECT * + |FROM Ticker + |MATCH_RECOGNIZE ( + | MEASURES + | STRT.tstamp AS start_tstamp, + | LAST(DOWN.tstamp) AS bottom_tstamp, + | LAST(UP.tstamp) AS end_tstamp, + | LAST(DOWN.price + DOWN.tax) AS bottom_total, + | LAST(UP.price + UP.tax + 1) AS end_total + | ONE ROW PER MATCH + | PATTERN (STRT DOWN+ UP+) + | DEFINE + | DOWN AS DOWN.price < PREV(DOWN.price), + | UP AS UP.price > PREV(UP.price) + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("2,3,4,16,20", "2,3,5,16,26", "2,3,6,16,31", "2,3,7,16,33") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testPrev() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(String, Long, Int)] + data.+=(("ACME", 1L, 12)) + data.+=(("ACME", 2L, 17)) + data.+=(("ACME", 3L, 13)) + data.+=(("ACME", 4L, 11)) + data.+=(("ACME", 5L, 14)) + data.+=(("ACME", 6L, 12)) + data.+=(("ACME", 7L, 13)) + data.+=(("ACME", 8L, 19)) + + val t = env.fromCollection(data).toTable(tEnv).as('symbol, 'tstamp, 'price) + tEnv.registerTable("Ticker", t) + + val sqlQuery = + s""" + |SELECT * + |FROM Ticker + |MATCH_RECOGNIZE ( + | MEASURES + | STRT.tstamp AS start_tstamp, + | LAST(DOWN.tstamp) AS up_days, + | LAST(UP.tstamp) AS total_days + | PATTERN (STRT DOWN+ UP+) + | DEFINE + | DOWN AS DOWN.price < PREV(DOWN.price), + | UP AS UP.price > PREV(UP.price, 2) + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("2,4,5", "2,4,6", "3,4,5", "3,4,6") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testRunningFirst() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(String, Long, Int, Int)] + data.+=(("ACME", 1L, 12, 1)) + data.+=(("ACME", 2L, 17, 2)) + data.+=(("ACME", 3L, 13, 4)) + data.+=(("ACME", 4L, 11, 3)) + data.+=(("ACME", 5L, 20, 5)) + data.+=(("ACME", 6L, 24, 4)) + data.+=(("ACME", 7L, 25, 3)) + data.+=(("ACME", 8L, 19, 8)) + + val t = env.fromCollection(data).toTable(tEnv).as('symbol, 'tstamp, 'price, 'tax) + tEnv.registerTable("Ticker", t) + + val sqlQuery = + s""" + |SELECT * + |FROM Ticker + |MATCH_RECOGNIZE ( + | MEASURES + | STRT.tstamp AS start_tstamp, + | LAST(DOWN.tstamp) AS bottom_tstamp, + | LAST(UP.tstamp) AS end_tstamp + | ONE ROW PER MATCH + | PATTERN (STRT DOWN+ UP+) + | DEFINE + | DOWN AS DOWN.price < PREV(DOWN.price), + | UP AS UP.price > PREV(UP.price) AND UP.tax > FIRST(DOWN.tax) + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("2,4,5", "3,4,5", "3,4,6") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testRunningLast() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(String, Long, Int, Int)] + data.+=(("ACME", 1L, 12, 1)) + data.+=(("ACME", 2L, 17, 2)) + data.+=(("ACME", 3L, 13, 4)) + data.+=(("ACME", 4L, 11, 3)) + data.+=(("ACME", 5L, 20, 4)) + data.+=(("ACME", 6L, 24, 4)) + data.+=(("ACME", 7L, 25, 3)) + data.+=(("ACME", 8L, 19, 8)) + + val t = env.fromCollection(data).toTable(tEnv).as('symbol, 'tstamp, 'price, 'tax) + tEnv.registerTable("Ticker", t) + + val sqlQuery = + s""" + |SELECT * + |FROM Ticker + |MATCH_RECOGNIZE ( + | MEASURES + | STRT.tstamp AS start_tstamp, + | LAST(DOWN.tstamp) AS bottom_tstamp, + | LAST(UP.tstamp) AS end_tstamp + | ONE ROW PER MATCH + | PATTERN (STRT DOWN+ UP+) + | DEFINE + | DOWN AS DOWN.price < PREV(DOWN.price), + | UP AS UP.price > PREV(UP.price) AND UP.tax > LAST(DOWN.tax) + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("2,4,5", "2,4,6", "3,4,5", "3,4,6") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testWithinEventTime() = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[Either[(Long, (String, Int, Int)), Long]] + data.+=(Left((3000L, ("ACME", 17, 2)))) + data.+=(Left((1000L, ("ACME", 12, 1)))) + data.+=(Right(4000L)) + data.+=(Left((5000L, ("ACME", 13, 3)))) + data.+=(Left((7000L, ("ACME", 15, 4)))) + data.+=(Right(8000L)) + data.+=(Left((9000L, ("ACME", 20, 5)))) + data.+=(Right(13000L)) + data.+=(Left((15000L, ("ACME", 19, 8)))) + data.+=(Right(16000L)) + + val t = env.addSource(new EventTimeSourceFunction[(String, Int, Int)](data)) + .toTable(tEnv, 'symbol, 'price, 'tax, 'tstamp.rowtime) + tEnv.registerTable("Ticker", t) + + val sqlQuery = + s""" + |SELECT * + |FROM Ticker + |MATCH_RECOGNIZE ( + | PARTITION BY symbol + | ORDER BY tstamp + | MEASURES + | STRT.tstamp AS start_tstamp, + | FIRST(DOWN.tstamp) AS bottom_tstamp, + | FIRST(UP.tstamp) AS end_tstamp, + | FIRST(DOWN.price + DOWN.tax + 1) AS bottom_total, + | FIRST(UP.price + UP.tax) AS end_total + | ONE ROW PER MATCH + | PATTERN (STRT DOWN+ UP+) within interval '5' second + | DEFINE + | DOWN AS DOWN.price < PREV(DOWN.price), + | UP AS UP.price > PREV(UP.price) + |) AS T + |""".stripMargin + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List( + "ACME,1970-01-01 00:00:03.0,1970-01-01 00:00:05.0,1970-01-01 00:00:07.0,17,19") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} From 08a25252d1546cf4266ba1af933e5dc1debbefce Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Wed, 25 Jul 2018 17:47:18 +0800 Subject: [PATCH 2/2] minor update --- .../table/calcite/RelTimeIndicatorConverter.scala | 14 +++++++++++--- .../apache/flink/table/codegen/CodeGenerator.scala | 6 +++--- .../flink/table/codegen/MatchCodeGenerator.scala | 14 +++++++------- .../{api => runtime}/stream/sql/CepITCase.scala | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) rename flink-libraries/flink-table/src/test/scala/org/apache/flink/table/{api => runtime}/stream/sql/CepITCase.scala (99%) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala index 56f700dc25558..5a62d6471d92f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala @@ -211,9 +211,17 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { private def convertMatch(`match`: Match): LogicalMatch = { val rowType = `match`.getInput.getRowType + val materializer = new RexTimeIndicatorMaterializer( + rexBuilder, + rowType.getFieldList.map(_.getType)) + + val patternDefinitions = + `match`.getPatternDefinitions.foldLeft(mutable.Map[String, RexNode]()) { + case (m, (k, v)) => m += k -> v.accept(materializer) + } + val measures = `match`.getMeasures.foldLeft(mutable.Map[String, RexNode]()) { - case (m, (k, v)) => - m += k -> RelTimeIndicatorConverter.convertExpression(v, rowType, rexBuilder) + case (m, (k, v)) => m += k -> v.accept(materializer) } val outputTypeBuilder = rexBuilder @@ -232,7 +240,7 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { `match`.getPattern, `match`.isStrictStart, `match`.isStrictEnd, - `match`.getPatternDefinitions, + patternDefinitions, measures, `match`.getAfter, `match`.getSubsets.asInstanceOf[java.util.Map[String, java.util.TreeSet[String]]], diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 43c3f8949a179..d74547b52ae8b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -757,15 +757,15 @@ abstract class CodeGenerator( o.accept(this) } - generateCallExpression(call.getOperator, operands, resultType) + generateCallExpression(call, operands, resultType) } def generateCallExpression( - operator: SqlOperator, + call: RexCall, operands: Seq[GeneratedExpression], resultType: TypeInformation[_]) : GeneratedExpression = { - operator match { + call.getOperator match { // arithmetic case PLUS if isNumeric(resultType) => val left = operands.head diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala index dd434b3bf4770..51566c71ffdea 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala @@ -219,10 +219,10 @@ class MatchCodeGenerator( } def generateSelectOutputExpression( - partitionKeys: util.List[RexNode], - measures: util.Map[String, RexNode], - returnType: RowSchema - ): GeneratedExpression = { + partitionKeys: util.List[RexNode], + measures: util.Map[String, RexNode], + returnType: RowSchema) + : GeneratedExpression = { val eventNameTerm = newName("event") val eventTypeTerm = boxedTypeTermForTypeInfo(input) @@ -432,7 +432,7 @@ class MatchCodeGenerator( FlinkTypeFactory.toTypeInfo(operand.getType)) } - generateCallExpression(rexCall.getOperator, operands, resultType) + generateCallExpression(rexCall, operands, resultType) case _ => generateExpression(rexNode) @@ -546,7 +546,7 @@ class MatchCodeGenerator( first) } - generateCallExpression(rexCall.getOperator, operands, resultType) + generateCallExpression(rexCall, operands, resultType) case _ => generateExpression(rexNode) @@ -580,7 +580,7 @@ class MatchCodeGenerator( running) } - generateCallExpression(rexCall.getOperator, operands, resultType) + generateCallExpression(rexCall, operands, resultType) case _ => generateExpression(rexNode) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/CepITCase.scala similarity index 99% rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/CepITCase.scala index 9fb9e6bc4dd23..66d724b057d0c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CepITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/CepITCase.scala @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.api.stream.sql +package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.TimeCharacteristic