From 89dae2470d515e2efb7e2c0cb7afedfb9416b9f2 Mon Sep 17 00:00:00 2001
From: Stefano Bortoli
Date: Tue, 25 Apr 2017 18:38:53 +0200
Subject: [PATCH 1/2] Distinct implementation in aggregation code generator

---
 .../flink/table/codegen/CodeGenerator.scala        |  87 ++++++-
 .../codegen/calls/FunctionGenerator.scala          |  17 ++
 ...supportedOperatorsIndicatorFunctions.scala      |  39 +++
 .../datastream/DataStreamOverAggregate.scala       |  42 ++-
 .../runtime/aggregate/AggregateUtil.scala          |  48 +++-
 .../aggregate/GeneratedAggregations.scala          |  12 +-
 .../aggregate/ProcTimeBoundedRowsOver.scala        |  11 +-
 .../table/validate/FunctionCatalog.scala           |   6 +-
 .../api/scala/stream/sql/SqlITCase.scala           | 241 ++++++++++++++++++
 9 files changed, 487 insertions(+), 16 deletions(-)
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UnsupportedOperatorsIndicatorFunctions.scala

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 298fb70d742fd..1b3166a045d41 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -272,6 +272,7 @@ class CodeGenerator(
       fwdMapping: Array[Int],
       mergeMapping: Option[Array[Int]],
       constantFlags: Option[Array[(Int, Boolean)]],
+      distinctAggsFlags: Array[Boolean],
       outputArity: Int,
       needRetract: Boolean,
       needMerge: Boolean)
@@ -296,6 +297,41 @@ class CodeGenerator(
       fields.mkString(", ")
     }
 
+    def genInitialize(existDistinct: Boolean): String = {
+
+      val sig: String =
+        j"""
+           |  org.apache.flink.api.common.state.MapState[] distStateList =
+           |    new org.apache.flink.api.common.state.MapState[${distinctAggsFlags.length}];
+           |
+           |  public void initialize(
+           |    org.apache.flink.api.common.functions.RuntimeContext ctx
+           |  )""".stripMargin
+      if (existDistinct) {
+        val initDist: String = {
+          for (i <- distinctAggsFlags.indices) yield
+            if (distinctAggsFlags(i)) {
+              j"""
+                 |  org.apache.flink.api.common.state.MapStateDescriptor distDesc$i =
+                 |    new org.apache.flink.api.common.state.MapStateDescriptor(
+                 |      "distinctValuesBufferMapState" + $i,
+                 |      Object.class, Long.class);
+                 |  distStateList[$i] = ctx.getMapState(distDesc$i);
+                 """.stripMargin
+            } else {
+              ""
+            }
+        }.mkString("\n")
+
+        j"""$sig {
+           |  $initDist
+           |  }""".stripMargin
+      } else {
+        j"""$sig { }""".stripMargin
+      }
+    }
+
     def genSetAggregationResults: String = {
 
       val sig: String =
@@ -335,14 +371,28 @@ class CodeGenerator(
         j"""
            |  public final void accumulate(
            |    org.apache.flink.types.Row accs,
-           |    org.apache.flink.types.Row input)""".stripMargin
+           |    org.apache.flink.types.Row input) throws Exception""".stripMargin
 
       val accumulate: String = {
         for (i <- aggs.indices) yield
-          j"""
+          if (distinctAggsFlags(i)) {
+            j"""
+               |  Long distValCount$i = (Long) distStateList[$i].get(${parameters(i)});
+               |  if (distValCount$i == null) {
+               |    ${aggs(i)}.accumulate(
+               |      ((${accTypes(i)}) accs.getField($i)),
+               |      ${parameters(i)});
+               |    distValCount$i = 0L;
+               |  }
+               |  distValCount$i += 1;
+               |  distStateList[$i].put(${parameters(i)}, distValCount$i);
+               """.stripMargin
+          } else {
+            j"""
                |  ${aggs(i)}.accumulate(
                |    ((${accTypes(i)}) accs.getField($i)),
                |    ${parameters(i)});""".stripMargin
+          }
       }.mkString("\n")
 
       j"""$sig {
@@ -356,14 +406,29 @@ class CodeGenerator(
         j"""
            |  public final void retract(
            |    org.apache.flink.types.Row accs,
-           |    org.apache.flink.types.Row input)""".stripMargin
+           |    org.apache.flink.types.Row input) throws Exception""".stripMargin
 
       val retract: String = {
         for (i <- aggs.indices) yield
-          j"""
-             |  ${aggs(i)}.retract(
-             |    ((${accTypes(i)}) accs.getField($i)),
-             |    ${parameters(i)});""".stripMargin
+          if (distinctAggsFlags(i)) {
+            j"""
+               |  Long distValCount$i = (Long) distStateList[$i].get(${parameters(i)});
+               |  if (distValCount$i == 1L) {
+               |    ${aggs(i)}.retract(
+               |      ((${accTypes(i)}) accs.getField($i)),
+               |      ${parameters(i)});
+               |    distStateList[$i].remove(${parameters(i)});
+               |  } else {
+               |    distValCount$i -= 1L;
+               |    distStateList[$i].put(${parameters(i)}, distValCount$i);
+               |  }
+               """.stripMargin
+          } else {
+            j"""
+               |  ${aggs(i)}.retract(
+               |    ((${accTypes(i)}) accs.getField($i)),
+               |    ${parameters(i)});""".stripMargin
+          }
       }.mkString("\n")
 
       if (needRetract) {
@@ -533,7 +598,12 @@ class CodeGenerator(
            |$reset
            |  }""".stripMargin
     }
-
+
+    // true if at least one aggregation is marked DISTINCT
+    val existDistinct = distinctAggsFlags.contains(true)
+
     var funcCode =
       j"""
          |public final class $funcName
@@ -548,6 +618,7 @@ class CodeGenerator(
          |""".stripMargin
 
+    funcCode += genInitialize(existDistinct) + "\n"
     funcCode += genSetAggregationResults + "\n"
     funcCode += genAccumulate + "\n"
     funcCode += genRetract + "\n"
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
index 27e6dc6a449cb..c82e7d8a450db 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/FunctionGenerator.scala
@@ -33,6 +33,7 @@
 import org.apache.flink.table.functions.{EventTimeExtractor, ProcTimeExtractor}
 import org.apache.flink.table.functions.utils.{ScalarSqlFunction, TableSqlFunction}
 
 import scala.collection.mutable
+import org.apache.flink.table.functions.DistinctAggregatorExtractor
 
 /**
   * Global hub for user-defined and built-in advanced SQL functions.
@@ -338,6 +339,22 @@ object FunctionGenerator {
         }
       })
 
+    /**
+      * Temporary workaround to support DISTINCT on aggregations until
+      * https://issues.apache.org/jira/browse/CALCITE-1740 is resolved.
+      */
+    case DistinctAggregatorExtractor =>
+      Some(new CallGenerator {
+        override def generate(codeGenerator: CodeGenerator, operands: Seq[GeneratedExpression]) = {
+          // the "empty" unary operator makes the result of the function equal
+          // to its parameter, so DIST is a dummy function from the computation
+          // perspective
+          ScalarOperators.
+            generateUnaryArithmeticOperator("", true, resultType, operands.head)
+        }
+      })
+
     // built-in scalar function
     case _ =>
       sqlFunctions.get((sqlOperator, operandTypes))
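The generated accumulate()/retract() above implement the distinct check as reference counting: a per-aggregate MapState maps each value to the number of times it currently sits in the window, the underlying aggregate is updated only on a value's first occurrence, and retracted only when its last occurrence leaves. A minimal hand-written sketch of the same bookkeeping; DistinctCounter and its callback names are hypothetical, and a mutable.Map stands in for Flink's MapState:

  import scala.collection.mutable

  // Sketch of the bookkeeping the generated accumulate()/retract() perform per
  // aggregate. `accumulate` and `retract` stand in for the underlying
  // aggregate function calls emitted by the code generator.
  class DistinctCounter[T](
      accumulate: T => Unit,  // forwarded only on the first occurrence of a value
      retract: T => Unit) {   // forwarded only when the last occurrence is removed

    private val counts = mutable.Map.empty[T, Long]

    def add(value: T): Unit = {
      val n = counts.getOrElse(value, 0L)
      if (n == 0L) accumulate(value)  // first time this value enters the window
      counts(value) = n + 1
    }

    def remove(value: T): Unit = counts.get(value) match {
      case Some(1L) =>
        retract(value)                // last copy leaves the window
        counts -= value
      case Some(n) =>
        counts(value) = n - 1
      case None => ()                 // never accumulated; nothing to retract
    }
  }

Counting multiplicities rather than keeping a plain set is what makes retraction correct: a value must stay accumulated until its last copy has left the window.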
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UnsupportedOperatorsIndicatorFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UnsupportedOperatorsIndicatorFunctions.scala
new file mode 100644
index 0000000000000..8aa6ec8a0d33f
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UnsupportedOperatorsIndicatorFunctions.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions
+
+import org.apache.calcite.sql.SqlFunction
+import org.apache.calcite.sql.SqlFunctionCategory
+import org.apache.calcite.sql.SqlKind
+import org.apache.calcite.sql.SqlSyntax
+import org.apache.calcite.sql.`type`.InferTypes
+import org.apache.calcite.sql.`type`.OperandTypes
+import org.apache.calcite.sql.`type`.ReturnTypes
+
+/**
+  * A SQL function DIST() used to mark the DISTINCT operator on an
+  * aggregation input. This is a temporary workaround until
+  * https://issues.apache.org/jira/browse/CALCITE-1740 is resolved.
+  */
+object DistinctAggregatorExtractor extends SqlFunction("DIST", SqlKind.OTHER_FUNCTION,
+    ReturnTypes.ARG0, InferTypes.RETURN_TYPE,
+    OperandTypes.ANY, SqlFunctionCategory.SYSTEM) {
+
+  override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index 2224752c0018b..f97a6effaf2a7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -37,6 +37,8 @@
 import org.apache.flink.api.java.functions.NullByteKeySelector
 import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
 import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
+import java.util.HashMap
+import java.util.Map
 
 class DataStreamOverAggregate(
     logicWindow: Window,
@@ -91,6 +93,22 @@ class DataStreamOverAggregate(
     val overWindow: org.apache.calcite.rel.core.Window.Group = logicWindow.groups.get(0)
 
+    // collect the input fields wrapped in the DIST() marker by scanning the
+    // select expressions of the input calc
+    val distinctVarMap: Map[String, Boolean] = new HashMap[String, Boolean]
+    if (input.isInstanceOf[DataStreamCalc]) {
+      val dsCalc = input.asInstanceOf[DataStreamCalc]
+      val iter = dsCalc
+        .selectionToString(dsCalc.getProgram, dsCalc.getExpressionString)
+        .split(",")
+        .iterator
+      while (iter.hasNext) {
+        val exp = iter.next
+        if (exp.contains("DIST")) {
+          val varName = exp.substring(exp.indexOf("$"))
+          distinctVarMap.put(varName, true)
+        }
+      }
+    }
+
     val orderKeys = overWindow.orderKeys.getFieldCollations
 
     if (orderKeys.size() != 1) {
@@ -124,6 +142,7 @@ class DataStreamOverAggregate(
           createUnboundedAndCurrentRowOverWindow(
             generator,
             inputDS,
+            distinctVarMap,
             isRowTimeType = false,
             isRowsClause = overWindow.isRows)
         } else if (
@@ -133,6 +152,7 @@ class DataStreamOverAggregate(
           createBoundedAndCurrentRowOverWindow(
             generator,
             inputDS,
+            distinctVarMap,
             isRowTimeType = false,
             isRowsClause = overWindow.isRows
           )
@@ -148,6 +168,7 @@ class DataStreamOverAggregate(
           createUnboundedAndCurrentRowOverWindow(
             generator,
             inputDS,
+            distinctVarMap,
             isRowTimeType = true,
             isRowsClause = overWindow.isRows
           )
@@ -156,6 +177,7 @@ class DataStreamOverAggregate(
           createBoundedAndCurrentRowOverWindow(
             generator,
             inputDS,
+            distinctVarMap,
             isRowTimeType = true,
             isRowsClause = overWindow.isRows
           )
@@ -174,6 +196,7 @@ class DataStreamOverAggregate(
   def createUnboundedAndCurrentRowOverWindow(
     generator: CodeGenerator,
     inputDS: DataStream[Row],
+    distinctVarMap: Map[String, Boolean],
     isRowTimeType: Boolean,
     isRowsClause: Boolean): DataStream[Row] = {
 
@@ -183,10 +206,18 @@ class DataStreamOverAggregate(
     // get the output types
     val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
 
+    val aggregateCalls = overWindow.getAggregateCalls(logicWindow)
+    val distinctAggFlags: Array[Boolean] = new Array[Boolean](aggregateCalls.size)
+    for (i <- 0 until aggregateCalls.size()) {
+      val aggParamName = "$" + namedAggregates(i).getKey.getArgList.get(0)
+      // containsKey keeps the lookup null-safe for fields without the marker
+      distinctAggFlags(i) = distinctVarMap.containsKey(aggParamName)
+    }
+
     val processFunction = AggregateUtil.createUnboundedOverProcessFunction(
       generator,
       namedAggregates,
+      distinctAggFlags,
       inputType,
       isRowTimeType,
       partitionKeys.nonEmpty,
@@ -224,6 +255,7 @@ class DataStreamOverAggregate(
   def createBoundedAndCurrentRowOverWindow(
     generator: CodeGenerator,
     inputDS: DataStream[Row],
+    distinctVarMap: Map[String, Boolean],
     isRowTimeType: Boolean,
     isRowsClause: Boolean): DataStream[Row] = {
 
@@ -231,6 +263,13 @@ class DataStreamOverAggregate(
     val partitionKeys: Array[Int] = overWindow.keys.toArray
     val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
 
+    val aggregateCalls = overWindow.getAggregateCalls(logicWindow)
+    val distinctAggFlags: Array[Boolean] = new Array[Boolean](aggregateCalls.size)
+    for (i <- 0 until aggregateCalls.size()) {
+      val aggParamName = "$" + namedAggregates(i).getKey.getArgList.get(0)
+      distinctAggFlags(i) = distinctVarMap.containsKey(aggParamName)
+    }
+
     val precedingOffset =
       getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRowsClause) 1 else 0)
 
@@ -240,6 +279,7 @@ class DataStreamOverAggregate(
     val processFunction = AggregateUtil.createBoundedOverProcessFunction(
       generator,
       namedAggregates,
+      distinctAggFlags,
       inputType,
       precedingOffset,
       isRowsClause,
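The per-aggregate flag loop appears verbatim in both createUnboundedAndCurrentRowOverWindow and createBoundedAndCurrentRowOverWindow: it lines each aggregate call up with the DIST() marker map via the call's first argument index. A sketch of a shared helper capturing the same lookup; the helper itself is hypothetical and not part of the patch:

  import java.{util => ju}
  import org.apache.calcite.rel.core.AggregateCall

  // Sketch: derive one distinct flag per aggregate call. The call's first
  // argument index ("$n") is the key recorded by the DIST() scan above;
  // containsKey keeps the lookup null-safe for unmarked fields.
  def distinctFlags(
      aggregateCalls: Seq[AggregateCall],
      distinctVarMap: ju.Map[String, Boolean]): Array[Boolean] =
    aggregateCalls.map { call =>
      distinctVarMap.containsKey("$" + call.getArgList.get(0))
    }.toArray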
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index a82f38312ce16..dfc3f868ed808 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -69,6 +69,7 @@ object AggregateUtil {
   private[flink] def createUnboundedOverProcessFunction(
     generator: CodeGenerator,
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+    distinctAggregatesFlags: Array[Boolean],
     inputType: RelDataType,
     isRowTimeType: Boolean,
     isPartitioned: Boolean,
@@ -99,6 +100,7 @@ object AggregateUtil {
       forwardMapping,
       None,
       None,
+      distinctAggregatesFlags,
       outputArity,
       needRetract,
       needMerge = false
@@ -146,6 +148,7 @@ object AggregateUtil {
   private[flink] def createBoundedOverProcessFunction(
     generator: CodeGenerator,
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+    distinctAggregatesFlags: Array[Boolean],
     inputType: RelDataType,
     precedingOffset: Long,
     isRowsClause: Boolean,
@@ -176,6 +179,7 @@ object AggregateUtil {
       forwardMapping,
       None,
       None,
+      distinctAggregatesFlags,
       outputArity,
       needRetract,
       needMerge = false
@@ -201,6 +205,7 @@ object AggregateUtil {
     if (isRowsClause) {
       new ProcTimeBoundedRowsOver(
         genFunction,
+        distinctAggregatesFlags,
         precedingOffset,
         aggregationStateType,
         inputRowType)
@@ -251,7 +256,7 @@ object AggregateUtil {
       namedAggregates.map(_.getKey),
       inputType,
       needRetract)
-
+
     val mapReturnType: RowTypeInfo =
       createDataSetAggregateBufferDataType(
         groupings,
@@ -290,6 +295,10 @@ object AggregateUtil {
     val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
     val outputArity = aggregates.length + groupings.length + 1
 
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     val genFunction = generator.generateAggregations(
       "DataSetAggregatePrepareMapHelper",
       generator,
@@ -301,6 +310,7 @@ object AggregateUtil {
       groupings,
       None,
       None,
+      distinctAggregatesFlags,
       outputArity,
       needRetract,
       needMerge = false
@@ -363,6 +373,10 @@ object AggregateUtil {
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     window match {
       case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
         // sliding time-window for partial aggregations
@@ -377,6 +391,7 @@ object AggregateUtil {
           groupings,
           Some(aggregates.indices.map(_ + groupings.length).toArray),
           None,
+          distinctAggregatesFlags,
           keysAndAggregatesArity + 1,
           needRetract,
           needMerge = true
@@ -468,6 +483,10 @@ object AggregateUtil {
     val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
 
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     val genPreAggFunction = generator.generateAggregations(
       "GroupingWindowAggregateHelper",
       generator,
@@ -479,6 +498,7 @@ object AggregateUtil {
       groupings,
       Some(aggregates.indices.map(_ + groupings.length).toArray),
       None,
+      distinctAggregatesFlags,
       outputType.getFieldCount,
       needRetract,
       needMerge = true
@@ -495,6 +515,7 @@ object AggregateUtil {
       groupings.indices.toArray,
       Some(aggregates.indices.map(_ + groupings.length).toArray),
       None,
+      distinctAggregatesFlags,
       outputType.getFieldCount,
       needRetract,
       needMerge = true
@@ -614,6 +635,10 @@ object AggregateUtil {
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     window match {
       case EventTimeSessionGroupWindow(_, _, gap) =>
         val combineReturnType: RowTypeInfo =
@@ -634,6 +659,7 @@ object AggregateUtil {
           groupings.indices.toArray,
           Some(aggregates.indices.map(_ + groupings.length).toArray),
           None,
+          distinctAggregatesFlags,
           groupings.length + aggregates.length + 2,
           needRetract,
           needMerge = true
@@ -681,6 +707,10 @@ object AggregateUtil {
       inputType,
       needRetract)
 
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
@@ -706,6 +736,7 @@ object AggregateUtil {
       groupings.indices.toArray,
       Some(aggregates.indices.map(_ + groupings.length).toArray),
       None,
+      distinctAggregatesFlags,
       groupings.length + aggregates.length + 2,
       needRetract,
       needMerge = true
@@ -724,6 +755,7 @@ object AggregateUtil {
     }
   }
 
+
   /**
     * Create functions to compute a [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
     * If all aggregation functions support pre-aggregation, a pre-aggregation function and the
@@ -763,7 +795,11 @@ object AggregateUtil {
     } else {
       None
     }
-
+
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     val aggOutFields = aggOutMapping.map(_._1)
 
     if (doAllSupportPartialMerge(aggregates)) {
@@ -785,6 +821,7 @@ object AggregateUtil {
       groupings,
       None,
       None,
+      distinctAggregatesFlags,
       groupings.length + aggregates.length,
       needRetract,
       needMerge = false
@@ -811,6 +848,7 @@ object AggregateUtil {
       gkeyMapping,
       Some(aggregates.indices.map(_ + groupings.length).toArray),
       constantFlags,
+      distinctAggregatesFlags,
       outputType.getFieldCount,
       needRetract,
       needMerge = true
@@ -834,6 +872,7 @@ object AggregateUtil {
       groupings,
       None,
       constantFlags,
+      distinctAggregatesFlags,
       outputType.getFieldCount,
       needRetract,
       needMerge = false
@@ -915,6 +954,10 @@ object AggregateUtil {
     val aggMapping = aggregates.indices.toArray
     val outputArity = aggregates.length
 
+    // DISTINCT is not supported on this path yet; pass all-false flags
+    // (remove when distinct is supported)
+    val distinctAggregatesFlags = Array.fill(aggregates.size)(false)
+
     val genFunction = generator.generateAggregations(
       "GroupingWindowAggregateHelper",
       generator,
@@ -926,6 +969,7 @@ object AggregateUtil {
       Array(),  // no fields are forwarded
       None,
       None,
+      distinctAggregatesFlags,
       outputArity,
       needRetract,
       needMerge = true
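The initialize(ctx) hook introduced below in GeneratedAggregations exists so the generated code can register its per-aggregate MapStates against the runtime context. A minimal sketch of that registration, mirroring what the generated initialize() does; the helper name is hypothetical, the state name matches the one the generated code uses:

  import org.apache.flink.api.common.functions.RuntimeContext
  import org.apache.flink.api.common.state.{MapState, MapStateDescriptor}

  // Sketch: register one value-count MapState per distinct aggregate.
  // Non-distinct slots stay null, as in the generated distStateList array.
  def registerDistinctStates(
      ctx: RuntimeContext,
      distinctAggsFlags: Array[Boolean]): Array[MapState[Object, java.lang.Long]] =
    distinctAggsFlags.zipWithIndex.map { case (isDistinct, i) =>
      if (isDistinct) {
        val desc = new MapStateDescriptor(
          "distinctValuesBufferMapState" + i,
          classOf[Object],
          classOf[java.lang.Long])
        ctx.getMapState(desc)
      } else {
        null
      }
    }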
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
index 5f48e091996e5..8cee58be745be 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
@@ -20,16 +20,22 @@ package org.apache.flink.table.runtime.aggregate
 
 import org.apache.flink.api.common.functions.Function
 import org.apache.flink.types.Row
+import org.apache.flink.api.common.functions.RuntimeContext
 
 /**
   * Base class for code-generated aggregations.
   */
 abstract class GeneratedAggregations extends Function {
+
+  /**
+    * Initializes the state used for the distinct aggregation check.
+    *
+    * @param ctx the runtime context used to retrieve and initialize the distinct states
+    */
+  def initialize(ctx: RuntimeContext)
 
   /**
-    * Sets the results of the aggregations (partial or final) to the output row.
-    * Final results are computed with the aggregation function.
-    * Partial results are the accumulators themselves.
+    * Calculates the results from the accumulators and sets them on the output row.
     *
     * @param accumulators the accumulators (saved in a row) which contains the current
     *                     aggregated results
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
index 31cfd730eab89..32a1d7e327e1d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
@@ -46,6 +46,7 @@
 import org.slf4j.LoggerFactory
   */
 class ProcTimeBoundedRowsOver(
     genAggregations: GeneratedAggregationsFunction,
+    distinctAggFlags: Array[Boolean],
     precedingOffset: Long,
     aggregatesTypeInfo: RowTypeInfo,
     inputType: TypeInformation[Row])
@@ -72,7 +73,15 @@ class ProcTimeBoundedRowsOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
-
+
+    // register the distinct state only if at least one aggregate is DISTINCT
+    if (distinctAggFlags.contains(true)) {
+      function.initialize(getRuntimeContext)
+    }
+
     output = function.createOutputRow()
     // We keep the elements received in a Map state keyed
     // by the ingestion time in the operator.
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 74b371aedd9fd..4d29a85439529 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -359,7 +359,11 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
     SqlStdOperatorTable.HOP_END,
     SqlStdOperatorTable.SESSION,
     SqlStdOperatorTable.SESSION_START,
-    SqlStdOperatorTable.SESSION_END
+    SqlStdOperatorTable.SESSION_END,
+
+    // TO BE REMOVED WHEN https://issues.apache.org/jira/browse/CALCITE-1740
+    // is merged and Calcite is updated
+    DistinctAggregatorExtractor
   )
 
   builtInSqlOperators.foreach(register)
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index 67d13b0f455a5..774cc923918f2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -1140,6 +1140,247 @@ class SqlITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
+  @Test
+  def testNonPartitionedProcTimeOverDistinctWindow(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setParallelism(1)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e)
+    tEnv.registerTable("MyTable", t)
+
+    val sqlQuery = "SELECT a, " +
+      "  SUM(DIST(e)) OVER (" +
+      "    ORDER BY procTime() ROWS BETWEEN 10 PRECEDING AND CURRENT ROW) AS sumE " +
+      "  FROM MyTable"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,1", "2,3", "2,3", "3,3", "3,3", "3,6", "4,6", "4,6", "4,6", "4,6",
+      "5,6", "5,6", "5,6", "5,6", "5,6")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testPartitionedProcTimeOverDistinctWindow(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setParallelism(1)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e)
+    tEnv.registerTable("MyTable", t)
+
+    val sqlQuery = "SELECT a, " +
+      "  SUM(DIST(e)) OVER (" +
+      "    PARTITION BY a ORDER BY procTime() ROWS BETWEEN 10 PRECEDING AND CURRENT ROW) AS sumE " +
+      "  FROM MyTable"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,1", "2,2", "2,3", "3,2", "3,2", "3,5", "4,2", "4,3", "4,3", "4,3",
+      "5,1", "5,4", "5,4", "5,6", "5,6")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testPartitionedProcTimeOverDistinctWindow2(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setParallelism(1)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e)
+    tEnv.registerTable("MyTable", t)
+
+    val sqlQuery = "SELECT a, " +
+      "  SUM(DIST(e)) OVER (" +
+      "    PARTITION BY a ORDER BY procTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS sumE " +
+      "  FROM MyTable"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,1", "2,2", "2,3", "3,2", "3,2", "3,5", "4,2", "4,3", "4,3", "4,3",
+      "5,1", "5,4", "5,4", "5,5", "5,5")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testNonPartitionedProcTimeOverDistinctWindow2(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setParallelism(1)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e)
+    tEnv.registerTable("MyTable", t)
+
+    val sqlQuery = "SELECT a, " +
+      "  MAX(e) OVER (ORDER BY procTime() ROWS BETWEEN 10 PRECEDING AND CURRENT ROW) AS maxE, " +
+      "  SUM(DIST(e)) " +
+      "    OVER (ORDER BY procTime() ROWS BETWEEN 10 PRECEDING AND CURRENT ROW) AS sumE " +
+      "  FROM MyTable"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,1,1", "2,2,3", "2,2,3", "3,2,3", "3,2,3", "3,3,6", "4,3,6", "4,3,6",
+      "4,3,6", "4,3,6", "5,3,6", "5,3,6", "5,3,6", "5,3,6", "5,3,6")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
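The expected lists in these tests can be sanity-checked offline. For the first (non-partitioned, 10 PRECEDING) test, assuming StreamTestData.get5TupleDataStream emits the e column in the order 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 3, 3, 2, 2 (an assumption about the shared test data, not restated in this patch), the running SUM(DIST(e)) column reproduces as:

  // Sketch: recompute the expected SUM(DIST(e)) column of the first test.
  val e = Seq(1L, 2L, 1L, 2L, 2L, 3L, 2L, 1L, 1L, 2L, 1L, 3L, 3L, 2L, 2L)
  val window = 11  // ROWS BETWEEN 10 PRECEDING AND CURRENT ROW
  val sums = e.indices.map { i =>
    // sum each window's values with duplicates counted once
    e.slice(math.max(0, i - window + 1), i + 1).distinct.sum
  }
  // sums == Vector(1, 3, 3, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
  // matching the second field of the expected list above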
+  @Test
+  def testNonPartitionedProcTimeOverDistinctWindow3(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setParallelism(1)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e)
+    tEnv.registerTable("MyTable", t)
+
+    val sqlQuery = "SELECT a, " +
+      "  COUNT(DIST(a)) " +
+      "    OVER (ORDER BY procTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS countA, " +
+      "  SUM(DIST(e)) " +
+      "    OVER (ORDER BY procTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS sumE " +
+      "  FROM MyTable"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,1,1", "2,2,3", "2,2,3", "3,2,3", "3,2,3", "3,1,5", "4,2,5", "4,2,6",
+      "4,1,3", "4,1,3", "5,2,3", "5,2,6", "5,1,4", "5,1,5", "5,1,5")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testPartitionedProcTimeOverDistinctWindow3(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setParallelism(1)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e)
+    tEnv.registerTable("MyTable", t)
+
+    val sqlQuery = "SELECT a, " +
+      "  MIN(DIST(b)) " +
+      "    OVER (PARTITION BY a ORDER BY procTime() " +
+      "      ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS minB, " +
+      "  SUM(DIST(e)) " +
+      "    OVER (PARTITION BY a ORDER BY procTime() " +
+      "      ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS sumE " +
+      "  FROM MyTable"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,1,1", "2,2,2", "2,2,3", "3,4,2", "3,4,2", "3,4,5", "4,7,2", "4,7,3",
+      "4,7,3", "4,8,3", "5,11,1", "5,11,4", "5,11,4", "5,12,5", "5,13,5")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
 }
 
 object SqlITCase {

From deab745cc0018c484538b606dac203196d740849 Mon Sep 17 00:00:00 2001
From: Stefano Bortoli
Date: Tue, 25 Apr 2017 18:53:08 +0200
Subject: [PATCH 2/2] fixing code generation test

---
 .../BoundedProcessingOverRangeProcessFunctionTest.scala | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
index 5e3e995221ec5..0604ae473b2dc 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
@@ -88,6 +88,12 @@ class BoundedProcessingOverRangeProcessFunctionTest {
       |        "mluZyRMb25nJOda0iCPo2ukAgAAeHA");
       |    }
       |
+      |    public void initialize(
+      |      org.apache.flink.api.common.functions.RuntimeContext ctx) {
+      |
+      |    }
+      |
+      |
       |    public void setAggregationResults(
       |      org.apache.flink.types.Row accs,
       |      org.apache.flink.types.Row output) {
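The empty initialize() body added to the expected-code string above reflects the non-distinct case. For reference, the generated-class surface that this patch settles on can be summarized as a sketch; the method set is taken from the patch, signatures are abbreviated and do not cover the full GeneratedAggregations contract:

  import org.apache.flink.api.common.functions.{Function, RuntimeContext}
  import org.apache.flink.types.Row

  // Sketch of the generated class surface after this patch: initialize(ctx)
  // is new, and accumulate/retract declare `throws Exception` in the generated
  // Java because MapState access can fail.
  abstract class GeneratedAggregationsSketch extends Function {
    def initialize(ctx: RuntimeContext): Unit  // registers distinct MapStates
    @throws[Exception] def accumulate(accs: Row, input: Row): Unit
    @throws[Exception] def retract(accs: Row, input: Row): Unit
    def setAggregationResults(accs: Row, output: Row): Unit
    def createOutputRow(): Row
  }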