From 3c720b926b6225d0ae6e62f08cffad2d2dd62b88 Mon Sep 17 00:00:00 2001
From: beyond1920
Date: Wed, 3 Apr 2019 18:37:32 +0800
Subject: [PATCH] [FLINK-12017][table-planner-blink] Support translation from
 Rank/Deduplicate to StreamTransformation.

---
 flink-table/flink-table-planner-blink/pom.xml | 7 +
 .../table/plan/util/KeySelectorUtil.java | 71 +
 .../apache/flink/table/api/TableConfig.scala | 14 +
 .../flink/table/codegen/CodeGenUtils.scala | 4 +-
 .../codegen/EqualiserCodeGenerator.scala | 148 ++
 .../table/codegen/SinkCodeGenerator.scala | 12 +-
 .../table/codegen/SortCodeGenerator.scala | 1 +
 .../table/codegen/ValuesCodeGenerator.scala | 4 +-
 .../plan/nodes/calcite/LogicalRank.scala | 2 +-
 .../flink/table/plan/nodes/calcite/Rank.scala | 78 +-
 .../plan/nodes/logical/FlinkLogicalRank.scala | 6 +-
 .../nodes/physical/batch/BatchExecRank.scala | 6 +-
 .../stream/StreamExecDataStreamScan.scala | 8 +-
 .../stream/StreamExecDeduplicate.scala | 100 +-
 .../physical/stream/StreamExecExchange.scala | 54 +-
 .../physical/stream/StreamExecRank.scala | 142 +-
 .../rules/logical/FlinkLogicalRankRule.scala | 7 +-
 .../physical/batch/BatchExecRankRule.scala | 6 +-
 .../stream/StreamExecDeduplicateRule.scala | 5 +-
 .../table/plan/util/FlinkRelMdUtil.scala | 4 +-
 .../flink/table/plan/util/RankUtil.scala | 16 +-
 .../flink/table/plan/util/SortUtil.scala | 77 +
 .../table/typeutils/TypeCheckUtils.scala | 65 +-
 .../utils/FailingCollectionSource.java | 269 ++++
 .../stream/sql/DeduplicateITCase.scala | 170 +++
 .../table/runtime/stream/sql/RankITCase.scala | 1296 +++++++++++++++++
 .../table/runtime/utils/StreamTestSink.scala | 277 +++-
 .../runtime/utils/StreamingTestBase.scala | 6 +
 .../utils/StreamingWithStateTestBase.scala | 271 ++++
 .../flink/table/runtime/utils/TableUtil.scala | 13 +-
 .../table/runtime/utils/TimeTestUtil.scala | 67 +
 .../flink/table/api/TableConfigOptions.java | 36 +
 .../flink/table/dataformat/BinaryRow.java | 16 +
 .../flink/table/dataformat/BinaryWriter.java | 3 +-
 .../dataformat/DataFormatConverters.java | 11 +-
 .../table/dataformat/TypeGetterSetters.java | 3 +-
 .../table/dataformat/util/BinaryRowUtil.java | 10 +
 .../deduplicate/DeduplicateFunction.java | 97 ++
 .../deduplicate/DeduplicateFunctionBase.java | 69 +
 .../runtime/functions/ProcessFunction.java | 119 ++
 .../ProcessFunctionWithCleanupState.java | 94 ++
 .../keySelector/BaseRowKeySelector.java | 33 +
 .../keySelector/BinaryRowKeySelector.java | 56 +
 .../keySelector/NullBinaryRowKeySelector.java | 41 +
 .../runtime/keyed/KeyedProcessOperator.java | 205 +++
 .../runtime/rank/AbstractRankFunction.java | 306 ++++
 .../rank/AbstractUpdateRankFunction.java | 294 ++++
 .../runtime/rank/AppendRankFunction.java | 226 +++
 .../table/runtime/rank/ConstantRankRange.java | 53 +
 .../rank/ConstantRankRangeWithoutEnd.java | 43 +
 .../flink/table/runtime/rank/RankRange.java | 32 +
 .../flink/table/runtime/rank/RankType.java | 49 +
 .../runtime/rank/RetractRankFunction.java | 263 ++++
 .../flink/table/runtime/rank/SortedMap.java | 214 +++
 .../runtime/rank/UpdateRankFunction.java | 259 ++++
 .../table/runtime/rank/VariableRankRange.java | 47 +
 .../runtime/values/ValuesInputFormat.java | 4 +-
 .../flink/table/type/TypeConverters.java | 4 +
 .../typeutils/AbstractMapSerializer.java | 200 +++
 .../table/typeutils/AbstractMapTypeInfo.java | 149 ++
 .../table/typeutils/SortedMapSerializer.java | 120 ++
 .../table/typeutils/SortedMapTypeInfo.java | 145 ++
 62 files changed, 6281 insertions(+), 126 deletions(-)
 create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java
 create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala
 create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SortUtil.scala
 create mode 100644 flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java
 create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala
 create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala
 create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala
 create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunction.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunctionWithCleanupState.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BaseRowKeySelector.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BinaryRowKeySelector.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/NullBinaryRowKeySelector.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyed/KeyedProcessOperator.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java
 create mode 100644
flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java diff --git a/flink-table/flink-table-planner-blink/pom.xml b/flink-table/flink-table-planner-blink/pom.xml index f8ab62e3a264e5..4fad3028883d85 100644 --- a/flink-table/flink-table-planner-blink/pom.xml +++ b/flink-table/flink-table-planner-blink/pom.xml @@ -200,6 +200,13 @@ under the License. test-jar test + + + org.apache.flink + flink-statebackend-rocksdb_${scala.binary.version} + ${project.version} + test + diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java new file mode 100644 index 00000000000000..c9ec1ee2977a8a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.util; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.codegen.CodeGeneratorContext; +import org.apache.flink.table.codegen.ProjectionCodeGenerator; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.runtime.keySelector.BaseRowKeySelector; +import org.apache.flink.table.runtime.keySelector.BinaryRowKeySelector; +import org.apache.flink.table.runtime.keySelector.NullBinaryRowKeySelector; +import org.apache.flink.table.type.InternalType; +import org.apache.flink.table.type.RowType; +import org.apache.flink.table.typeutils.BaseRowSerializer; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.table.typeutils.TypeCheckUtils; + +/** + * Utility for KeySelector. + */ +public class KeySelectorUtil { + + /** + * Create a BaseRowKeySelector to extract keys from DataStream which type is BaseRowTypeInfo. + * + * @param keyFields key fields + * @param rowType type of DataStream to extract keys + * + * @return the BaseRowKeySelector to extract keys from DataStream which type is BaseRowTypeInfo. 
+ */ + public static BaseRowKeySelector getBaseRowSelector(int[] keyFields, BaseRowTypeInfo rowType) { + if (keyFields.length > 0) { + InternalType[] inputFieldTypes = rowType.getInternalTypes(); + String[] inputFieldNames = rowType.getFieldNames(); + InternalType[] keyFieldTypes = new InternalType[keyFields.length]; + String[] keyFieldNames = new String[keyFields.length]; + for (int i = 0; i < keyFields.length; ++i) { + keyFieldTypes[i] = inputFieldTypes[keyFields[i]]; + keyFieldNames[i] = inputFieldNames[keyFields[i]]; + } + RowType returnType = new RowType(keyFieldTypes, keyFieldNames); + RowType inputType = new RowType(inputFieldTypes, rowType.getFieldNames()); + GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection( + CodeGeneratorContext.apply(new TableConfig()), + BaseRowSerializer.class.getSimpleName(), inputType, returnType, keyFields); + BaseRowTypeInfo keyRowType = returnType.toTypeInfo(); + // check if type implements proper equals/hashCode + TypeCheckUtils.validateEqualsHashCode("grouping", keyRowType); + return new BinaryRowKeySelector(keyRowType, generatedProjection); + } else { + return new NullBinaryRowKeySelector(); + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala index fdad9d57ae03a3..d684a3bd7efd97 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala @@ -161,6 +161,20 @@ class TableConfig { !disableOperators.contains(operator.toString) } } + + def getMinIdleStateRetentionTime: Long = { + this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) + } + + def getMaxIdleStateRetentionTime: Long = { + // only min idle ttl provided. 
+ if (this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) + && !this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS)) { + this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS, + getMinIdleStateRetentionTime * 2) + } + this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS) + } } object TableConfig { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala index 6a3104131e277c..a048733143a607 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala @@ -111,7 +111,7 @@ object CodeGenUtils { case InternalTypes.DATE => "int" case InternalTypes.TIME => "int" - case InternalTypes.TIMESTAMP => "long" + case _: TimestampType => "long" case InternalTypes.INTERVAL_MONTHS => "int" case InternalTypes.INTERVAL_MILLIS => "long" @@ -131,7 +131,7 @@ object CodeGenUtils { case InternalTypes.DATE => boxedTypeTermForType(InternalTypes.INT) case InternalTypes.TIME => boxedTypeTermForType(InternalTypes.INT) - case InternalTypes.TIMESTAMP => boxedTypeTermForType(InternalTypes.LONG) + case _: TimestampType => boxedTypeTermForType(InternalTypes.LONG) case InternalTypes.STRING => BINARY_STRING case InternalTypes.BINARY => "byte[]" diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala new file mode 100644 index 00000000000000..35c8d92d49f15d --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.codegen + +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow} +import org.apache.flink.table.generated.{GeneratedRecordEqualiser, RecordEqualiser} +import org.apache.flink.table.`type`.{DateType, InternalType, PrimitiveType, RowType, TimeType, TimestampType} + +class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) { + + private val BASE_ROW = className[BaseRow] + private val BINARY_ROW = className[BinaryRow] + private val RECORD_EQUALISER = className[RecordEqualiser] + private val LEFT_INPUT = "left" + private val RIGHT_INPUT = "right" + + def generateRecordEqualiser(name: String): GeneratedRecordEqualiser = { + // ignore time zone + val ctx = CodeGeneratorContext(new TableConfig) + val className = newName(name) + val header = + s""" + |if ($LEFT_INPUT.getHeader() != $RIGHT_INPUT.getHeader()) { + | return false; + |} + """.stripMargin + + val codes = for (i <- fieldTypes.indices) yield { + val fieldType = fieldTypes(i) + val fieldTypeTerm = primitiveTypeTermForType(fieldType) + val result = s"cmp$i" + val leftNullTerm = "leftIsNull$" + i + val rightNullTerm = "rightIsNull$" + i + val leftFieldTerm = "leftField$" + i + val rightFieldTerm = "rightField$" + i + val equalsCode = if (isInternalPrimitive(fieldType)) { + s"$leftFieldTerm == $rightFieldTerm" + } else if (isBaseRow(fieldType)) { + val equaliserGenerator = + new EqualiserCodeGenerator(fieldType.asInstanceOf[RowType].getFieldTypes) + val generatedEqualiser = equaliserGenerator + .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") + val generatedEqualiserTerm = ctx.addReusableObject( + generatedEqualiser, "field$" + i + "GeneratedEqualiser") + val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName + val equaliserTerm = newName("equaliser") + ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") + ctx.addReusableInitStatement( + s""" + |$equaliserTerm = ($equaliserTypeTerm) + | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); + |""".stripMargin) + s"$equaliserTerm.equalsWithoutHeader($leftFieldTerm, $rightFieldTerm)" + } else { + s"$leftFieldTerm.equals($rightFieldTerm)" + } + val leftReadCode = baseRowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType) + val rightReadCode = baseRowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType) + s""" + |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i); + |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i); + |boolean $result; + |if ($leftNullTerm && $rightNullTerm) { + | $result = true; + |} else if ($leftNullTerm || $rightNullTerm) { + | $result = false; + |} else { + | $fieldTypeTerm $leftFieldTerm = $leftReadCode; + | $fieldTypeTerm $rightFieldTerm = $rightReadCode; + | $result = $equalsCode; + |} + |if (!$result) { + | return false; + |} + """.stripMargin + } + + val functionCode = + j""" + public final class $className implements $RECORD_EQUALISER { + + ${ctx.reuseMemberCode()} + + public $className(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + @Override + public boolean equals($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) { + if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { + return $LEFT_INPUT.equals($RIGHT_INPUT); + } else { + $header + ${ctx.reuseLocalVariableCode()} + ${codes.mkString("\n")} + return true; + } + } + + @Override + public boolean 
equalsWithoutHeader($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) { + if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { + return (($BINARY_ROW)$LEFT_INPUT).equalsWithoutHeader((($BINARY_ROW)$RIGHT_INPUT)); + } else { + ${ctx.reuseLocalVariableCode()} + ${codes.mkString("\n")} + return true; + } + } + } + """.stripMargin + + new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray) + } + + private def isInternalPrimitive(t: InternalType): Boolean = t match { + case _: PrimitiveType => true + + case _: DateType => true + case TimeType.INSTANCE => true + case _: TimestampType => true + + case _ => false + } + + private def isBaseRow(t: InternalType): Boolean = t match { + case _: RowType => true + case _ => false + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala index 29431c2fb64582..ce388c8bc9f7fb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeIn import org.apache.flink.table.api.{Table, TableConfig, TableException, Types} import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, genToExternal} import org.apache.flink.table.codegen.OperatorCodeGenerator.generateCollect +import org.apache.flink.table.dataformat.util.BaseRowUtil import org.apache.flink.table.dataformat.{BaseRow, GenericRow} import org.apache.flink.table.runtime.OneInputOperatorWrapper import org.apache.flink.table.sinks.{DataStreamTableSink, TableSink} @@ -93,6 +94,10 @@ object SinkCodeGenerator { new RowTypeInfo( inputTypeInfo.getFieldTypes, inputTypeInfo.getFieldNames) + case gt: GenericTypeInfo[BaseRow] if gt.getTypeClass == classOf[BaseRow] => + new BaseRowTypeInfo( + inputTypeInfo.getInternalTypes, + inputTypeInfo.getFieldNames) case _ => requestedTypeInfo } @@ -154,13 +159,14 @@ object SinkCodeGenerator { val retractProcessCode = if (!withChangeFlag) { generateCollect(genToExternal(ctx, outputTypeInfo, afterIndexModify)) } else { - val flagResultTerm = s"$afterIndexModify.getHeader() == $BASE_ROW.ACCUMULATE_MSG" + val flagResultTerm = + s"${classOf[BaseRowUtil].getCanonicalName}.isAccumulateMsg($afterIndexModify)" val resultTerm = CodeGenUtils.newName("result") val genericRowField = classOf[GenericRow].getCanonicalName s""" |$genericRowField $resultTerm = new $genericRowField(2); - |$resultTerm.update(0, $flagResultTerm); - |$resultTerm.update(1, $afterIndexModify); + |$resultTerm.setField(0, $flagResultTerm); + |$resultTerm.setField(1, $afterIndexModify); |${generateCollect(genToExternal(ctx, outputTypeInfo, resultTerm))} """.stripMargin } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala index 0f6b2d230c09d8..dbebdab4bbaed6 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala @@ -478,6 +478,7 @@ class SortCodeGenerator( case InternalTypes.FLOAT 
=> 4 case InternalTypes.DOUBLE => 8 case InternalTypes.LONG => 8 + case _: TimestampType => 8 case dt: DecimalType if Decimal.isCompact(dt.precision()) => 8 case InternalTypes.STRING | InternalTypes.BINARY => Int.MaxValue } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala index c207bd685a52f2..86b79a9e590a0c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala @@ -22,7 +22,6 @@ import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.{BaseRow, GenericRow} import org.apache.flink.table.runtime.values.ValuesInputFormat -import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex.RexLiteral @@ -56,8 +55,7 @@ object ValuesCodeGenerator { generatedRecords.map(_.code), outputType) - val baseRowTypeInfo = new BaseRowTypeInfo(outputType.getFieldTypes, outputType.getFieldNames) - new ValuesInputFormat(generatedFunction, baseRowTypeInfo) + new ValuesInputFormat(generatedFunction, outputType.toTypeInfo) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala index de4238ea29c27e..e9e4c0231367b1 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.plan.nodes.calcite -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataTypeField diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala index 9e19e028e3dbc4..0ab70a3406a656 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala @@ -19,8 +19,8 @@ package org.apache.flink.table.plan.nodes.calcite import org.apache.flink.table.api.TableException -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType import org.apache.flink.table.plan.util._ +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType, VariableRankRange} import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} @@ -71,18 +71,19 @@ abstract class Rank( rankRange match { case r: ConstantRankRange => - if (r.rankEnd <= 0) { - throw new TableException(s"Rank end can't smaller than zero. 
The rank end is ${r.rankEnd}") + if (r.getRankEnd <= 0) { + throw new TableException( + s"Rank end can't smaller than zero. The rank end is ${r.getRankEnd}") } - if (r.rankStart > r.rankEnd) { + if (r.getRankStart > r.getRankEnd) { throw new TableException( - s"Rank start '${r.rankStart}' can't greater than rank end '${r.rankEnd}'.") + s"Rank start '${r.getRankStart}' can't greater than rank end '${r.getRankEnd}'.") } case v: VariableRankRange => - if (v.rankEndIndex < 0) { + if (v.getRankEndIndex < 0) { throw new TableException(s"Rank end index can't smaller than zero.") } - if (v.rankEndIndex >= input.getRowType.getFieldCount) { + if (v.getRankEndIndex >= input.getRowType.getFieldCount) { throw new TableException(s"Rank end index can't greater than input field count.") } } @@ -105,7 +106,7 @@ abstract class Rank( }.mkString(", ") super.explainTerms(pw) .item("rankType", rankType) - .item("rankRange", rankRange.toString()) + .item("rankRange", rankRange) .item("partitionBy", partitionKey.map(i => s"$$$i").mkString(",")) .item("orderBy", RelExplainUtil.collationToString(orderKey)) .item("select", select) @@ -134,64 +135,3 @@ abstract class Rank( } } - -/** - * An enumeration of rank type, usable to tell the [[Rank]] node how exactly generate rank number. - */ -object RankType extends Enumeration { - type RankType = Value - - /** - * Returns a unique sequential number for each row within the partition based on the order, - * starting at 1 for the first row in each partition and without repeating or skipping - * numbers in the ranking result of each partition. If there are duplicate values within the - * row set, the ranking numbers will be assigned arbitrarily. - */ - val ROW_NUMBER: RankType.Value = Value - - /** - * Returns a unique rank number for each distinct row within the partition based on the order, - * starting at 1 for the first row in each partition, with the same rank for duplicate values - * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate - * values. - */ - val RANK: RankType.Value = Value - - /** - * is similar to the RANK by generating a unique rank number for each distinct row - * within the partition based on the order, starting at 1 for the first row in each partition, - * ranking the rows with equal values with the same rank number, except that it does not skip - * any rank, leaving no gaps between the ranks. - */ - val DENSE_RANK: RankType.Value = Value -} - -sealed trait RankRange extends Serializable { - def toString(inputFieldNames: Seq[String]): String -} - -/** [[ConstantRankRangeWithoutEnd]] is a RankRange which not specify RankEnd. */ -case class ConstantRankRangeWithoutEnd(rankStart: Long) extends RankRange { - override def toString(inputFieldNames: Seq[String]): String = this.toString - - override def toString: String = s"rankStart=$rankStart" -} - -/** rankStart and rankEnd are inclusive, rankStart always start from one. 
*/ -case class ConstantRankRange(rankStart: Long, rankEnd: Long) extends RankRange { - - override def toString(inputFieldNames: Seq[String]): String = this.toString - - override def toString: String = s"rankStart=$rankStart, rankEnd=$rankEnd" -} - -/** changing rank limit depends on input */ -case class VariableRankRange(rankEndIndex: Int) extends RankRange { - override def toString(inputFieldNames: Seq[String]): String = { - s"rankEnd=${inputFieldNames(rankEndIndex)}" - } - - override def toString: String = { - s"rankEnd=$$$rankEndIndex" - } -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala index 806dadfa0edd0c..a4d58a1f658458 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala @@ -18,9 +18,9 @@ package org.apache.flink.table.plan.nodes.logical import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank, RankRange} +import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank} import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.calcite.plan._ import org.apache.calcite.rel.`type`.RelDataTypeField @@ -59,7 +59,7 @@ class FlinkLogicalRank( with FlinkLogicalRel { override def explainTerms(pw: RelWriter): RelWriter = { - val inputFieldNames = input.getRowType.getFieldNames + val inputFieldNames = getInput.getRowType.getFieldNames pw.item("input", getInput) .item("rankType", rankType) .item("rankRange", rankRange.toString(inputFieldNames)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala index d21e9750c66938..b29291475726e7 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala @@ -20,9 +20,9 @@ package org.apache.flink.table.plan.nodes.physical.batch import org.apache.flink.table.api.TableException import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, Rank, RankRange, RankType} +import org.apache.flink.table.plan.nodes.calcite.Rank import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType} import org.apache.calcite.plan._ import org.apache.calcite.rel._ @@ -64,7 +64,7 @@ class BatchExecRank( require(rankType == RankType.RANK, "Only RANK is supported now") val (rankStart, rankEnd) = rankRange match { - case r: ConstantRankRange => (r.rankStart, r.rankEnd) + case r: ConstantRankRange => (r.getRankStart, r.getRankEnd) case o => throw new TableException(s"$o is not supported now") } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala index 0ed3f5550b1d10..7f94a766adc4fc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.functions.sql.StreamRecordTimestampSqlFunction import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.schema.DataStreamTable import org.apache.flink.table.plan.util.ScanUtil +import org.apache.flink.table.runtime.AbstractProcessStreamOperator import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo.{ROWTIME_INDICATOR, ROWTIME_STREAM_MARKER} import org.apache.calcite.plan._ @@ -112,16 +113,17 @@ class StreamExecDataStreamScan( } else { ("", "") } - + val ctx = CodeGeneratorContext(config).setOperatorBaseClass( + classOf[AbstractProcessStreamOperator[BaseRow]]) ScanUtil.convertToInternalRow( - CodeGeneratorContext(config), + ctx, transform, dataStreamTable.fieldIndexes, dataStreamTable.typeInfo, getRowType, getTable.getQualifiedName, config, - None, + rowtimeExpr, beforeConvert = extractElement, afterConvert = resetElement) } else { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index e76af21841defa..3aebe27d07a133 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -18,12 +18,28 @@ package org.apache.flink.table.plan.nodes.physical.stream +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.EqualiserCodeGenerator +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.generated.GeneratedRecordEqualiser +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.plan.util.KeySelectorUtil +import org.apache.flink.table.runtime.deduplicate.DeduplicateFunction +import org.apache.flink.table.runtime.keyed.KeyedProcessOperator +import org.apache.flink.table.`type`.TypeConverters +import org.apache.flink.table.typeutils.BaseRowTypeInfo +import org.apache.flink.table.typeutils.TypeCheckUtils.isRowTime + import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.calcite.rel.`type`.RelDataType import java.util +import scala.collection.JavaConversions._ + /** * Stream physical RelNode which deduplicate on keys and keeps only first row or last row. * This node is an optimization of [[StreamExecRank]] for some special cases. 
@@ -38,7 +54,8 @@ class StreamExecDeduplicate( isRowtime: Boolean, keepLastRow: Boolean) extends SingleRel(cluster, traitSet, inputRel) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { def getUniqueKeys: Array[Int] = uniqueKeys @@ -73,4 +90,83 @@ class StreamExecDeduplicate( .item("order", orderString) } + //~ ExecNode methods ----------------------------------------------------------- + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + + // TODO checkInput is not acc retract after FLINK- is done + val inputIsAccRetract = false + + if (inputIsAccRetract) { + throw new TableException( + "Deduplicate: Retraction on Deduplicate is not supported yet.\n" + + "please re-check sql grammar. \n" + + "Note: Deduplicate should not follow a non-windowed GroupBy aggregation.") + } + + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + + val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + + val generateRetraction = true + + val inputRowType = FlinkTypeFactory.toInternalRowType(getInput.getRowType) + val rowTimeFieldIndex = inputRowType.getFieldTypes.zipWithIndex + .filter(e => isRowTime(e._1)) + .map(_._2) + if (rowTimeFieldIndex.size > 1) { + throw new RuntimeException("More than one row time field. Currently this is not supported!") + } + if (rowTimeFieldIndex.nonEmpty) { + throw new TableException("Currently not support Deduplicate on rowtime.") + } + val tableConfig = tableEnv.getConfig + val exeConfig = tableEnv.execEnv.getConfig + val isMiniBatchEnabled = tableConfig.getConf.contains( + TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) + val generatedRecordEqualiser = generateRecordEqualiser(rowTypeInfo) + // TODO use MiniBatchDeduplicateFunction if miniBatch is enabled + val minRetentionTime = tableConfig.getMinIdleStateRetentionTime + val maxRetentionTime = tableConfig.getMaxIdleStateRetentionTime + val processFunction = new DeduplicateFunction( + minRetentionTime, + maxRetentionTime, + rowTypeInfo, + generateRetraction, + keepLastRow, + generatedRecordEqualiser) + val operator = new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](processFunction) + val ret = new OneInputTransformation( + inputTransform, + getOperatorName, + operator, + rowTypeInfo, + inputTransform.getParallelism) + val selector = KeySelectorUtil.getBaseRowSelector(uniqueKeys, rowTypeInfo) + ret.setStateKeySelector(selector) + ret.setStateKeyType(selector.getProducedType) + ret + } + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + private def getOperatorName: String = { + val fieldNames = getRowType.getFieldNames + val keyNames = uniqueKeys.map(fieldNames.get).mkString(", ") + val orderString = if (isRowtime) "ROWTIME" else "PROCTIME" + s"${if (keepLastRow) "keepLastRow" else "KeepFirstRow"}" + + s": (key: ($keyNames), select: (${fieldNames.mkString(", ")}), order: ($orderString))" + } + + private def generateRecordEqualiser(rowTypeInfo: BaseRowTypeInfo): GeneratedRecordEqualiser = { + val generator = new EqualiserCodeGenerator( + rowTypeInfo.getFieldTypes.map(TypeConverters.createInternalTypeFromTypeInfo)) + val equaliserName = s"${if (keepLastRow) "LastRow" else "FirstRow"}ValueEqualiser" + generator.generateRecordEqualiser(equaliserName) + } + } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala index b503fb2d9bdedd..e1c431b0440f37 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala @@ -19,10 +19,22 @@ package org.apache.flink.table.plan.nodes.physical.stream import org.apache.flink.table.plan.nodes.common.CommonPhysicalExchange +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.streaming.api.transformations.{PartitionTransformation, StreamTransformation} +import org.apache.flink.streaming.runtime.partitioner.{GlobalPartitioner, KeyGroupStreamPartitioner, StreamPartitioner} +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.util.KeySelectorUtil +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.{RelDistribution, RelNode} +import java.util + +import scala.collection.JavaConversions._ + /** * Stream physical RelNode for [[org.apache.calcite.rel.core.Exchange]]. */ @@ -32,7 +44,10 @@ class StreamExecExchange( relNode: RelNode, relDistribution: RelDistribution) extends CommonPhysicalExchange(cluster, traitSet, relNode, relDistribution) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { + + private val DEFAULT_MAX_PARALLELISM = 1 << 7 override def producesUpdates: Boolean = false @@ -50,4 +65,41 @@ class StreamExecExchange( newDistribution: RelDistribution): StreamExecExchange = { new StreamExecExchange(cluster, traitSet, newInput, newDistribution) } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val inputTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + val outputTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo + relDistribution.getType match { + case RelDistribution.Type.SINGLETON => + val partitioner = new GlobalPartitioner[BaseRow] + val transformation = new PartitionTransformation( + inputTransform, + partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) + transformation.setOutputType(outputTypeInfo) + transformation + case RelDistribution.Type.HASH_DISTRIBUTED => + // TODO Eliminate duplicate keys + val selector = KeySelectorUtil.getBaseRowSelector( + relDistribution.getKeys.map(_.toInt).toArray, inputTypeInfo) + val partitioner = new KeyGroupStreamPartitioner(selector, DEFAULT_MAX_PARALLELISM) + val transformation = new PartitionTransformation( + inputTransform, + partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) + transformation.setOutputType(outputTypeInfo) + transformation + case _ => + throw new 
UnsupportedOperationException( + s"not support RelDistribution: ${relDistribution.getType} now!") + } + } + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala index ac833b548ea656..fdd0a471dc0f62 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala @@ -17,9 +17,17 @@ */ package org.apache.flink.table.plan.nodes.physical.stream -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{Rank, RankRange} -import org.apache.flink.table.plan.util.{RankProcessStrategy, RelExplainUtil, RetractStrategy} +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.{EqualiserCodeGenerator, SortCodeGenerator} +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.calcite.Rank +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.plan.util._ +import org.apache.flink.table.runtime.keyed.KeyedProcessOperator +import org.apache.flink.table.runtime.rank._ +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel._ @@ -53,7 +61,8 @@ class StreamExecRank( rankRange, rankNumberType, outputRankNumber) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { /** please uses [[getStrategy]] instead of this field */ private var strategy: RankProcessStrategy = _ @@ -101,4 +110,129 @@ class StreamExecRank( .item("orderBy", RelExplainUtil.collationToString(orderKey, inputRowType)) .item("select", getRowType.getFieldNames.mkString(", ")) } + + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val tableConfig = tableEnv.getConfig + rankType match { + case RankType.ROW_NUMBER => // ignore + case RankType.RANK => + throw new TableException("RANK() on streaming table is not supported currently") + case RankType.DENSE_RANK => + throw new TableException("DENSE_RANK() on streaming table is not supported currently") + case k => + throw new TableException(s"Streaming tables do not support $k rank function.") + } + + val inputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getInput.getRowType).toTypeInfo + val outputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo + + val fieldCollations = orderKey.getFieldCollations + val (sortFields, sortDirections, nullsIsLast) = SortUtil.getKeysAndOrders(fieldCollations) + val sortKeySelector = KeySelectorUtil.getBaseRowSelector(sortFields, inputRowTypeInfo) + val sortKeyType = sortKeySelector.getProducedType + val sortCodeGen = new SortCodeGenerator( 
+ tableConfig, sortFields.indices.toArray, sortKeyType.getInternalTypes, + sortDirections, nullsIsLast) + val comparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator") + // TODO infer generate retraction after FLINK- is done + val generateRetraction = true + val cacheSize = tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE) + val minIdleStateRetentionTime = tableConfig.getMinIdleStateRetentionTime + val maxIdleStateRetentionTime = tableConfig.getMaxIdleStateRetentionTime + val equaliserCodeGenerator = new EqualiserCodeGenerator(inputRowTypeInfo.getInternalTypes) + val equaliser = equaliserCodeGenerator.generateRecordEqualiser("RankValueEqualiser") + val processFunction = getStrategy(true) match { + case AppendFastStrategy => + new AppendRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + outputRowTypeInfo, + sortKeyType, + comparator, + sortKeySelector, + rankType, + rankRange, + equaliser, + generateRetraction, + cacheSize) + + case UpdateFastStrategy(primaryKeys) => + val internalTypes = inputRowTypeInfo.getInternalTypes + val fieldTypes = primaryKeys.map(internalTypes) + val rowKeyType = new BaseRowTypeInfo(fieldTypes: _*) + val rowKeySelector = KeySelectorUtil.getBaseRowSelector(primaryKeys, inputRowTypeInfo) + new UpdateRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + outputRowTypeInfo, + rowKeyType, + rowKeySelector, + comparator, + sortKeySelector, + rankType, + rankRange, + equaliser, + generateRetraction, + cacheSize) + + // TODO UnaryUpdateRank after SortedMapState is merged + case RetractStrategy | UnaryUpdateStrategy(_) => + new RetractRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + outputRowTypeInfo, + sortKeyType, + comparator, + sortKeySelector, + rankType, + rankRange, + equaliser, + generateRetraction) + } + val rankOpName = getOperatorName + val operator = new KeyedProcessOperator(processFunction) + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val ret = new OneInputTransformation( + inputTransform, + rankOpName, + operator, + outputRowTypeInfo, + inputTransform.getParallelism) + + if (partitionKey.isEmpty) { + ret.setParallelism(1) + ret.setMaxParallelism(1) + } + + // set KeyType and Selector for state + val selector = KeySelectorUtil.getBaseRowSelector(partitionKey.toArray, inputRowTypeInfo) + ret.setStateKeySelector(selector) + ret.setStateKeyType(selector.getProducedType) + ret + } + + private def getOperatorName: String = { + val inputRowType = inputRel.getRowType + var result = getStrategy().toString + result += s"(orderBy: (${RelExplainUtil.collationToString(orderKey, inputRowType)})" + if (partitionKey.nonEmpty) { + val partitionKeys = partitionKey.toArray + result += s", partitionBy: (${RelExplainUtil.fieldToString(partitionKeys, inputRowType)})" + } + result += s", ${getRowType.getFieldNames.mkString(", ")}" + result += s", ${rankRange.toString(inputRowType.getFieldNames)})" + result + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala index 3a53e2c1223d5c..3b21083e2a608a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala @@ -19,9 +19,9 @@ package org.apache.flink.table.plan.rules.logical import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkContext -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType} import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverWindow, FlinkLogicalRank} import org.apache.flink.table.plan.util.RankUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType} import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} @@ -85,8 +85,9 @@ abstract class FlinkLogicalRankRuleBase require(rankNumberType.isDefined) rankRange match { - case Some(ConstantRankRange(_, rankEnd)) if rankEnd <= 0 => - throw new TableException(s"Rank end should not less than zero, but now is $rankEnd") + case Some(crr: ConstantRankRange) if crr.getRankEnd <= 0 => + throw new TableException( + s"Rank end should not less than zero, but now is ${crr.getRankEnd}") case _ => // do nothing } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala index 1e8950b6091ea4..918211b9cda67c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala @@ -21,10 +21,10 @@ package org.apache.flink.table.plan.rules.physical.batch import org.apache.flink.table.api.TableException import org.apache.flink.table.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank import org.apache.flink.table.plan.util.RelFieldCollationUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.convert.ConverterRule @@ -58,7 +58,7 @@ class BatchExecRankRule def convert(rel: RelNode): RelNode = { val rank = rel.asInstanceOf[FlinkLogicalRank] val (_, rankEnd) = rank.rankRange match { - case r: ConstantRankRange => (r.rankStart, r.rankEnd) + case r: ConstantRankRange => (r.getRankStart, r.getRankEnd) case o => throw new TableException(s"$o is not supported now") } @@ -71,7 +71,7 @@ class BatchExecRankRule val newLocalInput = RelOptRule.convert(rank.getInput, localRequiredTraitSet) // create local BatchExecRank - val localRankRange = ConstantRankRange(1, rankEnd) // local rank always start from 1 + val localRankRange = new ConstantRankRange(1, rankEnd) // local rank always start from 1 val localRank = new BatchExecRank( cluster, emptyTraits, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala index 
b07ba7a0409893..609bae8b7d34f8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala @@ -21,9 +21,9 @@ package org.apache.flink.table.plan.rules.physical.stream import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecDeduplicate, StreamExecRank} +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.`type`.RelDataType @@ -111,7 +111,8 @@ object StreamExecDeduplicateRule { val isRowNumberType = rank.rankType == RankType.ROW_NUMBER val isLimit1 = rankRange match { - case ConstantRankRange(rankStart, rankEnd) => rankStart == 1 && rankEnd == 1 + case rankRange: ConstantRankRange => + rankRange.getRankStart() == 1 && rankRange.getRankEnd() == 1 case _ => false } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala index fbb16f50260cba..b88cfffbdc2f0c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.BinaryRow -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankRange} +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange} import org.apache.flink.table.runtime.sort.BinaryIndexedSortable import org.apache.flink.table.typeutils.BinaryRowSerializer @@ -37,7 +37,7 @@ import scala.collection.JavaConversions._ object FlinkRelMdUtil { def getRankRangeNdv(rankRange: RankRange): Double = rankRange match { - case r: ConstantRankRange => (r.rankEnd - r.rankStart + 1).toDouble + case r: ConstantRankRange => (r.getRankEnd - r.getRankStart + 1).toDouble case _ => 100D // default value now } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala index 338564c42c164f..f7f7378aa5dffa 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala @@ -19,7 +19,9 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, Rank, RankRange, VariableRankRange} +import org.apache.flink.table.plan.nodes.calcite.Rank +import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankRange, VariableRankRange} +import 
org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil} @@ -98,20 +100,20 @@ object RankUtil { val sortBounds = limitPreds.map(computeWindowBoundFromPredicate(_, rexBuilder, config)) val rankRange = sortBounds match { case Seq(Some(LowerBoundary(x)), Some(UpperBoundary(y))) => - ConstantRankRange(x, y) + new ConstantRankRange(x, y) case Seq(Some(UpperBoundary(x)), Some(LowerBoundary(y))) => - ConstantRankRange(y, x) + new ConstantRankRange(y, x) case Seq(Some(LowerBoundary(x))) => // only offset - ConstantRankRangeWithoutEnd(x) + new ConstantRankRangeWithoutEnd(x) case Seq(Some(UpperBoundary(x))) => // rankStart starts from one - ConstantRankRange(1, x) + new ConstantRankRange(1, x) case Seq(Some(BothBoundary(x, y))) => // nth rank - ConstantRankRange(x, y) + new ConstantRankRange(x, y) case Seq(Some(InputRefBoundary(x))) => - VariableRankRange(x) + new VariableRankRange(x) case _ => // TopN requires at least one rank comparison predicate return (None, Some(predicate)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SortUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SortUtil.scala new file mode 100644 index 00000000000000..854161aa7a71f7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/SortUtil.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.util + +import org.apache.flink.api.common.operators.Order +import org.apache.flink.table.api.TableException + +import org.apache.calcite.rel.RelFieldCollation +import org.apache.calcite.rel.RelFieldCollation.Direction + +import scala.collection.mutable + +/** + * Common methods for Flink sort operators. 
+ */ +object SortUtil { + + def getKeysAndOrders( + fieldCollations: Seq[RelFieldCollation]): (Array[Int], Array[Boolean], Array[Boolean]) = { + val fieldMappingDirections = fieldCollations map { + c => (c.getFieldIndex, directionToOrder(c.getDirection)) + } + val keys = fieldMappingDirections.map(_._1) + val orders = fieldMappingDirections.map(_._2 == Order.ASCENDING) + val nullsIsLast = fieldCollations.map(_.nullDirection).map { + case RelFieldCollation.NullDirection.LAST => true + case RelFieldCollation.NullDirection.FIRST => false + case RelFieldCollation.NullDirection.UNSPECIFIED => + throw new TableException(s"Do not support UNSPECIFIED for null order.") + }.toArray + + deduplicateSortKeys(keys.toArray, orders.toArray, nullsIsLast) + } + + def deduplicateSortKeys( + keys: Array[Int], + orders: Array[Boolean], + nullsIsLast: Array[Boolean]): (Array[Int], Array[Boolean], Array[Boolean]) = { + val keySet = new mutable.HashSet[Int] + val keyBuffer = new mutable.ArrayBuffer[Int] + val orderBuffer = new mutable.ArrayBuffer[Boolean] + val nullsIsLastBuffer = new mutable.ArrayBuffer[Boolean] + keys.indices foreach { i => + if (keySet.add(keys(i))) { + keyBuffer += keys(i) + orderBuffer += orders(i) + nullsIsLastBuffer += nullsIsLast(i) + } + } + (keyBuffer.toArray, orderBuffer.toArray, nullsIsLastBuffer.toArray) + } + + def directionToOrder(direction: Direction): Order = { + direction match { + case Direction.ASCENDING | Direction.STRICTLY_ASCENDING => Order.ASCENDING + case Direction.DESCENDING | Direction.STRICTLY_DESCENDING => Order.DESCENDING + case _ => throw new IllegalArgumentException("Unsupported direction.") + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala index 32c9f4f61c3bb0..07b9d93abcf92d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala @@ -18,8 +18,12 @@ package org.apache.flink.table.typeutils -import org.apache.flink.table.`type`._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.PojoTypeInfo +import org.apache.flink.table.api.ValidationException import org.apache.flink.table.codegen.GeneratedExpression +import org.apache.flink.table.`type`._ object TypeCheckUtils { @@ -96,4 +100,63 @@ object TypeCheckUtils { def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) + /** + * Checks whether a type implements own hashCode() and equals() methods for storing an instance + * in Flink's state or performing a keyBy operation. + * + * @param name name of the operation. + * @param t type information to be validated + */ + def validateEqualsHashCode(name: String, t: TypeInformation[_]): Unit = t match { + + // make sure that a POJO class is a valid state type + case pt: PojoTypeInfo[_] => + // we don't check the types recursively to give a chance of wrapping + // proper hashCode/equals methods around an immutable type + validateEqualsHashCode(name, pt.getClass) + // BinaryRow direct hash in bytes, no need to check field types. 
+ case bt: BaseRowTypeInfo => + // recursively check composite types + case ct: CompositeType[_] => + validateEqualsHashCode(name, t.getTypeClass) + // we check recursively for entering Flink types such as tuples and rows + for (i <- 0 until ct.getArity) { + val subtype = ct.getTypeAt(i) + validateEqualsHashCode(name, subtype) + } + // check other type information only based on the type class + case _: TypeInformation[_] => + validateEqualsHashCode(name, t.getTypeClass) + } + + /** + * Checks whether a class implements own hashCode() and equals() methods for storing an instance + * in Flink's state or performing a keyBy operation. + * + * @param name name of the operation + * @param c class to be validated + */ + def validateEqualsHashCode(name: String, c: Class[_]): Unit = { + + // skip primitives + if (!c.isPrimitive) { + // check the component type of arrays + if (c.isArray) { + validateEqualsHashCode(name, c.getComponentType) + } + // check type for methods + else { + if (c.getMethod("hashCode").getDeclaringClass eq classOf[Object]) { + throw new ValidationException( + s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " + + s"does not implement a proper hashCode() method.") + } + if (c.getMethod("equals", classOf[Object]).getDeclaringClass eq classOf[Object]) { + throw new ValidationException( + s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " + + s"does not implement a proper equals() method.") + } + } + } + } } diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java new file mode 100644 index 00000000000000..726f3d9ae09dab --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
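For illustration only: a minimal sketch of what the validateEqualsHashCode check added above rejects. ShopKey is a hypothetical class introduced solely for this sketch; it is not defined anywhere in the patch.

import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.typeutils.TypeCheckUtils

// Hypothetical key class: no default constructor and no own equals()/hashCode(),
// so Flink falls back to a generic type information for it.
class ShopKey(val category: String, val shopId: Int)

try {
  // Rejected, because ShopKey still uses the hashCode()/equals() declared on java.lang.Object.
  TypeCheckUtils.validateEqualsHashCode("keyBy", TypeExtractor.createTypeInfo(classOf[ShopKey]))
} catch {
  case _: ValidationException => // expected for this sketch
}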
+ */ + +package org.apache.flink.table.runtime.utils; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.Preconditions; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * The FailingCollectionSource will fail after emitting a specified number of elements. This is used + * to perform checkpoint and restore in IT cases. + */ +public class FailingCollectionSource + implements SourceFunction, CheckpointedFunction, CheckpointListener { + + public static volatile boolean failedBefore = false; + + private static final long serialVersionUID = 1L; + + /** The (de)serializer to be used for the data elements. */ + private final TypeSerializer serializer; + + /** The actual data elements, in serialized form. */ + private final byte[] elementsSerialized; + + /** The number of serialized elements. */ + private final int numElements; + + /** The number of elements emitted already. */ + private volatile int numElementsEmitted; + + /** The number of elements to skip initially. */ + private volatile int numElementsToSkip; + + /** Flag to make the source cancelable. */ + private volatile boolean isRunning = true; + + private transient ListState checkpointedState; + + /** A failure will occur when the given number of elements have been processed. */ + private final int failureAfterNumElements; + + /** The number of completed checkpoints. */ + private volatile int numSuccessfulCheckpoints; + + /** The checkpointed number of emitted elements. */ + private final Map checkpointedEmittedNums; + + /** The last successful checkpointed number of emitted elements. */ + private volatile int lastCheckpointedEmittedNum = 0; + + /** Whether to perform a checkpoint before the job finishes. 
*/ + private final boolean performCheckpointBeforeJobFinished; + + public FailingCollectionSource( + TypeSerializer serializer, + Iterable elements, + int failureAfterNumElements, + boolean performCheckpointBeforeJobFinished) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(baos); + + int count = 0; + try { + for (T element : elements) { + serializer.serialize(element, wrapper); + count++; + } + } + catch (Exception e) { + throw new IOException("Serializing the source elements failed: " + e.getMessage(), e); + } + + this.serializer = serializer; + this.elementsSerialized = baos.toByteArray(); + this.numElements = count; + checkArgument(failureAfterNumElements > 0); + this.failureAfterNumElements = failureAfterNumElements; + this.performCheckpointBeforeJobFinished = performCheckpointBeforeJobFinished; + this.checkpointedEmittedNums = new HashMap<>(); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + Preconditions.checkState( + this.checkpointedState == null, + "The " + getClass().getSimpleName() + " has already been initialized."); + + this.checkpointedState = context.getOperatorStateStore().getListState( + new ListStateDescriptor<>( + "from-elements-state", + IntSerializer.INSTANCE + ) + ); + + if (failedBefore && context.isRestored()) { + List retrievedStates = new ArrayList<>(); + for (Integer entry : this.checkpointedState.get()) { + retrievedStates.add(entry); + } + + // given that the parallelism of the function is 1, we can only have 1 state + Preconditions.checkArgument( + retrievedStates.size() == 1, + getClass().getSimpleName() + " retrieved invalid state."); + + this.numElementsToSkip = retrievedStates.get(0); + } + } + + @Override + public void run(SourceContext ctx) throws Exception { + ByteArrayInputStream bais = new ByteArrayInputStream(elementsSerialized); + final DataInputView input = new DataInputViewStreamWrapper(bais); + + // if we are restored from a checkpoint and need to skip elements, skip them now. + int toSkip = numElementsToSkip; + if (toSkip > 0) { + try { + while (toSkip > 0) { + serializer.deserialize(input); + toSkip--; + } + } + catch (Exception e) { + throw new IOException( + "Failed to deserialize an element from the source. " + + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); + } + + this.numElementsEmitted = this.numElementsToSkip; + } + + while (isRunning && numElementsEmitted < numElements) { + if (!failedBefore) { + // delay a bit, if we have not failed before + Thread.sleep(1); + if (numSuccessfulCheckpoints >= 1 && lastCheckpointedEmittedNum >= failureAfterNumElements) { + // cause a failure if we have not failed before and have reached + // enough completed checkpoints and elements + failedBefore = true; + throw new Exception("Artificial Failure"); + } + } + + if (failedBefore || numElementsEmitted < failureAfterNumElements) { + // the function failed before, or we are in the elements before the failure + T next; + try { + next = serializer.deserialize(input); + } + catch (Exception e) { + throw new IOException( + "Failed to deserialize an element from the source. 
" + + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); + } + + synchronized (ctx.getCheckpointLock()) { + ctx.collect(next); + numElementsEmitted++; + } + } else { + // if our work is done, delay a bit to prevent busy waiting + Thread.sleep(1); + } + } + + if (performCheckpointBeforeJobFinished) { + while (isRunning) { + // wait until the latest checkpoint records everything + if (lastCheckpointedEmittedNum < numElements) { + // delay a bit to prevent busy waiting + Thread.sleep(1); + } else { + // cause a failure to retain the last checkpoint + throw new Exception("Job finished normally"); + } + } + } + } + + @Override + public void cancel() { + isRunning = false; + } + + + /** + * Gets the number of elements produced in total by this function. + * + * @return The number of elements produced in total. + */ + public int getNumElements() { + return numElements; + } + + /** + * Gets the number of elements emitted so far. + * + * @return The number of elements emitted so far. + */ + public int getNumElementsEmitted() { + return numElementsEmitted; + } + + // ------------------------------------------------------------------------ + // Checkpointing + // ------------------------------------------------------------------------ + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Preconditions.checkState( + this.checkpointedState != null, + "The " + getClass().getSimpleName() + " has not been properly initialized."); + + this.checkpointedState.clear(); + this.checkpointedState.add(this.numElementsEmitted); + long checkpointId = context.getCheckpointId(); + checkpointedEmittedNums.put(checkpointId, numElementsEmitted); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + numSuccessfulCheckpoints++; + lastCheckpointedEmittedNum = checkpointedEmittedNums.get(checkpointId); + } + + public static void reset() { + failedBefore = false; + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala new file mode 100644 index 00000000000000..debeb5da3ffc04 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.runtime.utils._ +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset +import org.apache.flink.types.Row + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class DeduplicateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) { + + @Test + def testFirstRowOnProctime(): Unit = { + val t = failingDataSource(StreamTestData.get3TupleData) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT a, b, c + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,Hi", "2,2,Hello", "4,3,Hello world, how are you?", + "7,4,Comment#1", "11,5,Comment#5", "16,6,Comment#10") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } + + @Test + def testLastRowOnProctime(): Unit = { + val t = failingDataSource(StreamTestData.get3TupleData) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT a, b, c + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime DESC) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,Hi", "3,2,Hello world", "6,3,Luke Skywalker", + "10,4,Comment#4", "15,5,Comment#9", "21,6,Comment#15") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testFirstRowOnRowtime(): Unit = { + val data = List( + (3L, 2L, "Hello world", 3), + (2L, 2L, "Hello", 2), + (6L, 3L, "Luke Skywalker", 6), + (5L, 3L, "I am fine.", 5), + (7L, 4L, "Comment#1", 7), + (9L, 4L, "Comment#3", 9), + (10L, 4L, "Comment#4", 10), + (8L, 4L, "Comment#2", 8), + (1L, 1L, "Hi", 1), + (4L, 3L, "Helloworld, how are you?", 4)) + + val t = failingDataSource(data) + .assignTimestampsAndWatermarks( + new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) + .toTable(tEnv, 'rowtime, 'key, 'str, 'int) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT key, str, `int` + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY key ORDER BY rowtime) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + + // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently + env.execute() + val expected = List("1,Hi,1", "2,Hello,2", "3,Helloworld, how are you?,4", "4,Comment#1,7") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + @Test + def testLastRowOnRowtime(): Unit = { + val data = List( + (3L, 2L, "Hello world", 3), + (2L, 2L, "Hello", 2), + (6L, 3L, "Luke Skywalker", 6), + (5L, 3L, "I am fine.", 5), + (7L, 4L, "Comment#1", 7), + (9L, 4L, "Comment#3", 9), + (10L, 4L, "Comment#4", 10), + (8L, 4L, "Comment#2", 8), + (1L, 1L, "Hi", 1), + (4L, 3L, "Helloworld, how are you?", 4)) + + val t = failingDataSource(data) 
+ .assignTimestampsAndWatermarks( + new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) + .toTable(tEnv, 'rowtime, 'key, 'str, 'int) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT key, str, `int` + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY key ORDER BY rowtime DESC) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + + // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently + env.execute() + val expected = List("1,Hi,1", "2,Hello world,3", "3,Luke Skywalker,6", "4,Comment#4,10") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala new file mode 100644 index 00000000000000..30577b156b0291 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -0,0 +1,1296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
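For illustration only: the deduplication pattern exercised by the tests above, reduced to its core. A query keeps the first (or, with DESC, the last) row per key by ranking on time and filtering on rank 1; the table name T and its columns follow the tests above.

SELECT a, b, c
FROM (
  SELECT *,
    ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime) as rowNum  -- DESC keeps the last row instead
  FROM T
)
WHERE rowNum = 1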
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.TableConfigOptions +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.{TestingRetractTableSink, TestingUpsertTableSink, _} +import org.apache.flink.types.Row + +import org.junit.Assert._ +import org.junit.{Ignore, _} +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) { + + @Test + def testTopN(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,2,19,1", + "book,1,12,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testTopNth(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num = 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,1,12,2", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after retraction infer is introduced") + @Test + def testTopNWithUpsertSink(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(0, 3)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + env.execute() + + val expected = List( + "book,4,11,1", + "book,1,12,2", + "fruit,5,22,1", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithUnary(): Unit = { + val data = List( + ("book", 11, 100), + ("book", 11, 200), + ("book", 12, 400), + ("book", 12, 500), + ("book", 10, 600), + ("book", 10, 700), + ("book", 9, 800), + ("book", 9, 900), + ("book", 10, 500), + ("book", 8, 110), + ("book", 8, 120), + ("book", 7, 1800), + ("book", 9, 
300), + ("book", 6, 1900), + ("book", 7, 50), + ("book", 11, 1800), + ("book", 7, 50), + ("book", 8, 2000), + ("book", 6, 700), + ("book", 5, 800), + ("book", 4, 910), + ("book", 3, 1000), + ("book", 2, 1100), + ("book", 1, 1200) + ) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(0, 3)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,5,800,1", + "book,12,900,2", + "book,4,910,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable when state support SortedMapState") + @Test + def testUnarySortTopNOnString(): Unit = { + val data = List( + ("book", 11, "100"), + ("book", 11, "200"), + ("book", 12, "400"), + ("book", 12, "600"), + ("book", 10, "600"), + ("book", 10, "700"), + ("book", 9, "800"), + ("book", 9, "900"), + ("book", 10, "500"), + ("book", 8, "110"), + ("book", 8, "120"), + ("book", 7, "812"), + ("book", 9, "300"), + ("book", 6, "900"), + ("book", 7, "50"), + ("book", 11, "800"), + ("book", 7, "50"), + ("book", 8, "200"), + ("book", 6, "700"), + ("book", 5, "800"), + ("book", 4, "910"), + ("book", 3, "110"), + ("book", 2, "900"), + ("book", 1, "700") + ) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'price) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, max_price, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_price ASC) as rank_num + | FROM ( + | SELECT category, shopId, max(price) as max_price + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val sink = new TestingUpsertTableSink(Array(0, 3)) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,3,110,1", + "book,8,200,2", + "book,12,600,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupBy(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val sink = new TestingUpsertTableSink(Array(0, 3)) + writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,1,22,1", + "book,2,19,2", + "fruit,3,44,1", + "fruit,5,34,2") + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithSumAndCondition(): Unit = { + val data = List( + Row.of("book", Int.box(11), Double.box(100)), + Row.of("book", Int.box(11), Double.box(200)), + 
Row.of("book", Int.box(12), Double.box(400)), + Row.of("book", Int.box(12), Double.box(500)), + Row.of("book", Int.box(10), Double.box(600)), + Row.of("book", Int.box(10), Double.box(700))) + + implicit val tpe: TypeInformation[Row] = new RowTypeInfo( + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO) // tpe is automatically + + val ds = env.fromCollection(data) + val t = ds.toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", t) + + val subquery = + """ + |SELECT category, shopId, sum(num) as sum_num + |FROM T + |WHERE num >= cast(1.1 as double) + |GROUP BY category, shopId + """.stripMargin + + val sql = + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num + | FROM ($subquery)) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val sink = new TestingUpsertTableSink(Array(0, 3)) + writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,10,1300.0,1", + "book,12,900.0,2") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNthWithGroupBy(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val tableSink = new TestingUpsertTableSink(Array(0, 3)) + writeToSink(table, tableSink) + env.execute() + + val updatedExpected = List( + "book,2,19,2", + "fruit,5,34,2") + + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByAndRetract(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num, count(num) as cnt + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,1,22,2,1", + "book,2,19,1,2", + "fruit,3,44,1,1", + "fruit,5,34,2,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + def testTopNthWithGroupByAndRetract(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + 
tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num, count(num) as cnt + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,2,19,1,2", + "fruit,5,34,2,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByCount(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 2, 1002), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 3, 1006), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "book,1,5,4", + "book,2,4,1", + "book,3,2,2", + "book,4,1,3", + "fruit,1,3,5", + "fruit,2,2,4", + "fruit,3,1,3") + assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNthWithGroupByCount(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 2, 1002), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 3, 1006), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "book,3,2,2", + "fruit,3,1,3") + assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testNestedTopN(): Unit = { + val data = List( + ("book", "a", 1), + ("book", "b", 1), + ("book", "c", 1), + ("fruit", "a", 2), + ("book", "a", 1), + ("book", "d", 0), + ("book", "b", 3), + ("fruit", "b", 6), + ("book", "c", 1), + ("book", "e", 5), + ("book", "d", 4)) + + env.setParallelism(1) + val 
ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT cate, shopId, count(*) as cnt, max(sells) as sells + | FROM T + | GROUP BY cate, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + + val sql2 = + s""" + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT cate, shopId, sells, cnt, + | ROW_NUMBER() OVER (ORDER BY sells DESC) as rank_num + | FROM ($sql) + |) + |WHERE rank_num <= 4 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0)) + val table = tEnv.sqlQuery(sql2) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,1,book,a,1,1)", "(true,2,book,b,1,1)", "(true,3,book,c,1,1)", + "(true,1,fruit,a,2,1)", "(true,2,book,a,1,1)", "(true,3,book,b,1,1)", "(true,4,book,c,1,1)", + "(true,2,book,a,1,2)", + "(true,1,book,b,3,2)", "(true,2,fruit,a,2,1)", "(true,3,book,a,1,2)", + "(true,3,book,a,1,2)", + "(true,1,fruit,b,6,1)", "(true,2,book,b,3,2)", "(true,3,fruit,a,2,1)", "(true,4,book,a,1,2)", + "(true,3,fruit,a,2,1)", + "(true,2,book,e,5,1)", + "(true,3,book,b,3,2)", "(true,4,fruit,a,2,1)", + "(true,3,book,b,3,2)", + "(true,3,book,d,4,2)", + "(true,4,book,b,3,2)", + "(true,4,book,b,3,2)") + assertEquals(expected.mkString("\n"), tableSink.getRawResults.mkString("\n")) + + val expected2 = List("1,fruit,b,6,1", "2,book,e,5,1", "3,book,d,4,2", "4,book,b,3,2") + assertEquals(expected2, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithoutDeduplicate(): Unit = { + val data = List( + ("book", "a", 1), + ("book", "b", 1), + ("book", "c", 1), + ("fruit", "a", 2), + ("book", "a", 1), + ("book", "d", 0), + ("book", "b", 3), + ("fruit", "b", 6), + ("book", "c", 1), + ("book", "e", 5), + ("book", "d", 4)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT cate, shopId, count(*) as cnt, max(sells) as sells + | FROM T + | GROUP BY cate, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + tEnv.getConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE, 1) + val tableSink = new TestingUpsertTableSink(Array(0)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,1,book,a,1,1)", + "(true,2,book,b,1,1)", + "(true,3,book,c,1,1)", + "(true,1,fruit,a,2,1)", + "(true,1,book,a,1,2)", + "(true,4,book,d,0,1)", + "(true,1,book,b,3,2)", + "(true,2,book,a,1,2)", + "(true,1,fruit,b,6,1)", + "(true,2,fruit,a,2,1)", + "(true,3,book,c,1,2)", + "(true,1,book,e,5,1)", + "(true,2,book,b,3,2)", + "(true,3,book,a,1,2)", + "(true,4,book,c,1,2)", + "(true,2,book,d,4,2)", + "(true,3,book,b,3,2)", + "(true,4,book,a,1,2)") + + assertEquals(expected, tableSink.getRawResults) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithVariableTopSize(): Unit = { + val data = List( + ("book", 1, 1001, 4), + ("book", 2, 1002, 4), + ("book", 4, 1003, 4), + ("book", 1, 1004, 4), + ("book", 1, 1005, 4), + ("book", 3, 1006, 4), + ("book", 2, 1007, 4), + ("book", 4, 1008, 4), + ("book", 1, 1009, 4), + 
("book", 4, 1010, 4), + ("book", 4, 1012, 4), + ("book", 4, 1012, 4), + ("fruit", 4, 1013, 2), + ("fruit", 5, 1014, 2), + ("fruit", 3, 1015, 2), + ("fruit", 4, 1017, 2), + ("fruit", 5, 1018, 2), + ("fruit", 5, 1016, 2)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId, 'topSize) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, topSize, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells, max(topSize) as topSize + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= topSize + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "book,1,5,4", + "book,2,4,1", + "book,3,2,2", + "book,4,1,3", + "fruit,1,3,5", + "fruit,2,2,4") + assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNUnaryComplexScenario(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), // backward update in heap + ("book", 3, 23), // elems exceed topn size after insert + ("book", 5, 19), // sort map shirk out some elem after insert + ("book", 7, 10), // sort map keeps a little more than topn size elems after insert + ("book", 8, 13), // sort map now can shrink out-of-range elems after another insert + ("book", 10, 13), // Once again, sort map keeps a little more elems after insert + ("book", 8, 6), // backward update from heap to state + ("book", 10, 6), // backward update from heap to state, and sort map load more data + ("book", 5, 3), // backward update from heap to state + ("book", 10, 1), // backward update in heap, and then sort map shrink some data + ("book", 5, 1), // backward update in state + ("book", 5, -3), // forward update in state + ("book", 2, -10), // forward update in heap, and then sort map shrink some data + ("book", 10, -7), // forward update from state to heap + ("book", 11, 13), // insert into heap + ("book", 12, 10), // insert into heap, and sort map shrink some data + ("book", 15, 14) // insert into state + ) + + env.setParallelism(1) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 3)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,1,11,1)", + "(true,book,2,19,2)", + "(true,book,4,13,2)", + "(true,book,2,19,3)", + + "(true,book,4,13,1)", + "(true,book,2,19,2)", + "(true,book,1,22,3)", + + "(true,book,5,19,3)", + + "(true,book,7,10,1)", + "(true,book,4,13,2)", + "(true,book,2,19,3)", + + "(true,book,8,13,3)", + + "(true,book,10,13,3)", + + "(true,book,2,19,3)", + + "(true,book,2,9,1)", + "(true,book,7,10,2)", + "(true,book,4,13,3)", + + "(true,book,12,10,3)") + + assertEquals(expected.mkString("\n"), tableSink.getRawResults.mkString("\n")) + + val updatedExpected = List( + "book,2,9,1", + "book,7,10,2", + "book,12,10,3") + + 
assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByAvgWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 100), + ("book", 3, 110), + ("book", 4, 120), + ("book", 1, 200), + ("book", 1, 200), + ("book", 2, 300), + ("book", 2, 400), + ("book", 4, 500), + ("book", 1, 400), + ("fruit", 5, 100)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, shopId, avgSellId + |FROM ( + | SELECT category, shopId, avgSellId, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY avgSellId DESC) as rank_num + | FROM ( + | SELECT category, shopId, AVG(sellId) as avgSellId + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,1,100.0)", + "(true,book,3,110.0)", + "(true,book,4,120.0)", + "(true,book,1,150.0)", + "(true,book,1,166.66666666666666)", + "(true,book,2,300.0)", + "(false,book,3,110.0)", + "(true,book,2,350.0)", + "(true,book,4,310.0)", + "(true,book,1,225.0)", + "(true,fruit,5,100.0)") + + assertEquals(expected, tableSink.getRawResults) + + val updatedExpected = List( + "book,1,225.0", + "book,2,350.0", + "book,4,310.0", + "fruit,5,100.0") + + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByCountWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 3, 1006), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 2, 1002), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, shopId, sells + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,1,1)", + "(true,book,3,1)", + "(true,book,4,1)", + "(true,book,1,2)", + "(true,book,1,3)", + "(true,book,2,2)", + "(false,book,4,1)", + "(true,book,4,2)", + "(false,book,3,1)", + "(true,book,1,4)", + "(true,book,4,3)", + "(true,book,4,4)", + "(true,book,4,5)", + "(true,fruit,4,1)", + "(true,fruit,5,1)", + "(true,fruit,3,1)", + "(true,fruit,4,2)", + "(true,fruit,5,2)", + "(true,fruit,5,3)") + assertEquals(expected, tableSink.getRawResults) + + val updatedExpected = List( + "book,4,5", + "book,1,4", + "book,2,2", + "fruit,5,3", + "fruit,4,2", + "fruit,3,1") + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Test + def testTopNWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("book", 
5, 20), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22), + ("fruit", 1, 40)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, num, shopId + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 2)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,12,1)", + "(true,book,19,2)", + "(false,book,12,1)", + "(true,book,20,5)", + "(true,fruit,33,4)", + "(true,fruit,44,3)", + "(false,fruit,33,4)", + "(true,fruit,40,1)") + assertEquals(expected, tableSink.getRawResults) + + val updatedExpected = List( + "book,19,2", + "book,20,5", + "fruit,40,1", + "fruit,44,3") + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testMultipleRetractTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num, + | AVG(num) as avg_num, COUNT(num) as cnt + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingRetractSink + tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, avg_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC, avg_num ASC + | ) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin).toRetractStream[Row].addSink(sink1).setParallelism(1) + env.execute() + + val expected1 = List( + "book,1,25,12.5,1", + "book,2,19,19.0,2", + "fruit,3,44,44.0,1", + "fruit,4,33,33.0,2") + assertEquals(expected1.sorted, sink1.getRetractResults.sorted) + + val sink2 = new TestingRetractSink + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC, cnt ASC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin).toRetractStream[Row].addSink(sink2).setParallelism(1) + + env.execute() + + val expected2 = List( + "book,2,19,1,1", + "book,1,13,2,2", + "fruit,3,44,1,1", + "fruit,4,33,1,2") + assertEquals(expected2.sorted, sink2.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testMultipleUnaryTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingUpsertTableSink(Array(0, 3)) + val table1 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, + | 
ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + writeToSink(table1, sink1) + + val sink2 = new TestingUpsertTableSink(Array(0, 3)) + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + + writeToSink(table2, sink2) + + val expected1 = List( + "book,1,25,1", + "book,2,19,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected1.sorted, sink1.getUpsertResults.sorted) + + val expected2 = List( + "book,2,19,1", + "book,1,13,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected2.sorted, sink2.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testMultipleUpdateTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, COUNT(num) as cnt_num, MAX(num) as max_num + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingRetractTableSink + val table1 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, cnt_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY cnt_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + writeToSink(table1, sink1) + + val sink2 = new TestingRetractTableSink + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + + writeToSink(table2, sink2) + + val expected1 = List( + "book,1,2,1", + "book,2,1,2", + "fruit,4,1,1", + "fruit,3,1,2") + assertEquals(expected1.sorted, sink1.getRetractResults.sorted) + + val expected2 = List( + "book,2,19,1", + "book,1,13,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected2.sorted, sink2.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testUpdateRank(): Unit = { + val data = List( + (1, 1), (1, 2), (1, 3), + (2, 2), (2, 3), (2, 4), + (3, 3), (3, 4), (3, 5)) + + val ds = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.registerTable("T", ds) + + // We use max here to ensure the usage of update rank + val sql = "SELECT a, max(b) FROM T GROUP BY a ORDER BY a LIMIT 2" + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List("1,3", "2,4") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testUpdateRankWithOffset(): Unit = { + val data = List( + (1, 1), (1, 2), (1, 3), + (2, 2), (2, 3), (2, 4), + (3, 3), (3, 4), (3, 5)) + + val ds = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.registerTable("T", ds) + + // We use max here to ensure the usage of update rank + val sql = "SELECT a, max(b) FROM T GROUP BY a ORDER BY a LIMIT 2 OFFSET 1" + + val sink = new TestingRetractSink + 
tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List("2,4", "3,5") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala index 95138c28edf6de..50b242c2bde151 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala @@ -19,23 +19,26 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.common.io.OutputFormat import org.apache.flink.api.common.state.{ListState, ListStateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo} import org.apache.flink.configuration.Configuration import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink} import org.apache.flink.streaming.api.functions.sink.RichSinkFunction import org.apache.flink.table.api.{TableConfig, Types} -import org.apache.flink.table.dataformat.BaseRow -import org.apache.flink.table.sinks.{AppendStreamTableSink, BatchTableSink} +import org.apache.flink.table.dataformat.{BaseRow, DataFormatConverters, GenericRow} +import org.apache.flink.table.sinks._ +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.table.util.BaseRowTestUtil import org.apache.flink.types.Row +import _root_.java.lang.{Boolean => JBoolean} import _root_.java.util.TimeZone import _root_.java.util.concurrent.atomic.AtomicInteger @@ -150,6 +153,159 @@ final class TestingAppendSink(tz: TimeZone) extends AbstractExactlyOnceSink[Row] def getAppendResults: List[String] = getResults } +final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) + extends AbstractExactlyOnceSink[(Boolean, BaseRow)] { + + private var upsertResultsState: ListState[String] = _ + private var localUpsertResults: mutable.Map[String, String] = _ + private var fieldTypes: Array[TypeInformation[_]] = _ + + def this(keys: Array[Int]) { + this(keys, TimeZone.getTimeZone("UTC")) + } + + def configureTypes(fieldTypes: Array[TypeInformation[_]]): Unit = { + this.fieldTypes = fieldTypes + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + super.initializeState(context) + upsertResultsState = context.getOperatorStateStore.getListState( + new ListStateDescriptor[String]("sink-upsert-results", Types.STRING)) + + localUpsertResults = mutable.HashMap.empty[String, String] + + if (context.isRestored) { + var key: String = null + var value: String = null + for (entry <- upsertResultsState.get().asScala) { + if (key == null) { + key = entry + } else { + value = entry + localUpsertResults += (key -> value) + key = null + value = null + } + } + if (key != null) { + throw new 
RuntimeException("The resultState is corrupt.") + } + } + + val taskId = getRuntimeContext.getIndexOfThisSubtask + StreamTestSink.synchronized{ + StreamTestSink.globalUpsertResults(idx) += (taskId -> localUpsertResults) + } + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + super.snapshotState(context) + upsertResultsState.clear() + for ((key, value) <- localUpsertResults) { + upsertResultsState.add(key) + upsertResultsState.add(value) + } + } + + def invoke(d: (Boolean, BaseRow)): Unit = { + this.synchronized { + val wrapRow = new GenericRow(2) + wrapRow.setField(0, d._1) + wrapRow.setField(1, d._2) + val converter = + DataFormatConverters.getConverterForTypeInfo( + new TupleTypeInfo(Types.BOOLEAN, new RowTypeInfo(fieldTypes: _*))) + .asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, JTuple2[JBoolean, Row]]] + val v = converter.toExternal(wrapRow) + val rowString = TestSinkUtil.rowToString(v.f1, tz) + val tupleString = "(" + v.f0.toString + "," + rowString + ")" + localResults += tupleString + val keyString = TestSinkUtil.rowToString(Row.project(v.f1, keys), tz) + if (v.f0) { + localUpsertResults += (keyString -> rowString) + } else { + val oldValue = localUpsertResults.remove(keyString) + if (oldValue.isEmpty) { + throw new RuntimeException("Tried to delete a value that wasn't inserted first. " + + "This is probably an incorrectly implemented test. " + + "Try to set the parallelism of the sink to 1.") + } + } + } + } + + def getRawResults: List[String] = getResults + + def getUpsertResults: List[String] = { + clearAndStashGlobalResults() + val result = ArrayBuffer.empty[String] + this.globalUpsertResults.foreach { + case (_, map) => map.foreach(result += _._2) + } + result.toList + } +} + +final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) + extends UpsertStreamTableSink[BaseRow] { + var fNames: Array[String] = _ + var fTypes: Array[TypeInformation[_]] = _ + var sink = new TestingUpsertSink(keys, tz) + + def this(keys: Array[Int]) { + this(keys, TimeZone.getTimeZone("UTC")) + } + + override def setKeyFields(keys: Array[String]): Unit = { + // ignore + } + + override def setIsAppendOnly(isAppendOnly: JBoolean): Unit = { + // ignore + } + + override def getRecordType: TypeInformation[BaseRow] = + new BaseRowTypeInfo(fTypes.map(createInternalTypeFromTypeInfo(_)), fNames) + + override def getFieldNames: Array[String] = fNames + + override def getFieldTypes: Array[TypeInformation[_]] = fTypes + + override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, BaseRow]]) = { + dataStream.map(new MapFunction[JTuple2[JBoolean, BaseRow], (Boolean, BaseRow)] { + override def map(value: JTuple2[JBoolean, BaseRow]): (Boolean, BaseRow) = { + (value.f0, value.f1) + } + }) + .addSink(sink) + .name(s"TestingUpsertTableSink(keys=${ + if (keys != null) { + "(" + keys.mkString(",") + ")" + } else { + "null" + } + })") + .setParallelism(1) + } + + override def configure( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]]) + : TableSink[JTuple2[JBoolean, BaseRow]] = { + val copy = new TestingUpsertTableSink(keys, tz) + copy.fNames = fieldNames + copy.fTypes = fieldTypes + sink.configureTypes(fieldTypes) + copy.sink = sink + copy + } + + def getRawResults: List[String] = sink.getRawResults + + def getUpsertResults: List[String] = sink.getUpsertResults +} + final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[Row] with BatchTableSink[Row]{ var fNames: Array[String] = _ @@ -242,3 +398,118 @@ class 
TestingOutputFormat[T](tz: TimeZone) result.toList } } + +class TestingRetractSink(tz: TimeZone) + extends AbstractExactlyOnceSink[(Boolean, Row)] { + protected var retractResultsState: ListState[String] = _ + protected var localRetractResults: ArrayBuffer[String] = _ + + def this() { + this(TimeZone.getTimeZone("UTC")) + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + super.initializeState(context) + retractResultsState = context.getOperatorStateStore + .getListState(new ListStateDescriptor[String]("sink-retract-results", Types.STRING)) + + localRetractResults = mutable.ArrayBuffer.empty[String] + + if (context.isRestored) { + for (value <- retractResultsState.get().asScala) { + localRetractResults += value + } + } + + val taskId = getRuntimeContext.getIndexOfThisSubtask + StreamTestSink.synchronized{ + StreamTestSink.globalRetractResults(idx) += (taskId -> localRetractResults) + } + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + super.snapshotState(context) + retractResultsState.clear() + for (value <- localRetractResults) { + retractResultsState.add(value) + } + } + + def invoke(v: (Boolean, Row)): Unit = { + this.synchronized { + val tupleString = "(" + v._1.toString + "," + TestSinkUtil.rowToString(v._2, tz) + ")" + localResults += tupleString + val rowString = TestSinkUtil.rowToString(v._2, tz) + if (v._1) { + localRetractResults += rowString + } else { + val index = localRetractResults.indexOf(rowString) + if (index >= 0) { + localRetractResults.remove(index) + } else { + throw new RuntimeException("Tried to retract a value that wasn't added first. " + + "This is probably an incorrectly implemented test. " + + "Try to set the parallelism of the sink to 1.") + } + } + } + } + + def getRawResults: List[String] = getResults + + def getRetractResults: List[String] = { + clearAndStashGlobalResults() + val result = ArrayBuffer.empty[String] + this.globalRetractResults.foreach { + case (_, list) => result ++= list + } + result.toList + } +} + +final class TestingRetractTableSink(tz: TimeZone) extends RetractStreamTableSink[Row] { + + var fNames: Array[String] = _ + var fTypes: Array[TypeInformation[_]] = _ + var sink = new TestingRetractSink(tz) + + def this() { + this(TimeZone.getTimeZone("UTC")) + } + + override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, Row]]) = { + dataStream.map(new MapFunction[JTuple2[JBoolean, Row], (Boolean, Row)] { + override def map(value: JTuple2[JBoolean, Row]): (Boolean, Row) = { + (value.f0, value.f1) + } + }).setParallelism(dataStream.getParallelism) + .addSink(sink) + .name("TestingRetractTableSink") + .setParallelism(dataStream.getParallelism) + } + + override def getRecordType: TypeInformation[Row] = + new RowTypeInfo(fTypes, fNames) + + override def getFieldNames: Array[String] = fNames + + override def getFieldTypes: Array[TypeInformation[_]] = fTypes + + override def configure( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]]): TableSink[JTuple2[JBoolean, Row]] = { + val copy = new TestingRetractTableSink(tz) + copy.fNames = fieldNames + copy.fTypes = fieldTypes + copy.sink = sink + copy + } + + def getRawResults: List[String] = { + sink.getRawResults + } + + def getRetractResults: List[String] = { + sink.getRetractResults + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala index 1335e897bda1b1..94b814fe833a80 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala @@ -20,7 +20,9 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.{Table, TableImpl} import org.apache.flink.table.api.scala.StreamTableEnvironment +import org.apache.flink.table.sinks.TableSink import org.apache.flink.test.util.AbstractTestBase import org.junit.rules.{ExpectedException, TemporaryFolder} @@ -53,4 +55,8 @@ class StreamingTestBase extends AbstractTestBase { this.tEnv = StreamTableEnvironment.create(env) } + def writeToSink(table: Table, sink: TableSink[_]): Unit = { + TableUtil.writeToSink(table.asInstanceOf[TableImpl], sink) + } + } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala new file mode 100644 index 00000000000000..602edd84e4666e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.api.common.restartstrategy.RestartStrategies +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.configuration.{CheckpointingOptions, Configuration} +import org.apache.flink.runtime.state.memory.MemoryStateBackend +import org.apache.flink.streaming.api.CheckpointingMode +import org.apache.flink.streaming.api.functions.source.FromElementsFunction +import org.apache.flink.streaming.api.scala.DataStream +import org.apache.flink.table.api.{TableEnvironment, Types} +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, BinaryRowWriter, BinaryString} +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode} +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo +import org.apache.flink.table.`type`.RowType + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import java.io.File +import java.util + +import org.junit.runners.Parameterized +import org.junit.{After, Assert, Before} + +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend + +class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestBase { + + enableObjectReuse = state match { + case HEAP_BACKEND => false // TODO gemini not support obj reuse now. + case ROCKSDB_BACKEND => true + } + + private val classLoader = Thread.currentThread.getContextClassLoader + + var baseCheckpointPath: File = _ + + @Before + override def before(): Unit = { + super.before() + // set state backend + baseCheckpointPath = tempFolder.newFolder().getAbsoluteFile + state match { + case HEAP_BACKEND => + val conf = new Configuration() + conf.setBoolean(CheckpointingOptions.ASYNC_SNAPSHOTS, true) + env.setStateBackend(new MemoryStateBackend( + "file://" + baseCheckpointPath, null).configure(conf, classLoader)) + case ROCKSDB_BACKEND => + val conf = new Configuration() + conf.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true) + env.setStateBackend(new RocksDBStateBackend( + "file://" + baseCheckpointPath).configure(conf, classLoader)) + } + this.tEnv = TableEnvironment.getTableEnvironment(env) + FailingCollectionSource.failedBefore = true + } + + @After + def after(): Unit = { + Assert.assertTrue(FailingCollectionSource.failedBefore) + } + + /** + * Creates a BinaryRow DataStream from the given non-empty [[Seq]]. 
+ */ + def failingBinaryRowSource[T: TypeInformation](data: Seq[T]): DataStream[BaseRow] = { + val typeInfo = implicitly[TypeInformation[_]].asInstanceOf[CompositeType[_]] + val result = new mutable.MutableList[BaseRow] + val reuse = new BinaryRow(typeInfo.getArity) + val writer = new BinaryRowWriter(reuse) + data.foreach { + case p: Product => + for (i <- 0 until typeInfo.getArity) { + val fieldType = typeInfo.getTypeAt(i).asInstanceOf[TypeInformation[_]] + fieldType match { + case Types.INT => writer.writeInt(i, p.productElement(i).asInstanceOf[Int]) + case Types.LONG => writer.writeLong(i, p.productElement(i).asInstanceOf[Long]) + case Types.STRING => writer.writeString(i, + p.productElement(i).asInstanceOf[BinaryString]) + case Types.BOOLEAN => writer.writeBoolean(i, p.productElement(i).asInstanceOf[Boolean]) + } + } + writer.complete() + result += reuse.copy() + case _ => throw new UnsupportedOperationException + } + val newTypeInfo = createInternalTypeFromTypeInfo(typeInfo).asInstanceOf[RowType].toTypeInfo + failingDataSource(result)(newTypeInfo.asInstanceOf[TypeInformation[BaseRow]]) + } + + /** + * Creates a DataStream from the given non-empty [[Seq]]. + */ + def retainStateDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = { + env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE) + env.setRestartStrategy(RestartStrategies.noRestart()) + env.setParallelism(1) + // reset failedBefore flag to false + FailingCollectionSource.reset() + + require(data != null, "Data must not be null.") + val typeInfo = implicitly[TypeInformation[T]] + + val collection = scala.collection.JavaConversions.asJavaCollection(data) + // must not have null elements and mixed elements + FromElementsFunction.checkCollection(data, typeInfo.getTypeClass) + + val function = new FailingCollectionSource[T]( + typeInfo.createSerializer(env.getConfig), + collection, + data.length, // fail after half elements + true) + + env.addSource(function)(typeInfo).setMaxParallelism(1) + } + + /** + * Creates a DataStream from the given non-empty [[Seq]]. 
+ */ + def failingDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = { + env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE) + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0)) + env.setParallelism(1) + // reset failedBefore flag to false + FailingCollectionSource.reset() + + require(data != null, "Data must not be null.") + val typeInfo = implicitly[TypeInformation[T]] + + val collection = scala.collection.JavaConversions.asJavaCollection(data) + // must not have null elements and mixed elements + FromElementsFunction.checkCollection(data, typeInfo.getTypeClass) + + val function = new FailingCollectionSource[T]( + typeInfo.createSerializer(env.getConfig), + collection, + data.length / 2, // fail after half elements + false) + + env.addSource(function)(typeInfo).setMaxParallelism(1) + } + + private def mapStrEquals(str1: String, str2: String): Boolean = { + val array1 = str1.toCharArray + val array2 = str2.toCharArray + if (array1.length != array2.length) { + return false + } + val l = array1.length + val leftBrace = "{".charAt(0) + val rightBrace = "}".charAt(0) + val equalsChar = "=".charAt(0) + val lParenthesis = "(".charAt(0) + val rParenthesis = ")".charAt(0) + val dot = ",".charAt(0) + val whiteSpace = " ".charAt(0) + val map1 = Map[String, String]() + val map2 = Map[String, String]() + var idx = 0 + def findEquals(ss: CharSequence): Array[Int] = { + val ret = new ArrayBuffer[Int]() + (0 until ss.length) foreach (idx => if (ss.charAt(idx) == equalsChar) ret += idx) + ret.toArray + } + + def splitKV(ss: CharSequence, equalsIdx: Int): (String, String) = { + // find right, if starts with '(' find until the ')', else until the ',' + var endFlag = false + var curIdx = equalsIdx + 1 + var endChar = if (ss.charAt(curIdx) == lParenthesis) rParenthesis else dot + var valueStr: CharSequence = null + var keyStr: CharSequence = null + while (curIdx < ss.length && !endFlag) { + val curChar = ss.charAt(curIdx) + if (curChar != endChar && curChar != rightBrace) { + curIdx += 1 + if (curIdx == ss.length) { + valueStr = ss.subSequence(equalsIdx + 1, curIdx) + } + } else { + valueStr = ss.subSequence(equalsIdx + 1, curIdx) + endFlag = true + } + } + + // find left, if starts with ')' find until the '(', else until the ' ,' + endFlag = false + curIdx = equalsIdx - 1 + endChar = if (ss.charAt(curIdx) == rParenthesis) lParenthesis else whiteSpace + while (curIdx >= 0 && !endFlag) { + val curChar = ss.charAt(curIdx) + if (curChar != endChar && curChar != leftBrace) { + curIdx -= 1 + if (curIdx == -1) { + keyStr = ss.subSequence(0, equalsIdx) + } + } else { + keyStr = ss.subSequence(curIdx, equalsIdx) + endFlag = true + } + } + require(keyStr != null) + require(valueStr != null) + (keyStr.toString, valueStr.toString) + } + + def appendStrToMap(ss: CharSequence, m: Map[String, String]): Unit = { + val equalsIdxs = findEquals(ss) + equalsIdxs.foreach (idx => m + splitKV(ss, idx)) + } + + while (idx < l) { + val char1 = array1(idx) + val char2 = array2(idx) + if (char1 != char2) { + return false + } + + if (char1 == leftBrace) { + val rightBraceIdx = array1.subSequence(idx + 1, l).toString.indexOf(rightBrace) + appendStrToMap(array1.subSequence(idx + 1, rightBraceIdx + idx + 2), map1) + idx += rightBraceIdx + } else { + idx += 1 + } + } + map1.equals(map2) + } + + def assertMapStrEquals(str1: String, str2: String): Unit = { + if (!mapStrEquals(str1, str2)) { + throw new AssertionError(s"Expected: $str1 \n Actual: $str2") + } + } +} + +object StreamingWithStateTestBase { + + 
case class StateBackendMode(backend: String) { + override def toString: String = backend.toString + } + + val HEAP_BACKEND = StateBackendMode("HEAP") + val ROCKSDB_BACKEND = StateBackendMode("ROCKSDB") + + @Parameterized.Parameters(name = "StateBackend={0}") + def parameters(): util.Collection[Array[java.lang.Object]] = { + Seq[Array[AnyRef]](Array(HEAP_BACKEND), Array(ROCKSDB_BACKEND)) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala index 5ec370015c4b14..5bcc372091395c 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.`type`.TypeConverters.createExternalTypeInfoFromInternalType import org.apache.flink.table.api.{BatchTableEnvironment, TableImpl} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.sinks.{CollectRowTableSink, CollectTableSink} +import org.apache.flink.table.sinks.{CollectRowTableSink, CollectTableSink, TableSink} import org.apache.flink.types.Row import _root_.scala.collection.JavaConversions._ @@ -60,4 +60,15 @@ object TableUtil { BatchTableEnvUtil.collect(table.tableEnv.asInstanceOf[BatchTableEnvironment], table, configuredSink.asInstanceOf[CollectTableSink[T]], jobName) } + + def writeToSink(table: TableImpl, sink: TableSink[_]): Unit = { + // get schema information of table + val rowType = table.getRelNode.getRowType + val fieldNames = rowType.getFieldNames.asScala.toArray + val fieldTypes = rowType.getFieldList + .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray + val configuredSink = sink.configure( + fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType)) + table.tableEnv.writeToSink(table, configuredSink) + } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala new file mode 100644 index 00000000000000..df1dd3f435c52f --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.functions.source.SourceFunction +import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext +import org.apache.flink.streaming.api.operators.{AbstractStreamOperator, OneInputStreamOperator} +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord + +object TimeTestUtil { + + class EventTimeSourceFunction[T]( + dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] { + + override def run(ctx: SourceContext[T]): Unit = { + dataWithTimestampList.foreach { + case Left(t) => ctx.collectWithTimestamp(t._2, t._1) + case Right(w) => ctx.emitWatermark(new Watermark(w)) + } + } + + override def cancel(): Unit = ??? + } + + class TimestampAndWatermarkWithOffset[T <: Product]( + offset: Long) extends AssignerWithPunctuatedWatermarks[T] { + + override def checkAndGetNextWatermark(lastElement: T, extractedTimestamp: Long): Watermark = { + new Watermark(extractedTimestamp - offset) + } + + override def extractTimestamp(element: T, previousElementTimestamp: Long): Long = { + element.productElement(0).asInstanceOf[Long] + } + } + + class EventTimeProcessOperator[T] + extends AbstractStreamOperator[T] + with OneInputStreamOperator[Either[(Long, T), Long], T] { + + override def processElement(element: StreamRecord[Either[(Long, T), Long]]): Unit = { + element.getValue match { + case Left(t) => output.collect(new StreamRecord[T](t._2, t._1)) + case Right(w) => output.emitWatermark(new Watermark(w)) + } + } + + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java index b7a28b7425e4b9..41badd23451206 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java @@ -97,6 +97,15 @@ public class TableConfigOptions { .withDescription("Default parallelism of the job. If any node do not have special parallelism, use it." + "Its default value is the num of cpu cores in the client host."); + // ------------------------------------------------------------------------ + // topN Options + // ------------------------------------------------------------------------ + + public static final ConfigOption SQL_EXEC_TOPN_CACHE_SIZE = + key("sql.exec.topn.cache.size") + .defaultValue(10000L) + .withDescription("Cache size of every topn task."); + // ------------------------------------------------------------------------ // MiniBatch Options // ------------------------------------------------------------------------ @@ -106,6 +115,33 @@ public class TableConfigOptions { .defaultValue(Long.MIN_VALUE) .withDescription("MiniBatch allow latency(ms). 
Value > 0 means MiniBatch enabled."); + public static final ConfigOption SQL_EXEC_MINIBATCH_SIZE = + key("sql.exec.mini-batch.size") + .defaultValue(Long.MIN_VALUE) + .withDescription("The maximum number of inputs that MiniBatch buffer can accommodate."); + + public static final ConfigOption SQL_EXEC_MINI_BATCH_FLUSH_BEFORE_SNAPSHOT = + key("sql.exec.mini-batch.flush-before-snapshot") + .defaultValue(true) + .withDescription("Whether to enable flushing buffered data before snapshot."); + + // ------------------------------------------------------------------------ + // State Options + // ------------------------------------------------------------------------ + + public static final ConfigOption SQL_EXEC_STATE_TTL_MS = + key("sql.exec.state.ttl.ms") + .defaultValue(Long.MIN_VALUE) + .withDescription("The minimum time until state that was not updated will be retained. State" + + " might be cleared and removed if it was not updated for the defined period of time."); + + public static final ConfigOption SQL_EXEC_STATE_TTL_MAX_MS = + key("sql.exec.state.ttl.max.ms") + .defaultValue(Long.MIN_VALUE) + .withDescription("The maximum time until state which was not updated will be retained." + + "State will be cleared and removed if it was not updated for the defined " + + "period of time."); + // ------------------------------------------------------------------------ // STATE BACKEND Options // ------------------------------------------------------------------------ diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java index 9be8db632248d2..1514f2b9043267 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java @@ -428,4 +428,20 @@ public static String toOriginString(BaseRow row, InternalType[] types) { build.append(']'); return build.toString(); } + + public boolean equalsWithoutHeader(BaseRow o) { + return equalsFrom(o, 1); + } + + private boolean equalsFrom(Object o, int startIndex) { + if (o != null && o instanceof BinaryRow) { + BinaryRow other = (BinaryRow) o; + return sizeInBytes == other.sizeInBytes && + SegmentsUtil.equals( + segments, offset + startIndex, + other.segments, other.offset + startIndex, sizeInBytes - startIndex); + } else { + return false; + } + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java index b1730ec3c3892c..0adbfc2aa23e4f 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java @@ -24,6 +24,7 @@ import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.type.MapType; import org.apache.flink.table.type.RowType; +import org.apache.flink.table.type.TimestampType; import org.apache.flink.table.typeutils.BaseRowSerializer; /** @@ -102,7 +103,7 @@ static void write(BinaryWriter writer, int pos, Object o, InternalType type) { writer.writeInt(pos, (int) o); } else if (type.equals(InternalTypes.TIME)) { writer.writeInt(pos, (int) o); - } else if (type.equals(InternalTypes.TIMESTAMP)) { + } 
else if (type instanceof TimestampType) { writer.writeLong(pos, (long) o); } else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java index 101ff35d162535..ebc59bc3f3bda2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java @@ -146,9 +146,9 @@ public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeIn } else if (typeInfo instanceof BinaryMapTypeInfo) { return BinaryMapConverter.INSTANCE; } else if (typeInfo instanceof BaseRowTypeInfo) { - return BaseRowConverter.INSTANCE; + return new BaseRowConverter(typeInfo.getArity()); } else if (typeInfo.equals(BasicTypeInfo.BIG_DEC_TYPE_INFO)) { - return BaseRowConverter.INSTANCE; + return new BaseRowConverter(typeInfo.getArity()); } else if (typeInfo instanceof DecimalTypeInfo) { DecimalTypeInfo decimalType = (DecimalTypeInfo) typeInfo; return new DecimalConverter(decimalType.precision(), decimalType.scale()); @@ -993,14 +993,13 @@ E toExternalImpl(BaseRow row, int column) { public static final class BaseRowConverter extends IdentityConverter<BaseRow> { private static final long serialVersionUID = -4470307402371540680L; + private int arity; - public static final BaseRowConverter INSTANCE = new BaseRowConverter(); - - private BaseRowConverter() {} + private BaseRowConverter(int arity) { + this.arity = arity; + } @Override BaseRow toExternalImpl(BaseRow row, int column) { - throw new RuntimeException("Not support yet!"); + return row.getRow(column, arity); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java index f1e080ce89d575..71c323f5914db3 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java @@ -24,6 +24,7 @@ import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.type.MapType; import org.apache.flink.table.type.RowType; +import org.apache.flink.table.type.TimestampType; /** * Provide type specialized getters and setters to reduce if/else and eliminate box and unbox.
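 *
 * <p>A minimal sketch of the intended dispatch (illustrative only, not part of this patch; it
 * assumes {@code InternalTypes.TIMESTAMP} is an instance of {@code TimestampType}): a TIMESTAMP
 * column is stored as a long, so the generic accessor below ends up in {@code getLong}.
 * <pre>{@code
 * GenericRow row = new GenericRow(1);
 * row.setField(0, 1554285452000L); // epoch millis
 * Object millis = TypeGetterSetters.get(row, 0, InternalTypes.TIMESTAMP); // delegates to row.getLong(0)
 * }</pre>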
@@ -197,7 +198,7 @@ static Object get(TypeGetterSetters row, int ordinal, InternalType type) { return row.getInt(ordinal); } else if (type.equals(InternalTypes.TIME)) { return row.getInt(ordinal); - } else if (type.equals(InternalTypes.TIMESTAMP)) { + } else if (type instanceof TimestampType) { return row.getLong(ordinal); } else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java index 71dc6c2eacd7b6..e50e5463c03274 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java @@ -18,7 +18,9 @@ package org.apache.flink.table.dataformat.util; +import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.core.memory.MemoryUtils; +import org.apache.flink.table.dataformat.BinaryRow; /** * Util for binary row. Many of the methods in this class are used in code generation. @@ -29,6 +31,14 @@ public class BinaryRowUtil { public static final sun.misc.Unsafe UNSAFE = MemoryUtils.UNSAFE; public static final int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + public static final BinaryRow EMPTY_ROW = new BinaryRow(0); + + static { + int size = EMPTY_ROW.getFixedLengthPartSize(); + byte[] bytes = new byte[size]; + EMPTY_ROW.pointTo(MemorySegmentFactory.wrap(bytes), 0, size); + } + public static boolean byteArrayEquals(byte[] left, byte[] right, int length) { return byteArrayEquals( left, BYTE_ARRAY_BASE_OFFSET, right, BYTE_ARRAY_BASE_OFFSET, length); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java new file mode 100644 index 00000000000000..f35d24953b8598 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.runtime.functions.ProcessFunctionWithCleanupState; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +/** + * This function is used to deduplicate on keys and keeps only first row or last row. + */ +public class DeduplicateFunction + extends ProcessFunctionWithCleanupState + implements DeduplicateFunctionBase { + + private final BaseRowTypeInfo rowTypeInfo; + private final boolean generateRetraction; + private final boolean keepLastRow; + protected ValueState pkRow; + private GeneratedRecordEqualiser generatedEqualiser; + private transient RecordEqualiser equaliser; + + public DeduplicateFunction( + long minRetentionTime, + long maxRetentionTime, + BaseRowTypeInfo rowTypeInfo, + boolean generateRetraction, + boolean keepLastRow, + GeneratedRecordEqualiser generatedEqualiser) { + super(minRetentionTime, maxRetentionTime); + this.rowTypeInfo = rowTypeInfo; + this.generateRetraction = generateRetraction; + this.keepLastRow = keepLastRow; + this.generatedEqualiser = generatedEqualiser; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + String stateName = keepLastRow ? "DeduplicateFunctionCleanupTime" : "DeduplicateFunctionCleanupTime"; + initCleanupTimeState(stateName); + ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); + pkRow = ctx.getRuntimeContext().getState(rowStateDesc); + equaliser = generatedEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + } + + @Override + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { + long currentTime = ctx.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(ctx, currentTime); + + BaseRow preRow = pkRow.value(); + if (keepLastRow) { + processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + } else { + processFirstRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + } + } + + @Override + public void close() throws Exception { + super.close(); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + cleanupState(pkRow); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java new file mode 100644 index 00000000000000..22293ae75beb0b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +/** + * Base class to deduplicate on keys and keeps only first row or last row. + */ +public interface DeduplicateFunctionBase { + + default void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, + boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, + Collector out) throws Exception { + // should be accumulate msg. + Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); + // ignore same record + if (!stateCleaningEnabled && preRow != null && + equaliser.equalsWithoutHeader(preRow, currentRow)) { + return; + } + pkRow.update(currentRow); + if (preRow != null && generateRetraction) { + preRow.setHeader(BaseRowUtil.RETRACT_MSG); + out.collect(preRow); + } + out.collect(currentRow); + } + + default void processFirstRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, + boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, + Collector out) throws Exception { + // should be accumulate msg. + Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); + // ignore record with timestamp bigger than preRow + if (!isFirstRow(preRow)) { + return; + } + + pkRow.update(currentRow); + out.collect(currentRow); + } + + default boolean isFirstRow(BaseRow preRow) { + return preRow == null; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunction.java new file mode 100644 index 00000000000000..b8dd970f6c40b2 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunction.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.util.Collector; + +/** + * A function that processes elements of a stream. + * + *

For every element in the input stream {@link #processElement(Object, Context, Collector)} + * is invoked. This can produce zero or more elements as output. Implementations can also + * query the time and set timers through the provided {@link Context}. For firing timers + * {@link #onTimer(long, OnTimerContext, Collector)} will be invoked. This can again produce + * zero or more elements as output and register further timers. + * + *
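+ * <p>A minimal usage sketch (illustrative only, not part of this patch; the subclass name is
+ * made up): a pass-through function that forwards every input row unchanged.
+ * <pre>{@code
+ * public class PassThroughFunction extends ProcessFunction<BaseRow, BaseRow> {
+ *     @Override
+ *     public void processElement(BaseRow input, Context ctx, Collector<BaseRow> out) throws Exception {
+ *         // forward the element; a stateful function could also register a processing-time
+ *         // timer here via ctx.timerService().registerProcessingTimeTimer(...)
+ *         out.collect(input);
+ *     }
+ * }
+ * }</pre>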

NOTE: Access to keyed state and timers (which are also scoped to a key) is only + * available if the {@code ProcessFunction} is applied on a {@code KeyedStream}. + * + * @param <I> Type of the input elements. + * @param <O> Type of the output elements. + */ +public abstract class ProcessFunction<I, O> implements Function { + + private static final long serialVersionUID = 1L; + + protected transient ExecutionContext executionContext; + + public void open(ExecutionContext ctx) throws Exception { + this.executionContext = ctx; + } + + public void close() throws Exception {} + + public void endInput(Collector<O> out) throws Exception {} + + protected RuntimeContext getRuntimeContext() { + if (this.executionContext != null) { + return this.executionContext.getRuntimeContext(); + } else { + throw new IllegalStateException("The stream exec runtime context has not been initialized."); + } + } + + /** + * Process one element from the input stream. + * + *

This function can output zero or more elements using the {@link Collector} parameter + * and also update internal state or set timers using the {@link Context} parameter. + * + * @param input The input value. + * @param ctx A {@link Context} that allows querying the timestamp of the element and getting + * a {@link TimerService} for registering timers and querying the time. The + * context is only valid during the invocation of this method, do not store it. + * @param out The collector for returning result values. + * + * @throws Exception This method may throw exceptions. Throwing an exception will cause the operation + * to fail and may trigger recovery. + */ + public abstract void processElement(I input, Context ctx, Collector out) throws Exception; + + /** + * Called when a timer set using {@link TimerService} fires. + * + * @param timestamp The timestamp of the firing timer. + * @param ctx An {@link OnTimerContext} that allows querying the timestamp of the firing timer, + * querying the {@link TimeDomain} of the firing timer and getting a + * {@link TimerService} for registering timers and querying the time. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector for returning result values. + * + * @throws Exception This method may throw exceptions. Throwing an exception will cause the operation + * to fail and may trigger recovery. + */ + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception {} + + /** + * Information available in an invocation of {@link #processElement(Object, Context, Collector)} + * or {@link #onTimer(long, OnTimerContext, Collector)}. + */ + public abstract static class Context { + + /** + * A {@link TimerService} for querying time and registering timers. + */ + public abstract TimerService timerService(); + } + + /** + * Information available in an invocation of {@link #onTimer(long, OnTimerContext, Collector)}. + */ + public abstract static class OnTimerContext extends Context { + /** + * The {@link TimeDomain} of the firing timer. + */ + public abstract TimeDomain timeDomain(); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunctionWithCleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunctionWithCleanupState.java new file mode 100644 index 00000000000000..a841fe7d7b5774 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ProcessFunctionWithCleanupState.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.streaming.api.TimeDomain; + +/** + * A function that processes elements of a stream, and could cleanup state. + * @param Type of the input elements. + * @param Type of the output elements. + */ +public abstract class ProcessFunctionWithCleanupState extends ProcessFunction { + + protected final long minRetentionTime; + protected final long maxRetentionTime; + protected final boolean stateCleaningEnabled; + + // holds the latest registered cleanup timer + private ValueState cleanupTimeState; + + public ProcessFunctionWithCleanupState(long minRetentionTime, long maxRetentionTime) { + this.minRetentionTime = minRetentionTime; + this.maxRetentionTime = maxRetentionTime; + this.stateCleaningEnabled = minRetentionTime > 1; + } + + protected void initCleanupTimeState(String stateName) { + if (stateCleaningEnabled) { + ValueStateDescriptor inputCntDescriptor = new ValueStateDescriptor( + stateName, + Types.LONG); + cleanupTimeState = executionContext.getRuntimeContext().getState(inputCntDescriptor); + } + } + + protected void registerProcessingCleanupTimer(Context ctx, long currentTime) throws Exception { + if (stateCleaningEnabled) { + // last registered timer + Long curCleanupTime = cleanupTimeState.value(); + + // check if a cleanup timer is registered and + // that the current cleanup timer won't delete state we need to keep + if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) { + // we need to register a new (later) timer + Long cleanupTime = currentTime + maxRetentionTime; + // register timer and remember clean-up time + ctx.timerService().registerProcessingTimeTimer(cleanupTime); + cleanupTimeState.update(cleanupTime); + } + } + } + + protected boolean isProcessingTimeTimer(OnTimerContext ctx) { + return ctx.timeDomain() == TimeDomain.PROCESSING_TIME; + } + + protected boolean needToCleanupState(long timestamp) throws Exception { + if (stateCleaningEnabled) { + Long cleanupTime = cleanupTimeState.value(); + // check that the triggered timer is the last registered processing time timer. + return null != cleanupTime && timestamp == cleanupTime; + } else { + return false; + } + } + + protected void cleanupState(State... states) { + for (State state : states) { + state.clear(); + } + this.cleanupTimeState.clear(); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BaseRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BaseRowKeySelector.java new file mode 100644 index 00000000000000..01a5f581ac9e54 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BaseRowKeySelector.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keySelector; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * BaseRowKeySelector takes an BaseRow and returns the deterministic key for that BaseRow. + */ +public interface BaseRowKeySelector extends KeySelector, ResultTypeQueryable { + + BaseRowTypeInfo getProducedType(); + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BinaryRowKeySelector.java new file mode 100644 index 00000000000000..acfb0aaab3cb7b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/BinaryRowKeySelector.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keySelector; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.generated.Projection; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * A KeySelector which will extract key from BaseRow. 
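+ *
+ * <p>Usage sketch (illustrative only, not part of this patch; the variable names are made up):
+ * the generated {@link Projection} copies the key fields of the input row into a new key row,
+ * so the selector can be handed directly to {@code keyBy}, e.g.
+ * {@code stream.keyBy(new BinaryRowKeySelector(keyRowTypeInfo, generatedKeyProjection))}.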
+ */ +public class BinaryRowKeySelector implements BaseRowKeySelector { + + private static final long serialVersionUID = 5375355285015381919L; + + private final BaseRowTypeInfo keyRowType; + private final GeneratedProjection generatedProjection; + private transient Projection projection; + + public BinaryRowKeySelector(BaseRowTypeInfo keyRowType, GeneratedProjection generatedProjection) { + this.keyRowType = keyRowType; + this.generatedProjection = generatedProjection; + } + + @Override + public BaseRow getKey(BaseRow value) throws Exception { + if (projection == null) { + projection = generatedProjection.newInstance(Thread.currentThread().getContextClassLoader()); + } + return projection.apply(value).copy(); + } + + @Override + public BaseRowTypeInfo getProducedType() { + return keyRowType; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/NullBinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/NullBinaryRowKeySelector.java new file mode 100644 index 00000000000000..8b673247d0b624 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keySelector/NullBinaryRowKeySelector.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keySelector; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BinaryRowUtil; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * A KeySelector which key is always empty. + */ +public class NullBinaryRowKeySelector implements BaseRowKeySelector { + + private final BaseRowTypeInfo returnType = new BaseRowTypeInfo(); + + @Override + public BaseRow getKey(BaseRow value) throws Exception { + return BinaryRowUtil.EMPTY_ROW; + } + + @Override + public BaseRowTypeInfo getProducedType() { + return returnType; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyed/KeyedProcessOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyed/KeyedProcessOperator.java new file mode 100644 index 00000000000000..36c3e270f4606b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyed/KeyedProcessOperator.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keyed; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.streaming.api.SimpleTimerService; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.InternalTimer; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.functions.StreamingFunctionUtils; +import org.apache.flink.table.runtime.context.ExecutionContextImpl; +import org.apache.flink.table.runtime.functions.ProcessFunction; +import org.apache.flink.table.runtime.util.StreamRecordCollector; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * A {@link org.apache.flink.streaming.api.operators.StreamOperator} for executing keyed + * {@link ProcessFunction ProcessFunctions}. + */ +public class KeyedProcessOperator + extends AbstractStreamOperator + implements OneInputStreamOperator, Triggerable { + + private static final long serialVersionUID = 1L; + + private ProcessFunction function; + + private transient StreamRecordCollector collector; + + private transient ContextImpl context; + + private transient OnTimerContextImpl onTimerContext; + + /** Flag to prevent duplicate function.close() calls in close() and dispose(). 
*/ + private transient boolean functionsClosed = false; + + public KeyedProcessOperator(ProcessFunction function) { + this.function = function; + + chainingStrategy = ChainingStrategy.ALWAYS; + } + + @Override + public void open() throws Exception { + super.open(); + + function.open(new ExecutionContextImpl(this, getRuntimeContext())); + + collector = new StreamRecordCollector<>(output); + + InternalTimerService internalTimerService = + getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this); + + TimerService timerService = new SimpleTimerService(internalTimerService); + + context = new ContextImpl(timerService); + onTimerContext = new OnTimerContextImpl(timerService); + } + + + // ------------------------------------------------------------------------ + // checkpointing and recovery + // ------------------------------------------------------------------------ + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + super.notifyCheckpointComplete(checkpointId); + + if (function instanceof CheckpointListener) { + ((CheckpointListener) function).notifyCheckpointComplete(checkpointId); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + StreamingFunctionUtils.snapshotFunctionState(context, getOperatorStateBackend(), function); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + StreamingFunctionUtils.restoreFunctionState(context, function); + } + + @Override + public void close() throws Exception { + super.close(); + functionsClosed = true; + function.close(); + } + + @Override + public void dispose() throws Exception { + super.dispose(); + if (!functionsClosed) { + functionsClosed = true; + function.close(); + } + } + + @Override + public void onEventTime(InternalTimer timer) throws Exception { + setCurrentKey(timer.getKey()); + + onTimerContext.timeDomain = TimeDomain.EVENT_TIME; + function.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + } + + @Override + public void onProcessingTime(InternalTimer timer) throws Exception { + setCurrentKey(timer.getKey()); + + onTimerContext.timeDomain = TimeDomain.PROCESSING_TIME; + function.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + } + + @Override + public void processElement(StreamRecord element) throws Exception { + function.processElement(element.getValue(), context, collector); + } + + private class ContextImpl extends ProcessFunction.Context { + + private final TimerService timerService; + + ContextImpl(TimerService timerService) { + this.timerService = checkNotNull(timerService); + } + + @Override + public TimerService timerService() { + return timerService; + } + } + + private class OnTimerContextImpl extends ProcessFunction.OnTimerContext{ + + private final TimerService timerService; + + private TimeDomain timeDomain; + + OnTimerContextImpl(TimerService timerService) { + this.timerService = checkNotNull(timerService); + } + + @Override + public TimerService timerService() { + return timerService; + } + + @Override + public TimeDomain timeDomain() { + checkState(timeDomain != null); + return timeDomain; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java new file mode 100644 index 00000000000000..dba75de9c9fe1c --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.metrics.Counter; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.runtime.functions.ProcessFunctionWithCleanupState; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; + +/** + * Base class for Rank Function. 
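+ *
+ * <p>A rough, illustrative sketch of how the concrete rank functions in this patch are expected
+ * to drive the shared helpers of this base class (the surrounding processElement logic below is
+ * only an assumption for illustration; the field and method names are the ones defined here):
+ * <pre>
+ *   // inside a subclass' processElement(...)
+ *   long rankEnd = initRankEnd(input);                 // constant range, or read from state for VariableRankRange
+ *   BaseRow sortKey = sortKeySelector.getKey(input);   // extract the ORDER BY key
+ *   long rank = ...;                                   // computed by the subclass from its sorted buffer
+ *   if (isInRankRange(rank)) {
+ *       collect(out, input, rank);                     // appends the row number when isRowNumberAppend is true
+ *   }
+ * </pre>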
+ */ +public abstract class AbstractRankFunction extends ProcessFunctionWithCleanupState { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractRankFunction.class); + + // we set default topn size to 100 + private static final long DEFAULT_TOPN_SIZE = 100; + + private final RankRange rankRange; + private final GeneratedRecordEqualiser generatedEqualiser; + private final boolean generateRetraction; + protected final boolean isRowNumberAppend; + protected final RankType rankType; + protected final BaseRowTypeInfo inputRowType; + protected final BaseRowTypeInfo outputRowType; + protected final GeneratedRecordComparator generatedRecordComparator; + protected final KeySelector sortKeySelector; + + protected boolean isConstantRankEnd; + protected long rankEnd = -1; + protected long rankStart = -1; + protected RecordEqualiser equaliser; + private int rankEndIndex; + private ValueState rankEndState; + private Counter invalidCounter; + private JoinedRow outputRow; + + // metrics + protected long hitCount = 0L; + protected long requestCount = 0L; + + public AbstractRankFunction( + long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + GeneratedRecordComparator generatedRecordComparator, KeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction) { + super(minRetentionTime, maxRetentionTime); + this.rankRange = rankRange; + this.generatedEqualiser = generatedEqualiser; + this.generateRetraction = generateRetraction; + this.rankType = rankType; + // TODO support RANK and DENSE_RANK + switch (rankType) { + case ROW_NUMBER: + break; + case RANK: + LOG.error("RANK() on streaming table is not supported currently"); + throw new UnsupportedOperationException("RANK() on streaming table is not supported currently"); + case DENSE_RANK: + LOG.error("DENSE_RANK() on streaming table is not supported currently"); + throw new UnsupportedOperationException("DENSE_RANK() on streaming table is not supported currently"); + default: + LOG.error("Streaming tables do not support {}", rankType.name()); + throw new UnsupportedOperationException("Streaming tables do not support " + rankType.toString()); + } + this.inputRowType = inputRowType; + this.outputRowType = outputRowType; + this.isRowNumberAppend = inputRowType.getArity() + 1 == outputRowType.getArity(); + this.generatedRecordComparator = generatedRecordComparator; + this.sortKeySelector = sortKeySelector; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + initCleanupTimeState("RankFunctionCleanupTime"); + outputRow = new JoinedRow(); + + // variable rank limit + if (rankRange instanceof ConstantRankRange) { + ConstantRankRange constantRankRange = (ConstantRankRange) rankRange; + isConstantRankEnd = true; + rankEnd = constantRankRange.getRankEnd(); + rankStart = constantRankRange.getRankStart(); + } else if (rankRange instanceof VariableRankRange) { + VariableRankRange variableRankRange = (VariableRankRange) rankRange; + isConstantRankEnd = false; + rankEndIndex = variableRankRange.getRankEndIndex(); + ValueStateDescriptor rankStateDesc = new ValueStateDescriptor("rankEnd", Types.LONG); + rankEndState = ctx.getRuntimeContext().getState(rankStateDesc); + } + equaliser = generatedEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + invalidCounter = ctx.getRuntimeContext().getMetricGroup().counter("topn.invalidTopSize"); + } + + protected long 
getDefaultTopSize() { + return isConstantRankEnd ? rankEnd : DEFAULT_TOPN_SIZE; + } + + protected long initRankEnd(BaseRow row) throws Exception { + if (isConstantRankEnd) { + return rankEnd; + } else { + Long rankEndValue = rankEndState.value(); + long curRankEnd = row.getLong(rankEndIndex); + if (rankEndValue == null) { + rankEnd = curRankEnd; + rankEndState.update(rankEnd); + return rankEnd; + } else { + rankEnd = rankEndValue; + if (rankEnd != curRankEnd) { + // increment the invalid counter when the current rank end + // not equal to previous rank end + invalidCounter.inc(); + } + return rankEnd; + } + } + } + + protected Tuple2 rowNumber(K sortKey, BaseRow rowKey, SortedMap sortedMap) { + Iterator>> iterator = sortedMap.entrySet().iterator(); + int curRank = 1; + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + K curKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + if (curKey.equals(sortKey)) { + Iterator rowKeysIter = rowKeys.iterator(); + int innerRank = 1; + while (rowKeysIter.hasNext()) { + if (rowKey.equals(rowKeysIter.next())) { + return Tuple2.of(curRank, innerRank); + } else { + innerRank += 1; + curRank += 1; + } + } + } else { + curRank += rowKeys.size(); + } + } + LOG.error("Failed to find the sortKey: {}, rowkey: {} in SortedMap. " + + "This should never happen", sortKey, rowKey); + throw new RuntimeException( + "Failed to find the sortKey, rowkey in SortedMap. This should never happen"); + } + + /** + * return true if record should be put into sort map. + */ + protected boolean checkSortKeyInBufferRange(K sortKey, SortedMap sortedMap, Comparator sortKeyComparator) { + Map.Entry> worstEntry = sortedMap.lastEntry(); + if (worstEntry == null) { + // sort map is empty + return true; + } else { + K worstKey = worstEntry.getKey(); + int compare = sortKeyComparator.compare(sortKey, worstKey); + if (compare < 0) { + return true; + } else if (sortedMap.getCurrentTopNum() < getMaxSortMapSize()) { + return true; + } else { + return false; + } + } + } + + protected void registerMetric(long heapSize) { + executionContext.getRuntimeContext().getMetricGroup().>gauge( + "topn.cache.hitRate", + new Gauge() { + + @Override + public Double getValue() { + return requestCount == 0 ? 
1.0 : + Long.valueOf(hitCount).doubleValue() / requestCount; + } + }); + + executionContext.getRuntimeContext().getMetricGroup().>gauge( + "topn.cache.size", + new Gauge() { + + @Override + public Long getValue() { + return heapSize; + } + }); + } + + protected void collect(Collector out, BaseRow inputRow) { + BaseRowUtil.setAccumulate(inputRow); + out.collect(inputRow); + } + + /** + * This is similar to [[retract()]] but always send retraction message regardless of + * generateRetraction is true or not + */ + protected void delete(Collector out, BaseRow inputRow) { + BaseRowUtil.setRetract(inputRow); + out.collect(inputRow); + } + + /** + * This is with-row-number version of above delete() method + */ + protected void delete(Collector out, BaseRow inputRow, long rank) { + if (isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG)); + } + } + + protected void collect(Collector out, BaseRow inputRow, long rank) { + if (isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.ACCUMULATE_MSG)); + } + } + + protected void retract(Collector out, BaseRow inputRow, long rank) { + if (generateRetraction && isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG)); + } + } + + protected boolean isInRankEnd(long rank) { + return rank <= rankEnd; + } + + protected boolean isInRankRange(long rank) { + return rank <= rankEnd && rank >= rankStart; + } + + protected boolean hasOffset() { + // rank start is 1-based + return rankStart > 1; + } + + protected BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { + if (isRowNumberAppend) { + GenericRow rankRow = new GenericRow(1); + rankRow.setField(0, rank); + + outputRow.replace(inputRow, rankRow); + outputRow.setHeader(header); + return outputRow; + } else { + inputRow.setHeader(header); + return inputRow; + } + } + + /** + * get sorted map size limit + * Implementations may vary depending on each rank who has in-memory sort map. + * @return + */ + protected abstract long getMaxSortMapSize(); + + @Override + public void processElement( + BaseRow input, Context ctx, Collector out) throws Exception { + + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java new file mode 100644 index 00000000000000..bc56aa2365ea83 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.runtime.util.LRUMap; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.TreeMap; +import java.util.function.Supplier; + +/** + * Base class for Update Rank Function. + */ +abstract class AbstractUpdateRankFunction extends AbstractRankFunction + implements CheckpointedFunction { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractUpdateRankFunction.class); + + private final BaseRowTypeInfo rowKeyType; + private final long cacheSize; + + // a map state stores mapping from row key to record which is in topN + // in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key + // the f1 is used to preserve the record order in the same sort_key + protected transient MapState> dataState; + + // a sorted map stores mapping from sort key to rowkey list + protected transient SortedMap sortedMap; + + protected transient Map> kvSortedMap; + + // a HashMap stores mapping from rowkey to record, a heap mirror to dataState + protected transient Map rowKeyMap; + + protected transient LRUMap> kvRowKeyMap; + + protected Comparator sortKeyComparator; + + public AbstractUpdateRankFunction( + long minRetentionTime, + long maxRetentionTime, + BaseRowTypeInfo inputRowType, + BaseRowTypeInfo outputRowType, + BaseRowTypeInfo rowKeyType, + GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, + RankType rankType, + RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, + long cacheSize) { + super(minRetentionTime, + maxRetentionTime, + inputRowType, + outputRowType, + generatedRecordComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction); + this.rowKeyType = rowKeyType; + this.cacheSize = cacheSize; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + int lruCacheSize = Math.max(1, (int) (cacheSize / getMaxSortMapSize())); + // make sure the cached map is in a fixed size, avoid OOM + kvSortedMap = new HashMap<>(lruCacheSize); + kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener()); + + LOG.info("Top{} operator is using LRU caches key-size: {}", getMaxSortMapSize(), lruCacheSize); + + TupleTypeInfo> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state-with-update", 
rowKeyType, valueTypeInfo); + dataState = ctx.getRuntimeContext().getMapState(mapStateDescriptor); + + // metrics + registerMetric(kvSortedMap.size() * getMaxSortMapSize()); + + sortKeyComparator = generatedRecordComparator.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + BaseRow partitionKey = executionContext.currentKey(); + // cleanup cache + kvRowKeyMap.remove(partitionKey); + kvSortedMap.remove(partitionKey); + cleanupState(dataState); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Iterator>> iter = kvRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow partitionKey = entry.getKey(); + Map currentRowKeyMap = entry.getValue(); + executionContext.setCurrentKey(partitionKey); + synchronizeState(currentRowKeyMap); + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + // nothing to do + } + + protected void initHeapStates() throws Exception { + requestCount += 1; + BaseRow partitionKey = executionContext.currentKey(); + sortedMap = kvSortedMap.get(partitionKey); + rowKeyMap = kvRowKeyMap.get(partitionKey); + if (sortedMap == null) { + sortedMap = new SortedMap( + sortKeyComparator, + new Supplier>() { + + @Override + public Collection get() { + return new LinkedHashSet<>(); + } + }); + rowKeyMap = new HashMap<>(); + kvSortedMap.put(partitionKey, sortedMap); + kvRowKeyMap.put(partitionKey, rowKeyMap); + + // restore sorted map + Iterator>> iter = dataState.iterator(); + if (iter != null) { + // a temp map associate sort key to tuple2 + Map> tempSortedMap = new HashMap<>(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow rowkey = entry.getKey(); + Tuple2 recordAndInnerRank = entry.getValue(); + BaseRow record = recordAndInnerRank.f0; + Integer innerRank = recordAndInnerRank.f1; + rowKeyMap.put(rowkey, new RankRow(record, innerRank, false)); + + // insert into temp sort map to preserve the record order in the same sort key + BaseRow sortKey = sortKeySelector.getKey(record); + TreeMap treeMap = tempSortedMap.get(sortKey); + if (treeMap == null) { + treeMap = new TreeMap<>(); + tempSortedMap.put(sortKey, treeMap); + } + treeMap.put(innerRank, rowkey); + } + + // build sorted map from the temp map + Iterator>> tempIter = + tempSortedMap.entrySet().iterator(); + while (tempIter.hasNext()) { + Map.Entry> entry = tempIter.next(); + BaseRow sortKey = entry.getKey(); + TreeMap treeMap = entry.getValue(); + Iterator> treeMapIter = treeMap.entrySet().iterator(); + while (treeMapIter.hasNext()) { + Map.Entry treeMapEntry = treeMapIter.next(); + Integer innerRank = treeMapEntry.getKey(); + BaseRow recordRowKey = treeMapEntry.getValue(); + int size = sortedMap.put(sortKey, recordRowKey); + if (innerRank != size) { + LOG.warn("Failed to build sorted map from state, this may result in wrong result. " + + "The sort key is {}, partition key is {}, " + + "treeMap is {}. 
The expected inner rank is {}, " + + "but current size is {}", + sortKey, partitionKey, treeMap, innerRank, size); + } + } + } + } + } else { + hitCount += 1; + } + } + + private void synchronizeState(Map curRowKeyMap) throws Exception { + Iterator> iter = curRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry entry = iter.next(); + BaseRow key = entry.getKey(); + RankRow rankRow = entry.getValue(); + if (rankRow.dirty) { + // should update state + dataState.put(key, Tuple2.of(rankRow.row, rankRow.innerRank)); + rankRow.dirty = false; + } + } + } + + class CacheRemovalListener implements LRUMap.RemovalListener> { + + @Override + public void onRemoval(Map.Entry> eldest) { + BaseRow previousKey = executionContext.currentKey(); + BaseRow partitionKey = eldest.getKey(); + Map currentRowKeyMap = eldest.getValue(); + executionContext.setCurrentKey(partitionKey); + kvSortedMap.remove(partitionKey); + try { + synchronizeState(currentRowKeyMap); + } catch (Throwable e) { + LOG.error("Fail to synchronize state!"); + throw new RuntimeException(e); + } + executionContext.setCurrentKey(previousKey); + } + } + + protected class RankRow { + private final BaseRow row; + private int innerRank; + private boolean dirty; + + protected RankRow(BaseRow row, int innerRank, boolean dirty) { + this.row = row; + this.innerRank = innerRank; + this.dirty = dirty; + } + + protected BaseRow getRow() { + return row; + } + + protected int getInnerRank() { + return innerRank; + } + + protected boolean isDirty() { + return dirty; + } + + public void setInnerRank(int innerRank) { + this.innerRank = innerRank; + } + + public void setDirty(boolean dirty) { + this.dirty = dirty; + } + } +} + + + diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java new file mode 100644 index 00000000000000..0c316a412c32c0 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.runtime.util.LRUMap; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * RankFunction in Append Stream mode. + */ +public class AppendRankFunction extends AbstractRankFunction { + + private static final Logger LOG = LoggerFactory.getLogger(AppendRankFunction.class); + + protected final BaseRowTypeInfo sortKeyType; + private final TypeSerializer inputRowSer; + private final long cacheSize; + + // a map state stores mapping from sort key to records list which is in topN + private transient MapState> dataState; + + // a sorted map stores mapping from sort key to records list, a heap mirror to dataState + private transient SortedMap sortedMap; + private transient Map> kvSortedMap; + private Comparator sortKeyComparator; + + public AppendRankFunction( + long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, RankType rankType, RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, long cacheSize) { + super(minRetentionTime, maxRetentionTime, inputRowType, outputRowType, + generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction); + this.sortKeyType = sortKeyType; + this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); + this.cacheSize = cacheSize; + } + + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopSize())); + kvSortedMap = new LRUMap<>(lruCacheSize); + LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopSize(), lruCacheSize); + + ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state-with-append", sortKeyType, valueTypeInfo); + dataState = ctx.getRuntimeContext().getMapState(mapStateDescriptor); + + sortKeyComparator = generatedRecordComparator.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + + // metrics + registerMetric(kvSortedMap.size() * getDefaultTopSize()); + } + + @Override + public void processElement( + BaseRow input, Context context, Collector out) throws Exception { + long currentTime = context.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(context, currentTime); + + initHeapStates(); + initRankEnd(input); + + BaseRow sortKey = 
sortKeySelector.getKey(input); + // check whether the sortKey is in the topN range + if (checkSortKeyInBufferRange(sortKey, sortedMap, sortKeyComparator)) { + // insert sort key into sortedMap + sortedMap.put(sortKey, inputRowSer.copy(input)); + Collection inputs = sortedMap.get(sortKey); + // update data state + dataState.put(sortKey, (List) inputs); + if (isRowNumberAppend || hasOffset()) { + // the without-number-algorithm can't handle topn with offset, + // so use the with-number-algorithm to handle offset + emitRecordsWithRowNumber(sortKey, input, out); + } else { + processElementWithoutRowNumber(input, out); + } + } + } + + private void initHeapStates() throws Exception { + requestCount += 1; + BaseRow currentKey = executionContext.currentKey(); + sortedMap = kvSortedMap.get(currentKey); + if (sortedMap == null) { + sortedMap = new SortedMap(sortKeyComparator, new Supplier>() { + + @Override + public Collection get() { + return new ArrayList<>(); + } + }); + kvSortedMap.put(currentKey, sortedMap); + // restore sorted map + Iterator>> iter = dataState.iterator(); + if (iter != null) { + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow sortKey = entry.getKey(); + List values = entry.getValue(); + // the order is preserved + sortedMap.putAll(sortKey, values); + } + } + } else { + hitCount += 1; + } + } + + private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow input, Collector out) throws Exception { + Iterator>> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry> entry = iterator.next(); + Collection records = entry.getValue(); + // meet its own sort key + if (!findSortKey && entry.getKey().equals(sortKey)) { + curRank += records.size(); + collect(out, input, curRank); + findSortKey = true; + } else if (findSortKey) { + Iterator recordsIter = records.iterator(); + while (recordsIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = recordsIter.next(); + retract(out, prevRow, curRank - 1); + collect(out, prevRow, curRank); + } + } else { + curRank += records.size(); + } + } + + List toDeleteKeys = new ArrayList<>(); + // remove the records associated to the sort key which is out of topN + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + BaseRow key = entry.getKey(); + dataState.remove(key); + toDeleteKeys.add(key); + } + for (BaseRow toDeleteKey : toDeleteKeys) { + sortedMap.removeAll(toDeleteKey); + } + } + + private void processElementWithoutRowNumber(BaseRow input, Collector out) throws Exception { + // remove retired element + if (sortedMap.getCurrentTopNum() > rankEnd) { + Map.Entry> lastEntry = sortedMap.lastEntry(); + BaseRow lastKey = lastEntry.getKey(); + List lastList = (List) lastEntry.getValue(); + // remove last one + BaseRow lastElement = lastList.remove(lastList.size() - 1); + if (lastList.isEmpty()) { + sortedMap.removeAll(lastKey); + dataState.remove(lastKey); + } else { + dataState.put(lastKey, lastList); + } + // lastElement shouldn't be null + delete(out, lastElement); + } + collect(out, input); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + // cleanup cache + kvSortedMap.remove(executionContext.currentKey()); + cleanupState(dataState); + } + } + + @Override + protected long getMaxSortMapSize() { + return getDefaultTopSize(); + } + +} diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java new file mode 100644 index 00000000000000..b4bb201308293b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** rankStart and rankEnd are inclusive, rankStart always start from one. */ +public class ConstantRankRange implements RankRange { + + private static final long serialVersionUID = 9062345289888078376L; + private long rankStart; + private long rankEnd; + + public ConstantRankRange(long rankStart, long rankEnd) { + this.rankStart = rankStart; + this.rankEnd = rankEnd; + } + + public long getRankStart() { + return rankStart; + } + + public long getRankEnd() { + return rankEnd; + } + + @Override + public String toString(List inputFieldNames) { + return toString(); + } + + @Override + public String toString() { + return "rankStart=" + rankStart + ", rankEnd=" + rankEnd; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java new file mode 100644 index 00000000000000..0d0ddd3dcc2feb --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** ConstantRankRangeWithoutEnd is a RankRange which not specify RankEnd. 
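+ * For example (illustrative only), a lower-bound-only filter such as {@code rownum >= 5} would be
+ * described by {@code new ConstantRankRangeWithoutEnd(5)}, whereas {@code rownum <= 10} maps to
+ * {@code new ConstantRankRange(1, 10)}, and a rank end read from an input column maps to
+ * {@code VariableRankRange}.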
*/ +public class ConstantRankRangeWithoutEnd implements RankRange { + + private static final long serialVersionUID = -1944057111062598696L; + + private final long rankStart; + + public ConstantRankRangeWithoutEnd(long rankStart) { + this.rankStart = rankStart; + } + + @Override + public String toString(List inputFieldNames) { + return toString(); + } + + @Override + public String toString() { + return "rankStart=" + rankStart; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java new file mode 100644 index 00000000000000..0b16c236024d1a --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.io.Serializable; +import java.util.List; + +/** + * RankRange for Rank, including following 3 types : + * ConstantRankRange, ConstantRankRangeWithoutEnd, VariableRankRange + */ +public interface RankRange extends Serializable { + + String toString(List inputFieldNames); + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java new file mode 100644 index 00000000000000..d6573dd96b24b4 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +/** + * An enumeration of rank type, usable to show how exactly generate rank number. 
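+ * For example, for the ordered values (30, 30, 40) the three types number the rows as
+ * ROW_NUMBER: 1, 2, 3; RANK: 1, 1, 3; DENSE_RANK: 1, 1, 2 (worked example following the
+ * definitions of the constants below).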
+ */ +public enum RankType { + + /** + * Returns a unique sequential number for each row within the partition based on the order, + * starting at 1 for the first row in each partition and without repeating or skipping + * numbers in the ranking result of each partition. If there are duplicate values within the + * row set, the ranking numbers will be assigned arbitrarily. + */ + ROW_NUMBER, + + /** + * Returns a unique rank number for each distinct row within the partition based on the order, + * starting at 1 for the first row in each partition, with the same rank for duplicate values + * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate + * values. + */ + RANK, + + /** + * is similar to the RANK by generating a unique rank number for each distinct row + * within the partition based on the order, starting at 1 for the first row in each partition, + * ranking the rows with equal values with the same rank number, except that it does not skip + * any rank, leaving no gaps between the ranks. + */ + DENSE_RANK +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java new file mode 100644 index 00000000000000..5477d860359c2f --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.table.typeutils.SortedMapTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * RankFunction in Update Stream mode. + */ +public class RetractRankFunction extends AbstractRankFunction { + + private static final Logger LOG = LoggerFactory.getLogger(RetractRankFunction.class); + + /** + * Message to indicate the state is cleared because of ttl restriction. The message could be + * used to output to log. + */ + private static final String STATE_CLEARED_WARN_MSG = "The state is cleared because of state ttl. " + + "This will result in incorrect result. " + + "You can increase the state ttl to avoid this."; + + protected final BaseRowTypeInfo sortKeyType; + + // flag to skip records with non-exist error instead to fail, true by default. 
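+ // Illustrative sketch of how this flag is consulted further down in this class
+ // (stateWasClearedByTtl is a placeholder for "the retracted sort key is missing from state"):
+ //   if (stateWasClearedByTtl) {
+ //       if (lenient) { LOG.warn(STATE_CLEARED_WARN_MSG); }
+ //       else { throw new RuntimeException(STATE_CLEARED_WARN_MSG); }
+ //   }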
+ private final boolean lenient = true; + + // a map state stores mapping from sort key to records list + private transient MapState> dataState; + + // a sorted map stores mapping from sort key to records count + private transient ValueState> treeMap; + + private Comparator sortKeyComparator; + + public RetractRankFunction( + long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, RankType rankType, RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction) { + super(minRetentionTime, maxRetentionTime, inputRowType, outputRowType, + generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction); + this.sortKeyType = sortKeyType; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state", sortKeyType, valueTypeInfo); + dataState = ctx.getRuntimeContext().getMapState(mapStateDescriptor); + + sortKeyComparator = generatedRecordComparator.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + ValueStateDescriptor> valueStateDescriptor = new ValueStateDescriptor( + "sorted-map", + new SortedMapTypeInfo(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator) + ); + treeMap = ctx.getRuntimeContext().getState(valueStateDescriptor); + } + + @Override + public void processElement( + BaseRow input, Context ctx, Collector out) throws Exception { + + initRankEnd(input); + + SortedMap sortedMap = treeMap.value(); + if (sortedMap == null) { + sortedMap = new TreeMap<>(sortKeyComparator); + } + BaseRow sortKey = sortKeySelector.getKey(input); + + if (BaseRowUtil.isAccumulateMsg(input)) { + // update sortedMap + if (sortedMap.containsKey(sortKey)) { + sortedMap.put(sortKey, sortedMap.get(sortKey) + 1); + } else { + sortedMap.put(sortKey, 1L); + } + + // emit + emitRecordsWithRowNumber(sortedMap, sortKey, input, out); + + // update data state + List inputs = dataState.get(sortKey); + if (inputs == null) { + // the sort key is never seen + inputs = new ArrayList<>(); + } + inputs.add(input); + dataState.put(sortKey, inputs); + } else { + // retract input + + // emit updates first + retractRecordWithRowNumber(sortedMap, sortKey, input, out); + + // and then update sortedMap + if (sortedMap.containsKey(sortKey)) { + long count = sortedMap.get(sortKey) - 1; + if (count == 0) { + sortedMap.remove(sortKey); + } else { + sortedMap.put(sortKey, count); + } + } else { + if (sortedMap.isEmpty()) { + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + throw new RuntimeException( + "Can not retract a non-existent record: ${inputBaseRow.toString}. 
" + + "This should never happen."); + } + } + + } + treeMap.update(sortedMap); + } + + // ------------- ROW_NUMBER------------------------------- + + private void retractRecordWithRowNumber( + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + Iterator> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findSortKey = false; + while(iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry entry = iterator.next(); + BaseRow key = entry.getKey(); + if (!findSortKey && key.equals(sortKey)) { + List inputs = dataState.get(key); + if (inputs == null) { + // Skip the data if it's state is cleared because of state ttl. + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + Iterator inputIter = inputs.iterator(); + while (inputIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputIter.next(); + if (!findSortKey && equaliser.equalsWithoutHeader(prevRow, inputRow)) { + delete(out, prevRow, curRank); + curRank -= 1; + findSortKey = true; + inputIter.remove(); + } else if (findSortKey) { + retract(out, prevRow, curRank + 1); + collect(out, prevRow, curRank); + } + } + if (inputs.isEmpty()) { + dataState.remove(key); + } else { + dataState.put(key, inputs); + } + } + } else if (findSortKey) { + List inputs = dataState.get(key); + int i = 0; + while (i < inputs.size() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputs.get(i); + retract(out, prevRow, curRank + 1); + collect(out, prevRow, curRank); + i++; + } + } else { + curRank += entry.getValue(); + } + } + } + + private void emitRecordsWithRowNumber( + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + Iterator> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findSortKey = false; + while (iterator.hasNext() && isInRankRange(curRank)) { + Map.Entry entry = iterator.next(); + BaseRow key = entry.getKey(); + if (!findSortKey && key.equals(sortKey)) { + curRank += entry.getValue(); + collect(out, inputRow, curRank); + findSortKey = true; + } else if (findSortKey) { + List inputs = dataState.get(key); + if (inputs == null) { + // Skip the data if it's state is cleared because of state ttl. + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + int i = 0; + while (i < inputs.size() && isInRankRange(curRank)) { + curRank += 1; + BaseRow prevRow = inputs.get(i); + retract(out, prevRow, curRank - 1); + collect(out, prevRow, curRank); + i++; + } + } + } else { + curRank += entry.getValue(); + } + } + } + + @Override + protected long getMaxSortMapSize() { + // just let it go, retract rank has no interest in this + return 0L; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java new file mode 100644 index 00000000000000..e1e1e5b433c741 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.table.dataformat.BaseRow; + +import java.util.Collection; +import java.util.Comparator; +import java.util.function.Supplier; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +/** + * SortedMap stores mapping from sort key to records list, each record is BaseRow type. + * SortedMap could also track rank number of each records. + * + * @param Type of the sort key + */ +public class SortedMap { + + private final Supplier> valueSupplier; + private int currentTopNum = 0; + private TreeMap> treeMap; + + public SortedMap(Comparator sortKeyComparator, Supplier> valueSupplier) { + this.valueSupplier = valueSupplier; + this.treeMap = new TreeMap(sortKeyComparator); + } + + /** + * Appends a record into the SortedMap under the sortKey + * + * @param sortKey sort key with which the specified value is to be associated + * @param value record which is to be appended + * + * @return the size of the collection under the sortKey. + */ + public int put(T sortKey, BaseRow value) { + currentTopNum += 1; + // update treeMap + Collection collection = treeMap.get(sortKey); + if (collection == null) { + collection = valueSupplier.get(); + treeMap.put(sortKey, collection); + } + collection.add(value); + return collection.size(); + } + + /** + * Puts a record list into the SortedMap under the sortKey + * Note: if SortedMap already contains sortKey, putAll will overwrite the previous value + * + * @param sortKey sort key with which the specified values are to be associated + * @param values record lists to be associated with the specified key + */ + public void putAll(T sortKey, Collection values) { + treeMap.put(sortKey, values); + currentTopNum += values.size(); + } + + /** + * Get the record list from SortedMap under the sortKey + * + * @param sortKey key to get + * + * @return the record list from SortedMap under the sortKey + */ + public Collection get(T sortKey) { + return treeMap.get(sortKey); + } + + public void remove(T sortKey, BaseRow value) { + Collection list = treeMap.get(sortKey); + if (list != null) { + if (list.remove(value)) { + currentTopNum -= 1; + } + if (list.size() == 0) { + treeMap.remove(sortKey); + } + } + } + + /** + * Remove all record list from SortedMap under the sortKey + * + * @param sortKey key to remove + */ + public void removeAll(T sortKey) { + Collection list = treeMap.get(sortKey); + if (list != null) { + currentTopNum -= list.size(); + treeMap.remove(sortKey); + } + } + + /** + * Remove the last record of the last Entry in the TreeMap (according to the TreeMap's + * key-sort function). 
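+ * For example (illustrative), if the map currently holds {1 -> [a], 2 -> [b, c]} under a
+ * natural-order comparator, {@code removeLast()} removes and returns {@code c}, leaving
+ * {1 -> [a], 2 -> [b]} and decrementing the tracked record count by one.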
+ * + * @return removed record + */ + public BaseRow removeLast() { + Map.Entry> last = treeMap.lastEntry(); + BaseRow lastElement = null; + if (last != null) { + Collection list = last.getValue(); + lastElement = getLastElement(list); + if (lastElement != null) { + if (list.remove(lastElement)) { + currentTopNum -= 1; + } + if (list.size() == 0) { + treeMap.remove(last.getKey()); + } + } + } + return lastElement; + } + + /** + * Get record which rank is given value. + * + * @param rank rank value to search + * + * @return the record which rank is given value + */ + public BaseRow getElement(int rank) { + int curRank = 0; + Iterator>> iter = treeMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + Collection list = entry.getValue(); + + Iterator listIter = list.iterator(); + while (listIter.hasNext()) { + BaseRow elem = listIter.next(); + curRank += 1; + if (curRank == rank) { + return elem; + } + } + } + return null; + } + + private BaseRow getLastElement(Collection list) { + BaseRow element = null; + if (list != null && !list.isEmpty()) { + Iterator iter = list.iterator(); + while (iter.hasNext()) { + element = iter.next(); + } + } + return element; + } + + /** + * Returns a {@link Set} view of the mappings contained in this map. + */ + public Set>> entrySet() { + return treeMap.entrySet(); + } + + /** + * Returns the last Entry in the TreeMap (according to the TreeMap's + * key-sort function). Returns null if the TreeMap is empty. + */ + public Map.Entry> lastEntry() { + return treeMap.lastEntry(); + } + + /** + * Returns {@code true} if this map contains a mapping for the specified + * key. + * + * @param key key whose presence in this map is to be tested + * + * @return {@code true} if this map contains a mapping for the + * specified key + */ + public boolean containsKey(T key) { + return treeMap.containsKey(key); + } + + /** + * Get number of total records. + * + * @return the number of total records. + */ + public int getCurrentTopNum() { + return currentTopNum; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java new file mode 100644 index 00000000000000..4dfb72a474b534 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; + +/** + * A fast version of rank process function which only hold top n data in state, + * and keep sorted map in heap. This only works in some special scenarios, such + * as, rank a count(*) stream + */ +public class UpdateRankFunction extends AbstractUpdateRankFunction { + + private final TypeSerializer inputRowSer; + private final KeySelector rowKeySelector; + + public UpdateRankFunction( + long minRetentionTime, + long maxRetentionTime, + BaseRowTypeInfo inputRowType, + BaseRowTypeInfo outputRowType, + BaseRowTypeInfo rowKeyType, + KeySelector rowKeySelector, + GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, + RankType rankType, + RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, + long cacheSize) { + super(minRetentionTime, + maxRetentionTime, + inputRowType, + outputRowType, + rowKeyType, + generatedRecordComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction, + cacheSize); + this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); + this.rowKeySelector = rowKeySelector; + } + + @Override + public void processElement( + BaseRow input, Context context, Collector out) throws Exception { + long currentTime = context.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(context, currentTime); + + initHeapStates(); + initRankEnd(input); + if (isRowNumberAppend || hasOffset()) { + // the without-number-algorithm can't handle topn with offset, + // so use the with-number-algorithm to handle offset + processElementWithRowNumber(input, out); + } else { + processElementWithoutRowNumber(input, out); + } + } + + @Override + protected long getMaxSortMapSize() { + return getDefaultTopSize(); + } + + private void processElementWithRowNumber(BaseRow inputRow, Collector out) throws Exception { + BaseRow sortKey = sortKeySelector.getKey(inputRow); + BaseRow rowKey = rowKeySelector.getKey(inputRow); + if (rowKeyMap.containsKey(rowKey)) { + // it is an updated record which is in the topN, in this scenario, + // the new sort key must be higher than old sort key, this is guaranteed by rules + RankRow oldRow = rowKeyMap.get(rowKey); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.getRow()); + if (oldSortKey.equals(sortKey)) { + // sort key is not changed, so the rank is the same, only output the row + Tuple2 rankAndInnerRank = rowNumber(sortKey, rowKey, sortedMap); + int rank = rankAndInnerRank.f0; + int innerRank = rankAndInnerRank.f1; + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), innerRank, true)); + retract(out, oldRow.getRow(), rank); // retract old record + collect(out, inputRow, rank); + return; + } + + Tuple2 oldRankAndInnerRank = rowNumber(oldSortKey, rowKey, sortedMap); + int oldRank = oldRankAndInnerRank.f0; + // remove old sort key + 
sortedMap.remove(oldSortKey, rowKey); + // add new sort key + int size = sortedMap.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + // update inner rank of records under the old sort key + updateInnerRank(oldSortKey); + + // emit records + emitRecordsWithRowNumber(sortKey, inputRow, out, oldSortKey, oldRow, oldRank); + } else { + // out of topN + } + } + + private void updateInnerRank(BaseRow oldSortKey) { + Collection list = sortedMap.get(oldSortKey); + if (list != null) { + Iterator iter = list.iterator(); + int innerRank = 1; + while (iter.hasNext()) { + BaseRow rowkey = iter.next(); + RankRow row = rowKeyMap.get(rowkey); + if (row.getInnerRank() != innerRank) { + row.setInnerRank(innerRank); + row.setDirty(true); + } + innerRank += 1; + } + } + } + + private void emitRecordsWithRowNumber( + BaseRow sortKey, BaseRow inputRow, Collector out, BaseRow oldSortKey, RankRow oldRow, int oldRank) { + + int oldInnerRank = oldRow == null ? -1 : oldRow.getInnerRank(); + Iterator>> iterator = sortedMap.entrySet() + .iterator(); + int curRank = 0; + // whether we have found the sort key in the sorted tree + boolean findSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank + 1)) { + Map.Entry> entry = iterator.next(); + BaseRow curKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + // meet its own sort key + if (!findSortKey && curKey.equals(sortKey)) { + curRank += rowKeys.size(); + if (oldRow != null) { + retract(out, oldRow.getRow(), oldRank); + } + collect(out, inputRow, curRank); + findSortKey = true; + } else if (findSortKey) { + if (oldSortKey == null) { + // this is a new row, emit updates for all rows in the topn + Iterator rowKeyIter = rowKeys.iterator(); + while (rowKeyIter.hasNext() && isInRankEnd(curRank + 1)) { + curRank += 1; + BaseRow rowKey = rowKeyIter.next(); + RankRow prevRow = rowKeyMap.get(rowKey); + retract(out, prevRow.getRow(), curRank - 1); + collect(out, prevRow.getRow(), curRank); + } + } else { + // current sort key is higher than old sort key, + // the rank of current record is changed, need to update the following rank + int compare = sortKeyComparator.compare(curKey, oldSortKey); + if (compare <= 0) { + Iterator rowKeyIter = rowKeys.iterator(); + int curInnerRank = 0; + while (rowKeyIter.hasNext() && isInRankEnd(curRank + 1)) { + curRank += 1; + curInnerRank += 1; + if (compare == 0 && curInnerRank >= oldInnerRank) { + // match to the previous position + return; + } + + BaseRow rowKey = rowKeyIter.next(); + RankRow prevRow = rowKeyMap.get(rowKey); + retract(out, prevRow.getRow(), curRank - 1); + collect(out, prevRow.getRow(), curRank); + } + } else { + // current sort key is smaller than old sort key, + // the rank is not changed, so skip + return; + } + } + } else { + curRank += rowKeys.size(); + } + } + } + + private void processElementWithoutRowNumber(BaseRow inputRow, Collector out) throws Exception { + BaseRow sortKey = sortKeySelector.getKey(inputRow); + BaseRow rowKey = rowKeySelector.getKey(inputRow); + if (rowKeyMap.containsKey(rowKey)) { + // it is an updated record which is in the topN, in this scenario, + // the new sort key must be higher than old sort key, this is guaranteed by rules + RankRow oldRow = rowKeyMap.get(rowKey); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.getRow()); + if (!oldSortKey.equals(sortKey)) { + // remove old sort key + sortedMap.remove(oldSortKey, rowKey); + // add new sort key + int size = sortedMap.put(sortKey, rowKey); + rowKeyMap.put(rowKey, 
new RankRow(inputRowSer.copy(inputRow), size, true)); + // update inner rank of records under the old sort key + updateInnerRank(oldSortKey); + } else { + // row content may change, so we need to update the row in the map + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.getInnerRank(), true)); + } + // row content may change, so a retraction is needed + retract(out, oldRow.getRow(), oldRow.getInnerRank()); + collect(out, inputRow); + } else if (checkSortKeyInBufferRange(sortKey, sortedMap, sortKeyComparator)) { + // it is a new record which is in the topN + // insert the sort key into the sortedMap + int size = sortedMap.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + collect(out, inputRow); + // remove the retired element + if (sortedMap.getCurrentTopNum() > rankEnd) { + BaseRow lastRowKey = sortedMap.removeLast(); + if (lastRowKey != null) { + RankRow lastRow = rowKeyMap.remove(lastRowKey); + dataState.remove(lastRowKey); + // always send a retraction message + delete(out, lastRow.getRow()); + } + } + } else { + // out of topN + } + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java new file mode 100644 index 00000000000000..fdd1e660b55621 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** A {@link RankRange} whose rank end depends on an input field, i.e. the rank limit can change per input row. */ +public class VariableRankRange implements RankRange { + + private static final long serialVersionUID = 5579785886506433955L; + private int rankEndIndex; + + public VariableRankRange(int rankEndIndex) { + this.rankEndIndex = rankEndIndex; + } + + public int getRankEndIndex() { + return rankEndIndex; + } + + @Override + public String toString(List<String> inputFieldNames) { + return "rankEnd=" + inputFieldNames.get(rankEndIndex); + } + + @Override + public String toString() { + return "rankEnd=$$" + rankEndIndex; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java index 45ddd4ce331624..5ad8779059e787 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java @@ -40,8 +40,8 @@ public class ValuesInputFormat implements NonParallelInput, ResultTypeQueryable<BaseRow> { private static final Logger LOG = LoggerFactory.getLogger(ValuesInputFormat.class); - private GeneratedInput<GenericInputFormat<BaseRow>> generatedInput; - private BaseRowTypeInfo returnType; + private final GeneratedInput<GenericInputFormat<BaseRow>> generatedInput; + private final BaseRowTypeInfo returnType; private GenericInputFormat<BaseRow> format; public ValuesInputFormat(GeneratedInput<GenericInputFormat<BaseRow>> generatedInput, BaseRowTypeInfo returnType) { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java index a7b9199c962ca9..12aaf6192b4bc3 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java @@ -93,6 +93,8 @@ public class TypeConverters { internalTypeToInfo.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO); internalTypeToInfo.put(InternalTypes.DATE, BasicTypeInfo.INT_TYPE_INFO); internalTypeToInfo.put(InternalTypes.TIMESTAMP, BasicTypeInfo.LONG_TYPE_INFO); + internalTypeToInfo.put(InternalTypes.PROCTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO); + internalTypeToInfo.put(InternalTypes.ROWTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO); internalTypeToInfo.put(InternalTypes.TIME, BasicTypeInfo.INT_TYPE_INFO); internalTypeToInfo.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); internalTypeToInfo.put(InternalTypes.INTERVAL_MONTHS, BasicTypeInfo.INT_TYPE_INFO); @@ -111,6 +113,8 @@ public class TypeConverters { itToEti.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO); itToEti.put(InternalTypes.DATE, SqlTimeTypeInfo.DATE); itToEti.put(InternalTypes.TIMESTAMP, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.PROCTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.ROWTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.TIME, SqlTimeTypeInfo.TIME); itToEti.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); INTERNAL_TYPE_TO_EXTERNAL_TYPE_INFO = Collections.unmodifiableMap(itToEti); diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java new file mode 100644 index 00000000000000..efdaf9acd7972c --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Map; + +/** + * Base serializer for {@link Map}. The serializer relies on a key serializer + * and a value serializer for the serialization of the map's key-value pairs. + * + *

The serialization format for the map is as follows: four bytes for the + * length of the map, followed by the serialized representation of each + * key-value pair. To allow null values, each value is prefixed by a null flag. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + * @param The type of the map. + */ +abstract class AbstractMapSerializer> extends TypeSerializer { + + private static final long serialVersionUID = 1L; + + /** The serializer for the keys in the map */ + final TypeSerializer keySerializer; + + /** The serializer for the values in the map */ + final TypeSerializer valueSerializer; + + /** + * Creates a map serializer that uses the given serializers to serialize the key-value pairs in the map. + * + * @param keySerializer The serializer for the keys in the map + * @param valueSerializer The serializer for the values in the map + */ + AbstractMapSerializer( + TypeSerializer keySerializer, + TypeSerializer valueSerializer + ) { + Preconditions.checkNotNull(keySerializer, + "The key serializer must not be null"); + Preconditions.checkNotNull(valueSerializer, + "The value serializer must not be null."); + this.keySerializer = keySerializer; + this.valueSerializer = valueSerializer; + } + + // ------------------------------------------------------------------------ + + /** + * Returns the serializer for the keys in the map. + * + * @return The serializer for the keys in the map. + */ + public TypeSerializer getKeySerializer() { + return keySerializer; + } + + /** + * Returns the serializer for the values in the map. + * + * @return The serializer for the values in the map. + */ + public TypeSerializer getValueSerializer() { + return valueSerializer; + } + + // ------------------------------------------------------------------------ + // Type Serializer implementation + // ------------------------------------------------------------------------ + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public M copy(M from) { + M newMap = createInstance(); + + for (Map.Entry entry : from.entrySet()) { + K newKey = entry.getKey() == null ? null : keySerializer.copy(entry.getKey()); + V newValue = entry.getValue() == null ? null : valueSerializer.copy(entry.getValue()); + + newMap.put(newKey, newValue); + } + + return newMap; + } + + @Override + public M copy(M from, M reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; // var length + } + + @Override + public void serialize(M map, DataOutputView target) throws IOException { + final int size = map.size(); + target.writeInt(size); + + for (Map.Entry entry : map.entrySet()) { + keySerializer.serialize(entry.getKey(), target); + + if (entry.getValue() == null) { + target.writeBoolean(true); + } else { + target.writeBoolean(false); + valueSerializer.serialize(entry.getValue(), target); + } + } + } + + @Override + public M deserialize(DataInputView source) throws IOException { + final int size = source.readInt(); + + final M map = createInstance(); + for (int i = 0; i < size; ++i) { + K key = keySerializer.deserialize(source); + + boolean isNull = source.readBoolean(); + V value = isNull ? 
null : valueSerializer.deserialize(source); + + map.put(key, value); + } + + return map; + } + + @Override + public M deserialize(M reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + final int size = source.readInt(); + target.writeInt(size); + + for (int i = 0; i < size; ++i) { + keySerializer.copy(source, target); + + boolean isNull = source.readBoolean(); + target.writeBoolean(isNull); + + if (!isNull) { + valueSerializer.copy(source, target); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AbstractMapSerializer<?, ?, ?> that = (AbstractMapSerializer<?, ?, ?>) o; + + return keySerializer.equals(that.keySerializer) && + valueSerializer.equals(that.valueSerializer); + } + + @Override + public int hashCode() { + int result = keySerializer.hashCode(); + result = 31 * result + valueSerializer.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java new file mode 100644 index 00000000000000..eea9649b937559 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.util.Preconditions; + +import java.util.Map; + +/** + * Base type information for maps. + * + * @param <K> The type of the keys in the map. + * @param <V> The type of the values in the map. + */ +abstract class AbstractMapTypeInfo<K, V, M extends Map<K, V>> extends TypeInformation<M> { + + private static final long serialVersionUID = 1L; + + /* The type information for the keys in the map. */ + final TypeInformation<K> keyTypeInfo; + + /* The type information for the values in the map. */ + final TypeInformation<V> valueTypeInfo; + + /** + * Constructor with given type information for the keys and the values in + * the map. + * + * @param keyTypeInfo The type information for the keys in the map. + * @param valueTypeInfo The type information for the values in the map.
+ */ + AbstractMapTypeInfo( + TypeInformation keyTypeInfo, + TypeInformation valueTypeInfo + ) { + Preconditions.checkNotNull(keyTypeInfo, + "The type information for the keys cannot be null."); + Preconditions.checkNotNull(valueTypeInfo, + "The type information for the values cannot be null."); + this.keyTypeInfo = keyTypeInfo; + this.valueTypeInfo = valueTypeInfo; + } + + /** + * Constructor with the classes of the keys and the values in the map. + * + * @param keyClass The class of the keys in the map. + * @param valueClass The class of the values in the map. + */ + AbstractMapTypeInfo(Class keyClass, Class valueClass) { + Preconditions.checkNotNull(keyClass, + "The key class cannot be null."); + Preconditions.checkNotNull(valueClass, + "The value class cannot be null."); + + this.keyTypeInfo = TypeInformation.of(keyClass); + this.valueTypeInfo = TypeInformation.of(valueClass); + } + + // ------------------------------------------------------------------------ + + /** + * Returns the type information for the keys in the map. + * + * @return The type information for the keys in the map. + */ + public TypeInformation getKeyTypeInfo() { + return keyTypeInfo; + } + + /** + * Returns the type information for the values in the map. + * + * @return The type information for the values in the map. + */ + public TypeInformation getValueTypeInfo() { + return valueTypeInfo; + } + + // ------------------------------------------------------------------------ + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 0; + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean isKeyType() { + return false; + } + + // ------------------------------------------------------------------------ + + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AbstractMapTypeInfo that = (AbstractMapTypeInfo) o; + + return keyTypeInfo.equals(that.keyTypeInfo) && + valueTypeInfo.equals(that.valueTypeInfo); + } + + @Override + public int hashCode() { + int result = keyTypeInfo.hashCode(); + result = 31 * result + valueTypeInfo.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java new file mode 100644 index 00000000000000..1c745a40f63dfa --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; +import org.apache.flink.api.common.typeutils.base.MapSerializerConfigSnapshot; +import org.apache.flink.util.Preconditions; + +import java.util.Comparator; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * A serializer for {@link SortedMap}. The serializer relies on a key serializer + * and a value serializer for the serialization of the map's key-value pairs. + * It also deploys a comparator to ensure the order of the keys. + * + *

The serialization format for the map is as follows: four bytes for the + * length of the map, followed by the serialized representation of each + * key-value pair. To allow null values, each value is prefixed by a null flag. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +public final class SortedMapSerializer extends AbstractMapSerializer> { + + private static final long serialVersionUID = 1L; + + /** The comparator for the keys in the map. */ + private final Comparator comparator; + + /** + * Constructor with given comparator, and the serializers for the keys and + * values in the map. + * + * @param comparator The comparator for the keys in the map. + * @param keySerializer The serializer for the keys in the map. + * @param valueSerializer The serializer for the values in the map. + */ + public SortedMapSerializer( + Comparator comparator, + TypeSerializer keySerializer, + TypeSerializer valueSerializer) { + super(keySerializer, valueSerializer); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + /** + * Returns the comparator for the keys in the map. + * + * @return The comparator for the keys in the map. + */ + public Comparator getComparator() { + return comparator; + } + + @Override + public TypeSerializer> duplicate() { + TypeSerializer keySerializer = getKeySerializer().duplicate(); + TypeSerializer valueSerializer = getValueSerializer().duplicate(); + + return new SortedMapSerializer<>(comparator, keySerializer, valueSerializer); + } + + @Override + public SortedMap createInstance() { + return new TreeMap<>(comparator); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + + SortedMapSerializer that = (SortedMapSerializer) o; + return comparator.equals(that.comparator); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = result * 31 + comparator.hashCode(); + return result; + } + + @Override + public String toString() { + return "SortedMapSerializer{" + + "comparator = " + comparator + + ", keySerializer = " + keySerializer + + ", valueSerializer = " + valueSerializer + + "}"; + } + + @Override + public TypeSerializerConfigSnapshot snapshotConfiguration() { + return new MapSerializerConfigSnapshot<>(keySerializer, valueSerializer); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java new file mode 100644 index 00000000000000..010e2ae665db6d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Comparator; +import java.util.SortedMap; + +/** + * The type information for sorted maps. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +@PublicEvolving +public class SortedMapTypeInfo extends AbstractMapTypeInfo> { + + private static final long serialVersionUID = 1L; + + /** The comparator for the keys in the map. */ + private final Comparator comparator; + + public SortedMapTypeInfo( + TypeInformation keyTypeInfo, + TypeInformation valueTypeInfo, + Comparator comparator) { + super(keyTypeInfo, valueTypeInfo); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + public SortedMapTypeInfo( + Class keyClass, + Class valueClass, + Comparator comparator) { + super(keyClass, valueClass); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + public SortedMapTypeInfo(Class keyClass, Class valueClass) { + super(keyClass, valueClass); + + Preconditions.checkArgument(Comparable.class.isAssignableFrom(keyClass), + "The key class must be comparable when no comparator is given."); + this.comparator = new ComparableComparator<>(); + } + + // ------------------------------------------------------------------------ + + @SuppressWarnings("unchecked") + @Override + public Class> getTypeClass() { + return (Class>)(Class)SortedMap.class; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + TypeSerializer keyTypeSerializer = keyTypeInfo.createSerializer(config); + TypeSerializer valueTypeSerializer = valueTypeInfo.createSerializer(config); + + return new SortedMapSerializer<>(comparator, keyTypeSerializer, valueTypeSerializer); + } + + @Override + public boolean canEqual(Object obj) { + return null != obj && getClass() == obj.getClass(); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + + SortedMapTypeInfo that = (SortedMapTypeInfo) o; + + return comparator.equals(that.comparator); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + comparator.hashCode(); + return result; + } + + @Override + public String toString() { + return "SortedMapTypeInfo{" + + "comparator=" + comparator + + ", keyTypeInfo=" + getKeyTypeInfo() + + ", valueTypeInfo=" + getValueTypeInfo() + + "}"; + } + + //-------------------------------------------------------------------------- + + /** + * The default comparator for comparable types + */ + private static class ComparableComparator implements Comparator, Serializable { + private static final long serialVersionUID = 1L; + + @SuppressWarnings("unchecked") + public int compare(K obj1, K obj2) { + return ((Comparable) obj1).compareTo(obj2); + } + + @Override + public boolean equals(Object o) { + return (o == this) || (o != null && o.getClass() == getClass()); + } + + @Override + public int hashCode() { + return "ComparableComparator".hashCode(); + } + + } +}
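
A minimal usage sketch of the new sorted-map type information and serializer (hypothetical example, not part of the patch itself): it round-trips a TreeMap through the serializer created by SortedMapTypeInfo to illustrate the format described in the Javadoc above, i.e. an int size prefix followed by each key, a null flag, and the value when present. It relies only on the classes added in this change plus the existing flink-core DataOutputSerializer/DataInputDeserializer helpers; the class name SortedMapSerializerExample is made up for illustration.

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.table.typeutils.SortedMapTypeInfo;

import java.io.IOException;
import java.util.SortedMap;
import java.util.TreeMap;

public class SortedMapSerializerExample {

    public static void main(String[] args) throws IOException {
        // Integer is Comparable, so the two-argument constructor falls back to
        // the natural key order via the internal ComparableComparator.
        SortedMapTypeInfo<Integer, String> typeInfo =
            new SortedMapTypeInfo<>(Integer.class, String.class);
        TypeSerializer<SortedMap<Integer, String>> serializer =
            typeInfo.createSerializer(new ExecutionConfig());

        SortedMap<Integer, String> map = new TreeMap<>();
        map.put(2, "b");
        map.put(1, "a");
        map.put(3, null); // null values are allowed; they are encoded with a null flag

        // serialize: 4-byte size, then key / null-flag / value for each entry
        DataOutputSerializer out = new DataOutputSerializer(64);
        serializer.serialize(map, out);

        // deserialize into a fresh TreeMap created with the same comparator
        DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
        SortedMap<Integer, String> copy = serializer.deserialize(in);

        System.out.println(copy); // {1=a, 2=b, 3=null}
    }
}

The null flag written before each value is what lets the entry with the null value survive the round trip, and the deserialized map comes back in the comparator's key order because createInstance() builds the TreeMap with the same comparator.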