From 1b8e599b2c68344b941eeb141720b71e567da677 Mon Sep 17 00:00:00 2001
From: beyond1920
Date: Thu, 4 Apr 2019 18:47:02 +0800
Subject: [PATCH 1/5] [FLINK-12017][table-planner-blink] Support translation from Rank/Deduplicate to StreamTransformation

---
 flink-table/flink-table-planner-blink/pom.xml | 7 +
 .../table/plan/util/KeySelectorUtil.java | 71 +
 .../apache/flink/table/api/TableConfig.scala | 14 +
 .../flink/table/codegen/CodeGenUtils.scala | 4 +-
 .../codegen/EqualiserCodeGenerator.scala | 148 ++
 .../table/codegen/SinkCodeGenerator.scala | 12 +-
 .../table/codegen/ValuesCodeGenerator.scala | 4 +-
 .../codegen/sort/SortCodeGenerator.scala | 1 +
 .../plan/nodes/calcite/LogicalRank.scala | 2 +-
 .../flink/table/plan/nodes/calcite/Rank.scala | 78 +-
 .../plan/nodes/logical/FlinkLogicalRank.scala | 4 +-
 .../nodes/physical/batch/BatchExecRank.scala | 6 +-
 .../stream/StreamExecDataStreamScan.scala | 8 +-
 .../stream/StreamExecDeduplicate.scala | 111 +-
 .../physical/stream/StreamExecExchange.scala | 54 +-
 .../physical/stream/StreamExecRank.scala | 143 +-
 .../rules/logical/FlinkLogicalRankRule.scala | 7 +-
 .../physical/batch/BatchExecRankRule.scala | 6 +-
 .../stream/StreamExecDeduplicateRule.scala | 5 +-
 .../table/plan/util/FlinkRelMdUtil.scala | 4 +-
 .../flink/table/plan/util/RankUtil.scala | 16 +-
 .../table/typeutils/TypeCheckUtils.scala | 65 +-
 .../utils/FailingCollectionSource.java | 269 ++++
 .../stream/sql/DeduplicateITCase.scala | 172 +++
 .../table/runtime/stream/sql/RankITCase.scala | 1296 +++++++++++++++++
 .../table/runtime/utils/StreamTestSink.scala | 277 +++-
 .../runtime/utils/StreamingTestBase.scala | 6 +
 .../StreamingWithMiniBatchTestBase.scala | 73 +
 .../utils/StreamingWithStateTestBase.scala | 271 ++++
 .../flink/table/runtime/utils/TableUtil.scala | 13 +-
 .../table/runtime/utils/TimeTestUtil.scala | 67 +
 .../flink/table/api/TableConfigOptions.java | 9 +
 .../flink/table/dataformat/BinaryRow.java | 16 +
 .../flink/table/dataformat/BinaryWriter.java | 3 +-
 .../dataformat/DataFormatConverters.java | 11 +-
 .../table/dataformat/TypeGetterSetters.java | 3 +-
 .../table/dataformat/util/BinaryRowUtil.java | 10 +
 .../bundle/AbstractMapBundleOperator.java | 4 +-
 .../deduplicate/DeduplicateFunction.java | 97 ++
 .../deduplicate/DeduplicateFunctionBase.java | 69 +
 .../MiniBatchDeduplicateFunction.java | 101 ++
 .../KeyedProcessFunctionWithCleanupState.java | 94 ++
 .../keyselector/BaseRowKeySelector.java | 33 +
 .../keyselector/BinaryRowKeySelector.java | 56 +
 .../keyselector/NullBinaryRowKeySelector.java | 41 +
 .../runtime/rank/AbstractRankFunction.java | 318 ++++
 .../rank/AbstractUpdateRankFunction.java | 293 ++++
 .../runtime/rank/AppendRankFunction.java | 226 +++
 .../table/runtime/rank/ConstantRankRange.java | 53 +
 .../rank/ConstantRankRangeWithoutEnd.java | 43 +
 .../flink/table/runtime/rank/RankRange.java | 32 +
 .../flink/table/runtime/rank/RankType.java | 49 +
 .../runtime/rank/RetractRankFunction.java | 263 ++++
 .../flink/table/runtime/rank/SortedMap.java | 214 +++
 .../runtime/rank/UpdateRankFunction.java | 258 ++++
 .../table/runtime/rank/VariableRankRange.java | 49 +
 .../runtime/values/ValuesInputFormat.java | 4 +-
 .../flink/table/type/TypeConverters.java | 4 +
 .../typeutils/AbstractMapSerializer.java | 194 +++
 .../table/typeutils/AbstractMapTypeInfo.java | 140 ++
 .../table/typeutils/SortedMapSerializer.java | 120 ++
 .../table/typeutils/SortedMapTypeInfo.java | 145 ++
 62 files changed, 6039 insertions(+), 127 deletions(-)
 create mode 100644
flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java create mode 100644 
flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java diff --git a/flink-table/flink-table-planner-blink/pom.xml b/flink-table/flink-table-planner-blink/pom.xml index f8ab62e3a264e..4fad3028883d8 100644 --- a/flink-table/flink-table-planner-blink/pom.xml +++ b/flink-table/flink-table-planner-blink/pom.xml @@ -200,6 +200,13 @@ under the License. test-jar test + + + org.apache.flink + flink-statebackend-rocksdb_${scala.binary.version} + ${project.version} + test + diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java new file mode 100644 index 0000000000000..651ef531274ee --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.util; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.codegen.CodeGeneratorContext; +import org.apache.flink.table.codegen.ProjectionCodeGenerator; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; +import org.apache.flink.table.runtime.keyselector.BinaryRowKeySelector; +import org.apache.flink.table.runtime.keyselector.NullBinaryRowKeySelector; +import org.apache.flink.table.type.InternalType; +import org.apache.flink.table.type.RowType; +import org.apache.flink.table.typeutils.BaseRowSerializer; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.table.typeutils.TypeCheckUtils; + +/** + * Utility for KeySelector. + */ +public class KeySelectorUtil { + + /** + * Create a BaseRowKeySelector to extract keys from DataStream which type is BaseRowTypeInfo. + * + * @param keyFields key fields + * @param rowType type of DataStream to extract keys + * + * @return the BaseRowKeySelector to extract keys from DataStream which type is BaseRowTypeInfo. 
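As a rough, illustrative sketch (not part of this patch): the utility defined just below is wired into a physical node the same way StreamExecDeduplicate and StreamExecRank do later in this change, by projecting the key columns and then registering the selector and its produced type as the state key of the keyed transformation. The transformation variable, the row type and the key positions in this snippet are placeholder names, not taken from the patch.

    // hedged sketch only; mirrors the usage inside StreamExecDeduplicate/StreamExecRank
    int[] uniqueKeys = new int[] {0};  // assumed key positions for illustration
    BaseRowKeySelector selector = KeySelectorUtil.getBaseRowSelector(uniqueKeys, rowTypeInfo);
    transform.setStateKeySelector(selector);               // keyed state is partitioned by the projected key row
    transform.setStateKeyType(selector.getProducedType()); // BaseRowTypeInfo of the key fields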
+ */ + public static BaseRowKeySelector getBaseRowSelector(int[] keyFields, BaseRowTypeInfo rowType) { + if (keyFields.length > 0) { + InternalType[] inputFieldTypes = rowType.getInternalTypes(); + String[] inputFieldNames = rowType.getFieldNames(); + InternalType[] keyFieldTypes = new InternalType[keyFields.length]; + String[] keyFieldNames = new String[keyFields.length]; + for (int i = 0; i < keyFields.length; ++i) { + keyFieldTypes[i] = inputFieldTypes[keyFields[i]]; + keyFieldNames[i] = inputFieldNames[keyFields[i]]; + } + RowType returnType = new RowType(keyFieldTypes, keyFieldNames); + RowType inputType = new RowType(inputFieldTypes, rowType.getFieldNames()); + GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection( + CodeGeneratorContext.apply(new TableConfig()), + BaseRowSerializer.class.getSimpleName(), inputType, returnType, keyFields); + BaseRowTypeInfo keyRowType = returnType.toTypeInfo(); + // check if type implements proper equals/hashCode + TypeCheckUtils.validateEqualsHashCode("grouping", keyRowType); + return new BinaryRowKeySelector(keyRowType, generatedProjection); + } else { + return new NullBinaryRowKeySelector(); + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala index e6c0bc92419a8..a1844b0e4fd28 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala @@ -187,6 +187,20 @@ class TableConfig { this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS, maxTime.toMilliseconds) this } + + def getMinIdleStateRetentionTime: Long = { + this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) + } + + def getMaxIdleStateRetentionTime: Long = { + // only min idle ttl provided. 
+ if (this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) + && !this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS)) { + this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS, + getMinIdleStateRetentionTime * 2) + } + this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS) + } } object TableConfig { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala index 0d79320f5956b..87114d0ca7a06 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala @@ -115,7 +115,7 @@ object CodeGenUtils { case InternalTypes.DATE => "int" case InternalTypes.TIME => "int" - case InternalTypes.TIMESTAMP => "long" + case _: TimestampType => "long" case InternalTypes.INTERVAL_MONTHS => "int" case InternalTypes.INTERVAL_MILLIS => "long" @@ -135,7 +135,7 @@ object CodeGenUtils { case InternalTypes.DATE => boxedTypeTermForType(InternalTypes.INT) case InternalTypes.TIME => boxedTypeTermForType(InternalTypes.INT) - case InternalTypes.TIMESTAMP => boxedTypeTermForType(InternalTypes.LONG) + case _: TimestampType => boxedTypeTermForType(InternalTypes.LONG) case InternalTypes.STRING => BINARY_STRING case InternalTypes.BINARY => "byte[]" diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala new file mode 100644 index 0000000000000..35c8d92d49f15 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
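A quick, hedged illustration of the retention getters added to TableConfig above: when only the minimum idle-state TTL has been configured, getMaxIdleStateRetentionTime falls back to twice the minimum before returning. The option keys come from TableConfigOptions in this patch; the 10-minute value is an example and the Java-style getConf() spelling of the Scala accessor is an assumption.

    TableConfig config = new TableConfig();
    // configure only the minimum TTL (10 minutes); the maximum is left unset
    config.getConf().setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS, 10 * 60 * 1000L);
    long minTtl = config.getMinIdleStateRetentionTime();  // 600000 ms
    long maxTtl = config.getMaxIdleStateRetentionTime();  // 1200000 ms, i.e. 2 * minTtl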
+ */ +package org.apache.flink.table.codegen + +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow} +import org.apache.flink.table.generated.{GeneratedRecordEqualiser, RecordEqualiser} +import org.apache.flink.table.`type`.{DateType, InternalType, PrimitiveType, RowType, TimeType, TimestampType} + +class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) { + + private val BASE_ROW = className[BaseRow] + private val BINARY_ROW = className[BinaryRow] + private val RECORD_EQUALISER = className[RecordEqualiser] + private val LEFT_INPUT = "left" + private val RIGHT_INPUT = "right" + + def generateRecordEqualiser(name: String): GeneratedRecordEqualiser = { + // ignore time zone + val ctx = CodeGeneratorContext(new TableConfig) + val className = newName(name) + val header = + s""" + |if ($LEFT_INPUT.getHeader() != $RIGHT_INPUT.getHeader()) { + | return false; + |} + """.stripMargin + + val codes = for (i <- fieldTypes.indices) yield { + val fieldType = fieldTypes(i) + val fieldTypeTerm = primitiveTypeTermForType(fieldType) + val result = s"cmp$i" + val leftNullTerm = "leftIsNull$" + i + val rightNullTerm = "rightIsNull$" + i + val leftFieldTerm = "leftField$" + i + val rightFieldTerm = "rightField$" + i + val equalsCode = if (isInternalPrimitive(fieldType)) { + s"$leftFieldTerm == $rightFieldTerm" + } else if (isBaseRow(fieldType)) { + val equaliserGenerator = + new EqualiserCodeGenerator(fieldType.asInstanceOf[RowType].getFieldTypes) + val generatedEqualiser = equaliserGenerator + .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") + val generatedEqualiserTerm = ctx.addReusableObject( + generatedEqualiser, "field$" + i + "GeneratedEqualiser") + val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName + val equaliserTerm = newName("equaliser") + ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") + ctx.addReusableInitStatement( + s""" + |$equaliserTerm = ($equaliserTypeTerm) + | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); + |""".stripMargin) + s"$equaliserTerm.equalsWithoutHeader($leftFieldTerm, $rightFieldTerm)" + } else { + s"$leftFieldTerm.equals($rightFieldTerm)" + } + val leftReadCode = baseRowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType) + val rightReadCode = baseRowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType) + s""" + |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i); + |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i); + |boolean $result; + |if ($leftNullTerm && $rightNullTerm) { + | $result = true; + |} else if ($leftNullTerm || $rightNullTerm) { + | $result = false; + |} else { + | $fieldTypeTerm $leftFieldTerm = $leftReadCode; + | $fieldTypeTerm $rightFieldTerm = $rightReadCode; + | $result = $equalsCode; + |} + |if (!$result) { + | return false; + |} + """.stripMargin + } + + val functionCode = + j""" + public final class $className implements $RECORD_EQUALISER { + + ${ctx.reuseMemberCode()} + + public $className(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + @Override + public boolean equals($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) { + if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { + return $LEFT_INPUT.equals($RIGHT_INPUT); + } else { + $header + ${ctx.reuseLocalVariableCode()} + ${codes.mkString("\n")} + return true; + } + } + + @Override + public boolean 
equalsWithoutHeader($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) { + if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { + return (($BINARY_ROW)$LEFT_INPUT).equalsWithoutHeader((($BINARY_ROW)$RIGHT_INPUT)); + } else { + ${ctx.reuseLocalVariableCode()} + ${codes.mkString("\n")} + return true; + } + } + } + """.stripMargin + + new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray) + } + + private def isInternalPrimitive(t: InternalType): Boolean = t match { + case _: PrimitiveType => true + + case _: DateType => true + case TimeType.INSTANCE => true + case _: TimestampType => true + + case _ => false + } + + private def isBaseRow(t: InternalType): Boolean = t match { + case _: RowType => true + case _ => false + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala index 29431c2fb6458..ce388c8bc9f7f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeIn import org.apache.flink.table.api.{Table, TableConfig, TableException, Types} import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, genToExternal} import org.apache.flink.table.codegen.OperatorCodeGenerator.generateCollect +import org.apache.flink.table.dataformat.util.BaseRowUtil import org.apache.flink.table.dataformat.{BaseRow, GenericRow} import org.apache.flink.table.runtime.OneInputOperatorWrapper import org.apache.flink.table.sinks.{DataStreamTableSink, TableSink} @@ -93,6 +94,10 @@ object SinkCodeGenerator { new RowTypeInfo( inputTypeInfo.getFieldTypes, inputTypeInfo.getFieldNames) + case gt: GenericTypeInfo[BaseRow] if gt.getTypeClass == classOf[BaseRow] => + new BaseRowTypeInfo( + inputTypeInfo.getInternalTypes, + inputTypeInfo.getFieldNames) case _ => requestedTypeInfo } @@ -154,13 +159,14 @@ object SinkCodeGenerator { val retractProcessCode = if (!withChangeFlag) { generateCollect(genToExternal(ctx, outputTypeInfo, afterIndexModify)) } else { - val flagResultTerm = s"$afterIndexModify.getHeader() == $BASE_ROW.ACCUMULATE_MSG" + val flagResultTerm = + s"${classOf[BaseRowUtil].getCanonicalName}.isAccumulateMsg($afterIndexModify)" val resultTerm = CodeGenUtils.newName("result") val genericRowField = classOf[GenericRow].getCanonicalName s""" |$genericRowField $resultTerm = new $genericRowField(2); - |$resultTerm.update(0, $flagResultTerm); - |$resultTerm.update(1, $afterIndexModify); + |$resultTerm.setField(0, $flagResultTerm); + |$resultTerm.setField(1, $afterIndexModify); |${generateCollect(genToExternal(ctx, outputTypeInfo, resultTerm))} """.stripMargin } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala index c207bd685a52f..86b79a9e590a0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala @@ -22,7 +22,6 @@ import 
org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.{BaseRow, GenericRow} import org.apache.flink.table.runtime.values.ValuesInputFormat -import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex.RexLiteral @@ -56,8 +55,7 @@ object ValuesCodeGenerator { generatedRecords.map(_.code), outputType) - val baseRowTypeInfo = new BaseRowTypeInfo(outputType.getFieldTypes, outputType.getFieldNames) - new ValuesInputFormat(generatedFunction, baseRowTypeInfo) + new ValuesInputFormat(generatedFunction, outputType.toTypeInfo) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala index a3fefe15c691e..b7913a12285e6 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala @@ -449,6 +449,7 @@ class SortCodeGenerator( case InternalTypes.FLOAT => 4 case InternalTypes.DOUBLE => 8 case InternalTypes.LONG => 8 + case _: TimestampType => 8 case dt: DecimalType if Decimal.isCompact(dt.precision()) => 8 case InternalTypes.STRING | InternalTypes.BINARY => Int.MaxValue } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala index de4238ea29c27..e9e4c0231367b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.plan.nodes.calcite -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataTypeField diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala index 9e19e028e3dbc..0ab70a3406a65 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala @@ -19,8 +19,8 @@ package org.apache.flink.table.plan.nodes.calcite import org.apache.flink.table.api.TableException -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType import org.apache.flink.table.plan.util._ +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType, VariableRankRange} import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} @@ -71,18 +71,19 @@ abstract class Rank( rankRange match { case r: ConstantRankRange => - if (r.rankEnd <= 0) { - throw new TableException(s"Rank end can't smaller than zero. 
The rank end is ${r.rankEnd}") + if (r.getRankEnd <= 0) { + throw new TableException( + s"Rank end can't smaller than zero. The rank end is ${r.getRankEnd}") } - if (r.rankStart > r.rankEnd) { + if (r.getRankStart > r.getRankEnd) { throw new TableException( - s"Rank start '${r.rankStart}' can't greater than rank end '${r.rankEnd}'.") + s"Rank start '${r.getRankStart}' can't greater than rank end '${r.getRankEnd}'.") } case v: VariableRankRange => - if (v.rankEndIndex < 0) { + if (v.getRankEndIndex < 0) { throw new TableException(s"Rank end index can't smaller than zero.") } - if (v.rankEndIndex >= input.getRowType.getFieldCount) { + if (v.getRankEndIndex >= input.getRowType.getFieldCount) { throw new TableException(s"Rank end index can't greater than input field count.") } } @@ -105,7 +106,7 @@ abstract class Rank( }.mkString(", ") super.explainTerms(pw) .item("rankType", rankType) - .item("rankRange", rankRange.toString()) + .item("rankRange", rankRange) .item("partitionBy", partitionKey.map(i => s"$$$i").mkString(",")) .item("orderBy", RelExplainUtil.collationToString(orderKey)) .item("select", select) @@ -134,64 +135,3 @@ abstract class Rank( } } - -/** - * An enumeration of rank type, usable to tell the [[Rank]] node how exactly generate rank number. - */ -object RankType extends Enumeration { - type RankType = Value - - /** - * Returns a unique sequential number for each row within the partition based on the order, - * starting at 1 for the first row in each partition and without repeating or skipping - * numbers in the ranking result of each partition. If there are duplicate values within the - * row set, the ranking numbers will be assigned arbitrarily. - */ - val ROW_NUMBER: RankType.Value = Value - - /** - * Returns a unique rank number for each distinct row within the partition based on the order, - * starting at 1 for the first row in each partition, with the same rank for duplicate values - * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate - * values. - */ - val RANK: RankType.Value = Value - - /** - * is similar to the RANK by generating a unique rank number for each distinct row - * within the partition based on the order, starting at 1 for the first row in each partition, - * ranking the rows with equal values with the same rank number, except that it does not skip - * any rank, leaving no gaps between the ranks. - */ - val DENSE_RANK: RankType.Value = Value -} - -sealed trait RankRange extends Serializable { - def toString(inputFieldNames: Seq[String]): String -} - -/** [[ConstantRankRangeWithoutEnd]] is a RankRange which not specify RankEnd. */ -case class ConstantRankRangeWithoutEnd(rankStart: Long) extends RankRange { - override def toString(inputFieldNames: Seq[String]): String = this.toString - - override def toString: String = s"rankStart=$rankStart" -} - -/** rankStart and rankEnd are inclusive, rankStart always start from one. 
*/ -case class ConstantRankRange(rankStart: Long, rankEnd: Long) extends RankRange { - - override def toString(inputFieldNames: Seq[String]): String = this.toString - - override def toString: String = s"rankStart=$rankStart, rankEnd=$rankEnd" -} - -/** changing rank limit depends on input */ -case class VariableRankRange(rankEndIndex: Int) extends RankRange { - override def toString(inputFieldNames: Seq[String]): String = { - s"rankEnd=${inputFieldNames(rankEndIndex)}" - } - - override def toString: String = { - s"rankEnd=$$$rankEndIndex" - } -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala index 806dadfa0edd0..55d9575637fc5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala @@ -18,9 +18,9 @@ package org.apache.flink.table.plan.nodes.logical import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank, RankRange} +import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank} import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.calcite.plan._ import org.apache.calcite.rel.`type`.RelDataTypeField diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala index 84f05aae636dc..05cd3252d438b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala @@ -25,10 +25,10 @@ import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.sort.ComparatorCodeGenerator import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, Rank, RankRange, RankType} +import org.apache.flink.table.plan.nodes.calcite.Rank import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType} import org.apache.flink.table.runtime.sort.RankOperator import org.apache.calcite.plan._ @@ -72,7 +72,7 @@ class BatchExecRank( require(rankType == RankType.RANK, "Only RANK is supported now") val (rankStart, rankEnd) = rankRange match { - case r: ConstantRankRange => (r.rankStart, r.rankEnd) + case r: ConstantRankRange => (r.getRankStart, r.getRankEnd) case o => throw new TableException(s"$o is not supported now") } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala index 81fe874c56938..a8a42b189c11b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.functions.sql.StreamRecordTimestampSqlFunction import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.schema.DataStreamTable import org.apache.flink.table.plan.util.ScanUtil +import org.apache.flink.table.runtime.AbstractProcessStreamOperator import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo.{ROWTIME_INDICATOR, ROWTIME_STREAM_MARKER} import org.apache.calcite.plan._ @@ -112,16 +113,17 @@ class StreamExecDataStreamScan( } else { ("", "") } - + val ctx = CodeGeneratorContext(config).setOperatorBaseClass( + classOf[AbstractProcessStreamOperator[BaseRow]]) ScanUtil.convertToInternalRow( - CodeGeneratorContext(config), + ctx, transform, dataStreamTable.fieldIndexes, dataStreamTable.typeInfo, getRowType, getTable.getQualifiedName, config, - None, + rowtimeExpr, beforeConvert = extractElement, afterConvert = resetElement) } else { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index e76af21841def..6e75118022463 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -18,12 +18,30 @@ package org.apache.flink.table.plan.nodes.physical.stream +import org.apache.flink.streaming.api.operators.KeyedProcessOperator +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.EqualiserCodeGenerator +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.generated.GeneratedRecordEqualiser +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.plan.util.KeySelectorUtil +import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator +import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger +import org.apache.flink.table.runtime.deduplicate.{DeduplicateFunction, MiniBatchDeduplicateFunction} +import org.apache.flink.table.`type`.TypeConverters +import org.apache.flink.table.typeutils.BaseRowTypeInfo +import org.apache.flink.table.typeutils.TypeCheckUtils.isRowTime + import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.calcite.rel.`type`.RelDataType import java.util +import scala.collection.JavaConversions._ + /** * Stream physical RelNode which deduplicate on keys and keeps only first row or last row. 
* This node is an optimization of [[StreamExecRank]] for some special cases. @@ -38,7 +56,8 @@ class StreamExecDeduplicate( isRowtime: Boolean, keepLastRow: Boolean) extends SingleRel(cluster, traitSet, inputRel) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { def getUniqueKeys: Array[Int] = uniqueKeys @@ -73,4 +92,92 @@ class StreamExecDeduplicate( .item("order", orderString) } + //~ ExecNode methods ----------------------------------------------------------- + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + + // TODO checkInput is not acc retract after FLINK- is done + val inputIsAccRetract = false + + if (inputIsAccRetract) { + throw new TableException( + "Deduplicate: Retraction on Deduplicate is not supported yet.\n" + + "please re-check sql grammar. \n" + + "Note: Deduplicate should not follow a non-windowed GroupBy aggregation.") + } + + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + + val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + + val generateRetraction = true + + val inputRowType = FlinkTypeFactory.toInternalRowType(getInput.getRowType) + val rowTimeFieldIndex = inputRowType.getFieldTypes.zipWithIndex + .filter(e => isRowTime(e._1)) + .map(_._2) + if (rowTimeFieldIndex.size > 1) { + throw new RuntimeException("More than one row time field. Currently this is not supported!") + } + if (rowTimeFieldIndex.nonEmpty) { + throw new TableException("Currently not support Deduplicate on rowtime.") + } + val tableConfig = tableEnv.getConfig + val exeConfig = tableEnv.execEnv.getConfig + val isMiniBatchEnabled = tableConfig.getConf.getLong( + TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) > 0 + val generatedRecordEqualiser = generateRecordEqualiser(rowTypeInfo) + val operator = if (isMiniBatchEnabled) { + val processFunction = new MiniBatchDeduplicateFunction( + rowTypeInfo, + generateRetraction, + exeConfig, + keepLastRow, + generatedRecordEqualiser) + val trigger = new CountBundleTrigger[BaseRow]( + tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE)) + new KeyedMapBundleOperator( + processFunction, + trigger) + } else { + val minRetentionTime = tableConfig.getMinIdleStateRetentionTime + val maxRetentionTime = tableConfig.getMaxIdleStateRetentionTime + val processFunction = new DeduplicateFunction( + minRetentionTime, + maxRetentionTime, + rowTypeInfo, + generateRetraction, + keepLastRow, + generatedRecordEqualiser) + new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](processFunction) + } + val ret = new OneInputTransformation( + inputTransform, getOperatorName, operator, rowTypeInfo, inputTransform.getParallelism) + val selector = KeySelectorUtil.getBaseRowSelector(uniqueKeys, rowTypeInfo) + ret.setStateKeySelector(selector) + ret.setStateKeyType(selector.getProducedType) + ret + } + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + private def getOperatorName: String = { + val fieldNames = getRowType.getFieldNames + val keyNames = uniqueKeys.map(fieldNames.get).mkString(", ") + val orderString = if (isRowtime) "ROWTIME" else "PROCTIME" + s"${if (keepLastRow) "keepLastRow" else "KeepFirstRow"}" + + s": (key: ($keyNames), select: (${fieldNames.mkString(", ")}), order: ($orderString))" + } + + private def generateRecordEqualiser(rowTypeInfo: 
BaseRowTypeInfo): GeneratedRecordEqualiser = { + val generator = new EqualiserCodeGenerator( + rowTypeInfo.getFieldTypes.map(TypeConverters.createInternalTypeFromTypeInfo)) + val equaliserName = s"${if (keepLastRow) "LastRow" else "FirstRow"}ValueEqualiser" + generator.generateRecordEqualiser(equaliserName) + } + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala index b503fb2d9bded..e1c431b0440f3 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala @@ -19,10 +19,22 @@ package org.apache.flink.table.plan.nodes.physical.stream import org.apache.flink.table.plan.nodes.common.CommonPhysicalExchange +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.streaming.api.transformations.{PartitionTransformation, StreamTransformation} +import org.apache.flink.streaming.runtime.partitioner.{GlobalPartitioner, KeyGroupStreamPartitioner, StreamPartitioner} +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.util.KeySelectorUtil +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.{RelDistribution, RelNode} +import java.util + +import scala.collection.JavaConversions._ + /** * Stream physical RelNode for [[org.apache.calcite.rel.core.Exchange]]. 
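Stepping back to StreamExecDeduplicate above: the node is produced from the classic first-row/last-row query shape, where a ROW_NUMBER over a time attribute is filtered down to rank 1. The table and column names below are invented for illustration; whether the first or the last row is kept depends on the sort direction, and only proctime ordering is supported by this patch (deduplication on rowtime throws a TableException).

    // hedged example of a query that StreamExecDeduplicateRule rewrites into StreamExecDeduplicate
    String dedupSql =
        "SELECT a, b, c FROM (" +
        "  SELECT a, b, c, proctime," +
        "         ROW_NUMBER() OVER (PARTITION BY a ORDER BY proctime DESC) AS rn FROM T" +
        ") WHERE rn = 1";
    // rn = 1  <=>  RankType.ROW_NUMBER with ConstantRankRange(1, 1), the pattern the rule matches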
*/ @@ -32,7 +44,10 @@ class StreamExecExchange( relNode: RelNode, relDistribution: RelDistribution) extends CommonPhysicalExchange(cluster, traitSet, relNode, relDistribution) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { + + private val DEFAULT_MAX_PARALLELISM = 1 << 7 override def producesUpdates: Boolean = false @@ -50,4 +65,41 @@ class StreamExecExchange( newDistribution: RelDistribution): StreamExecExchange = { new StreamExecExchange(cluster, traitSet, newInput, newDistribution) } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val inputTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + val outputTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo + relDistribution.getType match { + case RelDistribution.Type.SINGLETON => + val partitioner = new GlobalPartitioner[BaseRow] + val transformation = new PartitionTransformation( + inputTransform, + partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) + transformation.setOutputType(outputTypeInfo) + transformation + case RelDistribution.Type.HASH_DISTRIBUTED => + // TODO Eliminate duplicate keys + val selector = KeySelectorUtil.getBaseRowSelector( + relDistribution.getKeys.map(_.toInt).toArray, inputTypeInfo) + val partitioner = new KeyGroupStreamPartitioner(selector, DEFAULT_MAX_PARALLELISM) + val transformation = new PartitionTransformation( + inputTransform, + partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) + transformation.setOutputType(outputTypeInfo) + transformation + case _ => + throw new UnsupportedOperationException( + s"not support RelDistribution: ${relDistribution.getType} now!") + } + } + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala index ac833b548ea65..97504a2e0a12f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala @@ -17,9 +17,17 @@ */ package org.apache.flink.table.plan.nodes.physical.stream -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{Rank, RankRange} -import org.apache.flink.table.plan.util.{RankProcessStrategy, RelExplainUtil, RetractStrategy} +import org.apache.flink.streaming.api.operators.KeyedProcessOperator +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.{EqualiserCodeGenerator, SortCodeGenerator} +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.calcite.Rank +import org.apache.flink.table.plan.nodes.exec.{ExecNode, 
StreamExecNode} +import org.apache.flink.table.plan.util._ +import org.apache.flink.table.runtime.rank._ +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel._ @@ -53,7 +61,8 @@ class StreamExecRank( rankRange, rankNumberType, outputRankNumber) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { /** please uses [[getStrategy]] instead of this field */ private var strategy: RankProcessStrategy = _ @@ -101,4 +110,130 @@ class StreamExecRank( .item("orderBy", RelExplainUtil.collationToString(orderKey, inputRowType)) .item("select", getRowType.getFieldNames.mkString(", ")) } + + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val tableConfig = tableEnv.getConfig + rankType match { + case RankType.ROW_NUMBER => // ignore + case RankType.RANK => + throw new TableException("RANK() on streaming table is not supported currently") + case RankType.DENSE_RANK => + throw new TableException("DENSE_RANK() on streaming table is not supported currently") + case k => + throw new TableException(s"Streaming tables do not support $k rank function.") + } + + val inputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getInput.getRowType).toTypeInfo + val outputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo + + val fieldCollations = orderKey.getFieldCollations + val (sortFields, sortDirections, nullsIsLast) = SortUtil.getKeysAndOrders(fieldCollations) + val sortKeySelector = KeySelectorUtil.getBaseRowSelector(sortFields, inputRowTypeInfo) + val sortKeyType = sortKeySelector.getProducedType + val sortCodeGen = new SortCodeGenerator( + tableConfig, sortFields.indices.toArray, sortKeyType.getInternalTypes, + sortDirections, nullsIsLast) + val comparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator") + // TODO infer generate retraction after retraction rules are merged + val generateRetraction = true + val cacheSize = tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE) + val minIdleStateRetentionTime = tableConfig.getMinIdleStateRetentionTime + val maxIdleStateRetentionTime = tableConfig.getMaxIdleStateRetentionTime + val equaliserCodeGenerator = new EqualiserCodeGenerator(inputRowTypeInfo.getInternalTypes) + val equaliser = equaliserCodeGenerator.generateRecordEqualiser("RankValueEqualiser") + val processFunction = getStrategy(true) match { + case AppendFastStrategy => + new AppendRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + outputRowTypeInfo, + sortKeyType, + comparator, + sortKeySelector, + rankType, + rankRange, + equaliser, + generateRetraction, + cacheSize) + + case UpdateFastStrategy(primaryKeys) => + val internalTypes = inputRowTypeInfo.getInternalTypes + val fieldTypes = primaryKeys.map(internalTypes) + val rowKeyType = new BaseRowTypeInfo(fieldTypes: _*) + val rowKeySelector = KeySelectorUtil.getBaseRowSelector(primaryKeys, inputRowTypeInfo) + new UpdateRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + outputRowTypeInfo, + rowKeyType, + rowKeySelector, + comparator, + sortKeySelector, + rankType, + rankRange, + 
equaliser, + generateRetraction, + cacheSize) + + // TODO UnaryUpdateRank after SortedMapState is merged + case RetractStrategy | UnaryUpdateStrategy(_) => + new RetractRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + outputRowTypeInfo, + sortKeyType, + comparator, + sortKeySelector, + rankType, + rankRange, + equaliser, + generateRetraction) + } + val rankOpName = getOperatorName + val operator = new KeyedProcessOperator(processFunction) + processFunction.setKeyContext(operator); + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val ret = new OneInputTransformation( + inputTransform, + rankOpName, + operator, + outputRowTypeInfo, + inputTransform.getParallelism) + + if (partitionKey.isEmpty) { + ret.setParallelism(1) + ret.setMaxParallelism(1) + } + + // set KeyType and Selector for state + val selector = KeySelectorUtil.getBaseRowSelector(partitionKey.toArray, inputRowTypeInfo) + ret.setStateKeySelector(selector) + ret.setStateKeyType(selector.getProducedType) + ret + } + + private def getOperatorName: String = { + val inputRowType = inputRel.getRowType + var result = getStrategy().toString + result += s"(orderBy: (${RelExplainUtil.collationToString(orderKey, inputRowType)})" + if (partitionKey.nonEmpty) { + val partitionKeys = partitionKey.toArray + result += s", partitionBy: (${RelExplainUtil.fieldToString(partitionKeys, inputRowType)})" + } + result += s", ${getRowType.getFieldNames.mkString(", ")}" + result += s", ${rankRange.toString(inputRowType.getFieldNames)})" + result + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala index 3a53e2c1223d5..3b21083e2a608 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala @@ -19,9 +19,9 @@ package org.apache.flink.table.plan.rules.logical import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkContext -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType} import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverWindow, FlinkLogicalRank} import org.apache.flink.table.plan.util.RankUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType} import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} @@ -85,8 +85,9 @@ abstract class FlinkLogicalRankRuleBase require(rankNumberType.isDefined) rankRange match { - case Some(ConstantRankRange(_, rankEnd)) if rankEnd <= 0 => - throw new TableException(s"Rank end should not less than zero, but now is $rankEnd") + case Some(crr: ConstantRankRange) if crr.getRankEnd <= 0 => + throw new TableException( + s"Rank end should not less than zero, but now is ${crr.getRankEnd}") case _ => // do nothing } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala index 1e8950b6091ea..918211b9cda67 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala @@ -21,10 +21,10 @@ package org.apache.flink.table.plan.rules.physical.batch import org.apache.flink.table.api.TableException import org.apache.flink.table.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank import org.apache.flink.table.plan.util.RelFieldCollationUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.convert.ConverterRule @@ -58,7 +58,7 @@ class BatchExecRankRule def convert(rel: RelNode): RelNode = { val rank = rel.asInstanceOf[FlinkLogicalRank] val (_, rankEnd) = rank.rankRange match { - case r: ConstantRankRange => (r.rankStart, r.rankEnd) + case r: ConstantRankRange => (r.getRankStart, r.getRankEnd) case o => throw new TableException(s"$o is not supported now") } @@ -71,7 +71,7 @@ class BatchExecRankRule val newLocalInput = RelOptRule.convert(rank.getInput, localRequiredTraitSet) // create local BatchExecRank - val localRankRange = ConstantRankRange(1, rankEnd) // local rank always start from 1 + val localRankRange = new ConstantRankRange(1, rankEnd) // local rank always start from 1 val localRank = new BatchExecRank( cluster, emptyTraits, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala index b07ba7a040989..609bae8b7d34f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala @@ -21,9 +21,9 @@ package org.apache.flink.table.plan.rules.physical.stream import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecDeduplicate, StreamExecRank} +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.`type`.RelDataType @@ -111,7 +111,8 @@ object StreamExecDeduplicateRule { val isRowNumberType = rank.rankType == RankType.ROW_NUMBER val isLimit1 = rankRange match { - case ConstantRankRange(rankStart, rankEnd) => rankStart == 1 && rankEnd == 1 + case rankRange: ConstantRankRange => + rankRange.getRankStart() == 1 && rankRange.getRankEnd() == 1 case _ => false } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala index fbb16f50260cb..b88cfffbdc2f0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.BinaryRow -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankRange} +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange} import org.apache.flink.table.runtime.sort.BinaryIndexedSortable import org.apache.flink.table.typeutils.BinaryRowSerializer @@ -37,7 +37,7 @@ import scala.collection.JavaConversions._ object FlinkRelMdUtil { def getRankRangeNdv(rankRange: RankRange): Double = rankRange match { - case r: ConstantRankRange => (r.rankEnd - r.rankStart + 1).toDouble + case r: ConstantRankRange => (r.getRankEnd - r.getRankStart + 1).toDouble case _ => 100D // default value now } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala index 338564c42c164..f7f7378aa5dff 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala @@ -19,7 +19,9 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, Rank, RankRange, VariableRankRange} +import org.apache.flink.table.plan.nodes.calcite.Rank +import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankRange, VariableRankRange} +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil} @@ -98,20 +100,20 @@ object RankUtil { val sortBounds = limitPreds.map(computeWindowBoundFromPredicate(_, rexBuilder, config)) val rankRange = sortBounds match { case Seq(Some(LowerBoundary(x)), Some(UpperBoundary(y))) => - ConstantRankRange(x, y) + new ConstantRankRange(x, y) case Seq(Some(UpperBoundary(x)), Some(LowerBoundary(y))) => - ConstantRankRange(y, x) + new ConstantRankRange(y, x) case Seq(Some(LowerBoundary(x))) => // only offset - ConstantRankRangeWithoutEnd(x) + new ConstantRankRangeWithoutEnd(x) case Seq(Some(UpperBoundary(x))) => // rankStart starts from one - ConstantRankRange(1, x) + new ConstantRankRange(1, x) case Seq(Some(BothBoundary(x, y))) => // nth rank - ConstantRankRange(x, y) + new ConstantRankRange(x, y) case Seq(Some(InputRefBoundary(x))) => - VariableRankRange(x) + new VariableRankRange(x) case _ => // TopN requires at least one rank comparison predicate return (None, Some(predicate)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala index 
32c9f4f61c3bb..07b9d93abcf92 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala @@ -18,8 +18,12 @@ package org.apache.flink.table.typeutils -import org.apache.flink.table.`type`._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.PojoTypeInfo +import org.apache.flink.table.api.ValidationException import org.apache.flink.table.codegen.GeneratedExpression +import org.apache.flink.table.`type`._ object TypeCheckUtils { @@ -96,4 +100,63 @@ object TypeCheckUtils { def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) + /** + * Checks whether a type implements own hashCode() and equals() methods for storing an instance + * in Flink's state or performing a keyBy operation. + * + * @param name name of the operation. + * @param t type information to be validated + */ + def validateEqualsHashCode(name: String, t: TypeInformation[_]): Unit = t match { + + // make sure that a POJO class is a valid state type + case pt: PojoTypeInfo[_] => + // we don't check the types recursively to give a chance of wrapping + // proper hashCode/equals methods around an immutable type + validateEqualsHashCode(name, pt.getClass) + // BinaryRow direct hash in bytes, no need to check field types. + case bt: BaseRowTypeInfo => + // recursively check composite types + case ct: CompositeType[_] => + validateEqualsHashCode(name, t.getTypeClass) + // we check recursively for entering Flink types such as tuples and rows + for (i <- 0 until ct.getArity) { + val subtype = ct.getTypeAt(i) + validateEqualsHashCode(name, subtype) + } + // check other type information only based on the type class + case _: TypeInformation[_] => + validateEqualsHashCode(name, t.getTypeClass) + } + + /** + * Checks whether a class implements own hashCode() and equals() methods for storing an instance + * in Flink's state or performing a keyBy operation. 
+ * + * @param name name of the operation + * @param c class to be validated + */ + def validateEqualsHashCode(name: String, c: Class[_]): Unit = { + + // skip primitives + if (!c.isPrimitive) { + // check the component type of arrays + if (c.isArray) { + validateEqualsHashCode(name, c.getComponentType) + } + // check type for methods + else { + if (c.getMethod("hashCode").getDeclaringClass eq classOf[Object]) { + throw new ValidationException( + s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " + + s"does not implement a proper hashCode() method.") + } + if (c.getMethod("equals", classOf[Object]).getDeclaringClass eq classOf[Object]) { + throw new ValidationException( + s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " + + s"does not implement a proper equals() method.") + } + } + } + } } diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java new file mode 100644 index 0000000000000..726f3d9ae09da --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.utils; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.Preconditions; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * The FailingCollectionSource will fail after emitted a specified number of elements. This is used + * to perform checkpoint and restore in IT cases. 
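Returning to the validateEqualsHashCode additions above: the check walks composite types recursively, skips primitives and BaseRowTypeInfo (BinaryRow hashes its binary representation directly), and for everything else uses reflection to verify that hashCode() and equals() are overridden somewhere below Object. A minimal sketch of the effect, using two hypothetical key classes:

    import org.apache.flink.table.typeutils.TypeCheckUtils

    class PlainKey(var id: Int)                  // inherits Object.equals/hashCode
    case class ProperKey(id: Int, name: String)  // the compiler generates both

    TypeCheckUtils.validateEqualsHashCode("groupBy", classOf[ProperKey]) // passes
    TypeCheckUtils.validateEqualsHashCode("groupBy", classOf[PlainKey])  // throws ValidationException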
+ */ +public class FailingCollectionSource<T> + implements SourceFunction<T>, CheckpointedFunction, CheckpointListener { + + public static volatile boolean failedBefore = false; + + private static final long serialVersionUID = 1L; + + /** The (de)serializer to be used for the data elements. */ + private final TypeSerializer<T> serializer; + + /** The actual data elements, in serialized form. */ + private final byte[] elementsSerialized; + + /** The number of serialized elements. */ + private final int numElements; + + /** The number of elements emitted already. */ + private volatile int numElementsEmitted; + + /** The number of elements to skip initially. */ + private volatile int numElementsToSkip; + + /** Flag to make the source cancelable. */ + private volatile boolean isRunning = true; + + private transient ListState<Integer> checkpointedState; + + /** A failure will occur when the given number of elements has been processed. */ + private final int failureAfterNumElements; + + /** The number of completed checkpoints. */ + private volatile int numSuccessfulCheckpoints; + + /** The checkpointed number of emitted elements. */ + private final Map<Long, Integer> checkpointedEmittedNums; + + /** The last successful checkpointed number of emitted elements. */ + private volatile int lastCheckpointedEmittedNum = 0; + + /** Whether to perform a checkpoint before the job finishes. */ + private final boolean performCheckpointBeforeJobFinished; + + public FailingCollectionSource( + TypeSerializer<T> serializer, + Iterable<T> elements, + int failureAfterNumElements, + boolean performCheckpointBeforeJobFinished) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(baos); + + int count = 0; + try { + for (T element : elements) { + serializer.serialize(element, wrapper); + count++; + } + } + catch (Exception e) { + throw new IOException("Serializing the source elements failed: " + e.getMessage(), e); + } + + this.serializer = serializer; + this.elementsSerialized = baos.toByteArray(); + this.numElements = count; + checkArgument(failureAfterNumElements > 0); + this.failureAfterNumElements = failureAfterNumElements; + this.performCheckpointBeforeJobFinished = performCheckpointBeforeJobFinished; + this.checkpointedEmittedNums = new HashMap<>(); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + Preconditions.checkState( + this.checkpointedState == null, + "The " + getClass().getSimpleName() + " has already been initialized."); + + this.checkpointedState = context.getOperatorStateStore().getListState( + new ListStateDescriptor<>( + "from-elements-state", + IntSerializer.INSTANCE + ) + ); + + if (failedBefore && context.isRestored()) { + List<Integer> retrievedStates = new ArrayList<>(); + for (Integer entry : this.checkpointedState.get()) { + retrievedStates.add(entry); + } + + // given that the parallelism of the function is 1, we can only have 1 state + Preconditions.checkArgument( + retrievedStates.size() == 1, + getClass().getSimpleName() + " retrieved invalid state."); + + this.numElementsToSkip = retrievedStates.get(0); + } + } + + @Override + public void run(SourceContext<T> ctx) throws Exception { + ByteArrayInputStream bais = new ByteArrayInputStream(elementsSerialized); + final DataInputView input = new DataInputViewStreamWrapper(bais); + + // if we are restored from a checkpoint and need to skip elements, skip them now.
+ int toSkip = numElementsToSkip; + if (toSkip > 0) { + try { + while (toSkip > 0) { + serializer.deserialize(input); + toSkip--; + } + } + catch (Exception e) { + throw new IOException( + "Failed to deserialize an element from the source. " + + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); + } + + this.numElementsEmitted = this.numElementsToSkip; + } + + while (isRunning && numElementsEmitted < numElements) { + if (!failedBefore) { + // delay a bit, if we have not failed before + Thread.sleep(1); + if (numSuccessfulCheckpoints >= 1 && lastCheckpointedEmittedNum >= failureAfterNumElements) { + // cause a failure if we have not failed before and have reached + // enough completed checkpoints and elements + failedBefore = true; + throw new Exception("Artificial Failure"); + } + } + + if (failedBefore || numElementsEmitted < failureAfterNumElements) { + // the function failed before, or we are in the elements before the failure + T next; + try { + next = serializer.deserialize(input); + } + catch (Exception e) { + throw new IOException( + "Failed to deserialize an element from the source. " + + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); + } + + synchronized (ctx.getCheckpointLock()) { + ctx.collect(next); + numElementsEmitted++; + } + } else { + // if our work is done, delay a bit to prevent busy waiting + Thread.sleep(1); + } + } + + if (performCheckpointBeforeJobFinished) { + while (isRunning) { + // wait until the latest checkpoint records everything + if (lastCheckpointedEmittedNum < numElements) { + // delay a bit to prevent busy waiting + Thread.sleep(1); + } else { + // cause a failure to retain the last checkpoint + throw new Exception("Job finished normally"); + } + } + } + } + + @Override + public void cancel() { + isRunning = false; + } + + + /** + * Gets the number of elements produced in total by this function. + * + * @return The number of elements produced in total. + */ + public int getNumElements() { + return numElements; + } + + /** + * Gets the number of elements emitted so far. + * + * @return The number of elements emitted so far. 
+ */ + public int getNumElementsEmitted() { + return numElementsEmitted; + } + + // ------------------------------------------------------------------------ + // Checkpointing + // ------------------------------------------------------------------------ + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Preconditions.checkState( + this.checkpointedState != null, + "The " + getClass().getSimpleName() + " has not been properly initialized."); + + this.checkpointedState.clear(); + this.checkpointedState.add(this.numElementsEmitted); + long checkpointId = context.getCheckpointId(); + checkpointedEmittedNums.put(checkpointId, numElementsEmitted); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + numSuccessfulCheckpoints++; + lastCheckpointedEmittedNum = checkpointedEmittedNums.get(checkpointId); + } + + public static void reset() { + failedBefore = false; + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala new file mode 100644 index 0000000000000..78744516d88b0 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
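The test base classes added later in this patch wrap their inputs in a failingDataSource(...) helper that builds the source above. The source serializes all elements up front, emits them one by one, throws an artificial exception once a completed checkpoint covers failureAfterNumElements elements, and on restore skips the elements already covered by that checkpoint, which is what lets the ITCases verify exactly-once behaviour of the new rank and deduplicate operators. A rough sketch of driving it directly (the serializer lookup and the choice of failure point are only indicative, and env is assumed to come from the streaming test base):

    import org.apache.flink.api.scala._
    import scala.collection.JavaConverters._

    val data = List(("book", 1, 12), ("book", 2, 19), ("fruit", 3, 44))
    val typeInfo = createTypeInformation[(String, Int, Int)]
    FailingCollectionSource.reset() // clear the static "already failed" flag between tests
    val source = new FailingCollectionSource(
      typeInfo.createSerializer(env.getConfig), // elements are serialized eagerly
      data.asJava,                              // the constructor expects a java.lang.Iterable
      data.length / 2,                          // fail once this many elements are checkpointed
      false)                                    // do not force a checkpoint before the job finishes
    val stream = env.addSource(source)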
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode +import org.apache.flink.table.runtime.utils._ +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset +import org.apache.flink.types.Row + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) + extends StreamingWithMiniBatchTestBase(miniBatch, mode) { + + @Test + def testFirstRowOnProctime(): Unit = { + val t = failingDataSource(StreamTestData.get3TupleData) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT a, b, c + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,Hi", "2,2,Hello", "4,3,Hello world, how are you?", + "7,4,Comment#1", "11,5,Comment#5", "16,6,Comment#10") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } + + @Test + def testLastRowOnProctime(): Unit = { + val t = failingDataSource(StreamTestData.get3TupleData) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT a, b, c + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime DESC) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,Hi", "3,2,Hello world", "6,3,Luke Skywalker", + "10,4,Comment#4", "15,5,Comment#9", "21,6,Comment#15") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testFirstRowOnRowtime(): Unit = { + val data = List( + (3L, 2L, "Hello world", 3), + (2L, 2L, "Hello", 2), + (6L, 3L, "Luke Skywalker", 6), + (5L, 3L, "I am fine.", 5), + (7L, 4L, "Comment#1", 7), + (9L, 4L, "Comment#3", 9), + (10L, 4L, "Comment#4", 10), + (8L, 4L, "Comment#2", 8), + (1L, 1L, "Hi", 1), + (4L, 3L, "Helloworld, how are you?", 4)) + + val t = failingDataSource(data) + .assignTimestampsAndWatermarks( + new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) + .toTable(tEnv, 'rowtime, 'key, 'str, 'int) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT key, str, `int` + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY key ORDER BY rowtime) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + + // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently + env.execute() + val expected = List("1,Hi,1", "2,Hello,2", "3,Helloworld, how are you?,4", "4,Comment#1,7") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + @Test + def testLastRowOnRowtime(): Unit = { + val data = List( + (3L, 2L, "Hello world", 3), + (2L, 2L, "Hello", 2), + (6L, 3L, "Luke Skywalker", 6), + (5L, 3L, "I am fine.", 5), + (7L, 4L, "Comment#1", 7), + (9L, 4L, "Comment#3", 9), + (10L, 4L, "Comment#4", 
10), + (8L, 4L, "Comment#2", 8), + (1L, 1L, "Hi", 1), + (4L, 3L, "Helloworld, how are you?", 4)) + + val t = failingDataSource(data) + .assignTimestampsAndWatermarks( + new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) + .toTable(tEnv, 'rowtime, 'key, 'str, 'int) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT key, str, `int` + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY key ORDER BY rowtime DESC) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + + // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently + env.execute() + val expected = List("1,Hi,1", "2,Hello world,3", "3,Luke Skywalker,6", "4,Comment#4,10") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala new file mode 100644 index 0000000000000..30577b156b029 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -0,0 +1,1296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
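A short note on the deduplication semantics exercised above: with ROW_NUMBER() partitioned by the key and rowNum = 1, ordering ascending keeps the first row seen per key (an append-only result), while ordering descending keeps the last row and therefore produces updates; on proctime these plans become a StreamExecDeduplicate, while on rowtime they still fall back to a rank, as the TODOs note. Worked out for key 4 in the two rowtime tests:

    // events for key 4, ordered by rowtime:
    //   (7,"Comment#1"), (8,"Comment#2"), (9,"Comment#3"), (10,"Comment#4")
    // ORDER BY rowtime,      rowNum = 1  ->  "4,Comment#1,7"
    // ORDER BY rowtime DESC, rowNum = 1  ->  "4,Comment#4,10"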
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.TableConfigOptions +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.{TestingRetractTableSink, TestingUpsertTableSink, _} +import org.apache.flink.types.Row + +import org.junit.Assert._ +import org.junit.{Ignore, _} +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) { + + @Test + def testTopN(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,2,19,1", + "book,1,12,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testTopNth(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num = 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,1,12,2", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after retraction infer is introduced") + @Test + def testTopNWithUpsertSink(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(0, 3)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + env.execute() + + val expected = List( + "book,4,11,1", + "book,1,12,2", + "fruit,5,22,1", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithUnary(): Unit = { + val data = List( + ("book", 11, 100), + ("book", 11, 200), + ("book", 12, 400), + ("book", 12, 500), + ("book", 10, 600), + ("book", 10, 700), + ("book", 9, 800), + ("book", 9, 900), + ("book", 10, 500), + ("book", 8, 110), + ("book", 8, 120), + ("book", 7, 1800), + ("book", 9, 
300), + ("book", 6, 1900), + ("book", 7, 50), + ("book", 11, 1800), + ("book", 7, 50), + ("book", 8, 2000), + ("book", 6, 700), + ("book", 5, 800), + ("book", 4, 910), + ("book", 3, 1000), + ("book", 2, 1100), + ("book", 1, 1200) + ) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val sink = new TestingUpsertTableSink(Array(0, 3)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,5,800,1", + "book,12,900,2", + "book,4,910,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable when state support SortedMapState") + @Test + def testUnarySortTopNOnString(): Unit = { + val data = List( + ("book", 11, "100"), + ("book", 11, "200"), + ("book", 12, "400"), + ("book", 12, "600"), + ("book", 10, "600"), + ("book", 10, "700"), + ("book", 9, "800"), + ("book", 9, "900"), + ("book", 10, "500"), + ("book", 8, "110"), + ("book", 8, "120"), + ("book", 7, "812"), + ("book", 9, "300"), + ("book", 6, "900"), + ("book", 7, "50"), + ("book", 11, "800"), + ("book", 7, "50"), + ("book", 8, "200"), + ("book", 6, "700"), + ("book", 5, "800"), + ("book", 4, "910"), + ("book", 3, "110"), + ("book", 2, "900"), + ("book", 1, "700") + ) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'price) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, max_price, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_price ASC) as rank_num + | FROM ( + | SELECT category, shopId, max(price) as max_price + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val sink = new TestingUpsertTableSink(Array(0, 3)) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,3,110,1", + "book,8,200,2", + "book,12,600,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupBy(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val sink = new TestingUpsertTableSink(Array(0, 3)) + writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,1,22,1", + "book,2,19,2", + "fruit,3,44,1", + "fruit,5,34,2") + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithSumAndCondition(): Unit = { + val data = List( + Row.of("book", Int.box(11), Double.box(100)), + Row.of("book", Int.box(11), Double.box(200)), + 
Row.of("book", Int.box(12), Double.box(400)), + Row.of("book", Int.box(12), Double.box(500)), + Row.of("book", Int.box(10), Double.box(600)), + Row.of("book", Int.box(10), Double.box(700))) + + implicit val tpe: TypeInformation[Row] = new RowTypeInfo( + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO) // tpe is automatically + + val ds = env.fromCollection(data) + val t = ds.toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", t) + + val subquery = + """ + |SELECT category, shopId, sum(num) as sum_num + |FROM T + |WHERE num >= cast(1.1 as double) + |GROUP BY category, shopId + """.stripMargin + + val sql = + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num + | FROM ($subquery)) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val sink = new TestingUpsertTableSink(Array(0, 3)) + writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,10,1300.0,1", + "book,12,900.0,2") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNthWithGroupBy(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val tableSink = new TestingUpsertTableSink(Array(0, 3)) + writeToSink(table, tableSink) + env.execute() + + val updatedExpected = List( + "book,2,19,2", + "fruit,5,34,2") + + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByAndRetract(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num, count(num) as cnt + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,1,22,2,1", + "book,2,19,1,2", + "fruit,3,44,1,1", + "fruit,5,34,2,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + def testTopNthWithGroupByAndRetract(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + 
tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num, count(num) as cnt + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,2,19,1,2", + "fruit,5,34,2,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByCount(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 2, 1002), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 3, 1006), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "book,1,5,4", + "book,2,4,1", + "book,3,2,2", + "book,4,1,3", + "fruit,1,3,5", + "fruit,2,2,4", + "fruit,3,1,3") + assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNthWithGroupByCount(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 2, 1002), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 3, 1006), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "book,3,2,2", + "fruit,3,1,3") + assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testNestedTopN(): Unit = { + val data = List( + ("book", "a", 1), + ("book", "b", 1), + ("book", "c", 1), + ("fruit", "a", 2), + ("book", "a", 1), + ("book", "d", 0), + ("book", "b", 3), + ("fruit", "b", 6), + ("book", "c", 1), + ("book", "e", 5), + ("book", "d", 4)) + + env.setParallelism(1) + val 
ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT cate, shopId, count(*) as cnt, max(sells) as sells + | FROM T + | GROUP BY cate, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + + val sql2 = + s""" + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT cate, shopId, sells, cnt, + | ROW_NUMBER() OVER (ORDER BY sells DESC) as rank_num + | FROM ($sql) + |) + |WHERE rank_num <= 4 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0)) + val table = tEnv.sqlQuery(sql2) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,1,book,a,1,1)", "(true,2,book,b,1,1)", "(true,3,book,c,1,1)", + "(true,1,fruit,a,2,1)", "(true,2,book,a,1,1)", "(true,3,book,b,1,1)", "(true,4,book,c,1,1)", + "(true,2,book,a,1,2)", + "(true,1,book,b,3,2)", "(true,2,fruit,a,2,1)", "(true,3,book,a,1,2)", + "(true,3,book,a,1,2)", + "(true,1,fruit,b,6,1)", "(true,2,book,b,3,2)", "(true,3,fruit,a,2,1)", "(true,4,book,a,1,2)", + "(true,3,fruit,a,2,1)", + "(true,2,book,e,5,1)", + "(true,3,book,b,3,2)", "(true,4,fruit,a,2,1)", + "(true,3,book,b,3,2)", + "(true,3,book,d,4,2)", + "(true,4,book,b,3,2)", + "(true,4,book,b,3,2)") + assertEquals(expected.mkString("\n"), tableSink.getRawResults.mkString("\n")) + + val expected2 = List("1,fruit,b,6,1", "2,book,e,5,1", "3,book,d,4,2", "4,book,b,3,2") + assertEquals(expected2, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithoutDeduplicate(): Unit = { + val data = List( + ("book", "a", 1), + ("book", "b", 1), + ("book", "c", 1), + ("fruit", "a", 2), + ("book", "a", 1), + ("book", "d", 0), + ("book", "b", 3), + ("fruit", "b", 6), + ("book", "c", 1), + ("book", "e", 5), + ("book", "d", 4)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT cate, shopId, count(*) as cnt, max(sells) as sells + | FROM T + | GROUP BY cate, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + tEnv.getConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE, 1) + val tableSink = new TestingUpsertTableSink(Array(0)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,1,book,a,1,1)", + "(true,2,book,b,1,1)", + "(true,3,book,c,1,1)", + "(true,1,fruit,a,2,1)", + "(true,1,book,a,1,2)", + "(true,4,book,d,0,1)", + "(true,1,book,b,3,2)", + "(true,2,book,a,1,2)", + "(true,1,fruit,b,6,1)", + "(true,2,fruit,a,2,1)", + "(true,3,book,c,1,2)", + "(true,1,book,e,5,1)", + "(true,2,book,b,3,2)", + "(true,3,book,a,1,2)", + "(true,4,book,c,1,2)", + "(true,2,book,d,4,2)", + "(true,3,book,b,3,2)", + "(true,4,book,a,1,2)") + + assertEquals(expected, tableSink.getRawResults) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithVariableTopSize(): Unit = { + val data = List( + ("book", 1, 1001, 4), + ("book", 2, 1002, 4), + ("book", 4, 1003, 4), + ("book", 1, 1004, 4), + ("book", 1, 1005, 4), + ("book", 3, 1006, 4), + ("book", 2, 1007, 4), + ("book", 4, 1008, 4), + ("book", 1, 1009, 4), + 
("book", 4, 1010, 4), + ("book", 4, 1012, 4), + ("book", 4, 1012, 4), + ("fruit", 4, 1013, 2), + ("fruit", 5, 1014, 2), + ("fruit", 3, 1015, 2), + ("fruit", 4, 1017, 2), + ("fruit", 5, 1018, 2), + ("fruit", 5, 1016, 2)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId, 'topSize) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, topSize, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells, max(topSize) as topSize + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= topSize + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "book,1,5,4", + "book,2,4,1", + "book,3,2,2", + "book,4,1,3", + "fruit,1,3,5", + "fruit,2,2,4") + assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNUnaryComplexScenario(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), // backward update in heap + ("book", 3, 23), // elems exceed topn size after insert + ("book", 5, 19), // sort map shirk out some elem after insert + ("book", 7, 10), // sort map keeps a little more than topn size elems after insert + ("book", 8, 13), // sort map now can shrink out-of-range elems after another insert + ("book", 10, 13), // Once again, sort map keeps a little more elems after insert + ("book", 8, 6), // backward update from heap to state + ("book", 10, 6), // backward update from heap to state, and sort map load more data + ("book", 5, 3), // backward update from heap to state + ("book", 10, 1), // backward update in heap, and then sort map shrink some data + ("book", 5, 1), // backward update in state + ("book", 5, -3), // forward update in state + ("book", 2, -10), // forward update in heap, and then sort map shrink some data + ("book", 10, -7), // forward update from state to heap + ("book", 11, 13), // insert into heap + ("book", 12, 10), // insert into heap, and sort map shrink some data + ("book", 15, 14) // insert into state + ) + + env.setParallelism(1) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 3)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,1,11,1)", + "(true,book,2,19,2)", + "(true,book,4,13,2)", + "(true,book,2,19,3)", + + "(true,book,4,13,1)", + "(true,book,2,19,2)", + "(true,book,1,22,3)", + + "(true,book,5,19,3)", + + "(true,book,7,10,1)", + "(true,book,4,13,2)", + "(true,book,2,19,3)", + + "(true,book,8,13,3)", + + "(true,book,10,13,3)", + + "(true,book,2,19,3)", + + "(true,book,2,9,1)", + "(true,book,7,10,2)", + "(true,book,4,13,3)", + + "(true,book,12,10,3)") + + assertEquals(expected.mkString("\n"), tableSink.getRawResults.mkString("\n")) + + val updatedExpected = List( + "book,2,9,1", + "book,7,10,2", + "book,12,10,3") + + 
assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByAvgWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 100), + ("book", 3, 110), + ("book", 4, 120), + ("book", 1, 200), + ("book", 1, 200), + ("book", 2, 300), + ("book", 2, 400), + ("book", 4, 500), + ("book", 1, 400), + ("fruit", 5, 100)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, shopId, avgSellId + |FROM ( + | SELECT category, shopId, avgSellId, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY avgSellId DESC) as rank_num + | FROM ( + | SELECT category, shopId, AVG(sellId) as avgSellId + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,1,100.0)", + "(true,book,3,110.0)", + "(true,book,4,120.0)", + "(true,book,1,150.0)", + "(true,book,1,166.66666666666666)", + "(true,book,2,300.0)", + "(false,book,3,110.0)", + "(true,book,2,350.0)", + "(true,book,4,310.0)", + "(true,book,1,225.0)", + "(true,fruit,5,100.0)") + + assertEquals(expected, tableSink.getRawResults) + + val updatedExpected = List( + "book,1,225.0", + "book,2,350.0", + "book,4,310.0", + "fruit,5,100.0") + + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testTopNWithGroupByCountWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 3, 1006), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 2, 1002), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, shopId, sells + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 1)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,1,1)", + "(true,book,3,1)", + "(true,book,4,1)", + "(true,book,1,2)", + "(true,book,1,3)", + "(true,book,2,2)", + "(false,book,4,1)", + "(true,book,4,2)", + "(false,book,3,1)", + "(true,book,1,4)", + "(true,book,4,3)", + "(true,book,4,4)", + "(true,book,4,5)", + "(true,fruit,4,1)", + "(true,fruit,5,1)", + "(true,fruit,3,1)", + "(true,fruit,4,2)", + "(true,fruit,5,2)", + "(true,fruit,5,3)") + assertEquals(expected, tableSink.getRawResults) + + val updatedExpected = List( + "book,4,5", + "book,1,4", + "book,2,2", + "fruit,5,3", + "fruit,4,2", + "fruit,3,1") + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Test + def testTopNWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("book", 
5, 20), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22), + ("fruit", 1, 40)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, num, shopId + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val tableSink = new TestingUpsertTableSink(Array(0, 2)) + val table = tEnv.sqlQuery(sql) + writeToSink(table, tableSink) + env.execute() + + val expected = List( + "(true,book,12,1)", + "(true,book,19,2)", + "(false,book,12,1)", + "(true,book,20,5)", + "(true,fruit,33,4)", + "(true,fruit,44,3)", + "(false,fruit,33,4)", + "(true,fruit,40,1)") + assertEquals(expected, tableSink.getRawResults) + + val updatedExpected = List( + "book,19,2", + "book,20,5", + "fruit,40,1", + "fruit,44,3") + assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testMultipleRetractTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num, + | AVG(num) as avg_num, COUNT(num) as cnt + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingRetractSink + tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, avg_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC, avg_num ASC + | ) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin).toRetractStream[Row].addSink(sink1).setParallelism(1) + env.execute() + + val expected1 = List( + "book,1,25,12.5,1", + "book,2,19,19.0,2", + "fruit,3,44,44.0,1", + "fruit,4,33,33.0,2") + assertEquals(expected1.sorted, sink1.getRetractResults.sorted) + + val sink2 = new TestingRetractSink + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC, cnt ASC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin).toRetractStream[Row].addSink(sink2).setParallelism(1) + + env.execute() + + val expected2 = List( + "book,2,19,1,1", + "book,1,13,2,2", + "fruit,3,44,1,1", + "fruit,4,33,1,2") + assertEquals(expected2.sorted, sink2.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testMultipleUnaryTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingUpsertTableSink(Array(0, 3)) + val table1 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, + | 
ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + writeToSink(table1, sink1) + + val sink2 = new TestingUpsertTableSink(Array(0, 3)) + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + + writeToSink(table2, sink2) + + val expected1 = List( + "book,1,25,1", + "book,2,19,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected1.sorted, sink1.getUpsertResults.sorted) + + val expected2 = List( + "book,2,19,1", + "book,1,13,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected2.sorted, sink2.getUpsertResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testMultipleUpdateTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, COUNT(num) as cnt_num, MAX(num) as max_num + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingRetractTableSink + val table1 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, cnt_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY cnt_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + writeToSink(table1, sink1) + + val sink2 = new TestingRetractTableSink + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + + writeToSink(table2, sink2) + + val expected1 = List( + "book,1,2,1", + "book,2,1,2", + "fruit,4,1,1", + "fruit,3,1,2") + assertEquals(expected1.sorted, sink1.getRetractResults.sorted) + + val expected2 = List( + "book,2,19,1", + "book,1,13,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected2.sorted, sink2.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testUpdateRank(): Unit = { + val data = List( + (1, 1), (1, 2), (1, 3), + (2, 2), (2, 3), (2, 4), + (3, 3), (3, 4), (3, 5)) + + val ds = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.registerTable("T", ds) + + // We use max here to ensure the usage of update rank + val sql = "SELECT a, max(b) FROM T GROUP BY a ORDER BY a LIMIT 2" + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List("1,3", "2,4") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Ignore("Enable after streamAgg implements StreamExecNode") + @Test + def testUpdateRankWithOffset(): Unit = { + val data = List( + (1, 1), (1, 2), (1, 3), + (2, 2), (2, 3), (2, 4), + (3, 3), (3, 4), (3, 5)) + + val ds = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.registerTable("T", ds) + + // We use max here to ensure the usage of update rank + val sql = "SELECT a, max(b) FROM T GROUP BY a ORDER BY a LIMIT 2 OFFSET 1" + + val sink = new TestingRetractSink + 
tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List("2,4", "3,5") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala index 95138c28edf6d..7c32d5c2a5bee 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala @@ -19,23 +19,26 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.common.io.OutputFormat import org.apache.flink.api.common.state.{ListState, ListStateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo} import org.apache.flink.configuration.Configuration import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink} import org.apache.flink.streaming.api.functions.sink.RichSinkFunction import org.apache.flink.table.api.{TableConfig, Types} -import org.apache.flink.table.dataformat.BaseRow -import org.apache.flink.table.sinks.{AppendStreamTableSink, BatchTableSink} +import org.apache.flink.table.dataformat.{BaseRow, DataFormatConverters, GenericRow} +import org.apache.flink.table.sinks._ +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.table.util.BaseRowTestUtil import org.apache.flink.types.Row +import _root_.java.lang.{Boolean => JBoolean} import _root_.java.util.TimeZone import _root_.java.util.concurrent.atomic.AtomicInteger @@ -150,6 +153,159 @@ final class TestingAppendSink(tz: TimeZone) extends AbstractExactlyOnceSink[Row] def getAppendResults: List[String] = getResults } +final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) + extends AbstractExactlyOnceSink[(Boolean, BaseRow)] { + + private var upsertResultsState: ListState[String] = _ + private var localUpsertResults: mutable.Map[String, String] = _ + private var fieldTypes: Array[TypeInformation[_]] = _ + + def this(keys: Array[Int]) { + this(keys, TimeZone.getTimeZone("UTC")) + } + + def configureTypes(fieldTypes: Array[TypeInformation[_]]): Unit = { + this.fieldTypes = fieldTypes + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + super.initializeState(context) + upsertResultsState = context.getOperatorStateStore.getListState( + new ListStateDescriptor[String]("sink-upsert-results", Types.STRING)) + + localUpsertResults = mutable.HashMap.empty[String, String] + + if (context.isRestored) { + var key: String = null + var value: String = null + for (entry <- upsertResultsState.get().asScala) { + if (key == null) { + key = entry + } else { + value = entry + localUpsertResults += (key -> value) + key = null + value = null + } + } + if (key != null) { + throw new 
RuntimeException("The resultState is corrupt.") + } + } + + val taskId = getRuntimeContext.getIndexOfThisSubtask + StreamTestSink.synchronized{ + StreamTestSink.globalUpsertResults(idx) += (taskId -> localUpsertResults) + } + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + super.snapshotState(context) + upsertResultsState.clear() + for ((key, value) <- localUpsertResults) { + upsertResultsState.add(key) + upsertResultsState.add(value) + } + } + + def invoke(d: (Boolean, BaseRow)): Unit = { + this.synchronized { + val wrapRow = new GenericRow(2) + wrapRow.setField(0, d._1) + wrapRow.setField(1, d._2) + val converter = + DataFormatConverters.getConverterForTypeInfo( + new TupleTypeInfo(Types.BOOLEAN, new RowTypeInfo(fieldTypes: _*))) + .asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, JTuple2[JBoolean, Row]]] + val v = converter.toExternal(wrapRow) + val rowString = TestSinkUtil.rowToString(v.f1, tz) + val tupleString = "(" + v.f0.toString + "," + rowString + ")" + localResults += tupleString + val keyString = TestSinkUtil.rowToString(Row.project(v.f1, keys), tz) + if (v.f0) { + localUpsertResults += (keyString -> rowString) + } else { + val oldValue = localUpsertResults.remove(keyString) + if (oldValue.isEmpty) { + throw new RuntimeException("Tried to delete a value that wasn't inserted first. " + + "This is probably an incorrectly implemented test. " + + "Try to set the parallelism of the sink to 1.") + } + } + } + } + + def getRawResults: List[String] = getResults + + def getUpsertResults: List[String] = { + clearAndStashGlobalResults() + val result = ArrayBuffer.empty[String] + this.globalUpsertResults.foreach { + case (_, map) => map.foreach(result += _._2) + } + result.toList + } +} + +final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) + extends UpsertStreamTableSink[BaseRow] { + var fNames: Array[String] = _ + var fTypes: Array[TypeInformation[_]] = _ + var sink = new TestingUpsertSink(keys, tz) + + def this(keys: Array[Int]) { + this(keys, TimeZone.getTimeZone("UTC")) + } + + override def setKeyFields(keys: Array[String]): Unit = { + // ignore + } + + override def setIsAppendOnly(isAppendOnly: JBoolean): Unit = { + // ignore + } + + override def getRecordType: TypeInformation[BaseRow] = + new BaseRowTypeInfo(fTypes.map(createInternalTypeFromTypeInfo(_)), fNames) + + override def getFieldNames: Array[String] = fNames + + override def getFieldTypes: Array[TypeInformation[_]] = fTypes + + override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, BaseRow]]) = { + dataStream.map(new MapFunction[JTuple2[JBoolean, BaseRow], (Boolean, BaseRow)] { + override def map(value: JTuple2[JBoolean, BaseRow]): (Boolean, BaseRow) = { + (value.f0, value.f1) + } + }) + .addSink(sink) + .name(s"TestingUpsertTableSink(keys=${ + if (keys != null) { + "(" + keys.mkString(",") + ")" + } else { + "null" + } + })") + .setParallelism(1) + } + + override def configure( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]]) + : TableSink[JTuple2[JBoolean, BaseRow]] = { + val copy = new TestingUpsertTableSink(keys, tz) + copy.fNames = fieldNames + copy.fTypes = fieldTypes + sink.configureTypes(fieldTypes) + copy.sink = sink + copy + } + + def getRawResults: List[String] = sink.getRawResults + + def getUpsertResults: List[String] = sink.getUpsertResults +} + final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[Row] with BatchTableSink[Row]{ var fNames: Array[String] = _ @@ -242,3 +398,118 @@ class 
TestingOutputFormat[T](tz: TimeZone) result.toList } } + +class TestingRetractSink(tz: TimeZone) + extends AbstractExactlyOnceSink[(Boolean, Row)] { + protected var retractResultsState: ListState[String] = _ + protected var localRetractResults: ArrayBuffer[String] = _ + + def this() { + this(TimeZone.getTimeZone("UTC")) + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + super.initializeState(context) + retractResultsState = context.getOperatorStateStore.getListState( + new ListStateDescriptor[String]("sink-retract-results", Types.STRING)) + + localRetractResults = mutable.ArrayBuffer.empty[String] + + if (context.isRestored) { + for (value <- retractResultsState.get().asScala) { + localRetractResults += value + } + } + + val taskId = getRuntimeContext.getIndexOfThisSubtask + StreamTestSink.synchronized{ + StreamTestSink.globalRetractResults(idx) += (taskId -> localRetractResults) + } + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + super.snapshotState(context) + retractResultsState.clear() + for (value <- localRetractResults) { + retractResultsState.add(value) + } + } + + def invoke(v: (Boolean, Row)): Unit = { + this.synchronized { + val tupleString = "(" + v._1.toString + "," + TestSinkUtil.rowToString(v._2, tz) + ")" + localResults += tupleString + val rowString = TestSinkUtil.rowToString(v._2, tz) + if (v._1) { + localRetractResults += rowString + } else { + val index = localRetractResults.indexOf(rowString) + if (index >= 0) { + localRetractResults.remove(index) + } else { + throw new RuntimeException("Tried to retract a value that wasn't added first. " + + "This is probably an incorrectly implemented test. " + + "Try to set the parallelism of the sink to 1.") + } + } + } + } + + def getRawResults: List[String] = getResults + + def getRetractResults: List[String] = { + clearAndStashGlobalResults() + val result = ArrayBuffer.empty[String] + this.globalRetractResults.foreach { + case (_, list) => result ++= list + } + result.toList + } +} + +final class TestingRetractTableSink(tz: TimeZone) extends RetractStreamTableSink[Row] { + + var fNames: Array[String] = _ + var fTypes: Array[TypeInformation[_]] = _ + var sink = new TestingRetractSink(tz) + + def this() { + this(TimeZone.getTimeZone("UTC")) + } + + override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, Row]]) = { + dataStream.map(new MapFunction[JTuple2[JBoolean, Row], (Boolean, Row)] { + override def map(value: JTuple2[JBoolean, Row]): (Boolean, Row) = { + (value.f0, value.f1) + } + }).setParallelism(dataStream.getParallelism) + .addSink(sink) + .name("TestingRetractTableSink") + .setParallelism(dataStream.getParallelism) + } + + override def getRecordType: TypeInformation[Row] = + new RowTypeInfo(fTypes, fNames) + + override def getFieldNames: Array[String] = fNames + + override def getFieldTypes: Array[TypeInformation[_]] = fTypes + + override def configure( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]]): TableSink[JTuple2[JBoolean, Row]] = { + val copy = new TestingRetractTableSink(tz) + copy.fNames = fieldNames + copy.fTypes = fieldTypes + copy.sink = sink + copy + } + + def getRawResults: List[String] = { + sink.getRawResults + } + + def getRetractResults: List[String] = { + sink.getRetractResults + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala index 1335e897bda1b..94b814fe833a8 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala @@ -20,7 +20,9 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.{Table, TableImpl} import org.apache.flink.table.api.scala.StreamTableEnvironment +import org.apache.flink.table.sinks.TableSink import org.apache.flink.test.util.AbstractTestBase import org.junit.rules.{ExpectedException, TemporaryFolder} @@ -53,4 +55,8 @@ class StreamingTestBase extends AbstractTestBase { this.tEnv = StreamTableEnvironment.create(env) } + def writeToSink(table: Table, sink: TableSink[_]): Unit = { + TableUtil.writeToSink(table.asInstanceOf[TableImpl], sink) + } + } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala new file mode 100644 index 0000000000000..d502b902f2d86 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.utils + +import org.apache.flink.table.api.TableConfigOptions +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode} +import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.{MiniBatchMode, MiniBatchOff, MiniBatchOn} + +import java.util + +import scala.collection.JavaConversions._ + +import org.junit.runners.Parameterized + +abstract class StreamingWithMiniBatchTestBase( + miniBatch: MiniBatchMode, + state: StateBackendMode) + extends StreamingWithStateTestBase(state) { + + override def before(): Unit = { + super.before() + // set mini batch + val tableConfig = tEnv.getConfig + miniBatch match { + case MiniBatchOn => + tableConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY, 1000L) + tableConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE, 3L) + case MiniBatchOff => + tableConfig.getConf.removeConfig(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) + } + } +} + +object StreamingWithMiniBatchTestBase { + + case class MiniBatchMode(on: Boolean) { + override def toString: String = { + if (on){ + "MiniBatch=ON" + } else { + "MiniBatch=OFF" + } + } + } + + val MiniBatchOff = MiniBatchMode(false) + val MiniBatchOn = MiniBatchMode(true) + + @Parameterized.Parameters(name = "{0}, StateBackend={1}") + def parameters(): util.Collection[Array[java.lang.Object]] = { + Seq[Array[AnyRef]]( + Array(MiniBatchOff, HEAP_BACKEND), + Array(MiniBatchOff, ROCKSDB_BACKEND), + Array(MiniBatchOn, HEAP_BACKEND), + Array(MiniBatchOn, ROCKSDB_BACKEND)) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala new file mode 100644 index 0000000000000..602edd84e4666 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.api.common.restartstrategy.RestartStrategies +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.configuration.{CheckpointingOptions, Configuration} +import org.apache.flink.runtime.state.memory.MemoryStateBackend +import org.apache.flink.streaming.api.CheckpointingMode +import org.apache.flink.streaming.api.functions.source.FromElementsFunction +import org.apache.flink.streaming.api.scala.DataStream +import org.apache.flink.table.api.{TableEnvironment, Types} +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, BinaryRowWriter, BinaryString} +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode} +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo +import org.apache.flink.table.`type`.RowType + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import java.io.File +import java.util + +import org.junit.runners.Parameterized +import org.junit.{After, Assert, Before} + +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend + +class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestBase { + + enableObjectReuse = state match { + case HEAP_BACKEND => false // TODO gemini not support obj reuse now. + case ROCKSDB_BACKEND => true + } + + private val classLoader = Thread.currentThread.getContextClassLoader + + var baseCheckpointPath: File = _ + + @Before + override def before(): Unit = { + super.before() + // set state backend + baseCheckpointPath = tempFolder.newFolder().getAbsoluteFile + state match { + case HEAP_BACKEND => + val conf = new Configuration() + conf.setBoolean(CheckpointingOptions.ASYNC_SNAPSHOTS, true) + env.setStateBackend(new MemoryStateBackend( + "file://" + baseCheckpointPath, null).configure(conf, classLoader)) + case ROCKSDB_BACKEND => + val conf = new Configuration() + conf.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true) + env.setStateBackend(new RocksDBStateBackend( + "file://" + baseCheckpointPath).configure(conf, classLoader)) + } + this.tEnv = TableEnvironment.getTableEnvironment(env) + FailingCollectionSource.failedBefore = true + } + + @After + def after(): Unit = { + Assert.assertTrue(FailingCollectionSource.failedBefore) + } + + /** + * Creates a BinaryRow DataStream from the given non-empty [[Seq]]. 
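+ * Each element must be a [[Product]] whose fields are written into a reused [[BinaryRow]]
+ * (only INT, LONG, STRING and BOOLEAN fields are handled here), and the rows are emitted
+ * through [[failingDataSource]], so the stream fails once and recovers while the test runs.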
+ */ + def failingBinaryRowSource[T: TypeInformation](data: Seq[T]): DataStream[BaseRow] = { + val typeInfo = implicitly[TypeInformation[_]].asInstanceOf[CompositeType[_]] + val result = new mutable.MutableList[BaseRow] + val reuse = new BinaryRow(typeInfo.getArity) + val writer = new BinaryRowWriter(reuse) + data.foreach { + case p: Product => + for (i <- 0 until typeInfo.getArity) { + val fieldType = typeInfo.getTypeAt(i).asInstanceOf[TypeInformation[_]] + fieldType match { + case Types.INT => writer.writeInt(i, p.productElement(i).asInstanceOf[Int]) + case Types.LONG => writer.writeLong(i, p.productElement(i).asInstanceOf[Long]) + case Types.STRING => writer.writeString(i, + p.productElement(i).asInstanceOf[BinaryString]) + case Types.BOOLEAN => writer.writeBoolean(i, p.productElement(i).asInstanceOf[Boolean]) + } + } + writer.complete() + result += reuse.copy() + case _ => throw new UnsupportedOperationException + } + val newTypeInfo = createInternalTypeFromTypeInfo(typeInfo).asInstanceOf[RowType].toTypeInfo + failingDataSource(result)(newTypeInfo.asInstanceOf[TypeInformation[BaseRow]]) + } + + /** + * Creates a DataStream from the given non-empty [[Seq]]. + */ + def retainStateDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = { + env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE) + env.setRestartStrategy(RestartStrategies.noRestart()) + env.setParallelism(1) + // reset failedBefore flag to false + FailingCollectionSource.reset() + + require(data != null, "Data must not be null.") + val typeInfo = implicitly[TypeInformation[T]] + + val collection = scala.collection.JavaConversions.asJavaCollection(data) + // must not have null elements and mixed elements + FromElementsFunction.checkCollection(data, typeInfo.getTypeClass) + + val function = new FailingCollectionSource[T]( + typeInfo.createSerializer(env.getConfig), + collection, + data.length, // fail after half elements + true) + + env.addSource(function)(typeInfo).setMaxParallelism(1) + } + + /** + * Creates a DataStream from the given non-empty [[Seq]]. 
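+ * Checkpointing is enabled and the underlying [[FailingCollectionSource]] throws an exception
+ * once after emitting half of the elements; the job then restarts from the last checkpoint,
+ * so exactly-once state handling of the operators under test is exercised.
+ * Illustrative usage in an ITCase (table and field names here are made up):
+ * {{{
+ *   val t = failingDataSource(data).toTable(tEnv, 'a, 'b, 'c)
+ *   tEnv.registerTable("T", t)
+ * }}}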
+ */ + def failingDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = { + env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE) + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0)) + env.setParallelism(1) + // reset failedBefore flag to false + FailingCollectionSource.reset() + + require(data != null, "Data must not be null.") + val typeInfo = implicitly[TypeInformation[T]] + + val collection = scala.collection.JavaConversions.asJavaCollection(data) + // must not have null elements and mixed elements + FromElementsFunction.checkCollection(data, typeInfo.getTypeClass) + + val function = new FailingCollectionSource[T]( + typeInfo.createSerializer(env.getConfig), + collection, + data.length / 2, // fail after half elements + false) + + env.addSource(function)(typeInfo).setMaxParallelism(1) + } + + private def mapStrEquals(str1: String, str2: String): Boolean = { + val array1 = str1.toCharArray + val array2 = str2.toCharArray + if (array1.length != array2.length) { + return false + } + val l = array1.length + val leftBrace = "{".charAt(0) + val rightBrace = "}".charAt(0) + val equalsChar = "=".charAt(0) + val lParenthesis = "(".charAt(0) + val rParenthesis = ")".charAt(0) + val dot = ",".charAt(0) + val whiteSpace = " ".charAt(0) + val map1 = Map[String, String]() + val map2 = Map[String, String]() + var idx = 0 + def findEquals(ss: CharSequence): Array[Int] = { + val ret = new ArrayBuffer[Int]() + (0 until ss.length) foreach (idx => if (ss.charAt(idx) == equalsChar) ret += idx) + ret.toArray + } + + def splitKV(ss: CharSequence, equalsIdx: Int): (String, String) = { + // find right, if starts with '(' find until the ')', else until the ',' + var endFlag = false + var curIdx = equalsIdx + 1 + var endChar = if (ss.charAt(curIdx) == lParenthesis) rParenthesis else dot + var valueStr: CharSequence = null + var keyStr: CharSequence = null + while (curIdx < ss.length && !endFlag) { + val curChar = ss.charAt(curIdx) + if (curChar != endChar && curChar != rightBrace) { + curIdx += 1 + if (curIdx == ss.length) { + valueStr = ss.subSequence(equalsIdx + 1, curIdx) + } + } else { + valueStr = ss.subSequence(equalsIdx + 1, curIdx) + endFlag = true + } + } + + // find left, if starts with ')' find until the '(', else until the ' ,' + endFlag = false + curIdx = equalsIdx - 1 + endChar = if (ss.charAt(curIdx) == rParenthesis) lParenthesis else whiteSpace + while (curIdx >= 0 && !endFlag) { + val curChar = ss.charAt(curIdx) + if (curChar != endChar && curChar != leftBrace) { + curIdx -= 1 + if (curIdx == -1) { + keyStr = ss.subSequence(0, equalsIdx) + } + } else { + keyStr = ss.subSequence(curIdx, equalsIdx) + endFlag = true + } + } + require(keyStr != null) + require(valueStr != null) + (keyStr.toString, valueStr.toString) + } + + def appendStrToMap(ss: CharSequence, m: Map[String, String]): Unit = { + val equalsIdxs = findEquals(ss) + equalsIdxs.foreach (idx => m + splitKV(ss, idx)) + } + + while (idx < l) { + val char1 = array1(idx) + val char2 = array2(idx) + if (char1 != char2) { + return false + } + + if (char1 == leftBrace) { + val rightBraceIdx = array1.subSequence(idx + 1, l).toString.indexOf(rightBrace) + appendStrToMap(array1.subSequence(idx + 1, rightBraceIdx + idx + 2), map1) + idx += rightBraceIdx + } else { + idx += 1 + } + } + map1.equals(map2) + } + + def assertMapStrEquals(str1: String, str2: String): Unit = { + if (!mapStrEquals(str1, str2)) { + throw new AssertionError(s"Expected: $str1 \n Actual: $str2") + } + } +} + +object StreamingWithStateTestBase { + + 
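+  /** The state backend a parameterized test run uses: heap (MemoryStateBackend) or RocksDB. */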
case class StateBackendMode(backend: String) { + override def toString: String = backend.toString + } + + val HEAP_BACKEND = StateBackendMode("HEAP") + val ROCKSDB_BACKEND = StateBackendMode("ROCKSDB") + + @Parameterized.Parameters(name = "StateBackend={0}") + def parameters(): util.Collection[Array[java.lang.Object]] = { + Seq[Array[AnyRef]](Array(HEAP_BACKEND), Array(ROCKSDB_BACKEND)) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala index 5ec370015c4b1..5bcc372091395 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.`type`.TypeConverters.createExternalTypeInfoFromInternalType import org.apache.flink.table.api.{BatchTableEnvironment, TableImpl} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.sinks.{CollectRowTableSink, CollectTableSink} +import org.apache.flink.table.sinks.{CollectRowTableSink, CollectTableSink, TableSink} import org.apache.flink.types.Row import _root_.scala.collection.JavaConversions._ @@ -60,4 +60,15 @@ object TableUtil { BatchTableEnvUtil.collect(table.tableEnv.asInstanceOf[BatchTableEnvironment], table, configuredSink.asInstanceOf[CollectTableSink[T]], jobName) } + + def writeToSink(table: TableImpl, sink: TableSink[_]): Unit = { + // get schema information of table + val rowType = table.getRelNode.getRowType + val fieldNames = rowType.getFieldNames.asScala.toArray + val fieldTypes = rowType.getFieldList + .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray + val configuredSink = sink.configure( + fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType)) + table.tableEnv.writeToSink(table, configuredSink) + } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala new file mode 100644 index 0000000000000..df1dd3f435c52 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.functions.source.SourceFunction +import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext +import org.apache.flink.streaming.api.operators.{AbstractStreamOperator, OneInputStreamOperator} +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord + +object TimeTestUtil { + + class EventTimeSourceFunction[T]( + dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] { + + override def run(ctx: SourceContext[T]): Unit = { + dataWithTimestampList.foreach { + case Left(t) => ctx.collectWithTimestamp(t._2, t._1) + case Right(w) => ctx.emitWatermark(new Watermark(w)) + } + } + + override def cancel(): Unit = ??? + } + + class TimestampAndWatermarkWithOffset[T <: Product]( + offset: Long) extends AssignerWithPunctuatedWatermarks[T] { + + override def checkAndGetNextWatermark(lastElement: T, extractedTimestamp: Long): Watermark = { + new Watermark(extractedTimestamp - offset) + } + + override def extractTimestamp(element: T, previousElementTimestamp: Long): Long = { + element.productElement(0).asInstanceOf[Long] + } + } + + class EventTimeProcessOperator[T] + extends AbstractStreamOperator[T] + with OneInputStreamOperator[Either[(Long, T), Long], T] { + + override def processElement(element: StreamRecord[Either[(Long, T), Long]]): Unit = { + element.getValue match { + case Left(t) => output.collect(new StreamRecord[T](t._2, t._1)) + case Right(w) => output.emitWatermark(new Watermark(w)) + } + } + + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java index c472c1f934a18..9a47a3b738d13 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java @@ -133,6 +133,15 @@ public class TableConfigOptions { .defaultValue(512) .withDescription("Sets the max buffer memory size for sort. 
It defines the upper memory for the sort."); + // ------------------------------------------------------------------------ + // topN Options + // ------------------------------------------------------------------------ + + public static final ConfigOption SQL_EXEC_TOPN_CACHE_SIZE = + key("sql.exec.topn.cache.size") + .defaultValue(10000L) + .withDescription("Cache size of every topn task."); + // ------------------------------------------------------------------------ // MiniBatch Options // ------------------------------------------------------------------------ diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java index 9be8db632248d..faa95fbb19a59 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java @@ -428,4 +428,20 @@ public static String toOriginString(BaseRow row, InternalType[] types) { build.append(']'); return build.toString(); } + + public boolean equalsWithoutHeader(BaseRow o) { + return equalsFrom(o, 1); + } + + private boolean equalsFrom(Object o, int startIndex) { + if (o != null && o instanceof BinaryRow) { + BinaryRow other = (BinaryRow) o; + return sizeInBytes == other.sizeInBytes && + SegmentsUtil.equals( + segments, offset + startIndex, + other.segments, other.offset + startIndex, sizeInBytes - startIndex); + } else { + return false; + } + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java index b1730ec3c3892..0adbfc2aa23e4 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java @@ -24,6 +24,7 @@ import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.type.MapType; import org.apache.flink.table.type.RowType; +import org.apache.flink.table.type.TimestampType; import org.apache.flink.table.typeutils.BaseRowSerializer; /** @@ -102,7 +103,7 @@ static void write(BinaryWriter writer, int pos, Object o, InternalType type) { writer.writeInt(pos, (int) o); } else if (type.equals(InternalTypes.TIME)) { writer.writeInt(pos, (int) o); - } else if (type.equals(InternalTypes.TIMESTAMP)) { + } else if (type instanceof TimestampType) { writer.writeLong(pos, (long) o); } else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java index 101ff35d16253..ebc59bc3f3bda 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java @@ -146,9 +146,9 @@ public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeIn } else if (typeInfo instanceof BinaryMapTypeInfo) { return BinaryMapConverter.INSTANCE; } else if (typeInfo instanceof BaseRowTypeInfo) { - 
return BaseRowConverter.INSTANCE; + return new BaseRowConverter(typeInfo.getArity()); } else if (typeInfo.equals(BasicTypeInfo.BIG_DEC_TYPE_INFO)) { - return BaseRowConverter.INSTANCE; + return new BaseRowConverter(typeInfo.getArity()); } else if (typeInfo instanceof DecimalTypeInfo) { DecimalTypeInfo decimalType = (DecimalTypeInfo) typeInfo; return new DecimalConverter(decimalType.precision(), decimalType.scale()); @@ -993,14 +993,13 @@ E toExternalImpl(BaseRow row, int column) { public static final class BaseRowConverter extends IdentityConverter { private static final long serialVersionUID = -4470307402371540680L; + private int arity; - public static final BaseRowConverter INSTANCE = new BaseRowConverter(); - - private BaseRowConverter() {} + private BaseRowConverter(int arity) { this.arity = arity; } @Override BaseRow toExternalImpl(BaseRow row, int column) { - throw new RuntimeException("Not support yet!"); + return row.getRow(column, arity); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java index f1e080ce89d57..71c323f5914db 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java @@ -24,6 +24,7 @@ import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.type.MapType; import org.apache.flink.table.type.RowType; +import org.apache.flink.table.type.TimestampType; /** * Provide type specialized getters and setters to reduce if/else and eliminate box and unbox. @@ -197,7 +198,7 @@ static Object get(TypeGetterSetters row, int ordinal, InternalType type) { return row.getInt(ordinal); } else if (type.equals(InternalTypes.TIME)) { return row.getInt(ordinal); - } else if (type.equals(InternalTypes.TIMESTAMP)) { + } else if (type instanceof TimestampType) { return row.getLong(ordinal); } else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java index 71dc6c2eacd7b..e50e5463c0327 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java @@ -18,7 +18,9 @@ package org.apache.flink.table.dataformat.util; +import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.core.memory.MemoryUtils; +import org.apache.flink.table.dataformat.BinaryRow; /** * Util for binary row. Many of the methods in this class are used in code generation.
@@ -29,6 +31,14 @@ public class BinaryRowUtil { public static final sun.misc.Unsafe UNSAFE = MemoryUtils.UNSAFE; public static final int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + public static final BinaryRow EMPTY_ROW = new BinaryRow(0); + + static { + int size = EMPTY_ROW.getFixedLengthPartSize(); + byte[] bytes = new byte[size]; + EMPTY_ROW.pointTo(MemorySegmentFactory.wrap(bytes), 0, size); + } + public static boolean byteArrayEquals(byte[] left, byte[] right, int length) { return byteArrayEquals( left, BYTE_ARRAY_BASE_OFFSET, right, BYTE_ARRAY_BASE_OFFSET, length); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java index d37aa31560c95..8f2c76363c008 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java @@ -57,7 +57,7 @@ public abstract class AbstractMapBundleOperator private static final long serialVersionUID = 5081841938324118594L; /** The map in heap to store elements. */ - private final transient Map bundle; + private transient Map bundle; /** The trigger that determines how many elements should be put into a bundle. */ private final BundleTrigger bundleTrigger; @@ -74,7 +74,6 @@ public abstract class AbstractMapBundleOperator MapBundleFunction function, BundleTrigger bundleTrigger) { chainingStrategy = ChainingStrategy.ALWAYS; - this.bundle = new HashMap<>(); this.function = checkNotNull(function, "function is null"); this.bundleTrigger = checkNotNull(bundleTrigger, "bundleTrigger is null"); } @@ -86,6 +85,7 @@ public void open() throws Exception { this.numOfElements = 0; this.collector = new StreamRecordCollector<>(output); + this.bundle = new HashMap<>(); bundleTrigger.registerCallback(this); // reset trigger diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java new file mode 100644 index 0000000000000..9cd3c1d4b299b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +/** + * This function is used to deduplicate on keys and keeps only the first row or the last row. + */ +public class DeduplicateFunction + extends KeyedProcessFunctionWithCleanupState + implements DeduplicateFunctionBase { + + private final BaseRowTypeInfo rowTypeInfo; + private final boolean generateRetraction; + private final boolean keepLastRow; + protected ValueState pkRow; + private GeneratedRecordEqualiser generatedEqualiser; + private transient RecordEqualiser equaliser; + + public DeduplicateFunction( + long minRetentionTime, + long maxRetentionTime, + BaseRowTypeInfo rowTypeInfo, + boolean generateRetraction, + boolean keepLastRow, + GeneratedRecordEqualiser generatedEqualiser) { + super(minRetentionTime, maxRetentionTime); + this.rowTypeInfo = rowTypeInfo; + this.generateRetraction = generateRetraction; + this.keepLastRow = keepLastRow; + this.generatedEqualiser = generatedEqualiser; + } + + @Override + public void open(Configuration configure) throws Exception { + super.open(configure); + // both keep-first-row and keep-last-row modes share the same cleanup-time state name + initCleanupTimeState("DeduplicateFunctionCleanupTime"); + ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); + pkRow = getRuntimeContext().getState(rowStateDesc); + equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + } + + @Override + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { + long currentTime = ctx.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(ctx, currentTime); + + BaseRow preRow = pkRow.value(); + if (keepLastRow) { + processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + } else { + processFirstRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + } + } + + @Override + public void close() throws Exception { + super.close(); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + cleanupState(pkRow); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java new file mode 100644 index 0000000000000..22293ae75beb0 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +/** + * Base class to deduplicate on keys and keeps only first row or last row. + */ +public interface DeduplicateFunctionBase { + + default void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, + boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, + Collector out) throws Exception { + // should be accumulate msg. + Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); + // ignore same record + if (!stateCleaningEnabled && preRow != null && + equaliser.equalsWithoutHeader(preRow, currentRow)) { + return; + } + pkRow.update(currentRow); + if (preRow != null && generateRetraction) { + preRow.setHeader(BaseRowUtil.RETRACT_MSG); + out.collect(preRow); + } + out.collect(currentRow); + } + + default void processFirstRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, + boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, + Collector out) throws Exception { + // should be accumulate msg. + Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); + // ignore record with timestamp bigger than preRow + if (!isFirstRow(preRow)) { + return; + } + + pkRow.update(currentRow); + out.collect(currentRow); + } + + default boolean isFirstRow(BaseRow preRow) { + return preRow == null; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java new file mode 100644 index 0000000000000..cff33d77948e7 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.bundle.MapBundleFunction; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import javax.annotation.Nullable; + +import java.util.Map; + +/** + * This function is used to get the first row or last row for every key partition in miniBatch + * mode. + */ +public class MiniBatchDeduplicateFunction + extends MapBundleFunction + implements DeduplicateFunctionBase { + + private BaseRowTypeInfo rowTypeInfo; + private boolean generateRetraction; + private boolean keepLastRow; + protected ValueState pkRow; + private TypeSerializer ser; + private GeneratedRecordEqualiser generatedEqualiser; + private transient RecordEqualiser equaliser; + + public MiniBatchDeduplicateFunction( + BaseRowTypeInfo rowTypeInfo, + boolean generateRetraction, + ExecutionConfig executionConfig, + boolean keepLastRow, + GeneratedRecordEqualiser generatedEqualiser) { + this.rowTypeInfo = rowTypeInfo; + this.generateRetraction = generateRetraction; + this.keepLastRow = keepLastRow; + ser = rowTypeInfo.createSerializer(executionConfig); + this.generatedEqualiser = generatedEqualiser; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); + pkRow = ctx.getRuntimeContext().getState(rowStateDesc); + equaliser = generatedEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + } + + @Override + public BaseRow addInput(@Nullable BaseRow value, BaseRow input) { + if (value == null || keepLastRow || (!keepLastRow && isFirstRow(value))) { + // put the input into buffer + return ser.copy(input); + } else { + // the input is not last row, ignore it + return value; + } + } + + @Override + public void finishBundle( + Map buffer, Collector out) throws Exception { + for (Map.Entry entry : buffer.entrySet()) { + BaseRow currentKey = entry.getKey(); + BaseRow currentRow = entry.getValue(); + ctx.setCurrentKey(currentKey); + BaseRow preRow = pkRow.value(); + + if (keepLastRow) { + processLastRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); + } else { + processFirstRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); + } + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java new file mode 100644 index 0000000000000..66123be6b9ab0 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; + +/** + * A function that processes elements of a stream, and could cleanup state. + * @param Type of the key. + * @param Type of the input elements. + * @param Type of the output elements. + */ +public abstract class KeyedProcessFunctionWithCleanupState extends KeyedProcessFunction { + + protected final long minRetentionTime; + protected final long maxRetentionTime; + protected final boolean stateCleaningEnabled; + + // holds the latest registered cleanup timer + private ValueState cleanupTimeState; + + public KeyedProcessFunctionWithCleanupState(long minRetentionTime, long maxRetentionTime) { + this.minRetentionTime = minRetentionTime; + this.maxRetentionTime = maxRetentionTime; + this.stateCleaningEnabled = minRetentionTime > 1; + } + + protected void initCleanupTimeState(String stateName) { + if (stateCleaningEnabled) { + ValueStateDescriptor inputCntDescriptor = new ValueStateDescriptor(stateName, Types.LONG); + cleanupTimeState = getRuntimeContext().getState(inputCntDescriptor); + } + } + + protected void registerProcessingCleanupTimer(Context ctx, long currentTime) throws Exception { + if (stateCleaningEnabled) { + // last registered timer + Long curCleanupTime = cleanupTimeState.value(); + + // check if a cleanup timer is registered and + // that the current cleanup timer won't delete state we need to keep + if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) { + // we need to register a new (later) timer + Long cleanupTime = currentTime + maxRetentionTime; + // register timer and remember clean-up time + ctx.timerService().registerProcessingTimeTimer(cleanupTime); + cleanupTimeState.update(cleanupTime); + } + } + } + + protected boolean isProcessingTimeTimer(OnTimerContext ctx) { + return ctx.timeDomain() == TimeDomain.PROCESSING_TIME; + } + + protected boolean needToCleanupState(long timestamp) throws Exception { + if (stateCleaningEnabled) { + Long cleanupTime = cleanupTimeState.value(); + // check that the triggered timer is the last registered processing time timer. + return null != cleanupTime && timestamp == cleanupTime; + } else { + return false; + } + } + + protected void cleanupState(State... 
states) { + for (State state : states) { + state.clear(); + } + this.cleanupTimeState.clear(); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java new file mode 100644 index 0000000000000..2ad9e5dd2c393 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keyselector; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * BaseRowKeySelector takes an BaseRow and extracts the deterministic key for the BaseRow. + */ +public interface BaseRowKeySelector extends KeySelector, ResultTypeQueryable { + + BaseRowTypeInfo getProducedType(); + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java new file mode 100644 index 0000000000000..6e3f1729cb8e8 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.keyselector; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.generated.Projection; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * A KeySelector which extracts the key from a BaseRow using a generated Projection. + */ +public class BinaryRowKeySelector implements BaseRowKeySelector { + + private static final long serialVersionUID = 5375355285015381919L; + + private final BaseRowTypeInfo keyRowType; + private final GeneratedProjection generatedProjection; + private transient Projection projection; + + public BinaryRowKeySelector(BaseRowTypeInfo keyRowType, GeneratedProjection generatedProjection) { + this.keyRowType = keyRowType; + this.generatedProjection = generatedProjection; + } + + @Override + public BaseRow getKey(BaseRow value) throws Exception { + if (projection == null) { + projection = generatedProjection.newInstance(Thread.currentThread().getContextClassLoader()); + } + return projection.apply(value).copy(); + } + + @Override + public BaseRowTypeInfo getProducedType() { + return keyRowType; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java new file mode 100644 index 0000000000000..3ca3ccaeaace7 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keyselector; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BinaryRowUtil; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * A KeySelector whose key is always empty.
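+ * It is used when an operator has no partition keys: every record is mapped to the shared
+ * BinaryRowUtil.EMPTY_ROW, so all data is processed under a single empty key.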
+ */ +public class NullBinaryRowKeySelector implements BaseRowKeySelector { + + private final BaseRowTypeInfo returnType = new BaseRowTypeInfo(); + + @Override + public BaseRow getKey(BaseRow value) throws Exception { + return BinaryRowUtil.EMPTY_ROW; + } + + @Override + public BaseRowTypeInfo getProducedType() { + return returnType; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java new file mode 100644 index 0000000000000..8477ef6cc987e --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Counter; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.streaming.api.operators.KeyContext; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; + +/** + * Base class for Rank Function. 
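+ * It only supports the ROW_NUMBER rank type, resolves the rank end from either a constant or a
+ * variable RankRange, appends the row number to the output when the output type has one more
+ * field than the input, and reports TopN metrics such as topn.cache.hitRate, topn.cache.size
+ * and topn.invalidTopSize. State cleanup is inherited from KeyedProcessFunctionWithCleanupState.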
+ */ +public abstract class AbstractRankFunction extends KeyedProcessFunctionWithCleanupState { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractRankFunction.class); + + // we set default topn size to 100 + private static final long DEFAULT_TOPN_SIZE = 100; + + private final RankRange rankRange; + private final GeneratedRecordEqualiser generatedEqualiser; + private final boolean generateRetraction; + protected final boolean isRowNumberAppend; + protected final RankType rankType; + protected final BaseRowTypeInfo inputRowType; + protected final BaseRowTypeInfo outputRowType; + protected final GeneratedRecordComparator generatedRecordComparator; + protected final KeySelector sortKeySelector; + + protected KeyContext keyContext; + protected boolean isConstantRankEnd; + protected long rankEnd = -1; + protected long rankStart = -1; + protected RecordEqualiser equaliser; + private int rankEndIndex; + private ValueState rankEndState; + private Counter invalidCounter; + private JoinedRow outputRow; + + // metrics + protected long hitCount = 0L; + protected long requestCount = 0L; + + public AbstractRankFunction( + long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + GeneratedRecordComparator generatedRecordComparator, KeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction) { + super(minRetentionTime, maxRetentionTime); + this.rankRange = rankRange; + this.generatedEqualiser = generatedEqualiser; + this.generateRetraction = generateRetraction; + this.rankType = rankType; + // TODO support RANK and DENSE_RANK + switch (rankType) { + case ROW_NUMBER: + break; + case RANK: + LOG.error("RANK() on streaming table is not supported currently"); + throw new UnsupportedOperationException("RANK() on streaming table is not supported currently"); + case DENSE_RANK: + LOG.error("DENSE_RANK() on streaming table is not supported currently"); + throw new UnsupportedOperationException("DENSE_RANK() on streaming table is not supported currently"); + default: + LOG.error("Streaming tables do not support {}", rankType.name()); + throw new UnsupportedOperationException("Streaming tables do not support " + rankType.toString()); + } + this.inputRowType = inputRowType; + this.outputRowType = outputRowType; + this.isRowNumberAppend = inputRowType.getArity() + 1 == outputRowType.getArity(); + this.generatedRecordComparator = generatedRecordComparator; + this.sortKeySelector = sortKeySelector; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + initCleanupTimeState("RankFunctionCleanupTime"); + outputRow = new JoinedRow(); + + // variable rank limit + if (rankRange instanceof ConstantRankRange) { + ConstantRankRange constantRankRange = (ConstantRankRange) rankRange; + isConstantRankEnd = true; + rankEnd = constantRankRange.getRankEnd(); + rankStart = constantRankRange.getRankStart(); + } else if (rankRange instanceof VariableRankRange) { + VariableRankRange variableRankRange = (VariableRankRange) rankRange; + isConstantRankEnd = false; + rankEndIndex = variableRankRange.getRankEndIndex(); + ValueStateDescriptor rankStateDesc = new ValueStateDescriptor("rankEnd", Types.LONG); + rankEndState = getRuntimeContext().getState(rankStateDesc); + } + equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + invalidCounter = 
getRuntimeContext().getMetricGroup().counter("topn.invalidTopSize"); + } + + protected long getDefaultTopSize() { + return isConstantRankEnd ? rankEnd : DEFAULT_TOPN_SIZE; + } + + protected long initRankEnd(BaseRow row) throws Exception { + if (isConstantRankEnd) { + return rankEnd; + } else { + Long rankEndValue = rankEndState.value(); + long curRankEnd = row.getLong(rankEndIndex); + if (rankEndValue == null) { + rankEnd = curRankEnd; + rankEndState.update(rankEnd); + return rankEnd; + } else { + rankEnd = rankEndValue; + if (rankEnd != curRankEnd) { + // increment the invalid counter when the current rank end + // not equal to previous rank end + invalidCounter.inc(); + } + return rankEnd; + } + } + } + + protected Tuple2 rowNumber(K sortKey, BaseRow rowKey, SortedMap sortedMap) { + Iterator>> iterator = sortedMap.entrySet().iterator(); + int curRank = 1; + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + K curKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + if (curKey.equals(sortKey)) { + Iterator rowKeysIter = rowKeys.iterator(); + int innerRank = 1; + while (rowKeysIter.hasNext()) { + if (rowKey.equals(rowKeysIter.next())) { + return Tuple2.of(curRank, innerRank); + } else { + innerRank += 1; + curRank += 1; + } + } + } else { + curRank += rowKeys.size(); + } + } + LOG.error("Failed to find the sortKey: {}, rowkey: {} in SortedMap. " + + "This should never happen", sortKey, rowKey); + throw new RuntimeException( + "Failed to find the sortKey, rowkey in SortedMap. This should never happen"); + } + + /** + * return true if record should be put into sort map. + */ + protected boolean checkSortKeyInBufferRange(K sortKey, SortedMap sortedMap, Comparator sortKeyComparator) { + Map.Entry> worstEntry = sortedMap.lastEntry(); + if (worstEntry == null) { + // sort map is empty + return true; + } else { + K worstKey = worstEntry.getKey(); + int compare = sortKeyComparator.compare(sortKey, worstKey); + if (compare < 0) { + return true; + } else if (sortedMap.getCurrentTopNum() < getMaxSortMapSize()) { + return true; + } else { + return false; + } + } + } + + protected void registerMetric(long heapSize) { + getRuntimeContext().getMetricGroup().>gauge( + "topn.cache.hitRate", + new Gauge() { + + @Override + public Double getValue() { + return requestCount == 0 ? 1.0 : + Long.valueOf(hitCount).doubleValue() / requestCount; + } + }); + + getRuntimeContext().getMetricGroup().>gauge( + "topn.cache.size", + new Gauge() { + + @Override + public Long getValue() { + return heapSize; + } + }); + } + + protected void collect(Collector out, BaseRow inputRow) { + BaseRowUtil.setAccumulate(inputRow); + out.collect(inputRow); + } + + /** + * This is similar to [[retract()]] but always send retraction message regardless of + * generateRetraction is true or not. + */ + protected void delete(Collector out, BaseRow inputRow) { + BaseRowUtil.setRetract(inputRow); + out.collect(inputRow); + } + + /** + * This is with-row-number version of above delete() method. 
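+ * The retraction is only emitted when {@code rank} lies inside [rankStart, rankEnd]
+ * (see {@link #isInRankRange(long)}).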
+ */ + protected void delete(Collector out, BaseRow inputRow, long rank) { + if (isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG)); + } + } + + protected void collect(Collector out, BaseRow inputRow, long rank) { + if (isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.ACCUMULATE_MSG)); + } + } + + protected void retract(Collector out, BaseRow inputRow, long rank) { + if (generateRetraction && isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG)); + } + } + + protected boolean isInRankEnd(long rank) { + return rank <= rankEnd; + } + + protected boolean isInRankRange(long rank) { + return rank <= rankEnd && rank >= rankStart; + } + + protected boolean hasOffset() { + // rank start is 1-based + return rankStart > 1; + } + + protected BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { + if (isRowNumberAppend) { + GenericRow rankRow = new GenericRow(1); + rankRow.setField(0, rank); + + outputRow.replace(inputRow, rankRow); + outputRow.setHeader(header); + return outputRow; + } else { + inputRow.setHeader(header); + return inputRow; + } + } + + /** + * get sorted map size limit + * Implementations may vary depending on each rank who has in-memory sort map. + * @return + */ + protected abstract long getMaxSortMapSize(); + + @Override + public void processElement( + BaseRow input, Context ctx, Collector out) throws Exception { + + } + + /** + * Set keyContext to RankFunction. + * + * @param keyContext keyContext of current function. + */ + public void setKeyContext(KeyContext keyContext) { + this.keyContext = keyContext; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java new file mode 100644 index 0000000000000..d54ce3af0afc9 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.util.LRUMap; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.TreeMap; +import java.util.function.Supplier; + +/** + * Base class for Update Rank Function. + */ +abstract class AbstractUpdateRankFunction extends AbstractRankFunction + implements CheckpointedFunction { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractUpdateRankFunction.class); + + private final BaseRowTypeInfo rowKeyType; + private final long cacheSize; + + // a map state stores mapping from row key to record which is in topN + // in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key + // the f1 is used to preserve the record order in the same sort_key + protected transient MapState> dataState; + + // a sorted map stores mapping from sort key to rowkey list + protected transient SortedMap sortedMap; + + protected transient Map> kvSortedMap; + + // a HashMap stores mapping from rowkey to record, a heap mirror to dataState + protected transient Map rowKeyMap; + + protected transient LRUMap> kvRowKeyMap; + + protected Comparator sortKeyComparator; + + public AbstractUpdateRankFunction( + long minRetentionTime, + long maxRetentionTime, + BaseRowTypeInfo inputRowType, + BaseRowTypeInfo outputRowType, + BaseRowTypeInfo rowKeyType, + GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, + RankType rankType, + RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, + long cacheSize) { + super(minRetentionTime, + maxRetentionTime, + inputRowType, + outputRowType, + generatedRecordComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction); + this.rowKeyType = rowKeyType; + this.cacheSize = cacheSize; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + int lruCacheSize = Math.max(1, (int) (cacheSize / getMaxSortMapSize())); + // make sure the cached map is in a fixed size, avoid OOM + kvSortedMap = new HashMap<>(lruCacheSize); + kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener()); + + LOG.info("Top{} operator is using LRU caches key-size: {}", getMaxSortMapSize(), lruCacheSize); + + TupleTypeInfo> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state-with-update", 
rowKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + // metrics + registerMetric(kvSortedMap.size() * getMaxSortMapSize()); + + sortKeyComparator = generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); + // cleanup cache + kvRowKeyMap.remove(partitionKey); + kvSortedMap.remove(partitionKey); + cleanupState(dataState); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Iterator>> iter = kvRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow partitionKey = entry.getKey(); + Map currentRowKeyMap = entry.getValue(); + keyContext.setCurrentKey(partitionKey); + synchronizeState(currentRowKeyMap); + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + // nothing to do + } + + protected void initHeapStates() throws Exception { + requestCount += 1; + BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); + sortedMap = kvSortedMap.get(partitionKey); + rowKeyMap = kvRowKeyMap.get(partitionKey); + if (sortedMap == null) { + sortedMap = new SortedMap( + sortKeyComparator, + new Supplier>() { + + @Override + public Collection get() { + return new LinkedHashSet<>(); + } + }); + rowKeyMap = new HashMap<>(); + kvSortedMap.put(partitionKey, sortedMap); + kvRowKeyMap.put(partitionKey, rowKeyMap); + + // restore sorted map + Iterator>> iter = dataState.iterator(); + if (iter != null) { + // a temp map associate sort key to tuple2 + Map> tempSortedMap = new HashMap<>(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow rowkey = entry.getKey(); + Tuple2 recordAndInnerRank = entry.getValue(); + BaseRow record = recordAndInnerRank.f0; + Integer innerRank = recordAndInnerRank.f1; + rowKeyMap.put(rowkey, new RankRow(record, innerRank, false)); + + // insert into temp sort map to preserve the record order in the same sort key + BaseRow sortKey = sortKeySelector.getKey(record); + TreeMap treeMap = tempSortedMap.get(sortKey); + if (treeMap == null) { + treeMap = new TreeMap<>(); + tempSortedMap.put(sortKey, treeMap); + } + treeMap.put(innerRank, rowkey); + } + + // build sorted map from the temp map + Iterator>> tempIter = + tempSortedMap.entrySet().iterator(); + while (tempIter.hasNext()) { + Map.Entry> entry = tempIter.next(); + BaseRow sortKey = entry.getKey(); + TreeMap treeMap = entry.getValue(); + Iterator> treeMapIter = treeMap.entrySet().iterator(); + while (treeMapIter.hasNext()) { + Map.Entry treeMapEntry = treeMapIter.next(); + Integer innerRank = treeMapEntry.getKey(); + BaseRow recordRowKey = treeMapEntry.getValue(); + int size = sortedMap.put(sortKey, recordRowKey); + if (innerRank != size) { + LOG.warn("Failed to build sorted map from state, this may result in wrong result. " + + "The sort key is {}, partition key is {}, " + + "treeMap is {}. 
The expected inner rank is {}, but current size is {}.", + sortKey, partitionKey, treeMap, innerRank, size); + } + } + } + } + } else { + hitCount += 1; + } + } + + private void synchronizeState(Map curRowKeyMap) throws Exception { + Iterator> iter = curRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry entry = iter.next(); + BaseRow key = entry.getKey(); + RankRow rankRow = entry.getValue(); + if (rankRow.dirty) { + // should update state + dataState.put(key, Tuple2.of(rankRow.row, rankRow.innerRank)); + rankRow.dirty = false; + } + } + } + + class CacheRemovalListener implements LRUMap.RemovalListener> { + + @Override + public void onRemoval(Map.Entry> eldest) { + BaseRow previousKey = (BaseRow) keyContext.getCurrentKey(); + BaseRow partitionKey = eldest.getKey(); + Map currentRowKeyMap = eldest.getValue(); + keyContext.setCurrentKey(partitionKey); + kvSortedMap.remove(partitionKey); + try { + synchronizeState(currentRowKeyMap); + } catch (Throwable e) { + LOG.error("Fail to synchronize state!"); + throw new RuntimeException(e); + } + keyContext.setCurrentKey(previousKey); + } + } + + protected class RankRow { + private final BaseRow row; + private int innerRank; + private boolean dirty; + + protected RankRow(BaseRow row, int innerRank, boolean dirty) { + this.row = row; + this.innerRank = innerRank; + this.dirty = dirty; + } + + protected BaseRow getRow() { + return row; + } + + protected int getInnerRank() { + return innerRank; + } + + protected boolean isDirty() { + return dirty; + } + + public void setInnerRank(int innerRank) { + this.innerRank = innerRank; + } + + public void setDirty(boolean dirty) { + this.dirty = dirty; + } + } +} + + + diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java new file mode 100644 index 0000000000000..65ca3e5e20216 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.util.LRUMap; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * RankFunction in Append Stream mode. + */ +public class AppendRankFunction extends AbstractRankFunction { + + private static final Logger LOG = LoggerFactory.getLogger(AppendRankFunction.class); + + protected final BaseRowTypeInfo sortKeyType; + private final TypeSerializer inputRowSer; + private final long cacheSize; + + // a map state stores mapping from sort key to records list which is in topN + private transient MapState> dataState; + + // a sorted map stores mapping from sort key to records list, a heap mirror to dataState + private transient SortedMap sortedMap; + private transient Map> kvSortedMap; + private Comparator sortKeyComparator; + + public AppendRankFunction( + long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, RankType rankType, RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, long cacheSize) { + super(minRetentionTime, maxRetentionTime, inputRowType, outputRowType, + generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction); + this.sortKeyType = sortKeyType; + this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); + this.cacheSize = cacheSize; + } + + public void open(Configuration parameters) throws Exception { + super.open(parameters); + int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopSize())); + kvSortedMap = new LRUMap<>(lruCacheSize); + LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopSize(), lruCacheSize); + + ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state-with-append", sortKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + sortKeyComparator = generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); + + // metrics + registerMetric(kvSortedMap.size() * getDefaultTopSize()); + } + + @Override + public void processElement( + BaseRow input, Context context, Collector out) throws Exception { + long currentTime = context.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(context, currentTime); + + initHeapStates(); + initRankEnd(input); + + BaseRow sortKey = sortKeySelector.getKey(input); + 
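+		// Note: records whose sort key cannot enter the TopN buffer are simply dropped below.
+		// In append-only mode a record's sort key never changes afterwards, so a row that does
+		// not make it into the TopN now can never re-enter it later.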
// check whether the sortKey is in the topN range + if (checkSortKeyInBufferRange(sortKey, sortedMap, sortKeyComparator)) { + // insert sort key into sortedMap + sortedMap.put(sortKey, inputRowSer.copy(input)); + Collection inputs = sortedMap.get(sortKey); + // update data state + dataState.put(sortKey, (List) inputs); + if (isRowNumberAppend || hasOffset()) { + // the without-number-algorithm can't handle topn with offset, + // so use the with-number-algorithm to handle offset + emitRecordsWithRowNumber(sortKey, input, out); + } else { + processElementWithoutRowNumber(input, out); + } + } + } + + private void initHeapStates() throws Exception { + requestCount += 1; + BaseRow currentKey = (BaseRow) keyContext.getCurrentKey(); + sortedMap = kvSortedMap.get(currentKey); + if (sortedMap == null) { + sortedMap = new SortedMap(sortKeyComparator, new Supplier>() { + + @Override + public Collection get() { + return new ArrayList<>(); + } + }); + kvSortedMap.put(currentKey, sortedMap); + // restore sorted map + Iterator>> iter = dataState.iterator(); + if (iter != null) { + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow sortKey = entry.getKey(); + List values = entry.getValue(); + // the order is preserved + sortedMap.putAll(sortKey, values); + } + } + } else { + hitCount += 1; + } + } + + private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow input, Collector out) throws Exception { + Iterator>> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry> entry = iterator.next(); + Collection records = entry.getValue(); + // meet its own sort key + if (!findSortKey && entry.getKey().equals(sortKey)) { + curRank += records.size(); + collect(out, input, curRank); + findSortKey = true; + } else if (findSortKey) { + Iterator recordsIter = records.iterator(); + while (recordsIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = recordsIter.next(); + retract(out, prevRow, curRank - 1); + collect(out, prevRow, curRank); + } + } else { + curRank += records.size(); + } + } + + List toDeleteKeys = new ArrayList<>(); + // remove the records associated to the sort key which is out of topN + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + BaseRow key = entry.getKey(); + dataState.remove(key); + toDeleteKeys.add(key); + } + for (BaseRow toDeleteKey : toDeleteKeys) { + sortedMap.removeAll(toDeleteKey); + } + } + + private void processElementWithoutRowNumber(BaseRow input, Collector out) throws Exception { + // remove retired element + if (sortedMap.getCurrentTopNum() > rankEnd) { + Map.Entry> lastEntry = sortedMap.lastEntry(); + BaseRow lastKey = lastEntry.getKey(); + List lastList = (List) lastEntry.getValue(); + // remove last one + BaseRow lastElement = lastList.remove(lastList.size() - 1); + if (lastList.isEmpty()) { + sortedMap.removeAll(lastKey); + dataState.remove(lastKey); + } else { + dataState.put(lastKey, lastList); + } + // lastElement shouldn't be null + delete(out, lastElement); + } + collect(out, input); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + // cleanup cache + kvSortedMap.remove(keyContext.getCurrentKey()); + cleanupState(dataState); + } + } + + @Override + protected long getMaxSortMapSize() { + return getDefaultTopSize(); + } + +} diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java new file mode 100644 index 0000000000000..b4bb201308293 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** rankStart and rankEnd are inclusive, rankStart always start from one. */ +public class ConstantRankRange implements RankRange { + + private static final long serialVersionUID = 9062345289888078376L; + private long rankStart; + private long rankEnd; + + public ConstantRankRange(long rankStart, long rankEnd) { + this.rankStart = rankStart; + this.rankEnd = rankEnd; + } + + public long getRankStart() { + return rankStart; + } + + public long getRankEnd() { + return rankEnd; + } + + @Override + public String toString(List inputFieldNames) { + return toString(); + } + + @Override + public String toString() { + return "rankStart=" + rankStart + ", rankEnd=" + rankEnd; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java new file mode 100644 index 0000000000000..0d0ddd3dcc2fe --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** ConstantRankRangeWithoutEnd is a RankRange which not specify RankEnd. 
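+ * For example, a rank predicate that only has a lower bound, such as {@code rn >= 2},
+ * is represented as {@code new ConstantRankRangeWithoutEnd(2)}.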
*/ +public class ConstantRankRangeWithoutEnd implements RankRange { + + private static final long serialVersionUID = -1944057111062598696L; + + private final long rankStart; + + public ConstantRankRangeWithoutEnd(long rankStart) { + this.rankStart = rankStart; + } + + @Override + public String toString(List inputFieldNames) { + return toString(); + } + + @Override + public String toString() { + return "rankStart=" + rankStart; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java new file mode 100644 index 0000000000000..e14b4bfff9dd9 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.io.Serializable; +import java.util.List; + +/** + * RankRange for Rank, including following 3 types : + * ConstantRankRange, ConstantRankRangeWithoutEnd, VariableRankRange. + */ +public interface RankRange extends Serializable { + + String toString(List inputFieldNames); + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java new file mode 100644 index 0000000000000..d6573dd96b24b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +/** + * An enumeration of rank type, usable to show how exactly generate rank number. 
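+ *
+ * <p>For example, ranking the ordered values (10, 20, 20, 30) yields:
+ * ROW_NUMBER: 1, 2, 3, 4; RANK: 1, 2, 2, 4; DENSE_RANK: 1, 2, 2, 3.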
+ */ +public enum RankType { + + /** + * Returns a unique sequential number for each row within the partition based on the order, + * starting at 1 for the first row in each partition and without repeating or skipping + * numbers in the ranking result of each partition. If there are duplicate values within the + * row set, the ranking numbers will be assigned arbitrarily. + */ + ROW_NUMBER, + + /** + * Returns a unique rank number for each distinct row within the partition based on the order, + * starting at 1 for the first row in each partition, with the same rank for duplicate values + * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate + * values. + */ + RANK, + + /** + * is similar to the RANK by generating a unique rank number for each distinct row + * within the partition based on the order, starting at 1 for the first row in each partition, + * ranking the rows with equal values with the same rank number, except that it does not skip + * any rank, leaving no gaps between the ranks. + */ + DENSE_RANK +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java new file mode 100644 index 0000000000000..f3ac79b3269ce --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.table.typeutils.SortedMapTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * RankFunction in Update Stream mode. 
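+ *
+ * <p>The input may contain retraction messages. All records are therefore kept in state,
+ * grouped by sort key, together with a per-sort-key count, so that retracting one record
+ * can shift the ranks of the records that follow it.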
+ */ +public class RetractRankFunction extends AbstractRankFunction { + + private static final Logger LOG = LoggerFactory.getLogger(RetractRankFunction.class); + + /** + * Message to indicate the state is cleared because of ttl restriction. The message could be + * used to output to log. + */ + private static final String STATE_CLEARED_WARN_MSG = "The state is cleared because of state ttl. " + + "This will result in incorrect result. " + + "You can increase the state ttl to avoid this."; + + protected final BaseRowTypeInfo sortKeyType; + + // flag to skip records with non-exist error instead to fail, true by default. + private final boolean lenient = true; + + // a map state stores mapping from sort key to records list + private transient MapState> dataState; + + // a sorted map stores mapping from sort key to records count + private transient ValueState> treeMap; + + private Comparator sortKeyComparator; + + public RetractRankFunction( + long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, RankType rankType, RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction) { + super(minRetentionTime, maxRetentionTime, inputRowType, outputRowType, + generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction); + this.sortKeyType = sortKeyType; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state", sortKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + sortKeyComparator = generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); + ValueStateDescriptor> valueStateDescriptor = new ValueStateDescriptor( + "sorted-map", + new SortedMapTypeInfo(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator) + ); + treeMap = getRuntimeContext().getState(valueStateDescriptor); + } + + @Override + public void processElement( + BaseRow input, Context ctx, Collector out) throws Exception { + + initRankEnd(input); + + SortedMap sortedMap = treeMap.value(); + if (sortedMap == null) { + sortedMap = new TreeMap<>(sortKeyComparator); + } + BaseRow sortKey = sortKeySelector.getKey(input); + + if (BaseRowUtil.isAccumulateMsg(input)) { + // update sortedMap + if (sortedMap.containsKey(sortKey)) { + sortedMap.put(sortKey, sortedMap.get(sortKey) + 1); + } else { + sortedMap.put(sortKey, 1L); + } + + // emit + emitRecordsWithRowNumber(sortedMap, sortKey, input, out); + + // update data state + List inputs = dataState.get(sortKey); + if (inputs == null) { + // the sort key is never seen + inputs = new ArrayList<>(); + } + inputs.add(input); + dataState.put(sortKey, inputs); + } else { + // retract input + + // emit updates first + retractRecordWithRowNumber(sortedMap, sortKey, input, out); + + // and then update sortedMap + if (sortedMap.containsKey(sortKey)) { + long count = sortedMap.get(sortKey) - 1; + if (count == 0) { + sortedMap.remove(sortKey); + } else { + sortedMap.put(sortKey, count); + } + } else { + if (sortedMap.isEmpty()) { + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + throw new RuntimeException( + "Can not 
retract a non-existent record: ${inputBaseRow.toString}. " + + "This should never happen."); + } + } + + } + treeMap.update(sortedMap); + } + + // ------------- ROW_NUMBER------------------------------- + + private void retractRecordWithRowNumber( + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + Iterator> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry entry = iterator.next(); + BaseRow key = entry.getKey(); + if (!findSortKey && key.equals(sortKey)) { + List inputs = dataState.get(key); + if (inputs == null) { + // Skip the data if it's state is cleared because of state ttl. + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + Iterator inputIter = inputs.iterator(); + while (inputIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputIter.next(); + if (!findSortKey && equaliser.equalsWithoutHeader(prevRow, inputRow)) { + delete(out, prevRow, curRank); + curRank -= 1; + findSortKey = true; + inputIter.remove(); + } else if (findSortKey) { + retract(out, prevRow, curRank + 1); + collect(out, prevRow, curRank); + } + } + if (inputs.isEmpty()) { + dataState.remove(key); + } else { + dataState.put(key, inputs); + } + } + } else if (findSortKey) { + List inputs = dataState.get(key); + int i = 0; + while (i < inputs.size() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputs.get(i); + retract(out, prevRow, curRank + 1); + collect(out, prevRow, curRank); + i++; + } + } else { + curRank += entry.getValue(); + } + } + } + + private void emitRecordsWithRowNumber( + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + Iterator> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findSortKey = false; + while (iterator.hasNext() && isInRankRange(curRank)) { + Map.Entry entry = iterator.next(); + BaseRow key = entry.getKey(); + if (!findSortKey && key.equals(sortKey)) { + curRank += entry.getValue(); + collect(out, inputRow, curRank); + findSortKey = true; + } else if (findSortKey) { + List inputs = dataState.get(key); + if (inputs == null) { + // Skip the data if it's state is cleared because of state ttl. + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + int i = 0; + while (i < inputs.size() && isInRankRange(curRank)) { + curRank += 1; + BaseRow prevRow = inputs.get(i); + retract(out, prevRow, curRank - 1); + collect(out, prevRow, curRank); + i++; + } + } + } else { + curRank += entry.getValue(); + } + } + } + + @Override + protected long getMaxSortMapSize() { + // just let it go, retract rank has no interest in this + return 0L; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java new file mode 100644 index 0000000000000..4f65a121453b2 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.table.dataformat.BaseRow; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Supplier; + +/** + * SortedMap stores mapping from sort key to records list, each record is BaseRow type. + * SortedMap could also track rank number of each records. + * + * @param Type of the sort key + */ +public class SortedMap { + + private final Supplier> valueSupplier; + private int currentTopNum = 0; + private TreeMap> treeMap; + + public SortedMap(Comparator sortKeyComparator, Supplier> valueSupplier) { + this.valueSupplier = valueSupplier; + this.treeMap = new TreeMap(sortKeyComparator); + } + + /** + * Appends a record into the SortedMap under the sortKey. + * + * @param sortKey sort key with which the specified value is to be associated + * @param value record which is to be appended + * + * @return the size of the collection under the sortKey. + */ + public int put(T sortKey, BaseRow value) { + currentTopNum += 1; + // update treeMap + Collection collection = treeMap.get(sortKey); + if (collection == null) { + collection = valueSupplier.get(); + treeMap.put(sortKey, collection); + } + collection.add(value); + return collection.size(); + } + + /** + * Puts a record list into the SortedMap under the sortKey. + * Note: if SortedMap already contains sortKey, putAll will overwrite the previous value + * + * @param sortKey sort key with which the specified values are to be associated + * @param values record lists to be associated with the specified key + */ + public void putAll(T sortKey, Collection values) { + treeMap.put(sortKey, values); + currentTopNum += values.size(); + } + + /** + * Get the record list from SortedMap under the sortKey. + * + * @param sortKey key to get + * + * @return the record list from SortedMap under the sortKey + */ + public Collection get(T sortKey) { + return treeMap.get(sortKey); + } + + public void remove(T sortKey, BaseRow value) { + Collection list = treeMap.get(sortKey); + if (list != null) { + if (list.remove(value)) { + currentTopNum -= 1; + } + if (list.size() == 0) { + treeMap.remove(sortKey); + } + } + } + + /** + * Remove all record list from SortedMap under the sortKey. + * + * @param sortKey key to remove + */ + public void removeAll(T sortKey) { + Collection list = treeMap.get(sortKey); + if (list != null) { + currentTopNum -= list.size(); + treeMap.remove(sortKey); + } + } + + /** + * Remove the last record of the last Entry in the TreeMap (according to the TreeMap's + * key-sort function). 
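+ * Returns {@code null} if the TreeMap is empty.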
+ * + * @return removed record + */ + public BaseRow removeLast() { + Map.Entry> last = treeMap.lastEntry(); + BaseRow lastElement = null; + if (last != null) { + Collection list = last.getValue(); + lastElement = getLastElement(list); + if (lastElement != null) { + if (list.remove(lastElement)) { + currentTopNum -= 1; + } + if (list.size() == 0) { + treeMap.remove(last.getKey()); + } + } + } + return lastElement; + } + + /** + * Get record which rank is given value. + * + * @param rank rank value to search + * + * @return the record which rank is given value + */ + public BaseRow getElement(int rank) { + int curRank = 0; + Iterator>> iter = treeMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + Collection list = entry.getValue(); + + Iterator listIter = list.iterator(); + while (listIter.hasNext()) { + BaseRow elem = listIter.next(); + curRank += 1; + if (curRank == rank) { + return elem; + } + } + } + return null; + } + + private BaseRow getLastElement(Collection list) { + BaseRow element = null; + if (list != null && !list.isEmpty()) { + Iterator iter = list.iterator(); + while (iter.hasNext()) { + element = iter.next(); + } + } + return element; + } + + /** + * Returns a {@link Set} view of the mappings contained in this map. + */ + public Set>> entrySet() { + return treeMap.entrySet(); + } + + /** + * Returns the last Entry in the TreeMap (according to the TreeMap's + * key-sort function). Returns null if the TreeMap is empty. + */ + public Map.Entry> lastEntry() { + return treeMap.lastEntry(); + } + + /** + * Returns {@code true} if this map contains a mapping for the specified + * key. + * + * @param key key whose presence in this map is to be tested + * + * @return {@code true} if this map contains a mapping for the + * specified key + */ + public boolean containsKey(T key) { + return treeMap.containsKey(key); + } + + /** + * Get number of total records. + * + * @return the number of total records. + */ + public int getCurrentTopNum() { + return currentTopNum; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java new file mode 100644 index 0000000000000..79f817fc5cd35 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; + +/** + * A fast version of rank process function which only hold top n data in state, + * and keep sorted map in heap. This only works in some special scenarios, such + * as, rank a count(*) stream + */ +public class UpdateRankFunction extends AbstractUpdateRankFunction { + + private final TypeSerializer inputRowSer; + private final KeySelector rowKeySelector; + + public UpdateRankFunction( + long minRetentionTime, + long maxRetentionTime, + BaseRowTypeInfo inputRowType, + BaseRowTypeInfo outputRowType, + BaseRowTypeInfo rowKeyType, + KeySelector rowKeySelector, + GeneratedRecordComparator generatedRecordComparator, + KeySelector sortKeySelector, + RankType rankType, + RankRange rankRange, + GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, + long cacheSize) { + super(minRetentionTime, + maxRetentionTime, + inputRowType, + outputRowType, + rowKeyType, + generatedRecordComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction, + cacheSize); + this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); + this.rowKeySelector = rowKeySelector; + } + + @Override + public void processElement( + BaseRow input, Context context, Collector out) throws Exception { + long currentTime = context.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(context, currentTime); + + initHeapStates(); + initRankEnd(input); + if (isRowNumberAppend || hasOffset()) { + // the without-number-algorithm can't handle topn with offset, + // so use the with-number-algorithm to handle offset + processElementWithRowNumber(input, out); + } else { + processElementWithoutRowNumber(input, out); + } + } + + @Override + protected long getMaxSortMapSize() { + return getDefaultTopSize(); + } + + private void processElementWithRowNumber(BaseRow inputRow, Collector out) throws Exception { + BaseRow sortKey = sortKeySelector.getKey(inputRow); + BaseRow rowKey = rowKeySelector.getKey(inputRow); + if (rowKeyMap.containsKey(rowKey)) { + // it is an updated record which is in the topN, in this scenario, + // the new sort key must be higher than old sort key, this is guaranteed by rules + RankRow oldRow = rowKeyMap.get(rowKey); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.getRow()); + if (oldSortKey.equals(sortKey)) { + // sort key is not changed, so the rank is the same, only output the row + Tuple2 rankAndInnerRank = rowNumber(sortKey, rowKey, sortedMap); + int rank = rankAndInnerRank.f0; + int innerRank = rankAndInnerRank.f1; + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), innerRank, true)); + retract(out, oldRow.getRow(), rank); // retract old record + collect(out, inputRow, rank); + return; + } + + Tuple2 oldRankAndInnerRank = rowNumber(oldSortKey, rowKey, sortedMap); + int oldRank = oldRankAndInnerRank.f0; + // remove old sort key + 
sortedMap.remove(oldSortKey, rowKey); + // add new sort key + int size = sortedMap.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + // update inner rank of records under the old sort key + updateInnerRank(oldSortKey); + + // emit records + emitRecordsWithRowNumber(sortKey, inputRow, out, oldSortKey, oldRow, oldRank); + } else { + // out of topN + } + } + + private void updateInnerRank(BaseRow oldSortKey) { + Collection list = sortedMap.get(oldSortKey); + if (list != null) { + Iterator iter = list.iterator(); + int innerRank = 1; + while (iter.hasNext()) { + BaseRow rowkey = iter.next(); + RankRow row = rowKeyMap.get(rowkey); + if (row.getInnerRank() != innerRank) { + row.setInnerRank(innerRank); + row.setDirty(true); + } + innerRank += 1; + } + } + } + + private void emitRecordsWithRowNumber( + BaseRow sortKey, BaseRow inputRow, Collector out, BaseRow oldSortKey, RankRow oldRow, int oldRank) { + + int oldInnerRank = oldRow == null ? -1 : oldRow.getInnerRank(); + Iterator>> iterator = sortedMap.entrySet().iterator(); + int curRank = 0; + // whether we have found the sort key in the sorted tree + boolean findSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank + 1)) { + Map.Entry> entry = iterator.next(); + BaseRow curKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + // meet its own sort key + if (!findSortKey && curKey.equals(sortKey)) { + curRank += rowKeys.size(); + if (oldRow != null) { + retract(out, oldRow.getRow(), oldRank); + } + collect(out, inputRow, curRank); + findSortKey = true; + } else if (findSortKey) { + if (oldSortKey == null) { + // this is a new row, emit updates for all rows in the topn + Iterator rowKeyIter = rowKeys.iterator(); + while (rowKeyIter.hasNext() && isInRankEnd(curRank + 1)) { + curRank += 1; + BaseRow rowKey = rowKeyIter.next(); + RankRow prevRow = rowKeyMap.get(rowKey); + retract(out, prevRow.getRow(), curRank - 1); + collect(out, prevRow.getRow(), curRank); + } + } else { + // current sort key is higher than old sort key, + // the rank of current record is changed, need to update the following rank + int compare = sortKeyComparator.compare(curKey, oldSortKey); + if (compare <= 0) { + Iterator rowKeyIter = rowKeys.iterator(); + int curInnerRank = 0; + while (rowKeyIter.hasNext() && isInRankEnd(curRank + 1)) { + curRank += 1; + curInnerRank += 1; + if (compare == 0 && curInnerRank >= oldInnerRank) { + // match to the previous position + return; + } + + BaseRow rowKey = rowKeyIter.next(); + RankRow prevRow = rowKeyMap.get(rowKey); + retract(out, prevRow.getRow(), curRank - 1); + collect(out, prevRow.getRow(), curRank); + } + } else { + // current sort key is smaller than old sort key, + // the rank is not changed, so skip + return; + } + } + } else { + curRank += rowKeys.size(); + } + } + } + + private void processElementWithoutRowNumber(BaseRow inputRow, Collector out) throws Exception { + BaseRow sortKey = sortKeySelector.getKey(inputRow); + BaseRow rowKey = rowKeySelector.getKey(inputRow); + if (rowKeyMap.containsKey(rowKey)) { + // it is an updated record which is in the topN, in this scenario, + // the new sort key must be higher than old sort key, this is guaranteed by rules + RankRow oldRow = rowKeyMap.get(rowKey); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.getRow()); + if (!oldSortKey.equals(sortKey)) { + // remove old sort key + sortedMap.remove(oldSortKey, rowKey); + // add new sort key + int size = sortedMap.put(sortKey, rowKey); + rowKeyMap.put(rowKey, 
new RankRow(inputRowSer.copy(inputRow), size, true)); + // update inner rank of records under the old sort key + updateInnerRank(oldSortKey); + } else { + // row content may change, so we need to update row in map + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.getInnerRank(), true)); + } + // row content may change, so a retract is needed + retract(out, oldRow.getRow(), oldRow.getInnerRank()); + collect(out, inputRow); + } else if (checkSortKeyInBufferRange(sortKey, sortedMap, sortKeyComparator)) { + // it is an unique record but is in the topN + // insert sort key into sortedMap + int size = sortedMap.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + collect(out, inputRow); + // remove retired element + if (sortedMap.getCurrentTopNum() > rankEnd) { + BaseRow lastRowKey = sortedMap.removeLast(); + if (lastRowKey != null) { + RankRow lastRow = rowKeyMap.remove(lastRowKey); + dataState.remove(lastRowKey); + // always send a retraction message + delete(out, lastRow.getRow()); + } + } + } else { + // out of topN + } + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java new file mode 100644 index 0000000000000..a065fe3ea1d2c --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** + * changing rank limit depends on input. 
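+ * For example, in a query whose rank predicate is {@code rn <= c} (where {@code c} is an input
+ * column), the rank end is read from field {@code c} of each input row; {@code rankEndIndex}
+ * is the index of that field.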
+ */ +public class VariableRankRange implements RankRange { + + private static final long serialVersionUID = 5579785886506433955L; + private int rankEndIndex; + + public VariableRankRange(int rankEndIndex) { + this.rankEndIndex = rankEndIndex; + } + + public int getRankEndIndex() { + return rankEndIndex; + } + + @Override + public String toString(List inputFieldNames) { + return "rankEnd=" + inputFieldNames.get(rankEndIndex); + } + + @Override + public String toString() { + return "rankEnd=$" + rankEndIndex; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java index 45ddd4ce33162..5ad8779059e78 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java @@ -40,8 +40,8 @@ public class ValuesInputFormat implements NonParallelInput, ResultTypeQueryable { private static final Logger LOG = LoggerFactory.getLogger(ValuesInputFormat.class); - private GeneratedInput> generatedInput; - private BaseRowTypeInfo returnType; + private final GeneratedInput> generatedInput; + private final BaseRowTypeInfo returnType; private GenericInputFormat format; public ValuesInputFormat(GeneratedInput> generatedInput, BaseRowTypeInfo returnType) { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java index a7b9199c962ca..12aaf6192b4bc 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java @@ -93,6 +93,8 @@ public class TypeConverters { internalTypeToInfo.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO); internalTypeToInfo.put(InternalTypes.DATE, BasicTypeInfo.INT_TYPE_INFO); internalTypeToInfo.put(InternalTypes.TIMESTAMP, BasicTypeInfo.LONG_TYPE_INFO); + internalTypeToInfo.put(InternalTypes.PROCTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO); + internalTypeToInfo.put(InternalTypes.ROWTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO); internalTypeToInfo.put(InternalTypes.TIME, BasicTypeInfo.INT_TYPE_INFO); internalTypeToInfo.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); internalTypeToInfo.put(InternalTypes.INTERVAL_MONTHS, BasicTypeInfo.INT_TYPE_INFO); @@ -111,6 +113,8 @@ public class TypeConverters { itToEti.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO); itToEti.put(InternalTypes.DATE, SqlTimeTypeInfo.DATE); itToEti.put(InternalTypes.TIMESTAMP, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.PROCTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.ROWTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP); itToEti.put(InternalTypes.TIME, SqlTimeTypeInfo.TIME); itToEti.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); INTERNAL_TYPE_TO_EXTERNAL_TYPE_INFO = Collections.unmodifiableMap(itToEti); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java new file 
mode 100644 index 0000000000000..344fb978c5fdc --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Map; + +/** + * Base serializer for {@link Map}. The serializer relies on a key serializer + * and a value serializer for the serialization of the map's key-value pairs. + * + *
<p>
The serialization format for the map is as follows: four bytes for the + * length of the map, followed by the serialized representation of each + * key-value pair. To allow null values, each value is prefixed by a null flag. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + * @param The type of the map. + */ +abstract class AbstractMapSerializer> extends TypeSerializer { + + private static final long serialVersionUID = 1L; + + /** The serializer for the keys in the map. */ + final TypeSerializer keySerializer; + + /** The serializer for the values in the map. */ + final TypeSerializer valueSerializer; + + /** + * Creates a map serializer that uses the given serializers to serialize the key-value pairs in the map. + * + * @param keySerializer The serializer for the keys in the map + * @param valueSerializer The serializer for the values in the map + */ + AbstractMapSerializer(TypeSerializer keySerializer, TypeSerializer valueSerializer) { + Preconditions.checkNotNull(keySerializer, "The key serializer must not be null"); + Preconditions.checkNotNull(valueSerializer, "The value serializer must not be null."); + this.keySerializer = keySerializer; + this.valueSerializer = valueSerializer; + } + + // ------------------------------------------------------------------------ + + /** + * Returns the serializer for the keys in the map. + * + * @return The serializer for the keys in the map. + */ + public TypeSerializer getKeySerializer() { + return keySerializer; + } + + /** + * Returns the serializer for the values in the map. + * + * @return The serializer for the values in the map. + */ + public TypeSerializer getValueSerializer() { + return valueSerializer; + } + + // ------------------------------------------------------------------------ + // Type Serializer implementation + // ------------------------------------------------------------------------ + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public M copy(M from) { + M newMap = createInstance(); + + for (Map.Entry entry : from.entrySet()) { + K newKey = entry.getKey() == null ? null : keySerializer.copy(entry.getKey()); + V newValue = entry.getValue() == null ? null : valueSerializer.copy(entry.getValue()); + + newMap.put(newKey, newValue); + } + + return newMap; + } + + @Override + public M copy(M from, M reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; // var length + } + + @Override + public void serialize(M map, DataOutputView target) throws IOException { + final int size = map.size(); + target.writeInt(size); + + for (Map.Entry entry : map.entrySet()) { + keySerializer.serialize(entry.getKey(), target); + + if (entry.getValue() == null) { + target.writeBoolean(true); + } else { + target.writeBoolean(false); + valueSerializer.serialize(entry.getValue(), target); + } + } + } + + @Override + public M deserialize(DataInputView source) throws IOException { + final int size = source.readInt(); + + final M map = createInstance(); + for (int i = 0; i < size; ++i) { + K key = keySerializer.deserialize(source); + + boolean isNull = source.readBoolean(); + V value = isNull ? 
null : valueSerializer.deserialize(source); + + map.put(key, value); + } + + return map; + } + + @Override + public M deserialize(M reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + final int size = source.readInt(); + target.writeInt(size); + + for (int i = 0; i < size; ++i) { + keySerializer.copy(source, target); + + boolean isNull = source.readBoolean(); + target.writeBoolean(isNull); + + if (!isNull) { + valueSerializer.copy(source, target); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AbstractMapSerializer that = (AbstractMapSerializer) o; + + return keySerializer.equals(that.keySerializer) && valueSerializer.equals(that.valueSerializer); + } + + @Override + public int hashCode() { + int result = keySerializer.hashCode(); + result = 31 * result + valueSerializer.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java new file mode 100644 index 0000000000000..2321d28494cbe --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.util.Preconditions; + +import java.util.Map; + +/** + * Base type information for maps. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +abstract class AbstractMapTypeInfo> extends TypeInformation { + + private static final long serialVersionUID = 1L; + + /* The type information for the keys in the map*/ + final TypeInformation keyTypeInfo; + + /* The type information for the values in the map */ + final TypeInformation valueTypeInfo; + + /** + * Constructor with given type information for the keys and the values in + * the map. + * + * @param keyTypeInfo The type information for the keys in the map. + * @param valueTypeInfo The type information for the values in th map. 
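For illustration only, and not part of this patch: the wire layout that the AbstractMapSerializer Javadoc above describes (a four-byte size, then per entry the key, a boolean null flag, and the value only when the flag is false) can be sketched with plain JDK streams roughly as below. The class name MapLayoutSketch, the writeSketch method, and the String-to-Integer map are invented for this sketch and use writeUTF/writeInt in place of the configured key and value serializers.

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;

public class MapLayoutSketch {

    // Mirrors the documented format: 4-byte size, then per entry the key,
    // a null flag for the value, and the value itself only if it is non-null.
    static byte[] writeSketch(Map<String, Integer> map) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(map.size());                      // four bytes for the length of the map
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            out.writeUTF(entry.getKey());              // key (writeUTF stands in for the key serializer)
            if (entry.getValue() == null) {
                out.writeBoolean(true);                // null flag set, no value written
            } else {
                out.writeBoolean(false);
                out.writeInt(entry.getValue());        // value follows its null flag
            }
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        Map<String, Integer> map = new LinkedHashMap<>();
        map.put("a", 1);
        map.put("b", null);
        // 4 (size) + (2+1 key + 1 flag + 4 value) + (2+1 key + 1 flag) = 16 bytes
        System.out.println(writeSketch(map).length);
    }
}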
+ */ + AbstractMapTypeInfo(TypeInformation keyTypeInfo, TypeInformation valueTypeInfo) { + Preconditions.checkNotNull(keyTypeInfo, "The type information for the keys cannot be null."); + Preconditions.checkNotNull(valueTypeInfo, "The type information for the values cannot be null."); + this.keyTypeInfo = keyTypeInfo; + this.valueTypeInfo = valueTypeInfo; + } + + /** + * Constructor with the classes of the keys and the values in the map. + * + * @param keyClass The class of the keys in the map. + * @param valueClass The class of the values in the map. + */ + AbstractMapTypeInfo(Class keyClass, Class valueClass) { + Preconditions.checkNotNull(keyClass, "The key class cannot be null."); + Preconditions.checkNotNull(valueClass, "The value class cannot be null."); + + this.keyTypeInfo = TypeInformation.of(keyClass); + this.valueTypeInfo = TypeInformation.of(valueClass); + } + + // ------------------------------------------------------------------------ + + /** + * Returns the type information for the keys in the map. + * + * @return The type information for the keys in the map. + */ + public TypeInformation getKeyTypeInfo() { + return keyTypeInfo; + } + + /** + * Returns the type information for the values in the map. + * + * @return The type information for the values in the map. + */ + public TypeInformation getValueTypeInfo() { + return valueTypeInfo; + } + + // ------------------------------------------------------------------------ + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 0; + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean isKeyType() { + return false; + } + + // ------------------------------------------------------------------------ + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AbstractMapTypeInfo that = (AbstractMapTypeInfo) o; + + return keyTypeInfo.equals(that.keyTypeInfo) && valueTypeInfo.equals(that.valueTypeInfo); + } + + @Override + public int hashCode() { + int result = keyTypeInfo.hashCode(); + result = 31 * result + valueTypeInfo.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java new file mode 100644 index 0000000000000..c1c9421afe928 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; +import org.apache.flink.api.common.typeutils.base.MapSerializerConfigSnapshot; +import org.apache.flink.util.Preconditions; + +import java.util.Comparator; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * A serializer for {@link SortedMap}. The serializer relies on a key serializer + * and a value serializer for the serialization of the map's key-value pairs. + * It also deploys a comparator to ensure the order of the keys. + * + *
<p>
The serialization format for the map is as follows: four bytes for the + * length of the map, followed by the serialized representation of each + * key-value pair. To allow null values, each value is prefixed by a null flag. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +public final class SortedMapSerializer extends AbstractMapSerializer> { + + private static final long serialVersionUID = 1L; + + /** The comparator for the keys in the map. */ + private final Comparator comparator; + + /** + * Constructor with given comparator, and the serializers for the keys and + * values in the map. + * + * @param comparator The comparator for the keys in the map. + * @param keySerializer The serializer for the keys in the map. + * @param valueSerializer The serializer for the values in the map. + */ + public SortedMapSerializer( + Comparator comparator, + TypeSerializer keySerializer, + TypeSerializer valueSerializer) { + super(keySerializer, valueSerializer); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + /** + * Returns the comparator for the keys in the map. + * + * @return The comparator for the keys in the map. + */ + public Comparator getComparator() { + return comparator; + } + + @Override + public TypeSerializer> duplicate() { + TypeSerializer keySerializer = getKeySerializer().duplicate(); + TypeSerializer valueSerializer = getValueSerializer().duplicate(); + + return new SortedMapSerializer<>(comparator, keySerializer, valueSerializer); + } + + @Override + public SortedMap createInstance() { + return new TreeMap<>(comparator); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + + SortedMapSerializer that = (SortedMapSerializer) o; + return comparator.equals(that.comparator); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = result * 31 + comparator.hashCode(); + return result; + } + + @Override + public String toString() { + return "SortedMapSerializer{" + + "comparator = " + comparator + + ", keySerializer = " + keySerializer + + ", valueSerializer = " + valueSerializer + + "}"; + } + + @Override + public TypeSerializerConfigSnapshot snapshotConfiguration() { + return new MapSerializerConfigSnapshot<>(keySerializer, valueSerializer); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java new file mode 100644 index 0000000000000..6850c23a1e0df --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
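A rough usage sketch for the SortedMapSerializer added above, again not part of the patch: it wires the constructor shown in the diff (comparator, key serializer, value serializer) to Flink's existing StringSerializer/IntSerializer and a DataOutputSerializer buffer. The map contents are made up; the point is only that the comparator fixes the key order that createInstance() and deserialize() will reproduce.

import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.table.typeutils.SortedMapSerializer;

import java.util.Comparator;
import java.util.SortedMap;
import java.util.TreeMap;

public class SortedMapSerializerSketch {
    public static void main(String[] args) throws Exception {
        // Key order is fixed by the comparator, both for createInstance() and on deserialize().
        SortedMapSerializer<String, Integer> serializer = new SortedMapSerializer<>(
            Comparator.<String>naturalOrder(), StringSerializer.INSTANCE, IntSerializer.INSTANCE);

        SortedMap<String, Integer> map = new TreeMap<>(Comparator.<String>naturalOrder());
        map.put("fruit", 44);
        map.put("book", 12);

        DataOutputSerializer out = new DataOutputSerializer(64);
        serializer.serialize(map, out);   // size, then (key, null flag, value) per entry
        System.out.println(out.length()); // number of bytes written
    }
}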
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Comparator; +import java.util.SortedMap; + +/** + * The type information for sorted maps. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +@PublicEvolving +public class SortedMapTypeInfo extends AbstractMapTypeInfo> { + + private static final long serialVersionUID = 1L; + + /** The comparator for the keys in the map. */ + private final Comparator comparator; + + public SortedMapTypeInfo( + TypeInformation keyTypeInfo, + TypeInformation valueTypeInfo, + Comparator comparator) { + super(keyTypeInfo, valueTypeInfo); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + public SortedMapTypeInfo( + Class keyClass, + Class valueClass, + Comparator comparator) { + super(keyClass, valueClass); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + public SortedMapTypeInfo(Class keyClass, Class valueClass) { + super(keyClass, valueClass); + + Preconditions.checkArgument(Comparable.class.isAssignableFrom(keyClass), + "The key class must be comparable when no comparator is given."); + this.comparator = new ComparableComparator<>(); + } + + // ------------------------------------------------------------------------ + + @SuppressWarnings("unchecked") + @Override + public Class> getTypeClass() { + return (Class>) (Class) SortedMap.class; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + TypeSerializer keyTypeSerializer = keyTypeInfo.createSerializer(config); + TypeSerializer valueTypeSerializer = valueTypeInfo.createSerializer(config); + + return new SortedMapSerializer<>(comparator, keyTypeSerializer, valueTypeSerializer); + } + + @Override + public boolean canEqual(Object obj) { + return null != obj && getClass() == obj.getClass(); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + + SortedMapTypeInfo that = (SortedMapTypeInfo) o; + + return comparator.equals(that.comparator); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + comparator.hashCode(); + return result; + } + + @Override + public String toString() { + return "SortedMapTypeInfo{" + + "comparator=" + comparator + + ", keyTypeInfo=" + getKeyTypeInfo() + + ", valueTypeInfo=" + getValueTypeInfo() + + "}"; + } + + //-------------------------------------------------------------------------- + + /** + * The default comparator for comparable types. 
+ */ + private static class ComparableComparator implements Comparator, Serializable { + private static final long serialVersionUID = 1L; + + @SuppressWarnings("unchecked") + public int compare(K obj1, K obj2) { + return ((Comparable) obj1).compareTo(obj2); + } + + @Override + public boolean equals(Object o) { + return (o == this) || (o != null && o.getClass() == getClass()); + } + + @Override + public int hashCode() { + return "ComparableComparator".hashCode(); + } + + } +} From ac1b502d45c19816deeafdc69df40184325115c8 Mon Sep 17 00:00:00 2001 From: beyond1920 Date: Tue, 9 Apr 2019 19:48:33 +0800 Subject: [PATCH 2/5] Update code style. --- .../table/plan/util/KeySelectorUtil.java | 3 +- .../codegen/EqualiserCodeGenerator.scala | 10 +- .../codegen/sort/SortCodeGenerator.scala | 2 + .../stream/StreamExecDeduplicate.scala | 16 +- .../physical/stream/StreamExecExchange.scala | 2 +- .../physical/stream/StreamExecRank.scala | 33 +- .../utils/FailingCollectionSource.java | 17 +- .../stream/sql/DeduplicateITCase.scala | 30 +- .../table/runtime/stream/sql/RankITCase.scala | 102 +++-- .../table/runtime/utils/StreamTestSink.scala | 68 +-- .../StreamingWithMiniBatchTestBase.scala | 2 +- .../utils/StreamingWithStateTestBase.scala | 4 +- .../flink/table/runtime/utils/TableUtil.scala | 10 +- .../table/runtime/utils/TimeTestUtil.scala | 3 +- .../flink/table/dataformat/BinaryRow.java | 8 +- .../flink/table/dataformat/BinaryWriter.java | 3 +- .../table/dataformat/TypeGetterSetters.java | 3 +- .../deduplicate/DeduplicateFunction.java | 16 +- ...se.java => DeduplicateFunctionHelper.java} | 18 +- .../MiniBatchDeduplicateFunction.java | 14 +- .../KeyedProcessFunctionWithCleanupState.java | 6 +- .../keyselector/BinaryRowKeySelector.java | 2 +- .../keyselector/NullBinaryRowKeySelector.java | 2 +- .../runtime/rank/AbstractRankFunction.java | 145 +++---- .../rank/AbstractUpdateRankFunction.java | 293 ------------- .../runtime/rank/AppendRankFunction.java | 121 +++--- .../runtime/rank/RetractRankFunction.java | 67 ++- .../rank/{SortedMap.java => TopNBuffer.java} | 93 ++-- .../runtime/rank/UpdateRankFunction.java | 403 ++++++++++++++---- .../runtime/values/ValuesInputFormat.java | 4 +- .../table/typeutils/SortedMapSerializer.java | 8 +- .../table/typeutils/SortedMapTypeInfo.java | 10 +- 32 files changed, 736 insertions(+), 782 deletions(-) rename flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/{DeduplicateFunctionBase.java => DeduplicateFunctionHelper.java} (78%) delete mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java rename flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/{SortedMap.java => TopNBuffer.java} (60%) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java index 651ef531274ee..50082c2beacfa 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java @@ -40,8 +40,7 @@ public class KeySelectorUtil { * Create a BaseRowKeySelector to extract keys from DataStream which type is BaseRowTypeInfo. 
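As a side note on the SortedMapTypeInfo introduced at the end of the first commit: when no comparator is passed, the two-argument constructor falls back to the private ComparableComparator, which is why the key class must implement Comparable. A minimal, hypothetical sketch of that path (class name and key/value types invented for illustration):

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.typeutils.SortedMapTypeInfo;

import java.util.SortedMap;

public class SortedMapTypeInfoSketch {
    public static void main(String[] args) {
        // No comparator given, so the type info falls back to ComparableComparator;
        // String implements Comparable, otherwise the constructor rejects the key class.
        SortedMapTypeInfo<String, Long> typeInfo = new SortedMapTypeInfo<>(String.class, Long.class);

        TypeSerializer<SortedMap<String, Long>> serializer =
            typeInfo.createSerializer(new ExecutionConfig());
        System.out.println(serializer); // prints the SortedMapSerializer's toString()
    }
}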
* * @param keyFields key fields - * @param rowType type of DataStream to extract keys - * + * @param rowType type of DataStream to extract keys * @return the BaseRowKeySelector to extract keys from DataStream which type is BaseRowTypeInfo. */ public static BaseRowKeySelector getBaseRowSelector(int[] keyFields, BaseRowTypeInfo rowType) { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala index 35c8d92d49f15..09658663e6ca0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala @@ -20,14 +20,12 @@ package org.apache.flink.table.codegen import org.apache.flink.table.api.TableConfig import org.apache.flink.table.codegen.CodeGenUtils._ import org.apache.flink.table.codegen.Indenter.toISC -import org.apache.flink.table.dataformat.{BaseRow, BinaryRow} import org.apache.flink.table.generated.{GeneratedRecordEqualiser, RecordEqualiser} -import org.apache.flink.table.`type`.{DateType, InternalType, PrimitiveType, RowType, TimeType, TimestampType} +import org.apache.flink.table.`type`.{DateType, InternalType, PrimitiveType, RowType, TimeType, +TimestampType} class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) { - private val BASE_ROW = className[BaseRow] - private val BINARY_ROW = className[BinaryRow] private val RECORD_EQUALISER = className[RecordEqualiser] private val LEFT_INPUT = "left" private val RIGHT_INPUT = "right" @@ -57,7 +55,7 @@ class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) { val equaliserGenerator = new EqualiserCodeGenerator(fieldType.asInstanceOf[RowType].getFieldTypes) val generatedEqualiser = equaliserGenerator - .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") + .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") val generatedEqualiserTerm = ctx.addReusableObject( generatedEqualiser, "field$" + i + "GeneratedEqualiser") val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName @@ -80,7 +78,7 @@ class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) { |boolean $result; |if ($leftNullTerm && $rightNullTerm) { | $result = true; - |} else if ($leftNullTerm || $rightNullTerm) { + |} else if ($leftNullTerm|| $rightNullTerm) { | $result = false; |} else { | $fieldTypeTerm $leftFieldTerm = $leftReadCode; diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala index b7913a12285e6..57a64212a0b7e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala @@ -450,6 +450,8 @@ class SortCodeGenerator( case InternalTypes.DOUBLE => 8 case InternalTypes.LONG => 8 case _: TimestampType => 8 + case _: DateType => 4 + case InternalTypes.TIME => 4 case dt: DecimalType if Decimal.isCompact(dt.precision()) => 8 case InternalTypes.STRING | InternalTypes.BINARY => Int.MaxValue } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index 6e75118022463..0c4171c78b2ab 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -29,7 +29,8 @@ import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.util.KeySelectorUtil import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger -import org.apache.flink.table.runtime.deduplicate.{DeduplicateFunction, MiniBatchDeduplicateFunction} +import org.apache.flink.table.runtime.deduplicate.{DeduplicateFunction, +MiniBatchDeduplicateFunction} import org.apache.flink.table.`type`.TypeConverters import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.table.typeutils.TypeCheckUtils.isRowTime @@ -56,8 +57,8 @@ class StreamExecDeduplicate( isRowtime: Boolean, keepLastRow: Boolean) extends SingleRel(cluster, traitSet, inputRel) - with StreamPhysicalRel - with StreamExecNode[BaseRow] { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { def getUniqueKeys: Array[Int] = uniqueKeys @@ -97,7 +98,7 @@ class StreamExecDeduplicate( override protected def translateToPlanInternal( tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { - // TODO checkInput is not acc retract after FLINK- is done + // FIXME checkInput is not acc retract after FLINK-12098 is done val inputIsAccRetract = false if (inputIsAccRetract) { @@ -108,16 +109,17 @@ class StreamExecDeduplicate( } val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) - .asInstanceOf[StreamTransformation[BaseRow]] + .asInstanceOf[StreamTransformation[BaseRow]] val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + // FIXME infer generate retraction after FLINK-12098 is done val generateRetraction = true val inputRowType = FlinkTypeFactory.toInternalRowType(getInput.getRowType) val rowTimeFieldIndex = inputRowType.getFieldTypes.zipWithIndex - .filter(e => isRowTime(e._1)) - .map(_._2) + .filter(e => isRowTime(e._1)) + .map(_._2) if (rowTimeFieldIndex.size > 1) { throw new RuntimeException("More than one row time field. 
Currently this is not supported!") } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala index e1c431b0440f3..2c58863dd6a40 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala @@ -75,7 +75,7 @@ class StreamExecExchange( override protected def translateToPlanInternal( tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) - .asInstanceOf[StreamTransformation[BaseRow]] + .asInstanceOf[StreamTransformation[BaseRow]] val inputTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] val outputTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo relDistribution.getType match { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala index 97504a2e0a12f..d0ec274740d31 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala @@ -132,8 +132,6 @@ class StreamExecRank( } val inputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getInput.getRowType).toTypeInfo - val outputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo - val fieldCollations = orderKey.getFieldCollations val (sortFields, sortDirections, nullsIsLast) = SortUtil.getKeysAndOrders(fieldCollations) val sortKeySelector = KeySelectorUtil.getBaseRowSelector(sortFields, inputRowTypeInfo) @@ -141,28 +139,28 @@ class StreamExecRank( val sortCodeGen = new SortCodeGenerator( tableConfig, sortFields.indices.toArray, sortKeyType.getInternalTypes, sortDirections, nullsIsLast) - val comparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator") - // TODO infer generate retraction after retraction rules are merged + val sortKeyComparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator") + // FIXME infer generate retraction after FLINK-12098 is done val generateRetraction = true val cacheSize = tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE) val minIdleStateRetentionTime = tableConfig.getMinIdleStateRetentionTime val maxIdleStateRetentionTime = tableConfig.getMaxIdleStateRetentionTime val equaliserCodeGenerator = new EqualiserCodeGenerator(inputRowTypeInfo.getInternalTypes) - val equaliser = equaliserCodeGenerator.generateRecordEqualiser("RankValueEqualiser") + val generatedEqualiser = equaliserCodeGenerator.generateRecordEqualiser("RankValueEqualiser") val processFunction = getStrategy(true) match { case AppendFastStrategy => new AppendRankFunction( minIdleStateRetentionTime, maxIdleStateRetentionTime, inputRowTypeInfo, - outputRowTypeInfo, sortKeyType, - comparator, + sortKeyComparator, sortKeySelector, rankType, rankRange, - equaliser, + generatedEqualiser, generateRetraction, + outputRankNumber, cacheSize) case 
UpdateFastStrategy(primaryKeys) => @@ -174,15 +172,15 @@ class StreamExecRank( minIdleStateRetentionTime, maxIdleStateRetentionTime, inputRowTypeInfo, - outputRowTypeInfo, rowKeyType, rowKeySelector, - comparator, + sortKeyComparator, sortKeySelector, rankType, rankRange, - equaliser, + generatedEqualiser, generateRetraction, + outputRankNumber, cacheSize) // TODO UnaryUpdateRank after SortedMapState is merged @@ -191,20 +189,21 @@ class StreamExecRank( minIdleStateRetentionTime, maxIdleStateRetentionTime, inputRowTypeInfo, - outputRowTypeInfo, sortKeyType, - comparator, + sortKeyComparator, sortKeySelector, rankType, rankRange, - equaliser, - generateRetraction) + generatedEqualiser, + generateRetraction, + outputRankNumber) } val rankOpName = getOperatorName val operator = new KeyedProcessOperator(processFunction) - processFunction.setKeyContext(operator); + processFunction.setKeyContext(operator) val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) - .asInstanceOf[StreamTransformation[BaseRow]] + .asInstanceOf[StreamTransformation[BaseRow]] + val outputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo val ret = new OneInputTransformation( inputTransform, rankOpName, diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java index 726f3d9ae09da..3d95e3c3a6407 100644 --- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java @@ -102,8 +102,7 @@ public FailingCollectionSource( serializer.serialize(element, wrapper); count++; } - } - catch (Exception e) { + } catch (Exception e) { throw new IOException("Serializing the source elements failed: " + e.getMessage(), e); } @@ -157,12 +156,11 @@ public void run(SourceContext ctx) throws Exception { serializer.deserialize(input); toSkip--; } - } - catch (Exception e) { + } catch (Exception e) { throw new IOException( "Failed to deserialize an element from the source. " + - "If you are using user-defined serialization (Value and Writable types), check the " + - "serialization functions.\nSerializer is " + serializer); + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); } this.numElementsEmitted = this.numElementsToSkip; @@ -185,12 +183,11 @@ public void run(SourceContext ctx) throws Exception { T next; try { next = serializer.deserialize(input); - } - catch (Exception e) { + } catch (Exception e) { throw new IOException( "Failed to deserialize an element from the source. 
" + - "If you are using user-defined serialization (Value and Writable types), check the " + - "serialization functions.\nSerializer is " + serializer); + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); } synchronized (ctx.getCheckpointLock()) { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala index 78744516d88b0..b218e0e5a6fb9 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala @@ -38,7 +38,7 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) @Test def testFirstRowOnProctime(): Unit = { val t = failingDataSource(StreamTestData.get3TupleData) - .toTable(tEnv, 'a, 'b, 'c, 'proctime) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) tEnv.registerTable("T", t) val sql = @@ -52,19 +52,19 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) |WHERE rowNum = 1 """.stripMargin - val sink = new TestingAppendSink - tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) env.execute() val expected = List("1,1,Hi", "2,2,Hello", "4,3,Hello world, how are you?", - "7,4,Comment#1", "11,5,Comment#5", "16,6,Comment#10") - assertEquals(expected.sorted, sink.getAppendResults.sorted) + "7,4,Comment#1", "11,5,Comment#5", "16,6,Comment#10") + assertEquals(expected.sorted, sink.getRetractResults.sorted) } @Test def testLastRowOnProctime(): Unit = { val t = failingDataSource(StreamTestData.get3TupleData) - .toTable(tEnv, 'a, 'b, 'c, 'proctime) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) tEnv.registerTable("T", t) val sql = @@ -83,10 +83,11 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) env.execute() val expected = List("1,1,Hi", "3,2,Hello world", "6,3,Luke Skywalker", - "10,4,Comment#4", "15,5,Comment#9", "21,6,Comment#15") + "10,4,Comment#4", "15,5,Comment#9", "21,6,Comment#15") assertEquals(expected.sorted, sink.getRetractResults.sorted) } + // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently @Test def testFirstRowOnRowtime(): Unit = { val data = List( @@ -102,9 +103,9 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) (4L, 3L, "Helloworld, how are you?", 4)) val t = failingDataSource(data) - .assignTimestampsAndWatermarks( - new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) - .toTable(tEnv, 'rowtime, 'key, 'str, 'int) + .assignTimestampsAndWatermarks( + new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) + .toTable(tEnv, 'rowtime, 'key, 'str, 'int) tEnv.registerTable("T", t) val sql = @@ -122,12 +123,12 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) val table = tEnv.sqlQuery(sql) writeToSink(table, sink) - // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently env.execute() val expected = List("1,Hi,1", "2,Hello,2", "3,Helloworld, how are you?,4", "4,Comment#1,7") assertEquals(expected.sorted, sink.getUpsertResults.sorted) } + // TODO Deduplicate does not support sort on rowtime now, 
so it is translated to Rank currently @Test def testLastRowOnRowtime(): Unit = { val data = List( @@ -143,9 +144,9 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) (4L, 3L, "Helloworld, how are you?", 4)) val t = failingDataSource(data) - .assignTimestampsAndWatermarks( - new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) - .toTable(tEnv, 'rowtime, 'key, 'str, 'int) + .assignTimestampsAndWatermarks( + new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) + .toTable(tEnv, 'rowtime, 'key, 'str, 'int) tEnv.registerTable("T", t) val sql = @@ -163,7 +164,6 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) val table = tEnv.sqlQuery(sql) writeToSink(table, sink) - // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently env.execute() val expected = List("1,Hi,1", "2,Hello world,3", "3,Luke Skywalker,6", "4,Comment#4,10") assertEquals(expected.sorted, sink.getUpsertResults.sorted) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala index 30577b156b029..06d280f5e636a 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.TableConfigOptions +import org.apache.flink.table.api.{TableConfigOptions, TableException} import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode import org.apache.flink.table.runtime.utils.{TestingRetractTableSink, TestingUpsertTableSink, _} import org.apache.flink.types.Row @@ -103,7 +103,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, sink.getRetractResults.sorted) } - @Ignore("Enable after retraction infer is introduced") + // FIXME + @Ignore("Enable after retraction infer (FLINK-12098) is introduced") @Test def testTopNWithUpsertSink(): Unit = { val data = List( @@ -140,8 +141,9 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, sink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") - @Test + // FIXME + @Ignore("Enable after agg rules added and SortedMapState is supported") + @Test(expected = classOf[TableException]) def testTopNWithUnary(): Unit = { val data = List( ("book", 11, 100), @@ -200,7 +202,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } - @Ignore("Enable when state support SortedMapState") + // FIXME + @Ignore("Enable after agg rules added and SortedMapState is supported") @Test def testUnarySortTopNOnString(): Unit = { val data = List( @@ -260,7 +263,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules 
added") @Test def testTopNWithGroupBy(): Unit = { val data = List( @@ -303,7 +307,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithSumAndCondition(): Unit = { val data = List( @@ -353,7 +358,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNthWithGroupBy(): Unit = { val data = List( @@ -395,7 +401,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithGroupByAndRetract(): Unit = { val data = List( @@ -437,7 +444,9 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, sink.getRetractResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") + @Test def testTopNthWithGroupByAndRetract(): Unit = { val data = List( ("book", 1, 11), @@ -476,7 +485,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, sink.getRetractResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithGroupByCount(): Unit = { val data = List( @@ -532,7 +542,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNthWithGroupByCount(): Unit = { val data = List( @@ -583,7 +594,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testNestedTopN(): Unit = { val data = List( @@ -654,7 +666,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected2, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithoutDeduplicate(): Unit = { val data = List( @@ -717,7 +730,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected, tableSink.getRawResults) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithVariableTopSize(): Unit = { val data = List( @@ -772,30 +786,31 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNUnaryComplexScenario(): Unit = { val data = List( 
("book", 1, 11), ("book", 2, 19), ("book", 4, 13), - ("book", 1, 11), // backward update in heap - ("book", 3, 23), // elems exceed topn size after insert - ("book", 5, 19), // sort map shirk out some elem after insert - ("book", 7, 10), // sort map keeps a little more than topn size elems after insert - ("book", 8, 13), // sort map now can shrink out-of-range elems after another insert - ("book", 10, 13), // Once again, sort map keeps a little more elems after insert - ("book", 8, 6), // backward update from heap to state - ("book", 10, 6), // backward update from heap to state, and sort map load more data - ("book", 5, 3), // backward update from heap to state - ("book", 10, 1), // backward update in heap, and then sort map shrink some data - ("book", 5, 1), // backward update in state - ("book", 5, -3), // forward update in state - ("book", 2, -10), // forward update in heap, and then sort map shrink some data - ("book", 10, -7), // forward update from state to heap - ("book", 11, 13), // insert into heap - ("book", 12, 10), // insert into heap, and sort map shrink some data - ("book", 15, 14) // insert into state + ("book", 1, 11), // backward update in heap + ("book", 3, 23), // elems exceed topn size after insert + ("book", 5, 19), // sort map shirk out some elem after insert + ("book", 7, 10), // sort map keeps a little more than topn size elems after insert + ("book", 8, 13), // sort map now can shrink out-of-range elems after another insert + ("book", 10, 13), // Once again, sort map keeps a little more elems after insert + ("book", 8, 6), // backward update from heap to state + ("book", 10, 6), // backward update from heap to state, and sort map load more data + ("book", 5, 3), // backward update from heap to state + ("book", 10, 1), // backward update in heap, and then sort map shrink some data + ("book", 5, 1), // backward update in state + ("book", 5, -3), // forward update in state + ("book", 2, -10), // forward update in heap, and then sort map shrink some data + ("book", 10, -7), // forward update from state to heap + ("book", 11, 13), // insert into heap + ("book", 12, 10), // insert into heap, and sort map shrink some data + ("book", 15, 14) // insert into state ) env.setParallelism(1) @@ -860,7 +875,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithGroupByAvgWithoutRowNumber(): Unit = { val data = List( @@ -922,7 +938,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testTopNWithGroupByCountWithoutRowNumber(): Unit = { val data = List( @@ -1050,7 +1067,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testMultipleRetractTopNAfterAgg(): Unit = { val data = List( @@ -1118,7 +1136,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected2.sorted, sink2.getRetractResults.sorted) } - @Ignore("Enable after streamAgg 
implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testMultipleUnaryTopNAfterAgg(): Unit = { val data = List( @@ -1184,7 +1203,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected2.sorted, sink2.getUpsertResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testMultipleUpdateTopNAfterAgg(): Unit = { val data = List( @@ -1250,7 +1270,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected2.sorted, sink2.getRetractResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testUpdateRank(): Unit = { val data = List( @@ -1272,7 +1293,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, sink.getRetractResults.sorted) } - @Ignore("Enable after streamAgg implements StreamExecNode") + // FIXME + @Ignore("Enable after agg rules added") @Test def testUpdateRankWithOffset(): Unit = { val data = List( diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala index 7c32d5c2a5bee..e3fc360a90ddd 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala @@ -61,7 +61,7 @@ object StreamTestSink { private[utils] def getNewSinkId: Int = { val idx = idCounter.getAndIncrement() - this.synchronized{ + this.synchronized { globalResults.put(idx, mutable.HashMap.empty[Int, ArrayBuffer[String]]) globalRetractResults.put(idx, mutable.HashMap.empty[Int, ArrayBuffer[String]]) globalUpsertResults.put(idx, mutable.HashMap.empty[Int, mutable.Map[String, String]]) @@ -81,7 +81,7 @@ abstract class AbstractExactlyOnceSink[T] extends RichSinkFunction[T] with Check protected var localResults: ArrayBuffer[String] = _ protected val idx: Int = StreamTestSink.getNewSinkId - protected var globalResults: mutable.Map[Int, ArrayBuffer[String]]= _ + protected var globalResults: mutable.Map[Int, ArrayBuffer[String]] = _ protected var globalRetractResults: mutable.Map[Int, ArrayBuffer[String]] = _ protected var globalUpsertResults: mutable.Map[Int, mutable.Map[String, String]] = _ @@ -112,7 +112,7 @@ abstract class AbstractExactlyOnceSink[T] extends RichSinkFunction[T] with Check protected def clearAndStashGlobalResults(): Unit = { if (globalResults == null) { - StreamTestSink.synchronized{ + StreamTestSink.synchronized { globalResults = StreamTestSink.globalResults.remove(idx).get globalRetractResults = StreamTestSink.globalRetractResults.remove(idx).get globalUpsertResults = StreamTestSink.globalUpsertResults.remove(idx).get @@ -149,7 +149,9 @@ final class TestingAppendSink(tz: TimeZone) extends AbstractExactlyOnceSink[Row] def this() { this(TimeZone.getTimeZone("UTC")) } + def invoke(value: Row): Unit = localResults += TestSinkUtil.rowToString(value, tz) + def getAppendResults: List[String] = getResults } @@ -194,7 +196,7 @@ final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) } val taskId = getRuntimeContext.getIndexOfThisSubtask - StreamTestSink.synchronized{ + StreamTestSink.synchronized { 
StreamTestSink.globalUpsertResults(idx) += (taskId -> localUpsertResults) } } @@ -216,7 +218,7 @@ final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) val converter = DataFormatConverters.getConverterForTypeInfo( new TupleTypeInfo(Types.BOOLEAN, new RowTypeInfo(fieldTypes: _*))) - .asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, JTuple2[JBoolean, Row]]] + .asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, JTuple2[JBoolean, Row]]] val v = converter.toExternal(wrapRow) val rowString = TestSinkUtil.rowToString(v.f1, tz) val tupleString = "(" + v.f0.toString + "," + rowString + ")" @@ -228,8 +230,8 @@ final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) val oldValue = localUpsertResults.remove(keyString) if (oldValue.isEmpty) { throw new RuntimeException("Tried to delete a value that wasn't inserted first. " + - "This is probably an incorrectly implemented test. " + - "Try to set the parallelism of the sink to 1.") + "This is probably an incorrectly implemented test. " + + "Try to set the parallelism of the sink to 1.") } } } @@ -278,15 +280,15 @@ final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) (value.f0, value.f1) } }) - .addSink(sink) - .name(s"TestingUpsertTableSink(keys=${ - if (keys != null) { - "(" + keys.mkString(",") + ")" - } else { - "null" - } - })") - .setParallelism(1) + .addSink(sink) + .name(s"TestingUpsertTableSink(keys=${ + if (keys != null) { + "(" + keys.mkString(",") + ")" + } else { + "null" + } + })") + .setParallelism(1) } override def configure( @@ -307,7 +309,7 @@ final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) } final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[Row] - with BatchTableSink[Row]{ + with BatchTableSink[Row] { var fNames: Array[String] = _ var fTypes: Array[TypeInformation[_]] = _ var sink = new TestingAppendSink(tz) @@ -319,7 +321,7 @@ final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[R override def emitDataStream(dataStream: DataStream[Row]): DataStreamSink[Row] = { dataStream.addSink(sink).name("TestingAppendTableSink") - .setParallelism(dataStream.getParallelism) + .setParallelism(dataStream.getParallelism) } override def emitBoundedStream( @@ -367,23 +369,25 @@ class TestingOutputFormat[T](tz: TimeZone) def open(taskNumber: Int, numTasks: Int): Unit = { localRetractResults = mutable.ArrayBuffer.empty[String] - StreamTestSink.synchronized{ + StreamTestSink.synchronized { StreamTestSink.globalResults(index) += (taskNumber -> localRetractResults) } } - def writeRecord(value: T): Unit = localRetractResults += { value match { - case r: Row => TestSinkUtil.rowToString(r, tz) - case tp: JTuple2[java.lang.Boolean, Row] => - "(" + tp.f0.toString + "," + TestSinkUtil.rowToString(tp.f1, tz) + ")" - case _ => "" - }} + def writeRecord(value: T): Unit = localRetractResults += { + value match { + case r: Row => TestSinkUtil.rowToString(r, tz) + case tp: JTuple2[java.lang.Boolean, Row] => + "(" + tp.f0.toString + "," + TestSinkUtil.rowToString(tp.f1, tz) + ")" + case _ => "" + } + } def close(): Unit = {} protected def clearAndStashGlobalResults(): Unit = { if (globalResults == null) { - StreamTestSink.synchronized{ + StreamTestSink.synchronized { globalResults = StreamTestSink.globalResults.remove(index).get } } @@ -422,7 +426,7 @@ class TestingRetractSink(tz: TimeZone) } val taskId = getRuntimeContext.getIndexOfThisSubtask - StreamTestSink.synchronized{ + StreamTestSink.synchronized { StreamTestSink.globalRetractResults(idx) 
+= (taskId -> localRetractResults) } } @@ -448,8 +452,8 @@ class TestingRetractSink(tz: TimeZone) localRetractResults.remove(index) } else { throw new RuntimeException("Tried to retract a value that wasn't added first. " + - "This is probably an incorrectly implemented test. " + - "Try to set the parallelism of the sink to 1.") + "This is probably an incorrectly implemented test. " + + "Try to set the parallelism of the sink to 1.") } } } @@ -483,9 +487,9 @@ final class TestingRetractTableSink(tz: TimeZone) extends RetractStreamTableSink (value.f0, value.f1) } }).setParallelism(dataStream.getParallelism) - .addSink(sink) - .name("TestingRetractTableSink") - .setParallelism(dataStream.getParallelism) + .addSink(sink) + .name("TestingRetractTableSink") + .setParallelism(dataStream.getParallelism) } override def getRecordType: TypeInformation[Row] = diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala index d502b902f2d86..56c2593cd9b48 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala @@ -50,7 +50,7 @@ object StreamingWithMiniBatchTestBase { case class MiniBatchMode(on: Boolean) { override def toString: String = { - if (on){ + if (on) { "MiniBatch=ON" } else { "MiniBatch=OFF" diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala index 602edd84e4666..ae80d9c0525bf 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala @@ -97,7 +97,7 @@ class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestB case Types.INT => writer.writeInt(i, p.productElement(i).asInstanceOf[Int]) case Types.LONG => writer.writeLong(i, p.productElement(i).asInstanceOf[Long]) case Types.STRING => writer.writeString(i, - p.productElement(i).asInstanceOf[BinaryString]) + p.productElement(i).asInstanceOf[BinaryString]) case Types.BOOLEAN => writer.writeBoolean(i, p.productElement(i).asInstanceOf[Boolean]) } } @@ -227,7 +227,7 @@ class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestB def appendStrToMap(ss: CharSequence, m: Map[String, String]): Unit = { val equalsIdxs = findEquals(ss) - equalsIdxs.foreach (idx => m + splitKV(ss, idx)) + equalsIdxs.foreach(idx => m + splitKV(ss, idx)) } while (idx < l) { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala index 5bcc372091395..1522bb63d38ec 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala @@ -36,25 +36,25 @@ object TableUtil { * Note: The 
difference between print() and collect() is * - print() prints data on workers and collect() collects data to the client. * - You have to call TableEnvironment.execute() to run the job for print(), while collect() - * calls execute automatically. + * calls execute automatically. */ def collect(table: TableImpl): Seq[Row] = collectSink(table, new CollectRowTableSink, None) def collect(table: TableImpl, jobName: String): Seq[Row] = collectSink(table, new CollectRowTableSink, Option.apply(jobName)) - def collectAsT[T](table: TableImpl, t: TypeInformation[_], jobName : String = null): Seq[T] = + def collectAsT[T](table: TableImpl, t: TypeInformation[_], jobName: String = null): Seq[T] = collectSink( table, new CollectTableSink(_ => t.asInstanceOf[TypeInformation[T]]), Option(jobName)) def collectSink[T]( - table: TableImpl, sink: CollectTableSink[T], jobName : Option[String] = None): Seq[T] = { + table: TableImpl, sink: CollectTableSink[T], jobName: Option[String] = None): Seq[T] = { // get schema information of table val rowType = table.getRelNode.getRowType val fieldNames = rowType.getFieldNames.asScala.toArray val fieldTypes = rowType.getFieldList - .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray + .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray val configuredSink = sink.configure( fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType)) BatchTableEnvUtil.collect(table.tableEnv.asInstanceOf[BatchTableEnvironment], @@ -66,7 +66,7 @@ object TableUtil { val rowType = table.getRelNode.getRowType val fieldNames = rowType.getFieldNames.asScala.toArray val fieldTypes = rowType.getFieldList - .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray + .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray val configuredSink = sink.configure( fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType)) table.tableEnv.writeToSink(table, configuredSink) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala index df1dd3f435c52..f93c2c2425b69 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala @@ -53,7 +53,7 @@ object TimeTestUtil { } class EventTimeProcessOperator[T] - extends AbstractStreamOperator[T] + extends AbstractStreamOperator[T] with OneInputStreamOperator[Either[(Long, T), Long], T] { override def processElement(element: StreamRecord[Either[(Long, T), Long]]): Unit = { @@ -64,4 +64,5 @@ object TimeTestUtil { } } + } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java index faa95fbb19a59..e52938960bd3c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java @@ -220,7 +220,7 @@ public void setDecimal(int pos, Decimal value, int precision) { } else { byte[] bytes = value.toUnscaledBytes(); - assert(bytes.length <= 16); + assert bytes.length <= 16; // Write the bytes to the variable length portion. 
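// Note on the assert above: a decimal with precision of at most 38 digits has an unscaled
// value of at most 128 bits, so its two's-complement byte representation never exceeds
// 16 bytes; the copy below therefore fits into the 16 bytes reserved for a non-compact
// decimal in the variable-length part of the row.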
SegmentsUtil.copyFromBytes(segments, offset + cursor, bytes, 0, bytes.length); @@ -437,9 +437,9 @@ private boolean equalsFrom(Object o, int startIndex) { if (o != null && o instanceof BinaryRow) { BinaryRow other = (BinaryRow) o; return sizeInBytes == other.sizeInBytes && - SegmentsUtil.equals( - segments, offset + startIndex, - other.segments, other.offset + startIndex, sizeInBytes - startIndex); + SegmentsUtil.equals( + segments, offset + startIndex, + other.segments, other.offset + startIndex, sizeInBytes - startIndex); } else { return false; } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java index 0adbfc2aa23e4..aafcb3feeab9c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java @@ -18,6 +18,7 @@ package org.apache.flink.table.dataformat; import org.apache.flink.table.type.ArrayType; +import org.apache.flink.table.type.DateType; import org.apache.flink.table.type.DecimalType; import org.apache.flink.table.type.GenericType; import org.apache.flink.table.type.InternalType; @@ -99,7 +100,7 @@ static void write(BinaryWriter writer, int pos, Object o, InternalType type) { writer.writeString(pos, (BinaryString) o); } else if (type.equals(InternalTypes.CHAR)) { writer.writeChar(pos, (char) o); - } else if (type.equals(InternalTypes.DATE)) { + } else if (type instanceof DateType) { writer.writeInt(pos, (int) o); } else if (type.equals(InternalTypes.TIME)) { writer.writeInt(pos, (int) o); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java index 71c323f5914db..ab263c40c3adc 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java @@ -18,6 +18,7 @@ package org.apache.flink.table.dataformat; import org.apache.flink.table.type.ArrayType; +import org.apache.flink.table.type.DateType; import org.apache.flink.table.type.DecimalType; import org.apache.flink.table.type.GenericType; import org.apache.flink.table.type.InternalType; @@ -194,7 +195,7 @@ static Object get(TypeGetterSetters row, int ordinal, InternalType type) { return row.getString(ordinal); } else if (type.equals(InternalTypes.CHAR)) { return row.getChar(ordinal); - } else if (type.equals(InternalTypes.DATE)) { + } else if (type instanceof DateType) { return row.getInt(ordinal); } else if (type.equals(InternalTypes.TIME)) { return row.getInt(ordinal); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java index 9cd3c1d4b299b..65dd9bf2ab028 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java @@ -32,13 +32,14 @@ * This function is used to 
deduplicate on keys and keeps only first row or last row. */ public class DeduplicateFunction - extends KeyedProcessFunctionWithCleanupState - implements DeduplicateFunctionBase { + extends KeyedProcessFunctionWithCleanupState { + + private static final long serialVersionUID = 4950071982706870944L; private final BaseRowTypeInfo rowTypeInfo; private final boolean generateRetraction; private final boolean keepLastRow; - protected ValueState pkRow; + private ValueState pkRow; private GeneratedRecordEqualiser generatedEqualiser; private transient RecordEqualiser equaliser; @@ -63,7 +64,10 @@ public void open(Configuration configure) throws Exception { initCleanupTimeState(stateName); ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); pkRow = getRuntimeContext().getState(rowStateDesc); + + // compile equaliser equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedEqualiser = null; } @Override @@ -74,9 +78,11 @@ public void processElement(BaseRow input, Context ctx, Collector out) t BaseRow preRow = pkRow.value(); if (keepLastRow) { - processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + DeduplicateFunctionHelper + .processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); } else { - processFirstRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + DeduplicateFunctionHelper + .processFirstRow(preRow, input, pkRow, out); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java similarity index 78% rename from flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java rename to flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java index 22293ae75beb0..7a77e16748fc7 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionBase.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java @@ -26,18 +26,18 @@ import org.apache.flink.util.Preconditions; /** - * Base class to deduplicate on keys and keeps only first row or last row. + * Utility for deduplicate function. */ -public interface DeduplicateFunctionBase { +class DeduplicateFunctionHelper { - default void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, + static void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, Collector out) throws Exception { // should be accumulate msg. 
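// Outline of what follows: when state cleaning is disabled and the current row equals the
// stored previous row (ignoring the header), the row is ignored; in every other case the
// keyed state is updated to the current row, a retraction for the previous row is emitted
// if generateRetraction is set and a previous row exists, and the current row is collected
// as an accumulate message.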
Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); // ignore same record if (!stateCleaningEnabled && preRow != null && - equaliser.equalsWithoutHeader(preRow, currentRow)) { + equaliser.equalsWithoutHeader(preRow, currentRow)) { return; } pkRow.update(currentRow); @@ -48,13 +48,12 @@ default void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generate out.collect(currentRow); } - default void processFirstRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, - boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, + static void processFirstRow(BaseRow preRow, BaseRow currentRow, ValueState pkRow, Collector out) throws Exception { // should be accumulate msg. Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); // ignore record with timestamp bigger than preRow - if (!isFirstRow(preRow)) { + if (preRow != null) { return; } @@ -62,8 +61,7 @@ default void processFirstRow(BaseRow preRow, BaseRow currentRow, boolean generat out.collect(currentRow); } - default boolean isFirstRow(BaseRow preRow) { - return preRow == null; - } + private DeduplicateFunctionHelper() { + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java index cff33d77948e7..ebb61b874fd90 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java @@ -39,8 +39,7 @@ * mode. */ public class MiniBatchDeduplicateFunction - extends MapBundleFunction - implements DeduplicateFunctionBase { + extends MapBundleFunction { private BaseRowTypeInfo rowTypeInfo; private boolean generateRetraction; @@ -68,12 +67,15 @@ public void open(ExecutionContext ctx) throws Exception { super.open(ctx); ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); pkRow = ctx.getRuntimeContext().getState(rowStateDesc); + + // compile equaliser equaliser = generatedEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + generatedEqualiser = null; } @Override public BaseRow addInput(@Nullable BaseRow value, BaseRow input) { - if (value == null || keepLastRow || (!keepLastRow && isFirstRow(value))) { + if (value == null || keepLastRow || (!keepLastRow && value == null)) { // put the input into buffer return ser.copy(input); } else { @@ -92,9 +94,11 @@ public void finishBundle( BaseRow preRow = pkRow.value(); if (keepLastRow) { - processLastRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); + DeduplicateFunctionHelper + .processLastRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); } else { - processFirstRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); + DeduplicateFunctionHelper + .processFirstRow(preRow, currentRow, pkRow, out); } } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java index 66123be6b9ab0..2e3835202949d 100644 --- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java @@ -27,8 +27,9 @@ /** * A function that processes elements of a stream, and could cleanup state. + * * @param Type of the key. - * @param Type of the input elements. + * @param Type of the input elements. * @param Type of the output elements. */ public abstract class KeyedProcessFunctionWithCleanupState extends KeyedProcessFunction { @@ -58,8 +59,7 @@ protected void registerProcessingCleanupTimer(Context ctx, long currentTime) thr // last registered timer Long curCleanupTime = cleanupTimeState.value(); - // check if a cleanup timer is registered and - // that the current cleanup timer won't delete state we need to keep + // check if a cleanup timer is registered and that the current cleanup timer won't delete state we need to keep if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) { // we need to register a new (later) timer Long cleanupTime = currentTime + maxRetentionTime; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java index 6e3f1729cb8e8..2b51e7a2db55a 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java @@ -25,7 +25,7 @@ import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** - * A KeySelector which will extract key from BaseRow. + * A utility class which will extract key from BaseRow. */ public class BinaryRowKeySelector implements BaseRowKeySelector { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java index 3ca3ccaeaace7..ed83f9e2adfa5 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java @@ -23,7 +23,7 @@ import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** - * A KeySelector which key is always empty. + * A utility class which key is always empty. 
*/ public class NullBinaryRowKeySelector implements BaseRowKeySelector { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java index 8477ef6cc987e..edce4e12f70e2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java @@ -22,7 +22,6 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.Gauge; @@ -43,7 +42,6 @@ import java.util.Collection; import java.util.Comparator; -import java.util.Iterator; import java.util.Map; /** @@ -53,24 +51,33 @@ public abstract class AbstractRankFunction extends KeyedProcessFunctionWithClean private static final Logger LOG = LoggerFactory.getLogger(AbstractRankFunction.class); - // we set default topn size to 100 + // we set default topN size to 100 private static final long DEFAULT_TOPN_SIZE = 100; private final RankRange rankRange; - private final GeneratedRecordEqualiser generatedEqualiser; + + /** + * The util to compare two BaseRow equals to each other. + * As different BaseRow can't be equals directly, we use a code generated util to handle this. + */ + private GeneratedRecordEqualiser generatedEqualiser; + protected RecordEqualiser equaliser; + + /** + * The util to compare two sortKey equals to each other. 
+ */ + private GeneratedRecordComparator generatedSortKeyComparator; + protected Comparator sortKeyComparator; + private final boolean generateRetraction; - protected final boolean isRowNumberAppend; - protected final RankType rankType; + protected final boolean outputRankNumber; protected final BaseRowTypeInfo inputRowType; - protected final BaseRowTypeInfo outputRowType; - protected final GeneratedRecordComparator generatedRecordComparator; protected final KeySelector sortKeySelector; protected KeyContext keyContext; - protected boolean isConstantRankEnd; + private boolean isConstantRankEnd; + private long rankStart = -1; protected long rankEnd = -1; - protected long rankStart = -1; - protected RecordEqualiser equaliser; private int rankEndIndex; private ValueState rankEndState; private Counter invalidCounter; @@ -80,15 +87,11 @@ public abstract class AbstractRankFunction extends KeyedProcessFunctionWithClean protected long hitCount = 0L; protected long requestCount = 0L; - public AbstractRankFunction( - long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, - GeneratedRecordComparator generatedRecordComparator, KeySelector sortKeySelector, - RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction) { + AbstractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + GeneratedRecordComparator generatedSortKeyComparator, KeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber) { super(minRetentionTime, maxRetentionTime); - this.rankRange = rankRange; - this.generatedEqualiser = generatedEqualiser; - this.generateRetraction = generateRetraction; - this.rankType = rankType; // TODO support RANK and DENSE_RANK switch (rankType) { case ROW_NUMBER: @@ -103,10 +106,12 @@ public AbstractRankFunction( LOG.error("Streaming tables do not support {}", rankType.name()); throw new UnsupportedOperationException("Streaming tables do not support " + rankType.toString()); } + this.rankRange = rankRange; + this.generatedEqualiser = generatedEqualiser; + this.generatedSortKeyComparator = generatedSortKeyComparator; + this.generateRetraction = generateRetraction; this.inputRowType = inputRowType; - this.outputRowType = outputRowType; - this.isRowNumberAppend = inputRowType.getArity() + 1 == outputRowType.getArity(); - this.generatedRecordComparator = generatedRecordComparator; + this.outputRankNumber = outputRankNumber; this.sortKeySelector = sortKeySelector; } @@ -129,14 +134,32 @@ public void open(Configuration parameters) throws Exception { ValueStateDescriptor rankStateDesc = new ValueStateDescriptor("rankEnd", Types.LONG); rankEndState = getRuntimeContext().getState(rankStateDesc); } + + // compile equaliser equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedEqualiser = null; + // compile comparator + sortKeyComparator = generatedSortKeyComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedSortKeyComparator = null; invalidCounter = getRuntimeContext().getMetricGroup().counter("topn.invalidTopSize"); } - protected long getDefaultTopSize() { + /** + * Gets default topN size. + * + * @return default topN size + */ + protected long getDefaultTopNSize() { return isConstantRankEnd ? rankEnd : DEFAULT_TOPN_SIZE; } + /** + * Initialize rank end. 
+ * + * @param row input record + * @return rank end + * @throws Exception + */ protected long initRankEnd(BaseRow row) throws Exception { if (isConstantRankEnd) { return rankEnd; @@ -150,8 +173,7 @@ protected long initRankEnd(BaseRow row) throws Exception { } else { rankEnd = rankEndValue; if (rankEnd != curRankEnd) { - // increment the invalid counter when the current rank end - // not equal to previous rank end + // increment the invalid counter when the current rank end not equal to previous rank end invalidCounter.inc(); } return rankEnd; @@ -159,51 +181,26 @@ protected long initRankEnd(BaseRow row) throws Exception { } } - protected Tuple2 rowNumber(K sortKey, BaseRow rowKey, SortedMap sortedMap) { - Iterator>> iterator = sortedMap.entrySet().iterator(); - int curRank = 1; - while (iterator.hasNext()) { - Map.Entry> entry = iterator.next(); - K curKey = entry.getKey(); - Collection rowKeys = entry.getValue(); - if (curKey.equals(sortKey)) { - Iterator rowKeysIter = rowKeys.iterator(); - int innerRank = 1; - while (rowKeysIter.hasNext()) { - if (rowKey.equals(rowKeysIter.next())) { - return Tuple2.of(curRank, innerRank); - } else { - innerRank += 1; - curRank += 1; - } - } - } else { - curRank += rowKeys.size(); - } - } - LOG.error("Failed to find the sortKey: {}, rowkey: {} in SortedMap. " + - "This should never happen", sortKey, rowKey); - throw new RuntimeException( - "Failed to find the sortKey, rowkey in SortedMap. This should never happen"); - } - /** - * return true if record should be put into sort map. + * Checks whether the record should be put into the buffer. + * + * @param sortKey sortKey to test + * @param buffer buffer to add + * @return true if the record should be put into the buffer. */ - protected boolean checkSortKeyInBufferRange(K sortKey, SortedMap sortedMap, Comparator sortKeyComparator) { - Map.Entry> worstEntry = sortedMap.lastEntry(); + protected boolean checkSortKeyInBufferRange(BaseRow sortKey, TopNBuffer buffer) { + Comparator comparator = buffer.getSortKeyComparator(); + Map.Entry> worstEntry = buffer.lastEntry(); if (worstEntry == null) { - // sort map is empty + // return true if the buffer is empty. return true; } else { - K worstKey = worstEntry.getKey(); - int compare = sortKeyComparator.compare(sortKey, worstKey); + BaseRow worstKey = worstEntry.getKey(); + int compare = comparator.compare(sortKey, worstKey); if (compare < 0) { return true; - } else if (sortedMap.getCurrentTopNum() < getMaxSortMapSize()) { - return true; } else { - return false; + return buffer.getCurrentTopNum() < getMaxSizeOfBuffer(); } } } @@ -237,8 +234,8 @@ protected void collect(Collector out, BaseRow inputRow) { } /** - * This is similar to [[retract()]] but always send retraction message regardless of - * generateRetraction is true or not. + * This is similar to [[retract()]] but always send retraction message regardless of generateRetraction is true or + * not. 
*/ protected void delete(Collector out, BaseRow inputRow) { BaseRowUtil.setRetract(inputRow); @@ -279,8 +276,8 @@ protected boolean hasOffset() { return rankStart > 1; } - protected BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { - if (isRowNumberAppend) { + private BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { + if (outputRankNumber) { GenericRow rankRow = new GenericRow(1); rankRow.setField(0, rank); @@ -294,20 +291,14 @@ protected BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { } /** - * get sorted map size limit - * Implementations may vary depending on each rank who has in-memory sort map. - * @return + * Gets buffer size limit. Implementations may vary depending on each rank who has in-memory buffer. + * + * @return buffer size limit */ - protected abstract long getMaxSortMapSize(); - - @Override - public void processElement( - BaseRow input, Context ctx, Collector out) throws Exception { - - } + protected abstract long getMaxSizeOfBuffer(); /** - * Set keyContext to RankFunction. + * Sets keyContext to RankFunction. * * @param keyContext keyContext of current function. */ diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java deleted file mode 100644 index d54ce3af0afc9..0000000000000 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractUpdateRankFunction.java +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.runtime.rank; - -import org.apache.flink.api.common.state.MapState; -import org.apache.flink.api.common.state.MapStateDescriptor; -import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.state.FunctionInitializationContext; -import org.apache.flink.runtime.state.FunctionSnapshotContext; -import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.table.dataformat.BaseRow; -import org.apache.flink.table.generated.GeneratedRecordComparator; -import org.apache.flink.table.generated.GeneratedRecordEqualiser; -import org.apache.flink.table.runtime.util.LRUMap; -import org.apache.flink.table.typeutils.BaseRowTypeInfo; -import org.apache.flink.util.Collector; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collection; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.TreeMap; -import java.util.function.Supplier; - -/** - * Base class for Update Rank Function. - */ -abstract class AbstractUpdateRankFunction extends AbstractRankFunction - implements CheckpointedFunction { - - private static final Logger LOG = LoggerFactory.getLogger(AbstractUpdateRankFunction.class); - - private final BaseRowTypeInfo rowKeyType; - private final long cacheSize; - - // a map state stores mapping from row key to record which is in topN - // in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key - // the f1 is used to preserve the record order in the same sort_key - protected transient MapState> dataState; - - // a sorted map stores mapping from sort key to rowkey list - protected transient SortedMap sortedMap; - - protected transient Map> kvSortedMap; - - // a HashMap stores mapping from rowkey to record, a heap mirror to dataState - protected transient Map rowKeyMap; - - protected transient LRUMap> kvRowKeyMap; - - protected Comparator sortKeyComparator; - - public AbstractUpdateRankFunction( - long minRetentionTime, - long maxRetentionTime, - BaseRowTypeInfo inputRowType, - BaseRowTypeInfo outputRowType, - BaseRowTypeInfo rowKeyType, - GeneratedRecordComparator generatedRecordComparator, - KeySelector sortKeySelector, - RankType rankType, - RankRange rankRange, - GeneratedRecordEqualiser generatedEqualiser, - boolean generateRetraction, - long cacheSize) { - super(minRetentionTime, - maxRetentionTime, - inputRowType, - outputRowType, - generatedRecordComparator, - sortKeySelector, - rankType, - rankRange, - generatedEqualiser, - generateRetraction); - this.rowKeyType = rowKeyType; - this.cacheSize = cacheSize; - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - int lruCacheSize = Math.max(1, (int) (cacheSize / getMaxSortMapSize())); - // make sure the cached map is in a fixed size, avoid OOM - kvSortedMap = new HashMap<>(lruCacheSize); - kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener()); - - LOG.info("Top{} operator is using LRU caches key-size: {}", getMaxSortMapSize(), lruCacheSize); - - TupleTypeInfo> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT); - MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( - "data-state-with-update", 
rowKeyType, valueTypeInfo); - dataState = getRuntimeContext().getMapState(mapStateDescriptor); - - // metrics - registerMetric(kvSortedMap.size() * getMaxSortMapSize()); - - sortKeyComparator = generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); - } - - @Override - public void onTimer( - long timestamp, - OnTimerContext ctx, - Collector out) throws Exception { - if (needToCleanupState(timestamp)) { - BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); - // cleanup cache - kvRowKeyMap.remove(partitionKey); - kvSortedMap.remove(partitionKey); - cleanupState(dataState); - } - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - Iterator>> iter = kvRowKeyMap.entrySet().iterator(); - while (iter.hasNext()) { - Map.Entry> entry = iter.next(); - BaseRow partitionKey = entry.getKey(); - Map currentRowKeyMap = entry.getValue(); - keyContext.setCurrentKey(partitionKey); - synchronizeState(currentRowKeyMap); - } - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - // nothing to do - } - - protected void initHeapStates() throws Exception { - requestCount += 1; - BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); - sortedMap = kvSortedMap.get(partitionKey); - rowKeyMap = kvRowKeyMap.get(partitionKey); - if (sortedMap == null) { - sortedMap = new SortedMap( - sortKeyComparator, - new Supplier>() { - - @Override - public Collection get() { - return new LinkedHashSet<>(); - } - }); - rowKeyMap = new HashMap<>(); - kvSortedMap.put(partitionKey, sortedMap); - kvRowKeyMap.put(partitionKey, rowKeyMap); - - // restore sorted map - Iterator>> iter = dataState.iterator(); - if (iter != null) { - // a temp map associate sort key to tuple2 - Map> tempSortedMap = new HashMap<>(); - while (iter.hasNext()) { - Map.Entry> entry = iter.next(); - BaseRow rowkey = entry.getKey(); - Tuple2 recordAndInnerRank = entry.getValue(); - BaseRow record = recordAndInnerRank.f0; - Integer innerRank = recordAndInnerRank.f1; - rowKeyMap.put(rowkey, new RankRow(record, innerRank, false)); - - // insert into temp sort map to preserve the record order in the same sort key - BaseRow sortKey = sortKeySelector.getKey(record); - TreeMap treeMap = tempSortedMap.get(sortKey); - if (treeMap == null) { - treeMap = new TreeMap<>(); - tempSortedMap.put(sortKey, treeMap); - } - treeMap.put(innerRank, rowkey); - } - - // build sorted map from the temp map - Iterator>> tempIter = - tempSortedMap.entrySet().iterator(); - while (tempIter.hasNext()) { - Map.Entry> entry = tempIter.next(); - BaseRow sortKey = entry.getKey(); - TreeMap treeMap = entry.getValue(); - Iterator> treeMapIter = treeMap.entrySet().iterator(); - while (treeMapIter.hasNext()) { - Map.Entry treeMapEntry = treeMapIter.next(); - Integer innerRank = treeMapEntry.getKey(); - BaseRow recordRowKey = treeMapEntry.getValue(); - int size = sortedMap.put(sortKey, recordRowKey); - if (innerRank != size) { - LOG.warn("Failed to build sorted map from state, this may result in wrong result. " + - "The sort key is {}, partition key is {}, " + - "treeMap is {}. 
The expected inner rank is {}, but current size is {}.", - sortKey, partitionKey, treeMap, innerRank, size); - } - } - } - } - } else { - hitCount += 1; - } - } - - private void synchronizeState(Map curRowKeyMap) throws Exception { - Iterator> iter = curRowKeyMap.entrySet().iterator(); - while (iter.hasNext()) { - Map.Entry entry = iter.next(); - BaseRow key = entry.getKey(); - RankRow rankRow = entry.getValue(); - if (rankRow.dirty) { - // should update state - dataState.put(key, Tuple2.of(rankRow.row, rankRow.innerRank)); - rankRow.dirty = false; - } - } - } - - class CacheRemovalListener implements LRUMap.RemovalListener> { - - @Override - public void onRemoval(Map.Entry> eldest) { - BaseRow previousKey = (BaseRow) keyContext.getCurrentKey(); - BaseRow partitionKey = eldest.getKey(); - Map currentRowKeyMap = eldest.getValue(); - keyContext.setCurrentKey(partitionKey); - kvSortedMap.remove(partitionKey); - try { - synchronizeState(currentRowKeyMap); - } catch (Throwable e) { - LOG.error("Fail to synchronize state!"); - throw new RuntimeException(e); - } - keyContext.setCurrentKey(previousKey); - } - } - - protected class RankRow { - private final BaseRow row; - private int innerRank; - private boolean dirty; - - protected RankRow(BaseRow row, int innerRank, boolean dirty) { - this.row = row; - this.innerRank = innerRank; - this.dirty = dirty; - } - - protected BaseRow getRow() { - return row; - } - - protected int getInnerRank() { - return innerRank; - } - - protected boolean isDirty() { - return dirty; - } - - public void setInnerRank(int innerRank) { - this.innerRank = innerRank; - } - - public void setDirty(boolean dirty) { - this.dirty = dirty; - } - } -} - - - diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java index 65ca3e5e20216..e21b6aaee5ee3 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -37,7 +37,6 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -48,27 +47,28 @@ */ public class AppendRankFunction extends AbstractRankFunction { + private static final long serialVersionUID = -4708453213104128010L; + private static final Logger LOG = LoggerFactory.getLogger(AppendRankFunction.class); - protected final BaseRowTypeInfo sortKeyType; + private final BaseRowTypeInfo sortKeyType; private final TypeSerializer inputRowSer; private final long cacheSize; // a map state stores mapping from sort key to records list which is in topN private transient MapState> dataState; - // a sorted map stores mapping from sort key to records list, a heap mirror to dataState - private transient SortedMap sortedMap; - private transient Map> kvSortedMap; - private Comparator sortKeyComparator; + // the buffer stores mapping from sort key to records list, a heap mirror to dataState + private transient TopNBuffer buffer; + private transient Map kvSortedMap; - public AppendRankFunction( - long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, - BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, + public AppendRankFunction(long 
minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + BaseRowTypeInfo sortKeyType, GeneratedRecordComparator sortKeyGeneratedRecordComparator, KeySelector sortKeySelector, RankType rankType, RankRange rankRange, - GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, long cacheSize) { - super(minRetentionTime, maxRetentionTime, inputRowType, outputRowType, - generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction); + GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, boolean outputRankNumber, + long cacheSize) { + super(minRetentionTime, maxRetentionTime, inputRowType, sortKeyGeneratedRecordComparator, sortKeySelector, + rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber); this.sortKeyType = sortKeyType; this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); this.cacheSize = cacheSize; @@ -76,24 +76,21 @@ public AppendRankFunction( public void open(Configuration parameters) throws Exception { super.open(parameters); - int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopSize())); + int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopNSize())); kvSortedMap = new LRUMap<>(lruCacheSize); - LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopSize(), lruCacheSize); + LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize); ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( "data-state-with-append", sortKeyType, valueTypeInfo); dataState = getRuntimeContext().getMapState(mapStateDescriptor); - sortKeyComparator = generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); - // metrics - registerMetric(kvSortedMap.size() * getDefaultTopSize()); + registerMetric(kvSortedMap.size() * getDefaultTopNSize()); } @Override - public void processElement( - BaseRow input, Context context, Collector out) throws Exception { + public void processElement(BaseRow input, Context context, Collector out) throws Exception { long currentTime = context.timerService().currentProcessingTime(); // register state-cleanup timer registerProcessingCleanupTimer(context, currentTime); @@ -103,36 +100,53 @@ public void processElement( BaseRow sortKey = sortKeySelector.getKey(input); // check whether the sortKey is in the topN range - if (checkSortKeyInBufferRange(sortKey, sortedMap, sortKeyComparator)) { - // insert sort key into sortedMap - sortedMap.put(sortKey, inputRowSer.copy(input)); - Collection inputs = sortedMap.get(sortKey); + if (checkSortKeyInBufferRange(sortKey, buffer)) { + // insert sort key into buffer + buffer.put(sortKey, inputRowSer.copy(input)); + Collection inputs = buffer.get(sortKey); // update data state dataState.put(sortKey, (List) inputs); - if (isRowNumberAppend || hasOffset()) { - // the without-number-algorithm can't handle topn with offset, + if (outputRankNumber || hasOffset()) { + // the without-number-algorithm can't handle topN with offset, // so use the with-number-algorithm to handle offset - emitRecordsWithRowNumber(sortKey, input, out); + processElementWithRowNumber(sortKey, input, out); } else { processElementWithoutRowNumber(input, out); } } } + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + // cleanup cache + 
kvSortedMap.remove(keyContext.getCurrentKey()); + cleanupState(dataState); + } + } + + @Override + protected long getMaxSizeOfBuffer() { + return getDefaultTopNSize(); + } + private void initHeapStates() throws Exception { requestCount += 1; BaseRow currentKey = (BaseRow) keyContext.getCurrentKey(); - sortedMap = kvSortedMap.get(currentKey); - if (sortedMap == null) { - sortedMap = new SortedMap(sortKeyComparator, new Supplier>() { + buffer = kvSortedMap.get(currentKey); + if (buffer == null) { + buffer = new TopNBuffer(sortKeyComparator, new Supplier>() { @Override public Collection get() { return new ArrayList<>(); } }); - kvSortedMap.put(currentKey, sortedMap); - // restore sorted map + kvSortedMap.put(currentKey, buffer); + // restore buffer Iterator>> iter = dataState.iterator(); if (iter != null) { while (iter.hasNext()) { @@ -140,7 +154,7 @@ public Collection get() { BaseRow sortKey = entry.getKey(); List values = entry.getValue(); // the order is preserved - sortedMap.putAll(sortKey, values); + buffer.putAll(sortKey, values); } } } else { @@ -148,19 +162,19 @@ public Collection get() { } } - private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow input, Collector out) throws Exception { - Iterator>> iterator = sortedMap.entrySet().iterator(); + private void processElementWithRowNumber(BaseRow sortKey, BaseRow input, Collector out) throws Exception { + Iterator>> iterator = buffer.entrySet().iterator(); long curRank = 0L; - boolean findSortKey = false; + boolean findsSortKey = false; while (iterator.hasNext() && isInRankEnd(curRank)) { Map.Entry> entry = iterator.next(); Collection records = entry.getValue(); // meet its own sort key - if (!findSortKey && entry.getKey().equals(sortKey)) { + if (!findsSortKey && entry.getKey().equals(sortKey)) { curRank += records.size(); collect(out, input, curRank); - findSortKey = true; - } else if (findSortKey) { + findsSortKey = true; + } else if (findsSortKey) { Iterator recordsIter = records.iterator(); while (recordsIter.hasNext() && isInRankEnd(curRank)) { curRank += 1; @@ -173,29 +187,29 @@ private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow input, Collector< } } - List toDeleteKeys = new ArrayList<>(); // remove the records associated to the sort key which is out of topN + List toDeleteSortKeys = new ArrayList<>(); while (iterator.hasNext()) { Map.Entry> entry = iterator.next(); BaseRow key = entry.getKey(); dataState.remove(key); - toDeleteKeys.add(key); + toDeleteSortKeys.add(key); } - for (BaseRow toDeleteKey : toDeleteKeys) { - sortedMap.removeAll(toDeleteKey); + for (BaseRow toDeleteKey : toDeleteSortKeys) { + buffer.removeAll(toDeleteKey); } } private void processElementWithoutRowNumber(BaseRow input, Collector out) throws Exception { // remove retired element - if (sortedMap.getCurrentTopNum() > rankEnd) { - Map.Entry> lastEntry = sortedMap.lastEntry(); + if (buffer.getCurrentTopNum() > rankEnd) { + Map.Entry> lastEntry = buffer.lastEntry(); BaseRow lastKey = lastEntry.getKey(); List lastList = (List) lastEntry.getValue(); // remove last one BaseRow lastElement = lastList.remove(lastList.size() - 1); if (lastList.isEmpty()) { - sortedMap.removeAll(lastKey); + buffer.removeAll(lastKey); dataState.remove(lastKey); } else { dataState.put(lastKey, lastList); @@ -206,21 +220,4 @@ private void processElementWithoutRowNumber(BaseRow input, Collector ou collect(out, input); } - @Override - public void onTimer( - long timestamp, - OnTimerContext ctx, - Collector out) throws Exception { - if (needToCleanupState(timestamp)) { 
- // cleanup cache - kvSortedMap.remove(keyContext.getCurrentKey()); - cleanupState(dataState); - } - } - - @Override - protected long getMaxSortMapSize() { - return getDefaultTopSize(); - } - } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java index f3ac79b3269ce..4f4e7ab980722 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java @@ -38,7 +38,6 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; -import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -50,17 +49,15 @@ */ public class RetractRankFunction extends AbstractRankFunction { + private static final long serialVersionUID = 1365312180599454479L; + private static final Logger LOG = LoggerFactory.getLogger(RetractRankFunction.class); - /** - * Message to indicate the state is cleared because of ttl restriction. The message could be - * used to output to log. - */ + // Message to indicate the state is cleared because of ttl restriction. The message could be used to output to log. private static final String STATE_CLEARED_WARN_MSG = "The state is cleared because of state ttl. " + - "This will result in incorrect result. " + - "You can increase the state ttl to avoid this."; + "This will result in incorrect result. You can increase the state ttl to avoid this."; - protected final BaseRowTypeInfo sortKeyType; + private final BaseRowTypeInfo sortKeyType; // flag to skip records with non-exist error instead to fail, true by default. 
private final boolean lenient = true; @@ -71,15 +68,12 @@ public class RetractRankFunction extends AbstractRankFunction { // a sorted map stores mapping from sort key to records count private transient ValueState> treeMap; - private Comparator sortKeyComparator; - - public RetractRankFunction( - long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo outputRowType, + public RetractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, KeySelector sortKeySelector, RankType rankType, RankRange rankRange, - GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction) { - super(minRetentionTime, maxRetentionTime, inputRowType, outputRowType, - generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction); + GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, boolean outputRankNumber) { + super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType, + rankRange, generatedEqualiser, generateRetraction, outputRankNumber); this.sortKeyType = sortKeyType; } @@ -91,26 +85,20 @@ public void open(Configuration parameters) throws Exception { "data-state", sortKeyType, valueTypeInfo); dataState = getRuntimeContext().getMapState(mapStateDescriptor); - sortKeyComparator = generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); ValueStateDescriptor> valueStateDescriptor = new ValueStateDescriptor( "sorted-map", - new SortedMapTypeInfo(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator) - ); + new SortedMapTypeInfo(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator)); treeMap = getRuntimeContext().getState(valueStateDescriptor); } @Override - public void processElement( - BaseRow input, Context ctx, Collector out) throws Exception { - + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { initRankEnd(input); - SortedMap sortedMap = treeMap.value(); if (sortedMap == null) { sortedMap = new TreeMap<>(sortKeyComparator); } BaseRow sortKey = sortKeySelector.getKey(input); - if (BaseRowUtil.isAccumulateMsg(input)) { // update sortedMap if (sortedMap.containsKey(sortKey)) { @@ -131,8 +119,6 @@ public void processElement( inputs.add(input); dataState.put(sortKey, inputs); } else { - // retract input - // emit updates first retractRecordWithRowNumber(sortedMap, sortKey, input, out); @@ -152,8 +138,7 @@ public void processElement( throw new RuntimeException(STATE_CLEARED_WARN_MSG); } } else { - throw new RuntimeException( - "Can not retract a non-existent record: ${inputBaseRow.toString}. " + + throw new RuntimeException("Can not retract a non-existent record: ${inputBaseRow.toString}. 
" + "This should never happen."); } } @@ -165,14 +150,15 @@ public void processElement( // ------------- ROW_NUMBER------------------------------- private void retractRecordWithRowNumber( - SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) + throws Exception { Iterator> iterator = sortedMap.entrySet().iterator(); long curRank = 0L; - boolean findSortKey = false; + boolean findsSortKey = false; while (iterator.hasNext() && isInRankEnd(curRank)) { Map.Entry entry = iterator.next(); BaseRow key = entry.getKey(); - if (!findSortKey && key.equals(sortKey)) { + if (!findsSortKey && key.equals(sortKey)) { List inputs = dataState.get(key); if (inputs == null) { // Skip the data if it's state is cleared because of state ttl. @@ -186,12 +172,12 @@ private void retractRecordWithRowNumber( while (inputIter.hasNext() && isInRankEnd(curRank)) { curRank += 1; BaseRow prevRow = inputIter.next(); - if (!findSortKey && equaliser.equalsWithoutHeader(prevRow, inputRow)) { + if (!findsSortKey && equaliser.equalsWithoutHeader(prevRow, inputRow)) { delete(out, prevRow, curRank); curRank -= 1; - findSortKey = true; + findsSortKey = true; inputIter.remove(); - } else if (findSortKey) { + } else if (findsSortKey) { retract(out, prevRow, curRank + 1); collect(out, prevRow, curRank); } @@ -202,7 +188,7 @@ private void retractRecordWithRowNumber( dataState.put(key, inputs); } } - } else if (findSortKey) { + } else if (findsSortKey) { List inputs = dataState.get(key); int i = 0; while (i < inputs.size() && isInRankEnd(curRank)) { @@ -219,18 +205,19 @@ private void retractRecordWithRowNumber( } private void emitRecordsWithRowNumber( - SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) + throws Exception { Iterator> iterator = sortedMap.entrySet().iterator(); long curRank = 0L; - boolean findSortKey = false; + boolean findsSortKey = false; while (iterator.hasNext() && isInRankRange(curRank)) { Map.Entry entry = iterator.next(); BaseRow key = entry.getKey(); - if (!findSortKey && key.equals(sortKey)) { + if (!findsSortKey && key.equals(sortKey)) { curRank += entry.getValue(); collect(out, inputRow, curRank); - findSortKey = true; - } else if (findSortKey) { + findsSortKey = true; + } else if (findsSortKey) { List inputs = dataState.get(key); if (inputs == null) { // Skip the data if it's state is cleared because of state ttl. 
@@ -256,7 +243,7 @@ private void emitRecordsWithRowNumber( } @Override - protected long getMaxSortMapSize() { + protected long getMaxSizeOfBuffer() { // just let it go, retract rank has no interest in this return 0L; } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java similarity index 60% rename from flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java rename to flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java index 4f65a121453b2..8600d919c7e93 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/SortedMap.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java @@ -20,6 +20,7 @@ import org.apache.flink.table.dataformat.BaseRow; +import java.io.Serializable; import java.util.Collection; import java.util.Comparator; import java.util.Iterator; @@ -29,31 +30,32 @@ import java.util.function.Supplier; /** - * SortedMap stores mapping from sort key to records list, each record is BaseRow type. - * SortedMap could also track rank number of each records. - * - * @param Type of the sort key + * TopNBuffer stores mapping from sort key to records list, sortKey is BaseRow type, each record is BaseRow type. + * TopNBuffer could also track rank number of each record. */ -public class SortedMap { +class TopNBuffer implements Serializable { + + private static final long serialVersionUID = 6824488508991990228L; private final Supplier> valueSupplier; + private final Comparator sortKeyComparator; private int currentTopNum = 0; - private TreeMap> treeMap; + private TreeMap> treeMap; - public SortedMap(Comparator sortKeyComparator, Supplier> valueSupplier) { + TopNBuffer(Comparator sortKeyComparator, Supplier> valueSupplier) { this.valueSupplier = valueSupplier; + this.sortKeyComparator = sortKeyComparator; this.treeMap = new TreeMap(sortKeyComparator); } /** - * Appends a record into the SortedMap under the sortKey. + * Appends a record into the buffer. * * @param sortKey sort key with which the specified value is to be associated - * @param value record which is to be appended - * + * @param value record which is to be appended * @return the size of the collection under the sortKey. */ - public int put(T sortKey, BaseRow value) { + public int put(BaseRow sortKey, BaseRow value) { currentTopNum += 1; // update treeMap Collection collection = treeMap.get(sortKey); @@ -66,29 +68,28 @@ public int put(T sortKey, BaseRow value) { } /** - * Puts a record list into the SortedMap under the sortKey. - * Note: if SortedMap already contains sortKey, putAll will overwrite the previous value + * Puts a record list into the buffer under the sortKey. + * Note: if buffer already contains sortKey, putAll will overwrite the previous value * * @param sortKey sort key with which the specified values are to be associated - * @param values record lists to be associated with the specified key + * @param values record lists to be associated with the specified key */ - public void putAll(T sortKey, Collection values) { + void putAll(BaseRow sortKey, Collection values) { treeMap.put(sortKey, values); currentTopNum += values.size(); } /** - * Get the record list from SortedMap under the sortKey. + * Gets the record list from the buffer under the sortKey. 
* * @param sortKey key to get - * - * @return the record list from SortedMap under the sortKey + * @return the record list from the buffer under the sortKey */ - public Collection get(T sortKey) { + public Collection get(BaseRow sortKey) { return treeMap.get(sortKey); } - public void remove(T sortKey, BaseRow value) { + public void remove(BaseRow sortKey, BaseRow value) { Collection list = treeMap.get(sortKey); if (list != null) { if (list.remove(value)) { @@ -101,11 +102,11 @@ public void remove(T sortKey, BaseRow value) { } /** - * Remove all record list from SortedMap under the sortKey. + * Removes all record list from the buffer under the sortKey. * * @param sortKey key to remove */ - public void removeAll(T sortKey) { + void removeAll(BaseRow sortKey) { Collection list = treeMap.get(sortKey); if (list != null) { currentTopNum -= list.size(); @@ -114,13 +115,12 @@ public void removeAll(T sortKey) { } /** - * Remove the last record of the last Entry in the TreeMap (according to the TreeMap's - * key-sort function). + * Removes the last record of the last Entry in the buffer. * * @return removed record */ - public BaseRow removeLast() { - Map.Entry> last = treeMap.lastEntry(); + BaseRow removeLast() { + Map.Entry> last = treeMap.lastEntry(); BaseRow lastElement = null; if (last != null) { Collection list = last.getValue(); @@ -138,17 +138,16 @@ public BaseRow removeLast() { } /** - * Get record which rank is given value. + * Gets record which rank is given value. * * @param rank rank value to search - * * @return the record which rank is given value */ - public BaseRow getElement(int rank) { + BaseRow getElement(int rank) { int curRank = 0; - Iterator>> iter = treeMap.entrySet().iterator(); + Iterator>> iter = treeMap.entrySet().iterator(); while (iter.hasNext()) { - Map.Entry> entry = iter.next(); + Map.Entry> entry = iter.next(); Collection list = entry.getValue(); Iterator listIter = list.iterator(); @@ -175,40 +174,44 @@ private BaseRow getLastElement(Collection list) { } /** - * Returns a {@link Set} view of the mappings contained in this map. + * Returns a {@link Set} view of the mappings contained in the buffer. */ - public Set>> entrySet() { + Set>> entrySet() { return treeMap.entrySet(); } /** - * Returns the last Entry in the TreeMap (according to the TreeMap's - * key-sort function). Returns null if the TreeMap is empty. + * Returns the last Entry in the buffer. Returns null if the TreeMap is empty. */ - public Map.Entry> lastEntry() { + Map.Entry> lastEntry() { return treeMap.lastEntry(); } /** - * Returns {@code true} if this map contains a mapping for the specified - * key. + * Returns {@code true} if the buffer contains a mapping for the specified key. * - * @param key key whose presence in this map is to be tested - * - * @return {@code true} if this map contains a mapping for the - * specified key + * @param key key whose presence in the buffer is to be tested + * @return {@code true} if the buffer contains a mapping for the specified key */ - public boolean containsKey(T key) { + boolean containsKey(BaseRow key) { return treeMap.containsKey(key); } /** - * Get number of total records. + * Gets number of total records. * * @return the number of total records. */ - public int getCurrentTopNum() { + int getCurrentTopNum() { return currentTopNum; } + /** + * Gets sort key comparator used by buffer. 
+ * + * @return sort key comparator used by buffer + */ + Comparator getSortKeyComparator() { + return sortKeyComparator; + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java index 79f817fc5cd35..1e2849bc4f5ef 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -19,59 +19,119 @@ package org.apache.flink.table.runtime.rank; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.generated.GeneratedRecordComparator; import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.util.LRUMap; import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; +import java.util.TreeMap; +import java.util.function.Supplier; /** * A fast version of rank process function which only hold top n data in state, - * and keep sorted map in heap. This only works in some special scenarios, such - * as, rank a count(*) stream + * and keep sorted map in heap. 
This only works in some special scenarios, such as, rank a count(*) stream */ -public class UpdateRankFunction extends AbstractUpdateRankFunction { +public class UpdateRankFunction extends AbstractRankFunction implements CheckpointedFunction { + + private static final long serialVersionUID = 6786508184355952780L; + + private static final Logger LOG = LoggerFactory.getLogger(UpdateRankFunction.class); + + private final BaseRowTypeInfo rowKeyType; + private final long cacheSize; + + // a map state stores mapping from row key to record which is in topN + // in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key + // the f1 is used to preserve the record order in the same sort_key + private transient MapState> dataState; + + // a buffer stores mapping from sort key to record list + private transient TopNBuffer buffer; + + private transient Map kvSortedMap; + + // a HashMap stores mapping from rowkey to record, a heap mirror to dataState + private transient Map rowKeyMap; + + private transient LRUMap> kvRowKeyMap; private final TypeSerializer inputRowSer; private final KeySelector rowKeySelector; - public UpdateRankFunction( - long minRetentionTime, - long maxRetentionTime, - BaseRowTypeInfo inputRowType, - BaseRowTypeInfo outputRowType, - BaseRowTypeInfo rowKeyType, - KeySelector rowKeySelector, - GeneratedRecordComparator generatedRecordComparator, - KeySelector sortKeySelector, - RankType rankType, - RankRange rankRange, - GeneratedRecordEqualiser generatedEqualiser, - boolean generateRetraction, - long cacheSize) { - super(minRetentionTime, - maxRetentionTime, - inputRowType, - outputRowType, - rowKeyType, - generatedRecordComparator, - sortKeySelector, - rankType, - rankRange, - generatedEqualiser, - generateRetraction, - cacheSize); + public UpdateRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + BaseRowTypeInfo rowKeyType, KeySelector rowKeySelector, + GeneratedRecordComparator generatedRecordComparator, KeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber, long cacheSize) { + super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType, + rankRange, generatedEqualiser, generateRetraction, outputRankNumber); + this.rowKeyType = rowKeyType; + this.cacheSize = cacheSize; this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); this.rowKeySelector = rowKeySelector; } + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + int lruCacheSize = Math.max(1, (int) (cacheSize / getMaxSizeOfBuffer())); + // make sure the cached map is in a fixed size, avoid OOM + kvSortedMap = new HashMap<>(lruCacheSize); + kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener()); + + LOG.info("Top{} operator is using LRU caches key-size: {}", getMaxSizeOfBuffer(), lruCacheSize); + + TupleTypeInfo> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + "data-state-with-update", rowKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + // metrics + registerMetric(kvSortedMap.size() * getMaxSizeOfBuffer()); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (needToCleanupState(timestamp)) { + BaseRow partitionKey = (BaseRow) 
keyContext.getCurrentKey(); + // cleanup cache + kvRowKeyMap.remove(partitionKey); + kvSortedMap.remove(partitionKey); + cleanupState(dataState); + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + // nothing to do + } + @Override public void processElement( BaseRow input, Context context, Collector out) throws Exception { @@ -81,8 +141,8 @@ public void processElement( initHeapStates(); initRankEnd(input); - if (isRowNumberAppend || hasOffset()) { - // the without-number-algorithm can't handle topn with offset, + if (outputRankNumber || hasOffset()) { + // the without-number-algorithm can't handle topN with offset, // so use the with-number-algorithm to handle offset processElementWithRowNumber(input, out); } else { @@ -91,8 +151,86 @@ public void processElement( } @Override - protected long getMaxSortMapSize() { - return getDefaultTopSize(); + protected long getMaxSizeOfBuffer() { + return getDefaultTopNSize(); + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Iterator>> iter = kvRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow partitionKey = entry.getKey(); + Map currentRowKeyMap = entry.getValue(); + keyContext.setCurrentKey(partitionKey); + synchronizeState(currentRowKeyMap); + } + } + + private void initHeapStates() throws Exception { + requestCount += 1; + BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); + buffer = kvSortedMap.get(partitionKey); + rowKeyMap = kvRowKeyMap.get(partitionKey); + if (buffer == null) { + buffer = new TopNBuffer(sortKeyComparator, new Supplier>() { + + @Override + public Collection get() { + return new LinkedHashSet<>(); + } + }); + rowKeyMap = new HashMap<>(); + kvSortedMap.put(partitionKey, buffer); + kvRowKeyMap.put(partitionKey, rowKeyMap); + + // restore sorted map + Iterator>> iter = dataState.iterator(); + if (iter != null) { + // a temp map associate sort key to tuple2 + Map> tempSortedMap = new HashMap<>(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow rowkey = entry.getKey(); + Tuple2 recordAndInnerRank = entry.getValue(); + BaseRow record = recordAndInnerRank.f0; + Integer innerRank = recordAndInnerRank.f1; + rowKeyMap.put(rowkey, new RankRow(record, innerRank, false)); + + // insert into temp sort map to preserve the record order in the same sort key + BaseRow sortKey = sortKeySelector.getKey(record); + TreeMap treeMap = tempSortedMap.get(sortKey); + if (treeMap == null) { + treeMap = new TreeMap<>(); + tempSortedMap.put(sortKey, treeMap); + } + treeMap.put(innerRank, rowkey); + } + + // build sorted map from the temp map + Iterator>> tempIter = tempSortedMap.entrySet().iterator(); + while (tempIter.hasNext()) { + Map.Entry> entry = tempIter.next(); + BaseRow sortKey = entry.getKey(); + TreeMap treeMap = entry.getValue(); + Iterator> treeMapIter = treeMap.entrySet().iterator(); + while (treeMapIter.hasNext()) { + Map.Entry treeMapEntry = treeMapIter.next(); + Integer innerRank = treeMapEntry.getKey(); + BaseRow recordRowKey = treeMapEntry.getValue(); + int size = buffer.put(sortKey, recordRowKey); + if (innerRank != size) { + LOG.warn("Failed to build sorted map from state, this may result in wrong result. " + + "The sort key is {}, partition key is {}, " + + "treeMap is {}. 
The expected inner rank is {}, but current size is {}.", + sortKey, partitionKey, treeMap, innerRank, size); + } + } + } + } + } else { + hitCount += 1; + } } private void processElementWithRowNumber(BaseRow inputRow, Collector out) throws Exception { @@ -102,73 +240,92 @@ private void processElementWithRowNumber(BaseRow inputRow, Collector ou // it is an updated record which is in the topN, in this scenario, // the new sort key must be higher than old sort key, this is guaranteed by rules RankRow oldRow = rowKeyMap.get(rowKey); - BaseRow oldSortKey = sortKeySelector.getKey(oldRow.getRow()); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.row); if (oldSortKey.equals(sortKey)) { // sort key is not changed, so the rank is the same, only output the row - Tuple2 rankAndInnerRank = rowNumber(sortKey, rowKey, sortedMap); + Tuple2 rankAndInnerRank = rowNumber(sortKey, rowKey, buffer); int rank = rankAndInnerRank.f0; int innerRank = rankAndInnerRank.f1; rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), innerRank, true)); - retract(out, oldRow.getRow(), rank); // retract old record + retract(out, oldRow.row, rank); // retract old record collect(out, inputRow, rank); return; } - Tuple2 oldRankAndInnerRank = rowNumber(oldSortKey, rowKey, sortedMap); + Tuple2 oldRankAndInnerRank = rowNumber(oldSortKey, rowKey, buffer); int oldRank = oldRankAndInnerRank.f0; // remove old sort key - sortedMap.remove(oldSortKey, rowKey); + buffer.remove(oldSortKey, rowKey); // add new sort key - int size = sortedMap.put(sortKey, rowKey); + int size = buffer.put(sortKey, rowKey); rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); // update inner rank of records under the old sort key updateInnerRank(oldSortKey); // emit records emitRecordsWithRowNumber(sortKey, inputRow, out, oldSortKey, oldRow, oldRank); - } else { - // out of topN + } else if (checkSortKeyInBufferRange(sortKey, buffer)) { + // it is an unique record but is in the topN, insert sort key into buffer + int size = buffer.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + + // emit records + emitRecordsWithRowNumber(sortKey, inputRow, out); } } - private void updateInnerRank(BaseRow oldSortKey) { - Collection list = sortedMap.get(oldSortKey); - if (list != null) { - Iterator iter = list.iterator(); - int innerRank = 1; - while (iter.hasNext()) { - BaseRow rowkey = iter.next(); - RankRow row = rowKeyMap.get(rowkey); - if (row.getInnerRank() != innerRank) { - row.setInnerRank(innerRank); - row.setDirty(true); + private Tuple2 rowNumber(BaseRow sortKey, BaseRow rowKey, TopNBuffer buffer) { + Iterator>> iterator = buffer.entrySet().iterator(); + int curRank = 1; + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + BaseRow curKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + if (curKey.equals(sortKey)) { + Iterator rowKeysIter = rowKeys.iterator(); + int innerRank = 1; + while (rowKeysIter.hasNext()) { + if (rowKey.equals(rowKeysIter.next())) { + return Tuple2.of(curRank, innerRank); + } else { + innerRank += 1; + curRank += 1; + } } - innerRank += 1; + } else { + curRank += rowKeys.size(); } } + LOG.error("Failed to find the sortKey: {}, rowkey: {} in the buffer. This should never happen", sortKey, + rowKey); + throw new RuntimeException("Failed to find the sortKey, rowkey in the buffer. 
This should never happen"); } - private void emitRecordsWithRowNumber( - BaseRow sortKey, BaseRow inputRow, Collector out, BaseRow oldSortKey, RankRow oldRow, int oldRank) { + private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + emitRecordsWithRowNumber(sortKey, inputRow, out, null, null, -1); + } + + private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collector out, BaseRow oldSortKey, + RankRow oldRow, int oldRank) throws Exception { - int oldInnerRank = oldRow == null ? -1 : oldRow.getInnerRank(); - Iterator>> iterator = sortedMap.entrySet().iterator(); + int oldInnerRank = oldRow == null ? -1 : oldRow.innerRank; + Iterator>> iterator = buffer.entrySet().iterator(); int curRank = 0; - // whether we have found the sort key in the sorted tree - boolean findSortKey = false; + // whether we have found the sort key in the buffer + boolean findsSortKey = false; while (iterator.hasNext() && isInRankEnd(curRank + 1)) { Map.Entry> entry = iterator.next(); - BaseRow curKey = entry.getKey(); + BaseRow curSortKey = entry.getKey(); Collection rowKeys = entry.getValue(); // meet its own sort key - if (!findSortKey && curKey.equals(sortKey)) { + if (!findsSortKey && curSortKey.equals(sortKey)) { curRank += rowKeys.size(); if (oldRow != null) { - retract(out, oldRow.getRow(), oldRank); + retract(out, oldRow.row, oldRank); } collect(out, inputRow, curRank); - findSortKey = true; - } else if (findSortKey) { + findsSortKey = true; + } else if (findsSortKey) { if (oldSortKey == null) { // this is a new row, emit updates for all rows in the topn Iterator rowKeyIter = rowKeys.iterator(); @@ -176,13 +333,13 @@ private void emitRecordsWithRowNumber( curRank += 1; BaseRow rowKey = rowKeyIter.next(); RankRow prevRow = rowKeyMap.get(rowKey); - retract(out, prevRow.getRow(), curRank - 1); - collect(out, prevRow.getRow(), curRank); + retract(out, prevRow.row, curRank - 1); + collect(out, prevRow.row, curRank); } } else { // current sort key is higher than old sort key, // the rank of current record is changed, need to update the following rank - int compare = sortKeyComparator.compare(curKey, oldSortKey); + int compare = sortKeyComparator.compare(curSortKey, oldSortKey); if (compare <= 0) { Iterator rowKeyIter = rowKeys.iterator(); int curInnerRank = 0; @@ -196,12 +353,11 @@ private void emitRecordsWithRowNumber( BaseRow rowKey = rowKeyIter.next(); RankRow prevRow = rowKeyMap.get(rowKey); - retract(out, prevRow.getRow(), curRank - 1); - collect(out, prevRow.getRow(), curRank); + retract(out, prevRow.row, curRank - 1); + collect(out, prevRow.row, curRank); } } else { - // current sort key is smaller than old sort key, - // the rank is not changed, so skip + // current sort key is smaller than old sort key, the rank is not changed, so skip return; } } @@ -209,6 +365,23 @@ private void emitRecordsWithRowNumber( curRank += rowKeys.size(); } } + + // remove the records associated to the sort key which is out of topN + List toDeleteSortKeys = new ArrayList<>(); + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + Collection rowKeys = entry.getValue(); + Iterator rowKeyIter = rowKeys.iterator(); + while (rowKeyIter.hasNext()) { + BaseRow rowKey = rowKeyIter.next(); + rowKeyMap.remove(rowKey); + dataState.remove(rowKey); + } + toDeleteSortKeys.add(entry.getKey()); + } + for (BaseRow toDeleteKey : toDeleteSortKeys) { + buffer.removeAll(toDeleteKey); + } } private void processElementWithoutRowNumber(BaseRow inputRow, Collector out) 
throws Exception { @@ -218,40 +391,100 @@ private void processElementWithoutRowNumber(BaseRow inputRow, Collector // it is an updated record which is in the topN, in this scenario, // the new sort key must be higher than old sort key, this is guaranteed by rules RankRow oldRow = rowKeyMap.get(rowKey); - BaseRow oldSortKey = sortKeySelector.getKey(oldRow.getRow()); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.row); if (!oldSortKey.equals(sortKey)) { // remove old sort key - sortedMap.remove(oldSortKey, rowKey); + buffer.remove(oldSortKey, rowKey); // add new sort key - int size = sortedMap.put(sortKey, rowKey); + int size = buffer.put(sortKey, rowKey); rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); // update inner rank of records under the old sort key updateInnerRank(oldSortKey); } else { // row content may change, so we need to update row in map - rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.getInnerRank(), true)); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.innerRank, true)); } // row content may change, so a retract is needed - retract(out, oldRow.getRow(), oldRow.getInnerRank()); + retract(out, oldRow.row, oldRow.innerRank); collect(out, inputRow); - } else if (checkSortKeyInBufferRange(sortKey, sortedMap, sortKeyComparator)) { - // it is an unique record but is in the topN - // insert sort key into sortedMap - int size = sortedMap.put(sortKey, rowKey); + } else if (checkSortKeyInBufferRange(sortKey, buffer)) { + // it is an unique record but is in the topN, insert sort key into buffer + int size = buffer.put(sortKey, rowKey); rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); collect(out, inputRow); // remove retired element - if (sortedMap.getCurrentTopNum() > rankEnd) { - BaseRow lastRowKey = sortedMap.removeLast(); + if (buffer.getCurrentTopNum() > rankEnd) { + BaseRow lastRowKey = buffer.removeLast(); if (lastRowKey != null) { RankRow lastRow = rowKeyMap.remove(lastRowKey); dataState.remove(lastRowKey); // always send a retraction message - delete(out, lastRow.getRow()); + delete(out, lastRow.row); } } - } else { - // out of topN + } + } + + private void synchronizeState(Map curRowKeyMap) throws Exception { + Iterator> iter = curRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry entry = iter.next(); + BaseRow key = entry.getKey(); + RankRow rankRow = entry.getValue(); + if (rankRow.dirty) { + // should update state + dataState.put(key, Tuple2.of(rankRow.row, rankRow.innerRank)); + rankRow.dirty = false; + } + } + } + + private void updateInnerRank(BaseRow oldSortKey) { + Collection list = buffer.get(oldSortKey); + if (list != null) { + Iterator iter = list.iterator(); + int innerRank = 1; + while (iter.hasNext()) { + BaseRow rowkey = iter.next(); + RankRow row = rowKeyMap.get(rowkey); + if (row.innerRank != innerRank) { + row.innerRank = innerRank; + row.dirty = true; + } + innerRank += 1; + } + } + } + + private class CacheRemovalListener implements LRUMap.RemovalListener> { + + @Override + public void onRemoval(Map.Entry> eldest) { + BaseRow previousKey = (BaseRow) keyContext.getCurrentKey(); + BaseRow partitionKey = eldest.getKey(); + Map currentRowKeyMap = eldest.getValue(); + keyContext.setCurrentKey(partitionKey); + kvSortedMap.remove(partitionKey); + try { + synchronizeState(currentRowKeyMap); + } catch (Throwable e) { + LOG.error("Fail to synchronize state!", e); + throw new RuntimeException(e); + } finally { + 
keyContext.setCurrentKey(previousKey); + } + } + } + + private class RankRow { + private final BaseRow row; + private int innerRank; + private boolean dirty; + + private RankRow(BaseRow row, int innerRank, boolean dirty) { + this.row = row; + this.innerRank = innerRank; + this.dirty = dirty; } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java index 5ad8779059e78..49cc0734dcf51 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java @@ -40,7 +40,7 @@ public class ValuesInputFormat implements NonParallelInput, ResultTypeQueryable { private static final Logger LOG = LoggerFactory.getLogger(ValuesInputFormat.class); - private final GeneratedInput> generatedInput; + private GeneratedInput> generatedInput; private final BaseRowTypeInfo returnType; private GenericInputFormat format; @@ -54,7 +54,9 @@ public void open(GenericInputSplit split) { LOG.debug("Compiling GenericInputFormat: $name \n\n Code:\n$code", generatedInput.getClassName(), generatedInput.getCode()); LOG.debug("Instantiating GenericInputFormat."); + format = generatedInput.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedInput = null; } @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java index c1c9421afe928..443ceed66b3d2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java @@ -106,10 +106,10 @@ public int hashCode() { @Override public String toString() { return "SortedMapSerializer{" + - "comparator = " + comparator + - ", keySerializer = " + keySerializer + - ", valueSerializer = " + valueSerializer + - "}"; + "comparator = " + comparator + + ", keySerializer = " + keySerializer + + ", valueSerializer = " + valueSerializer + + "}"; } @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java index 6850c23a1e0df..0e3cef7297044 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java @@ -66,7 +66,7 @@ public SortedMapTypeInfo(Class keyClass, Class valueClass) { super(keyClass, valueClass); Preconditions.checkArgument(Comparable.class.isAssignableFrom(keyClass), - "The key class must be comparable when no comparator is given."); + "The key class must be comparable when no comparator is given."); this.comparator = new ComparableComparator<>(); } @@ -112,10 +112,10 @@ public int hashCode() { @Override public String toString() { return "SortedMapTypeInfo{" + - "comparator=" + comparator + - ", keyTypeInfo=" + getKeyTypeInfo() + - ", valueTypeInfo=" + getValueTypeInfo() + - "}"; + "comparator=" + comparator + + ", 
keyTypeInfo=" + getKeyTypeInfo() + + ", valueTypeInfo=" + getValueTypeInfo() + + "}"; } //-------------------------------------------------------------------------- From 2c794f9b410cbf93ba02a69570fe770b1f757d13 Mon Sep 17 00:00:00 2001 From: beyond1920 Date: Fri, 12 Apr 2019 16:17:21 +0800 Subject: [PATCH 3/5] 1. add ut for Deduplicate functions and Rank functions. 2. Introduce SortedMapSerializerSnapshot to do snapshot for SortedMapTypeInfo. --- .../stream/StreamExecDeduplicate.scala | 11 +- .../physical/stream/StreamExecRank.scala | 11 +- .../table/plan/optimize/StreamOptimizer.scala | 7 +- .../table/runtime/stream/sql/RankITCase.scala | 86 +---- .../deduplicate/DeduplicateFunction.java | 20 +- .../DeduplicateFunctionHelper.java | 3 +- .../MiniBatchDeduplicateFunction.java | 16 +- .../table/runtime/functions/CleanupState.java | 55 +++ .../KeyedProcessFunctionWithCleanupState.java | 38 +- .../runtime/rank/AbstractRankFunction.java | 69 ++-- .../runtime/rank/AppendRankFunction.java | 24 +- .../runtime/rank/RetractRankFunction.java | 16 +- .../runtime/rank/UpdateRankFunction.java | 41 ++- .../table/typeutils/SortedMapSerializer.java | 17 +- .../SortedMapSerializerSnapshot.java | 110 ++++++ .../BaseDeduplicateFunctionTest.java | 59 +++ .../deduplicate/DeduplicateFunctionTest.java | 125 +++++++ .../MiniBatchDeduplicateFunctionTest.java | 155 ++++++++ .../runtime/rank/AppendRankFunctionTest.java | 71 ++++ .../runtime/rank/BaseRankFunctionTest.java | 337 ++++++++++++++++++ .../runtime/rank/RetractRankFunctionTest.java | 223 ++++++++++++ .../runtime/rank/UpdateRankFunctionTest.java | 228 ++++++++++++ .../runtime/sort/IntRecordComparator.java | 4 + .../runtime/util/BaseRowRecordEqualiser.java | 54 +++ .../runtime/util/BinaryRowKeySelector.java | 8 +- .../util/GenericRowRecordSortComparator.java | 60 ++++ .../table/runtime/util/StreamRecordUtils.java | 55 +++ .../window/WindowOperatorContractTest.java | 4 +- .../runtime/window/WindowOperatorTest.java | 40 +-- .../table/runtime/window/WindowTestUtils.java | 22 -- 30 files changed, 1715 insertions(+), 254 deletions(-) create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java create mode 100644 
flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index 0c4171c78b2ab..024ee45700fd8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -32,6 +32,7 @@ import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger import org.apache.flink.table.runtime.deduplicate.{DeduplicateFunction, MiniBatchDeduplicateFunction} import org.apache.flink.table.`type`.TypeConverters +import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.table.typeutils.TypeCheckUtils.isRowTime @@ -98,8 +99,7 @@ class StreamExecDeduplicate( override protected def translateToPlanInternal( tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { - // FIXME checkInput is not acc retract after FLINK-12098 is done - val inputIsAccRetract = false + val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(getInput) if (inputIsAccRetract) { throw new TableException( @@ -113,8 +113,7 @@ class StreamExecDeduplicate( val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] - // FIXME infer generate retraction after FLINK-12098 is done - val generateRetraction = true + val generateRetraction = StreamExecRetractionRules.isAccRetract(this) val inputRowType = FlinkTypeFactory.toInternalRowType(getInput.getRowType) val rowTimeFieldIndex = inputRowType.getFieldTypes.zipWithIndex @@ -127,15 +126,15 @@ class StreamExecDeduplicate( throw new TableException("Currently not support Deduplicate on rowtime.") } val tableConfig = tableEnv.getConfig - val exeConfig = tableEnv.execEnv.getConfig val isMiniBatchEnabled = tableConfig.getConf.getLong( TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) > 0 val generatedRecordEqualiser = generateRecordEqualiser(rowTypeInfo) val operator = if (isMiniBatchEnabled) { + val exeConfig = tableEnv.execEnv.getConfig val processFunction = new MiniBatchDeduplicateFunction( rowTypeInfo, generateRetraction, - exeConfig, + rowTypeInfo.createSerializer(exeConfig), keepLastRow, generatedRecordEqualiser) val trigger = new CountBundleTrigger[BaseRow]( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala index d0ec274740d31..ea2b5e48bd8c8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala @@ -25,9 +25,9 @@ import org.apache.flink.table.codegen.{EqualiserCodeGenerator, SortCodeGenerator import org.apache.flink.table.dataformat.BaseRow import 
org.apache.flink.table.plan.nodes.calcite.Rank import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules import org.apache.flink.table.plan.util._ import org.apache.flink.table.runtime.rank._ -import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel._ @@ -140,8 +140,7 @@ class StreamExecRank( tableConfig, sortFields.indices.toArray, sortKeyType.getInternalTypes, sortDirections, nullsIsLast) val sortKeyComparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator") - // FIXME infer generate retraction after FLINK-12098 is done - val generateRetraction = true + val generateRetraction = StreamExecRetractionRules.isAccRetract(this) val cacheSize = tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE) val minIdleStateRetentionTime = tableConfig.getMinIdleStateRetentionTime val maxIdleStateRetentionTime = tableConfig.getMaxIdleStateRetentionTime @@ -153,7 +152,6 @@ class StreamExecRank( minIdleStateRetentionTime, maxIdleStateRetentionTime, inputRowTypeInfo, - sortKeyType, sortKeyComparator, sortKeySelector, rankType, @@ -164,15 +162,11 @@ class StreamExecRank( cacheSize) case UpdateFastStrategy(primaryKeys) => - val internalTypes = inputRowTypeInfo.getInternalTypes - val fieldTypes = primaryKeys.map(internalTypes) - val rowKeyType = new BaseRowTypeInfo(fieldTypes: _*) val rowKeySelector = KeySelectorUtil.getBaseRowSelector(primaryKeys, inputRowTypeInfo) new UpdateRankFunction( minIdleStateRetentionTime, maxIdleStateRetentionTime, inputRowTypeInfo, - rowKeyType, rowKeySelector, sortKeyComparator, sortKeySelector, @@ -189,7 +183,6 @@ class StreamExecRank( minIdleStateRetentionTime, maxIdleStateRetentionTime, inputRowTypeInfo, - sortKeyType, sortKeyComparator, sortKeySelector, rankType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala index 1bff861eb7287..19fbd3df9021a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala @@ -22,7 +22,7 @@ import org.apache.flink.table.api.{StreamTableEnvironment, TableConfig} import org.apache.flink.table.plan.`trait`.UpdateAsRetractionTraitDef import org.apache.flink.table.plan.nodes.calcite.Sink import org.apache.flink.table.plan.optimize.program.{FlinkStreamProgram, StreamOptimizeContext} -import org.apache.flink.table.sinks.RetractStreamTableSink +import org.apache.flink.table.sinks.{DataStreamTableSink, RetractStreamTableSink} import org.apache.flink.util.Preconditions import org.apache.calcite.plan.volcano.VolcanoPlanner @@ -38,7 +38,10 @@ class StreamOptimizer(tEnv: StreamTableEnvironment) extends Optimizer { roots.map { root => val retractionFromRoot = root match { case n: Sink => - n.sink.isInstanceOf[RetractStreamTableSink[_]] + n.sink match { + case _: RetractStreamTableSink[_] => true + case s: DataStreamTableSink[_] => s.updatesAsRetraction + } case o => o.getTraitSet.getTrait(UpdateAsRetractionTraitDef.INSTANCE).sendsUpdatesAsRetractions } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala index 06d280f5e636a..fd2485dc282cb 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -61,10 +61,9 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode val sink = new TestingRetractSink tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) env.execute() - val expected = List( - "book,2,19,1", "book,1,12,2", + "book,2,19,1", "fruit,3,44,1", "fruit,4,33,2") assertEquals(expected.sorted, sink.getRetractResults.sorted) @@ -103,8 +102,6 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected.sorted, sink.getRetractResults.sorted) } - // FIXME - @Ignore("Enable after retraction infer (FLINK-12098) is introduced") @Test def testTopNWithUpsertSink(): Unit = { val data = List( @@ -142,7 +139,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added and SortedMapState is supported") + @Ignore("Enable after agg implements ExecNode and SortedMapState is supported") @Test(expected = classOf[TableException]) def testTopNWithUnary(): Unit = { val data = List( @@ -203,7 +200,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added and SortedMapState is supported") + @Ignore("Enable after agg implements ExecNode and SortedMapState is supported") @Test def testUnarySortTopNOnString(): Unit = { val data = List( @@ -264,7 +261,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithGroupBy(): Unit = { val data = List( @@ -308,7 +305,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithSumAndCondition(): Unit = { val data = List( @@ -359,7 +356,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNthWithGroupBy(): Unit = { val data = List( @@ -402,7 +399,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithGroupByAndRetract(): Unit = { val data = List( @@ -445,7 +442,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNthWithGroupByAndRetract(): Unit = { val data = List( @@ -486,7 +483,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithGroupByCount(): Unit = { val data = List( @@ -543,7 +540,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + 
@Ignore("Enable after agg implements ExecNode") @Test def testTopNthWithGroupByCount(): Unit = { val data = List( @@ -595,7 +592,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testNestedTopN(): Unit = { val data = List( @@ -667,7 +664,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithoutDeduplicate(): Unit = { val data = List( @@ -731,7 +728,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithVariableTopSize(): Unit = { val data = List( @@ -787,7 +784,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNUnaryComplexScenario(): Unit = { val data = List( @@ -876,7 +873,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithGroupByAvgWithoutRowNumber(): Unit = { val data = List( @@ -939,7 +936,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testTopNWithGroupByCountWithoutRowNumber(): Unit = { val data = List( @@ -1068,7 +1065,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testMultipleRetractTopNAfterAgg(): Unit = { val data = List( @@ -1137,7 +1134,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testMultipleUnaryTopNAfterAgg(): Unit = { val data = List( @@ -1204,7 +1201,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode } // FIXME - @Ignore("Enable after agg rules added") + @Ignore("Enable after agg implements ExecNode") @Test def testMultipleUpdateTopNAfterAgg(): Unit = { val data = List( @@ -1270,49 +1267,4 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected2.sorted, sink2.getRetractResults.sorted) } - // FIXME - @Ignore("Enable after agg rules added") - @Test - def testUpdateRank(): Unit = { - val data = List( - (1, 1), (1, 2), (1, 3), - (2, 2), (2, 3), (2, 4), - (3, 3), (3, 4), (3, 5)) - - val ds = failingDataSource(data).toTable(tEnv, 'a, 'b) - tEnv.registerTable("T", ds) - - // We use max here to ensure the usage of update rank - val sql = "SELECT a, max(b) FROM T GROUP BY a ORDER BY a LIMIT 2" - - val sink = new TestingRetractSink - tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) - env.execute() - - val expected = List("1,3", "2,4") - assertEquals(expected.sorted, sink.getRetractResults.sorted) - } - - // FIXME - @Ignore("Enable after agg rules added") - @Test - def testUpdateRankWithOffset(): Unit = { - val data = List( - (1, 1), (1, 2), (1, 3), - (2, 2), (2, 
3), (2, 4), - (3, 3), (3, 4), (3, 5)) - - val ds = failingDataSource(data).toTable(tEnv, 'a, 'b) - tEnv.registerTable("T", ds) - - // We use max here to ensure the usage of update rank - val sql = "SELECT a, max(b) FROM T GROUP BY a ORDER BY a LIMIT 2 OFFSET 1" - - val sink = new TestingRetractSink - tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) - env.execute() - - val expected = List("2,4", "3,5") - assertEquals(expected.sorted, sink.getRetractResults.sorted) - } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java index 65dd9bf2ab028..e1a9f3485dbb2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java @@ -28,6 +28,9 @@ import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; + /** * This function is used to deduplicate on keys and keeps only first row or last row. */ @@ -43,13 +46,8 @@ public class DeduplicateFunction private GeneratedRecordEqualiser generatedEqualiser; private transient RecordEqualiser equaliser; - public DeduplicateFunction( - long minRetentionTime, - long maxRetentionTime, - BaseRowTypeInfo rowTypeInfo, - boolean generateRetraction, - boolean keepLastRow, - GeneratedRecordEqualiser generatedEqualiser) { + public DeduplicateFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo rowTypeInfo, + boolean generateRetraction, boolean keepLastRow, GeneratedRecordEqualiser generatedEqualiser) { super(minRetentionTime, maxRetentionTime); this.rowTypeInfo = rowTypeInfo; this.generateRetraction = generateRetraction; @@ -78,11 +76,9 @@ public void processElement(BaseRow input, Context ctx, Collector out) t BaseRow preRow = pkRow.value(); if (keepLastRow) { - DeduplicateFunctionHelper - .processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); } else { - DeduplicateFunctionHelper - .processFirstRow(preRow, input, pkRow, out); + processFirstRow(preRow, input, pkRow, out); } } @@ -96,7 +92,7 @@ public void onTimer( long timestamp, OnTimerContext ctx, Collector out) throws Exception { - if (needToCleanupState(timestamp)) { + if (stateCleaningEnabled) { cleanupState(pkRow); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java index 7a77e16748fc7..d3b4f2c5102bb 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java @@ -35,9 +35,10 @@ static void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generateR Collector out) throws Exception { 
// should be accumulate msg. Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); - // ignore same record if (!stateCleaningEnabled && preRow != null && equaliser.equalsWithoutHeader(preRow, currentRow)) { + // If state cleaning is not enabled, don't emit retraction and acc message. But if state cleaning is + // enabled, we have to emit message to prevent too early state eviction of downstream operators. return; } pkRow.update(currentRow); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java index ebb61b874fd90..aa3c4cdec6820 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java @@ -18,7 +18,6 @@ package org.apache.flink.table.runtime.deduplicate; -import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; @@ -34,6 +33,9 @@ import java.util.Map; +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; + /** * This function is used to get the first row or last row for every key partition in miniBatch * mode. @@ -44,7 +46,7 @@ public class MiniBatchDeduplicateFunction private BaseRowTypeInfo rowTypeInfo; private boolean generateRetraction; private boolean keepLastRow; - protected ValueState pkRow; + private ValueState pkRow; private TypeSerializer ser; private GeneratedRecordEqualiser generatedEqualiser; private transient RecordEqualiser equaliser; @@ -52,13 +54,13 @@ public class MiniBatchDeduplicateFunction public MiniBatchDeduplicateFunction( BaseRowTypeInfo rowTypeInfo, boolean generateRetraction, - ExecutionConfig executionConfig, + TypeSerializer typeSerializer, boolean keepLastRow, GeneratedRecordEqualiser generatedEqualiser) { this.rowTypeInfo = rowTypeInfo; this.generateRetraction = generateRetraction; this.keepLastRow = keepLastRow; - ser = rowTypeInfo.createSerializer(executionConfig); + ser = typeSerializer; this.generatedEqualiser = generatedEqualiser; } @@ -94,11 +96,9 @@ public void finishBundle( BaseRow preRow = pkRow.value(); if (keepLastRow) { - DeduplicateFunctionHelper - .processLastRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); + processLastRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); } else { - DeduplicateFunctionHelper - .processFirstRow(preRow, currentRow, pkRow, out); + processFirstRow(preRow, currentRow, pkRow, out); } } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java new file mode 100644 index 0000000000000..ddb823c60b003 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; + +/** + * Base interface for clean up state, both for {@link ProcessFunction} and {@link CoProcessFunction}. + */ +public interface CleanupState { + + default void registerProcessingCleanupTimer( + ValueState cleanupTimeState, + long currentTime, + long minRetentionTime, + long maxRetentionTime, + TimerService timerService) throws Exception { + + // last registered timer + Long curCleanupTime = cleanupTimeState.value(); + + // check if a cleanup timer is registered and + // that the current cleanup timer won't delete state we need to keep + if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) { + // we need to register a new (later) timer + long cleanupTime = currentTime + maxRetentionTime; + // register timer and remember clean-up time + timerService.registerProcessingTimeTimer(cleanupTime); + // delete expired timer + if (curCleanupTime != null) { + timerService.deleteProcessingTimeTimer(curCleanupTime); + } + cleanupTimeState.update(cleanupTime); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java index 2e3835202949d..c5797632f71c9 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java @@ -27,12 +27,14 @@ /** * A function that processes elements of a stream, and could cleanup state. - * * @param Type of the key. - * @param Type of the input elements. + * @param Type of the input elements. * @param Type of the output elements. 
*/ -public abstract class KeyedProcessFunctionWithCleanupState extends KeyedProcessFunction { +public abstract class KeyedProcessFunctionWithCleanupState + extends KeyedProcessFunction implements CleanupState { + + private static final long serialVersionUID = 2084560869233898457L; protected final long minRetentionTime; protected final long maxRetentionTime; @@ -49,24 +51,20 @@ public KeyedProcessFunctionWithCleanupState(long minRetentionTime, long maxReten protected void initCleanupTimeState(String stateName) { if (stateCleaningEnabled) { - ValueStateDescriptor inputCntDescriptor = new ValueStateDescriptor(stateName, Types.LONG); + ValueStateDescriptor inputCntDescriptor = new ValueStateDescriptor<>(stateName, Types.LONG); cleanupTimeState = getRuntimeContext().getState(inputCntDescriptor); } } protected void registerProcessingCleanupTimer(Context ctx, long currentTime) throws Exception { if (stateCleaningEnabled) { - // last registered timer - Long curCleanupTime = cleanupTimeState.value(); - - // check if a cleanup timer is registered and that the current cleanup timer won't delete state we need to keep - if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) { - // we need to register a new (later) timer - Long cleanupTime = currentTime + maxRetentionTime; - // register timer and remember clean-up time - ctx.timerService().registerProcessingTimeTimer(cleanupTime); - cleanupTimeState.update(cleanupTime); - } + registerProcessingCleanupTimer( + cleanupTimeState, + currentTime, + minRetentionTime, + maxRetentionTime, + ctx.timerService() + ); } } @@ -74,16 +72,6 @@ protected boolean isProcessingTimeTimer(OnTimerContext ctx) { return ctx.timeDomain() == TimeDomain.PROCESSING_TIME; } - protected boolean needToCleanupState(long timestamp) throws Exception { - if (stateCleaningEnabled) { - Long cleanupTime = cleanupTimeState.value(); - // check that the triggered timer is the last registered processing time timer. - return null != cleanupTime && timestamp == cleanupTime; - } else { - return false; - } - } - protected void cleanupState(State... 
states) { for (State state : states) { state.clear(); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java index edce4e12f70e2..431d25fb8f977 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java @@ -34,6 +34,9 @@ import org.apache.flink.table.generated.GeneratedRecordEqualiser; import org.apache.flink.table.generated.RecordEqualiser; import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; +import org.apache.flink.table.type.InternalType; +import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; @@ -51,11 +54,15 @@ public abstract class AbstractRankFunction extends KeyedProcessFunctionWithClean private static final Logger LOG = LoggerFactory.getLogger(AbstractRankFunction.class); + private static final String RANK_UNSUPPORTED_MSG = "RANK() on streaming table is not supported currently"; + + private static final String DENSE_RANK_UNSUPPORTED_MSG = "DENSE_RANK() on streaming table is not supported currently"; + + private static final String WITHOUT_RANK_END_UNSUPPORTED_MSG = "Rank end is not specified. Currently rank only support TopN, which means the rank end must be specified."; + // we set default topN size to 100 private static final long DEFAULT_TOPN_SIZE = 100; - private final RankRange rankRange; - /** * The util to compare two BaseRow equals to each other. * As different BaseRow can't be equals directly, we use a code generated util to handle this. 
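[Editor's note] The timer coalescing in CleanupState#registerProcessingCleanupTimer above is easy to misread, so here is a minimal sketch of just the decision it makes, with plain longs standing in for ValueState and TimerService; the class and method names are placeholders, not part of this patch.

    final class CleanupTimerSketch {
        // Returns the cleanup timer to (re)register, or null if the previously
        // registered timer still guarantees at least minRetentionTime of retention.
        // A registered timer fires at currentTime + maxRetentionTime, so state lives
        // between minRetentionTime and maxRetentionTime after its last access.
        static Long nextCleanupTime(Long prevCleanupTime, long currentTime,
                long minRetentionTime, long maxRetentionTime) {
            if (prevCleanupTime == null || currentTime + minRetentionTime > prevCleanupTime) {
                return currentTime + maxRetentionTime;
            }
            return null; // keep the existing timer, nothing to delete or register
        }
    }

For example, with a minimum retention of one hour and a maximum of two hours, a key first touched at t=0 gets a timer at t=2h; accesses before t=1h leave it untouched, while an access after t=1h deletes it and registers a later one, which is exactly what the default method above does via the TimerService.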
@@ -75,10 +82,10 @@ public abstract class AbstractRankFunction extends KeyedProcessFunctionWithClean protected final KeySelector sortKeySelector; protected KeyContext keyContext; - private boolean isConstantRankEnd; - private long rankStart = -1; - protected long rankEnd = -1; - private int rankEndIndex; + private final boolean isConstantRankEnd; + private final long rankStart; + protected long rankEnd; + private final int rankEndIndex; private ValueState rankEndState; private Counter invalidCounter; private JoinedRow outputRow; @@ -88,7 +95,7 @@ public abstract class AbstractRankFunction extends KeyedProcessFunctionWithClean protected long requestCount = 0L; AbstractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, - GeneratedRecordComparator generatedSortKeyComparator, KeySelector sortKeySelector, + GeneratedRecordComparator generatedSortKeyComparator, BaseRowKeySelector sortKeySelector, RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, boolean outputRankNumber) { super(minRetentionTime, maxRetentionTime); @@ -97,16 +104,42 @@ public abstract class AbstractRankFunction extends KeyedProcessFunctionWithClean case ROW_NUMBER: break; case RANK: - LOG.error("RANK() on streaming table is not supported currently"); - throw new UnsupportedOperationException("RANK() on streaming table is not supported currently"); + LOG.error(RANK_UNSUPPORTED_MSG); + throw new UnsupportedOperationException(RANK_UNSUPPORTED_MSG); case DENSE_RANK: - LOG.error("DENSE_RANK() on streaming table is not supported currently"); - throw new UnsupportedOperationException("DENSE_RANK() on streaming table is not supported currently"); + LOG.error(DENSE_RANK_UNSUPPORTED_MSG); + throw new UnsupportedOperationException(DENSE_RANK_UNSUPPORTED_MSG); default: LOG.error("Streaming tables do not support {}", rankType.name()); throw new UnsupportedOperationException("Streaming tables do not support " + rankType.toString()); } - this.rankRange = rankRange; + + if (rankRange instanceof ConstantRankRange) { + ConstantRankRange constantRankRange = (ConstantRankRange) rankRange; + isConstantRankEnd = true; + rankStart = constantRankRange.getRankStart(); + rankEnd = constantRankRange.getRankEnd(); + rankEndIndex = -1; + } else if (rankRange instanceof VariableRankRange) { + VariableRankRange variableRankRange = (VariableRankRange) rankRange; + int rankEndIdx = variableRankRange.getRankEndIndex(); + InternalType rankEndIdxType = inputRowType.getInternalTypes()[rankEndIdx]; + if (!rankEndIdxType.equals(InternalTypes.LONG)) { + LOG.error("variable rank index column must be long type, while input type is {}", + rankEndIdxType.getClass().getName()); + throw new UnsupportedOperationException( + "variable rank index column must be long type, while input type is " + + rankEndIdxType.getClass().getName()); + } + rankEndIndex = rankEndIdx; + isConstantRankEnd = false; + rankStart = -1; + rankEnd = -1; + + } else { + LOG.error(WITHOUT_RANK_END_UNSUPPORTED_MSG); + throw new UnsupportedOperationException(WITHOUT_RANK_END_UNSUPPORTED_MSG); + } this.generatedEqualiser = generatedEqualiser; this.generatedSortKeyComparator = generatedSortKeyComparator; this.generateRetraction = generateRetraction; @@ -121,20 +154,10 @@ public void open(Configuration parameters) throws Exception { initCleanupTimeState("RankFunctionCleanupTime"); outputRow = new JoinedRow(); - // variable rank limit - if (rankRange instanceof ConstantRankRange) { - ConstantRankRange 
constantRankRange = (ConstantRankRange) rankRange; - isConstantRankEnd = true; - rankEnd = constantRankRange.getRankEnd(); - rankStart = constantRankRange.getRankStart(); - } else if (rankRange instanceof VariableRankRange) { - VariableRankRange variableRankRange = (VariableRankRange) rankRange; - isConstantRankEnd = false; - rankEndIndex = variableRankRange.getRankEndIndex(); + if (!isConstantRankEnd) { ValueStateDescriptor rankStateDesc = new ValueStateDescriptor("rankEnd", Types.LONG); rankEndState = getRuntimeContext().getState(rankStateDesc); } - // compile equaliser equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); generatedEqualiser = null; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java index e21b6aaee5ee3..41ea9d703ca55 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -22,12 +22,13 @@ import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.ListTypeInfo; import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.api.TableConfigOptions; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.generated.GeneratedRecordComparator; import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; import org.apache.flink.table.runtime.util.LRUMap; import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; @@ -43,7 +44,7 @@ import java.util.function.Supplier; /** - * RankFunction in Append Stream mode. + * AppendRankFunction's input stream only contains append records.
*/ public class AppendRankFunction extends AbstractRankFunction { @@ -63,13 +64,12 @@ public class AppendRankFunction extends AbstractRankFunction { private transient Map kvSortedMap; public AppendRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, - BaseRowTypeInfo sortKeyType, GeneratedRecordComparator sortKeyGeneratedRecordComparator, - KeySelector sortKeySelector, RankType rankType, RankRange rankRange, - GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, boolean outputRankNumber, - long cacheSize) { + GeneratedRecordComparator sortKeyGeneratedRecordComparator, BaseRowKeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber, long cacheSize) { super(minRetentionTime, maxRetentionTime, inputRowType, sortKeyGeneratedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber); - this.sortKeyType = sortKeyType; + this.sortKeyType = sortKeySelector.getProducedType(); this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); this.cacheSize = cacheSize; } @@ -121,7 +121,7 @@ public void onTimer( long timestamp, OnTimerContext ctx, Collector out) throws Exception { - if (needToCleanupState(timestamp)) { + if (stateCleaningEnabled) { // cleanup cache kvSortedMap.remove(keyContext.getCurrentKey()); cleanupState(dataState); @@ -214,8 +214,12 @@ private void processElementWithoutRowNumber(BaseRow input, Collector ou } else { dataState.put(lastKey, lastList); } - // lastElement shouldn't be null - delete(out, lastElement); + if (input.equals(lastElement)) { + return; + } else { + // lastElement shouldn't be null + delete(out, lastElement); + } } collect(out, input); } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java index 4f4e7ab980722..21c3e934ab004 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java @@ -23,13 +23,13 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.ListTypeInfo; import org.apache.flink.configuration.Configuration; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.dataformat.util.BaseRowUtil; import org.apache.flink.table.generated.GeneratedRecordComparator; import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.table.typeutils.SortedMapTypeInfo; import org.apache.flink.util.Collector; @@ -45,7 +45,7 @@ import java.util.TreeMap; /** - * RankFunction in Update Stream mode. + * RetractRankFunction's input stream can contain append, update, and delete records.
*/ public class RetractRankFunction extends AbstractRankFunction { @@ -69,12 +69,12 @@ public class RetractRankFunction extends AbstractRankFunction { private transient ValueState> treeMap; public RetractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, - BaseRowTypeInfo sortKeyType, GeneratedRecordComparator generatedRecordComparator, - KeySelector sortKeySelector, RankType rankType, RankRange rankRange, - GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, boolean outputRankNumber) { + GeneratedRecordComparator generatedRecordComparator, BaseRowKeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber) { super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber); - this.sortKeyType = sortKeyType; + this.sortKeyType = sortKeySelector.getProducedType(); } @Override @@ -210,7 +210,7 @@ private void emitRecordsWithRowNumber( Iterator> iterator = sortedMap.entrySet().iterator(); long curRank = 0L; boolean findsSortKey = false; - while (iterator.hasNext() && isInRankRange(curRank)) { + while (iterator.hasNext() && isInRankEnd(curRank)) { Map.Entry entry = iterator.next(); BaseRow key = entry.getKey(); if (!findsSortKey && key.equals(sortKey)) { @@ -228,7 +228,7 @@ private void emitRecordsWithRowNumber( } } else { int i = 0; - while (i < inputs.size() && isInRankRange(curRank)) { + while (i < inputs.size() && isInRankEnd(curRank)) { curRank += 1; BaseRow prevRow = inputs.get(i); retract(out, prevRow, curRank - 1); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java index 1e2849bc4f5ef..d5e90834538fb 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -33,6 +33,7 @@ import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.generated.GeneratedRecordComparator; import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; import org.apache.flink.table.runtime.util.LRUMap; import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; @@ -51,8 +52,12 @@ import java.util.function.Supplier; /** - * A fast version of rank process function which only hold top n data in state, - * and keep sorted map in heap. This only works in some special scenarios, such as, rank a count(*) stream + * A fast version of the rank process function which only holds the top-n data in state and keeps a sorted map in heap. + * This only works in some special scenarios: + * 1. the sort field collation is ascending and its monotonicity is decreasing, or the sort field collation is descending and its monotonicity + * is increasing + * 2. the input data has unique keys + * 3. 
the input stream must not contain delete or retract records */ public class UpdateRankFunction extends AbstractRankFunction implements CheckpointedFunction { @@ -68,12 +73,12 @@ public class UpdateRankFunction extends AbstractRankFunction implements Checkpoi // the f1 is used to preserve the record order in the same sort_key private transient MapState> dataState; - // a buffer stores mapping from sort key to record list + // a buffer stores mapping from sort key to rowKey list private transient TopNBuffer buffer; private transient Map kvSortedMap; - // a HashMap stores mapping from rowkey to record, a heap mirror to dataState + // a HashMap stores mapping from rowKey to record, a heap mirror to dataState private transient Map rowKeyMap; private transient LRUMap> kvRowKeyMap; @@ -82,13 +87,13 @@ public class UpdateRankFunction extends AbstractRankFunction implements Checkpoi private final KeySelector rowKeySelector; public UpdateRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, - BaseRowTypeInfo rowKeyType, KeySelector rowKeySelector, - GeneratedRecordComparator generatedRecordComparator, KeySelector sortKeySelector, - RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, - boolean generateRetraction, boolean outputRankNumber, long cacheSize) { + BaseRowKeySelector rowKeySelector, GeneratedRecordComparator generatedRecordComparator, + BaseRowKeySelector sortKeySelector, RankType rankType, + RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, + boolean outputRankNumber, long cacheSize) { super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber); - this.rowKeyType = rowKeyType; + this.rowKeyType = rowKeySelector.getProducedType(); this.cacheSize = cacheSize; this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); this.rowKeySelector = rowKeySelector; @@ -118,7 +123,7 @@ public void onTimer( long timestamp, OnTimerContext ctx, Collector out) throws Exception { - if (needToCleanupState(timestamp)) { + if (stateCleaningEnabled) { BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); // cleanup cache kvRowKeyMap.remove(partitionKey); @@ -191,11 +196,11 @@ public Collection get() { Map> tempSortedMap = new HashMap<>(); while (iter.hasNext()) { Map.Entry> entry = iter.next(); - BaseRow rowkey = entry.getKey(); + BaseRow rowKey = entry.getKey(); Tuple2 recordAndInnerRank = entry.getValue(); BaseRow record = recordAndInnerRank.f0; Integer innerRank = recordAndInnerRank.f1; - rowKeyMap.put(rowkey, new RankRow(record, innerRank, false)); + rowKeyMap.put(rowKey, new RankRow(record, innerRank, false)); // insert into temp sort map to preserve the record order in the same sort key BaseRow sortKey = sortKeySelector.getKey(record); @@ -204,7 +209,7 @@ public Collection get() { treeMap = new TreeMap<>(); tempSortedMap.put(sortKey, treeMap); } - treeMap.put(innerRank, rowkey); + treeMap.put(innerRank, rowKey); } // build sorted map from the temp map @@ -313,7 +318,7 @@ private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collect int curRank = 0; // whether we have found the sort key in the buffer boolean findsSortKey = false; - while (iterator.hasNext() && isInRankEnd(curRank + 1)) { + while (iterator.hasNext() && isInRankEnd(curRank)) { Map.Entry> entry = iterator.next(); BaseRow curSortKey = entry.getKey(); Collection rowKeys = 
entry.getValue(); @@ -329,7 +334,7 @@ private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collect if (oldSortKey == null) { // this is a new row, emit updates for all rows in the topn Iterator rowKeyIter = rowKeys.iterator(); - while (rowKeyIter.hasNext() && isInRankEnd(curRank + 1)) { + while (rowKeyIter.hasNext() && isInRankEnd(curRank)) { curRank += 1; BaseRow rowKey = rowKeyIter.next(); RankRow prevRow = rowKeyMap.get(rowKey); @@ -343,7 +348,7 @@ private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collect if (compare <= 0) { Iterator rowKeyIter = rowKeys.iterator(); int curInnerRank = 0; - while (rowKeyIter.hasNext() && isInRankEnd(curRank + 1)) { + while (rowKeyIter.hasNext() && isInRankEnd(curRank)) { curRank += 1; curInnerRank += 1; if (compare == 0 && curInnerRank >= oldInnerRank) { @@ -445,8 +450,8 @@ private void updateInnerRank(BaseRow oldSortKey) { Iterator iter = list.iterator(); int innerRank = 1; while (iter.hasNext()) { - BaseRow rowkey = iter.next(); - RankRow row = rowKeyMap.get(rowkey); + BaseRow rowKey = iter.next(); + RankRow row = rowKeyMap.get(rowKey); if (row.innerRank != innerRank) { row.innerRank = innerRank; row.dirty = true; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java index 443ceed66b3d2..ea5b8200f17fb 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java @@ -19,9 +19,7 @@ package org.apache.flink.table.typeutils; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.base.MapSerializerConfigSnapshot; -import org.apache.flink.util.Preconditions; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import java.util.Comparator; import java.util.SortedMap; @@ -59,8 +57,6 @@ public SortedMapSerializer( TypeSerializer keySerializer, TypeSerializer valueSerializer) { super(keySerializer, valueSerializer); - - Preconditions.checkNotNull(comparator, "The comparator cannot be null."); this.comparator = comparator; } @@ -112,9 +108,12 @@ public String toString() { "}"; } + // -------------------------------------------------------------------------------------------- + // Serializer configuration snapshot + // -------------------------------------------------------------------------------------------- + @Override - public TypeSerializerConfigSnapshot snapshotConfiguration() { - return new MapSerializerConfigSnapshot<>(keySerializer, valueSerializer); + public TypeSerializerSnapshot> snapshotConfiguration() { + return new SortedMapSerializerSnapshot<>(this); } - -} +} \ No newline at end of file diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java new file mode 100644 index 0000000000000..d7a1dcf68ce1d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.NestedSerializersSnapshotDelegate; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream; +import org.apache.flink.api.java.typeutils.runtime.DataOutputViewStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.InstantiationUtil; + +import java.io.IOException; +import java.util.Comparator; +import java.util.SortedMap; + +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Snapshot class for the {@link SortedMapSerializer}. + */ +public class SortedMapSerializerSnapshot implements TypeSerializerSnapshot> { + + private Comparator comparator; + + private NestedSerializersSnapshotDelegate nestedSerializersSnapshotDelegate; + + private static final int CURRENT_VERSION = 3; + + @SuppressWarnings("unused") + public SortedMapSerializerSnapshot() { + // this constructor is used when restoring from a checkpoint/savepoint. 
+ } + + SortedMapSerializerSnapshot(SortedMapSerializer sortedMapSerializer) { + this.comparator = sortedMapSerializer.getComparator(); + TypeSerializer[] typeSerializers = + new TypeSerializer[] { sortedMapSerializer.getKeySerializer(), sortedMapSerializer.getValueSerializer() }; + this.nestedSerializersSnapshotDelegate = new NestedSerializersSnapshotDelegate(typeSerializers); + } + + @Override + public int getCurrentVersion() { + return CURRENT_VERSION; + } + + @Override + public void writeSnapshot(DataOutputView out) throws IOException { + checkState(comparator != null, "Comparator cannot be null."); + InstantiationUtil.serializeObject(new DataOutputViewStream(out), comparator); + nestedSerializersSnapshotDelegate.writeNestedSerializerSnapshots(out); + } + + @Override + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException { + try { + comparator = InstantiationUtil.deserializeObject( + new DataInputViewStream(in), userCodeClassLoader); + } catch (ClassNotFoundException e) { + throw new IOException(e); + } + this.nestedSerializersSnapshotDelegate = NestedSerializersSnapshotDelegate.readNestedSerializerSnapshots( + in, + userCodeClassLoader); + } + + @Override + public SortedMapSerializer restoreSerializer() { + TypeSerializer[] nestedSerializers = nestedSerializersSnapshotDelegate.getRestoredNestedSerializers(); + @SuppressWarnings("unchecked") + TypeSerializer keySerializer = (TypeSerializer) nestedSerializers[0]; + + @SuppressWarnings("unchecked") + TypeSerializer valueSerializer = (TypeSerializer) nestedSerializers[1]; + + return new SortedMapSerializer(comparator, keySerializer, valueSerializer); + } + + @Override + public TypeSerializerSchemaCompatibility> resolveSchemaCompatibility( + TypeSerializer> newSerializer) { + if (!(newSerializer instanceof SortedMapSerializer)) { + return TypeSerializerSchemaCompatibility.incompatible(); + } + SortedMapSerializer newSortedMapSerializer = (SortedMapSerializer) newSerializer; + if (!comparator.equals(newSortedMapSerializer.getComparator())) { + return TypeSerializerSchemaCompatibility.incompatible(); + } else { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java new file mode 100644 index 0000000000000..d8a475c45241f --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BaseRowRecordEqualiser; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * Base Tests for {@link DeduplicateFunction} and {@link MiniBatchDeduplicateFunction}. + */ +public abstract class BaseDeduplicateFunctionTest { + + protected BaseRowTypeInfo inputRowType = new BaseRowTypeInfo( + InternalTypes.STRING, + InternalTypes.LONG, + InternalTypes.INT); + + protected GeneratedRecordEqualiser generatedEqualiser = new GeneratedRecordEqualiser("", "", new Object[0]) { + + private static final long serialVersionUID = -5080236034372380295L; + + @Override + public RecordEqualiser newInstance(ClassLoader classLoader) { + return new BaseRowRecordEqualiser(); + } + }; + + private int rowKeyIdx = 1; + protected BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + + protected BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java new file mode 100644 index 0000000000000..d99e6783c9c75 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link DeduplicateFunction}. 
+ */ +public class DeduplicateFunctionTest extends BaseDeduplicateFunctionTest { + + private Time minTime = Time.milliseconds(10); + private Time maxTime = Time.milliseconds(20); + + private DeduplicateFunction createFunction(boolean generateRetraction, boolean keepLastRow) { + DeduplicateFunction func = new DeduplicateFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, generateRetraction, keepLastRow, + generatedEqualiser); + return func; + } + + private OneInputStreamOperatorTestHarness createTestHarness( + DeduplicateFunction func) + throws Exception { + KeyedProcessOperator operator = new KeyedProcessOperator(func); + return new KeyedOneInputStreamOperatorTestHarness(operator, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void testKeepFirstRowWithoutGenerateRetraction() throws Exception { + DeduplicateFunction func = createFunction(false, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testKeepFirstRowWithGenerateRetraction() throws Exception { + DeduplicateFunction func = createFunction(true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + // Keep FirstRow in deduplicate will not send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testKeepLastWithoutGenerateRetraction() throws Exception { + DeduplicateFunction func = createFunction(false, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testKeepLastRowWithGenerateRetraction() throws Exception { + DeduplicateFunction func = createFunction(true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + // Keep LastRow in deduplicate may send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + 
expectedOutputOutput.add(retractRecord("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java new file mode 100644 index 0000000000000..aa4602ff7631d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link MiniBatchDeduplicateFunction}. 
+ */ +public class MiniBatchDeduplicateFunctionTest extends BaseDeduplicateFunctionTest { + private TypeSerializer typeSerializer = inputRowType.createSerializer(new ExecutionConfig()); + + private MiniBatchDeduplicateFunction createFunction(boolean generateRetraction, boolean keepLastRow) { + MiniBatchDeduplicateFunction func = new MiniBatchDeduplicateFunction(inputRowType, generateRetraction, + typeSerializer, keepLastRow, generatedEqualiser); + return func; + } + + private OneInputStreamOperatorTestHarness createTestHarness( + MiniBatchDeduplicateFunction func) + throws Exception { + CountBundleTrigger> trigger = new CountBundleTrigger<>(3); + KeyedMapBundleOperator op = new KeyedMapBundleOperator(func, trigger); + return new KeyedOneInputStreamOperatorTestHarness(op, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void testKeepFirstRowWithoutGenerateRetraction() throws Exception { + MiniBatchDeduplicateFunction func = createFunction(false, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + + // output is empty because the bundle has not been triggered yet. + Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + // output is not empty because the bundle has been triggered. + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + testHarness.close(); + } + + @Test + public void testKeepFirstRowWithGenerateRetraction() throws Exception { + MiniBatchDeduplicateFunction func = createFunction(true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + + // output is empty because the bundle has not been triggered yet. + Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + // Keeping the first row in deduplicate will not send retraction messages + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + testHarness.close(); + } + + @Test + public void testKeepLastWithoutGenerateRetraction() throws Exception { + MiniBatchDeduplicateFunction func = createFunction(false, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + testHarness.processElement(record("book", 2L, 11)); + // output is empty because the bundle has not been triggered yet. 
+ Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 11)); + + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 3L, 11)); + testHarness.close(); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testKeepLastRowWithGenerateRetraction() throws Exception { + MiniBatchDeduplicateFunction func = createFunction(true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + testHarness.processElement(record("book", 2L, 11)); + // output is empty because the bundle has not been triggered yet. + Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 11)); + + // this will send a retract message to downstream + expectedOutputOutput.add(retractRecord("book", 1L, 13)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 3L, 11)); + testHarness.close(); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java new file mode 100644 index 0000000000000..288be22469343 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link AppendRankFunction}. + */ +public class AppendRankFunctionTest extends BaseRankFunctionTest { + + @Override + protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber) { + AbstractRankFunction rankFunction = new AppendRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, + generateRetraction, outputRankNumber, cacheSize); + return rankFunction; + } + + @Test + public void testVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(1), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("fruit", 1L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33)); + expectedOutputOutput.add(record("fruit", 1L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java new file mode 100644 index 0000000000000..9abc9d43e0845 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordComparator; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.sort.IntRecordComparator; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BaseRowRecordEqualiser; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.deleteRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Base Tests for all subclass of {@link AbstractRankFunction}. + */ +abstract class BaseRankFunctionTest { + + protected Time minTime = Time.milliseconds(10); + protected Time maxTime = Time.milliseconds(20); + protected long cacheSize = 10000L; + + BaseRowTypeInfo inputRowType = new BaseRowTypeInfo( + InternalTypes.STRING, + InternalTypes.LONG, + InternalTypes.INT); + + GeneratedRecordComparator sortKeyComparator = new GeneratedRecordComparator("", "", new Object[0]) { + + private static final long serialVersionUID = 1434685115916728955L; + + @Override + public RecordComparator newInstance(ClassLoader classLoader) { + + return IntRecordComparator.INSTANCE; + } + }; + + private int sortKeyIdx = 2; + + BinaryRowKeySelector sortKeySelector = new BinaryRowKeySelector(new int[] { sortKeyIdx }, + inputRowType.getInternalTypes()); + + GeneratedRecordEqualiser generatedEqualiser = new GeneratedRecordEqualiser("", "", new Object[0]) { + + private static final long serialVersionUID = 8932460173848746733L; + + @Override + public RecordEqualiser newInstance(ClassLoader classLoader) { + return new BaseRowRecordEqualiser(); + } + }; + + private int partitionKeyIdx = 0; + + private BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[] { partitionKeyIdx }, + inputRowType.getInternalTypes()); + + private BaseRowTypeInfo outputTypeWithoutRowNumber = inputRowType; + + private BaseRowTypeInfo outputTypeWithRowNumber = new BaseRowTypeInfo( + InternalTypes.STRING, + InternalTypes.LONG, + InternalTypes.INT, + InternalTypes.LONG); + + BaseRowHarnessAssertor assertorWithoutRowNumber = new BaseRowHarnessAssertor( + outputTypeWithoutRowNumber.getFieldTypes(), + new GenericRowRecordSortComparator(sortKeyIdx, outputTypeWithoutRowNumber.getInternalTypes()[sortKeyIdx])); + + BaseRowHarnessAssertor assertorWithRowNumber = new BaseRowHarnessAssertor( + outputTypeWithRowNumber.getFieldTypes(), + new GenericRowRecordSortComparator(sortKeyIdx, outputTypeWithRowNumber.getInternalTypes()[sortKeyIdx])); + + // rowKey only used in UpdateRankFunction + 
private int rowKeyIdx = 1; + BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + @Test(expected = UnsupportedOperationException.class) + public void testInvalidVariableRankRangeWithIntType() throws Exception { + createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(2), true, false); + } + + @Test(expected = UnsupportedOperationException.class) + public void testNotSupportRank() throws Exception { + createRankFunction(RankType.RANK, new ConstantRankRange(1, 10), true, true); + } + + @Test(expected = UnsupportedOperationException.class) + public void testNotSupportDenseRank() throws Exception { + createRankFunction(RankType.DENSE_RANK, new ConstantRankRange(1, 10), true, true); + } + + @Test(expected = UnsupportedOperationException.class) + public void testNotSupportWithoutRankEnd() throws Exception { + createRankFunction(RankType.ROW_NUMBER, new ConstantRankRangeWithoutEnd(1), true, true); + } + + @Test + public void testDisableGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + // Notes: Delete messages will be sent even if generating retraction is disabled, when rankNumber is not output. + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(deleteRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(deleteRecord("book", 1L, 12)); + expectedOutputOutput.add(record("book", 5L, 11)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(deleteRecord("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 5L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testDisableGenerateRetractionAndOutputRankNumber() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + // Notes: Retract messages will not be sent if generating retraction is disabled and rankNumber is output, + // because partition key + rankNumber forms a unique key. 
+ List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + assertorWithRowNumber.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testOutputRankNumberWithConstantRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(retractRecord("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testConstantRankRangeWithOffset() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(2, 2), true, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void 
testOutputRankNumberWithVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(1), true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 12, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("book", 2L, 12, 2L)); + expectedOutputOutput.add(record("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 22, 1L)); + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testConstantRankRangeWithoutOffset() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 5L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + // do a snapshot, data could be recovered from state + OperatorSubtaskState snapshot = testHarness.snapshot(0L, 0); + testHarness.close(); + expectedOutputOutput.clear(); + + func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, false); + testHarness = createTestHarness(func); + testHarness.setup(); + testHarness.initializeState(snapshot); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + testHarness.close(); + + expectedOutputOutput.add(retractRecord("book", 1L, 12)); + expectedOutputOutput.add(record("book", 1L, 10)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + protected OneInputStreamOperatorTestHarness createTestHarness( + AbstractRankFunction rankFunction) + throws Exception { + KeyedProcessOperator operator = new KeyedProcessOperator(rankFunction); + rankFunction.setKeyContext(operator); + return new KeyedOneInputStreamOperatorTestHarness(operator, 
keySelector, keySelector.getProducedType()); + } + + protected abstract AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber); + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java new file mode 100644 index 0000000000000..c068412470a02 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.deleteRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link RetractRankFunction}. 
+ */ +public class RetractRankFunctionTest extends BaseRankFunctionTest { + + @Override + protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber) { + AbstractRankFunction rankFunction = new RetractRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, + generateRetraction, outputRankNumber); + return rankFunction; + } + + @Test + public void testProcessRetractMessageWithNotGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(retractRecord("book", 1L, 12)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(deleteRecord("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + assertorWithRowNumber.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testProcessRetractMessageWithGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(retractRecord("book", 1L, 12)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + 
expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(retractRecord("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + assertorWithRowNumber.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + // TODO RetractRankFunction could be sent less retraction message when does not need to retract row_number + @Override + @Test + public void testConstantRankRangeWithoutOffset() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 1L, 12)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 4L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 5L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + // do a snapshot, data could be recovered from state + OperatorSubtaskState snapshot = testHarness.snapshot(0L, 0); + testHarness.close(); + expectedOutputOutput.clear(); + + func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, false); + testHarness = createTestHarness(func); + testHarness.setup(); + testHarness.initializeState(snapshot); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + + expectedOutputOutput.add(retractRecord("book", 1L, 12)); + expectedOutputOutput.add(retractRecord("book", 4L, 11)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(record("book", 1L, 10)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + testHarness.close(); + } + + // TODO RetractRankFunction could be sent less retraction message when does not need to retract row_number + @Test + public void testVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(1), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List 
expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("fruit", 1L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33)); + expectedOutputOutput.add(record("fruit", 1L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + // TODO + @Test + public void testDisableGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 5L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java new file mode 100644 index 0000000000000..b657dab26d85c --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.deleteRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link UpdateRankFunction}. + */ +public class UpdateRankFunctionTest extends BaseRankFunctionTest { + + @Override + protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber) { + + AbstractRankFunction rankFunction = new UpdateRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, rowKeySelector, sortKeyComparator, sortKeySelector, rankType, rankRange, + generatedEqualiser, generateRetraction, outputRankNumber, cacheSize); + return rankFunction; + } + + @Test + public void testVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new VariableRankRange(1), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 18)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 2L, 18)); + expectedOutputOutput.add(record("fruit", 1L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 44)); + expectedOutputOutput.add(record("fruit", 1L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33)); + expectedOutputOutput.add(record("fruit", 1L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Override + @Test + public void testOutputRankNumberWithVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new VariableRankRange(1), true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 1L)); + expectedOutputOutput.add(record("book", 2L, 12, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 44, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 44, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33, 1L)); + 
expectedOutputOutput.add(record("fruit", 1L, 22, 1L)); + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenOutputRankNumber() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 3L, 16, 1L)); + expectedOutputOutput.add(retractRecord("book", 3L, 16, 1L)); + expectedOutputOutput.add(record("book", 3L, 16, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(retractRecord("book", 3L, 16, 2L)); + expectedOutputOutput.add(record("book", 3L, 15, 2L)); + expectedOutputOutput.add(retractRecord("book", 3L, 15, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("book", 2L, 11, 2L)); + expectedOutputOutput.add(record("book", 4L, 2, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 11, 2L)); + expectedOutputOutput.add(retractRecord("book", 4L, 2, 1L)); + expectedOutputOutput.add(record("book", 2L, 1, 1L)); + expectedOutputOutput.add(record("book", 4L, 2, 2L)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenOutputRankNumberAndNotGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), false, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 3L, 16, 1L)); + expectedOutputOutput.add(record("book", 3L, 16, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("book", 3L, 15, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 2L)); + expectedOutputOutput.add(record("book", 4L, 2, 1L)); + expectedOutputOutput.add(record("book", 2L, 1, 1L)); + expectedOutputOutput.add(record("book", 4L, 2, 2L)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenNotOutputRankNumber() throws Exception { + 
AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(record("book", 3L, 16)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(retractRecord("book", 3L, 16)); + expectedOutputOutput.add(record("book", 3L, 15)); + expectedOutputOutput.add(record("book", 4L, 2)); + expectedOutputOutput.add(retractRecord("book", 3L, 15)); + expectedOutputOutput.add(record("book", 2L, 1)); + expectedOutputOutput.add(retractRecord("book", 2L, 11)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenNotOutputRankNumberAndNotGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), false, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(record("book", 3L, 16)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 3L, 15)); + expectedOutputOutput.add(record("book", 4L, 2)); + expectedOutputOutput.add(deleteRecord("book", 3L, 15)); + expectedOutputOutput.add(record("book", 2L, 1)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java index 12ac7df525c97..6b5abcd2e4857 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java @@ -41,4 +41,8 @@ public int compare(BaseRow o1, BaseRow o2) { return 0; } + @Override + public boolean equals(Object obj) { + return obj instanceof IntRecordComparator; + } } diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java new file mode 100644 index 0000000000000..fd867f07bb274 --- /dev/null +++ 
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.generated.RecordEqualiser; + +/** + * A utility class to check whether two BaseRow are equal. + * Note: Only support to compare two BinaryRows or two GenericRows. + */ +public class BaseRowRecordEqualiser implements RecordEqualiser { + + @Override + public boolean equals(BaseRow row1, BaseRow row2) { + if (row1 instanceof BinaryRow && row2 instanceof BinaryRow) { + return row1.equals(row2); + } else if (row1 instanceof GenericRow && row2 instanceof GenericRow) { + return row1.equals(row2); + } else { + throw new UnsupportedOperationException(); + } + + } + + @Override + public boolean equalsWithoutHeader(BaseRow row1, BaseRow row2) { + if (row1 instanceof BinaryRow && row2 instanceof BinaryRow) { + return ((BinaryRow) row1).equalsWithoutHeader(row2); + } else if (row1 instanceof GenericRow && row2 instanceof GenericRow) { + return ((GenericRow) row1).equalsWithoutHeader(row2); + } else { + throw new UnsupportedOperationException(); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java index 3b5b9e8822983..e697ed80c227a 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java @@ -18,21 +18,19 @@ package org.apache.flink.table.runtime.util; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.dataformat.BinaryRow; import org.apache.flink.table.dataformat.BinaryRowWriter; import org.apache.flink.table.dataformat.BinaryWriter; import org.apache.flink.table.dataformat.TypeGetterSetters; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; import org.apache.flink.table.type.InternalType; import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** * A utility class which will extract key from BaseRow. 
*/ -public class BinaryRowKeySelector implements KeySelector, ResultTypeQueryable { +public class BinaryRowKeySelector implements BaseRowKeySelector { private final int[] keyFields; private final InternalType[] inputFieldTypes; @@ -67,7 +65,7 @@ public BaseRow getKey(BaseRow value) throws Exception { } @Override - public TypeInformation getProducedType() { + public BaseRowTypeInfo getProducedType() { return new BaseRowTypeInfo(keyFieldTypes); } } diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java new file mode 100644 index 0000000000000..c572f8b9e2e38 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util; + +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.TypeGetterSetters; +import org.apache.flink.table.type.InternalType; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * A utility class to compare two GenericRow based on sortKey value. + * Note: Only support sortKey is Comparable value. 
+ */ +public class GenericRowRecordSortComparator implements Comparator, Serializable { + + private static final long serialVersionUID = -4988371592272863772L; + + private final int sortKeyIdx; + private final InternalType sortKeyType; + + public GenericRowRecordSortComparator(int sortKeyIdx, InternalType sortKeyType) { + this.sortKeyIdx = sortKeyIdx; + this.sortKeyType = sortKeyType; + } + + @Override + public int compare(GenericRow row1, GenericRow row2) { + byte header1 = row1.getHeader(); + byte header2 = row2.getHeader(); + if (header1 != header2) { + return header1 - header2; + } else { + Object key1 = TypeGetterSetters.get(row1, sortKeyIdx, sortKeyType); + Object key2 = TypeGetterSetters.get(row2, sortKeyIdx, sortKeyType); + if (key1 instanceof Comparable) { + return ((Comparable) key1).compareTo(key2); + } else { + throw new UnsupportedOperationException(); + } + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java new file mode 100644 index 0000000000000..7cf7d5b5e90f5 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util; + +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; + +import static org.apache.flink.table.dataformat.BinaryString.fromString; + +/** + * Utilities to generate a StreamRecord which encapsulates value of BaseRow type. + */ +public class StreamRecordUtils { + + public static StreamRecord record(String key, Object... fields) { + return new StreamRecord<>(baserow(key, fields)); + } + + public static StreamRecord retractRecord(String key, Object... fields) { + BaseRow row = baserow(key, fields); + BaseRowUtil.setRetract(row); + return new StreamRecord<>(row); + } + + public static StreamRecord deleteRecord(String key, Object... fields) { + BaseRow row = baserow(key, fields); + BaseRowUtil.setRetract(row); + return new StreamRecord<>(row); + } + + public static BaseRow baserow(String key, Object... 
fields) { + Object[] objects = new Object[fields.length + 1]; + objects[0] = fromString(key); + System.arraycopy(fields, 0, objects, 1, fields.length); + return GenericRow.of(objects); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java index 54863a6d235d9..4961892c2eb8d 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java @@ -44,8 +44,8 @@ import java.util.Arrays; import java.util.Collections; -import static org.apache.flink.table.runtime.window.WindowTestUtils.baserow; -import static org.apache.flink.table.runtime.window.WindowTestUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.baserow; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java index 0004fc0ba418e..c0338f1a4453a 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java @@ -35,6 +35,7 @@ import org.apache.flink.table.runtime.context.ExecutionContext; import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; import org.apache.flink.table.runtime.window.assigners.SessionWindowAssigner; import org.apache.flink.table.runtime.window.assigners.TumblingWindowAssigner; import org.apache.flink.table.runtime.window.assigners.WindowAssigner; @@ -47,17 +48,15 @@ import org.junit.Test; -import java.io.Serializable; import java.time.Duration; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import static org.apache.flink.table.dataformat.BinaryString.fromString; -import static org.apache.flink.table.runtime.window.WindowTestUtils.record; -import static org.apache.flink.table.runtime.window.WindowTestUtils.retractRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -70,10 +69,10 @@ public class WindowOperatorTest { // For counting if close() is called the correct number of times on the SumReducer private static AtomicInteger closeCalled = new AtomicInteger(0); - private InternalType[] inputFieldTypes = new InternalType[]{ + private InternalType[] inputFieldTypes = new InternalType[] { InternalTypes.STRING, InternalTypes.INT, - InternalTypes.LONG}; + InternalTypes.LONG }; private BaseRowTypeInfo 
outputType = new BaseRowTypeInfo( InternalTypes.STRING, @@ -83,15 +82,15 @@ public class WindowOperatorTest { InternalTypes.LONG, InternalTypes.LONG); - private InternalType[] aggResultTypes = new InternalType[]{InternalTypes.LONG, InternalTypes.LONG}; - private InternalType[] accTypes = new InternalType[]{InternalTypes.LONG, InternalTypes.LONG}; - private InternalType[] windowTypes = new InternalType[]{InternalTypes.LONG, InternalTypes.LONG, InternalTypes.LONG}; + private InternalType[] aggResultTypes = new InternalType[] { InternalTypes.LONG, InternalTypes.LONG }; + private InternalType[] accTypes = new InternalType[] { InternalTypes.LONG, InternalTypes.LONG }; + private InternalType[] windowTypes = new InternalType[] { InternalTypes.LONG, InternalTypes.LONG, InternalTypes.LONG }; private GenericRowEqualiser equaliser = new GenericRowEqualiser(accTypes, windowTypes); - private BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[]{0}, inputFieldTypes); + private BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[] { 0 }, inputFieldTypes); private TypeInformation keyType = keySelector.getProducedType(); private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( outputType.getFieldTypes(), - new GenericRowResultSortComparator()); + new GenericRowRecordSortComparator(0, InternalTypes.STRING)); @Test public void testEventTimeSlidingWindows() throws Exception { @@ -709,7 +708,7 @@ public void testProcessingTimeSessionWindows() throws Throwable { OneInputStreamOperatorTestHarness testHarness = createTestHarness(operator); BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( - outputType.getFieldTypes(), new GenericRowResultSortComparator()); + outputType.getFieldTypes(), new GenericRowRecordSortComparator(0, InternalTypes.STRING)); ConcurrentLinkedQueue expectedOutputOutput = new ConcurrentLinkedQueue<>(); @@ -953,7 +952,7 @@ public void testCleanupTimerWithEmptyReduceStateForTumblingWindows() throws Exce public void testTumblingCountWindow() throws Exception { closeCalled.set(0); final int windowSize = 3; - InternalType[] windowTypes = new InternalType[]{InternalTypes.LONG}; + InternalType[] windowTypes = new InternalType[] { InternalTypes.LONG }; WindowOperator operator = WindowOperatorBuilder.builder() .withInputFields(inputFieldTypes) @@ -1023,7 +1022,7 @@ public void testSlidingCountWindow() throws Exception { closeCalled.set(0); final int windowSize = 5; final int windowSlide = 3; - InternalType[] windowTypes = new InternalType[]{InternalTypes.LONG}; + InternalType[] windowTypes = new InternalType[] { InternalTypes.LONG }; WindowOperator operator = WindowOperatorBuilder.builder() .withInputFields(inputFieldTypes) @@ -1127,19 +1126,6 @@ public SessionWindowAssigner withProcessingTime() { } } - // schema: String, Long, Long, Long, Long - private static class GenericRowResultSortComparator implements Comparator, Serializable { - - private static final long serialVersionUID = -4988371592272863772L; - - @Override - public int compare(GenericRow row1, GenericRow row2) { - String key1 = row1.getString(0).toString(); - String key2 = row2.getString(0).toString(); - return key1.compareTo(key2); - } - } - // sum, count, window_start, window_end private static class SumAndCountAggTimeWindow extends SumAndCountAgg { diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java 
index ac8bfe3ede364..4abdc5d4b13da 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java @@ -18,17 +18,11 @@ package org.apache.flink.table.runtime.window; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.api.window.TimeWindow; -import org.apache.flink.table.dataformat.BaseRow; -import org.apache.flink.table.dataformat.GenericRow; -import org.apache.flink.table.dataformat.util.BaseRowUtil; import org.hamcrest.Matcher; import org.hamcrest.Matchers; -import static org.apache.flink.table.dataformat.BinaryString.fromString; - /** * Utilities that are useful for working with Window tests. */ @@ -38,20 +32,4 @@ static Matcher timeWindow(long start, long end) { return Matchers.equalTo(new TimeWindow(start, end)); } - static StreamRecord record(String key, Object... fields) { - return new StreamRecord<>(baserow(key, fields)); - } - - static StreamRecord retractRecord(String key, Object... fields) { - BaseRow row = baserow(key, fields); - BaseRowUtil.setRetract(row); - return new StreamRecord<>(row); - } - - static BaseRow baserow(String key, Object... fields) { - Object[] objects = new Object[fields.length + 1]; - objects[0] = fromString(key); - System.arraycopy(fields, 0, objects, 1, fields.length); - return GenericRow.of(objects); - } } From 25bdd55f20fa8db7ccfbfcb206cbf8bb23e7858b Mon Sep 17 00:00:00 2001 From: beyond1920 Date: Sat, 13 Apr 2019 01:13:22 +0800 Subject: [PATCH 4/5] 1.Update Deduplicate Function state. 2.Other minor update. --- .../stream/StreamExecDeduplicate.scala | 46 ++--------- .../physical/stream/StreamExecExchange.scala | 11 +-- .../stream/StreamExecDeduplicateRule.scala | 6 +- .../stream/sql/DeduplicateITCase.scala | 82 ------------------- .../table/runtime/stream/sql/RankITCase.scala | 2 +- .../deduplicate/DeduplicateFunction.java | 47 +++++------ .../DeduplicateFunctionHelper.java | 54 +++++++----- .../MiniBatchDeduplicateFunction.java | 36 ++++---- .../runtime/rank/AppendRankFunction.java | 3 +- .../runtime/rank/UpdateRankFunction.java | 2 + .../table/typeutils/SortedMapSerializer.java | 2 +- .../BaseDeduplicateFunctionTest.java | 59 ------------- .../deduplicate/DeduplicateFunctionTest.java | 23 +++++- .../MiniBatchDeduplicateFunctionTest.java | 25 +++++- .../runtime/rank/BaseRankFunctionTest.java | 2 +- 15 files changed, 137 insertions(+), 263 deletions(-) delete mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index 024ee45700fd8..bacd3629acde5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -21,20 +21,15 @@ package org.apache.flink.table.plan.nodes.physical.stream import org.apache.flink.streaming.api.operators.KeyedProcessOperator import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} import 
org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.codegen.EqualiserCodeGenerator import org.apache.flink.table.dataformat.BaseRow -import org.apache.flink.table.generated.GeneratedRecordEqualiser import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.util.KeySelectorUtil import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger import org.apache.flink.table.runtime.deduplicate.{DeduplicateFunction, MiniBatchDeduplicateFunction} -import org.apache.flink.table.`type`.TypeConverters import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules import org.apache.flink.table.typeutils.BaseRowTypeInfo -import org.apache.flink.table.typeutils.TypeCheckUtils.isRowTime import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} @@ -48,18 +43,18 @@ import scala.collection.JavaConversions._ * Stream physical RelNode which deduplicate on keys and keeps only first row or last row. * This node is an optimization of [[StreamExecRank]] for some special cases. * Compared to [[StreamExecRank]], this node could use mini-batch and access less state. - *

<p>NOTES: only supports sort on proctime now. + * <p>
NOTES: only supports sort on proctime now, sort on rowtime will not translated into + * StreamExecDeduplicate node. */ class StreamExecDeduplicate( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode, uniqueKeys: Array[Int], - isRowtime: Boolean, keepLastRow: Boolean) extends SingleRel(cluster, traitSet, inputRel) - with StreamPhysicalRel - with StreamExecNode[BaseRow] { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { def getUniqueKeys: Array[Int] = uniqueKeys @@ -81,17 +76,15 @@ class StreamExecDeduplicate( traitSet, inputs.get(0), uniqueKeys, - isRowtime, keepLastRow) } override def explainTerms(pw: RelWriter): RelWriter = { val fieldNames = getRowType.getFieldNames - val orderString = if (isRowtime) "ROWTIME" else "PROCTIME" super.explainTerms(pw) .item("keepLastRow", keepLastRow) .item("key", uniqueKeys.map(fieldNames.get).mkString(", ")) - .item("order", orderString) + .item("order", "PROCTIME") } //~ ExecNode methods ----------------------------------------------------------- @@ -112,31 +105,17 @@ class StreamExecDeduplicate( .asInstanceOf[StreamTransformation[BaseRow]] val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] - val generateRetraction = StreamExecRetractionRules.isAccRetract(this) - - val inputRowType = FlinkTypeFactory.toInternalRowType(getInput.getRowType) - val rowTimeFieldIndex = inputRowType.getFieldTypes.zipWithIndex - .filter(e => isRowTime(e._1)) - .map(_._2) - if (rowTimeFieldIndex.size > 1) { - throw new RuntimeException("More than one row time field. Currently this is not supported!") - } - if (rowTimeFieldIndex.nonEmpty) { - throw new TableException("Currently not support Deduplicate on rowtime.") - } val tableConfig = tableEnv.getConfig val isMiniBatchEnabled = tableConfig.getConf.getLong( TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) > 0 - val generatedRecordEqualiser = generateRecordEqualiser(rowTypeInfo) val operator = if (isMiniBatchEnabled) { val exeConfig = tableEnv.execEnv.getConfig val processFunction = new MiniBatchDeduplicateFunction( rowTypeInfo, generateRetraction, rowTypeInfo.createSerializer(exeConfig), - keepLastRow, - generatedRecordEqualiser) + keepLastRow) val trigger = new CountBundleTrigger[BaseRow]( tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE)) new KeyedMapBundleOperator( @@ -150,8 +129,7 @@ class StreamExecDeduplicate( maxRetentionTime, rowTypeInfo, generateRetraction, - keepLastRow, - generatedRecordEqualiser) + keepLastRow) new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](processFunction) } val ret = new OneInputTransformation( @@ -169,16 +147,8 @@ class StreamExecDeduplicate( private def getOperatorName: String = { val fieldNames = getRowType.getFieldNames val keyNames = uniqueKeys.map(fieldNames.get).mkString(", ") - val orderString = if (isRowtime) "ROWTIME" else "PROCTIME" s"${if (keepLastRow) "keepLastRow" else "KeepFirstRow"}" + - s": (key: ($keyNames), select: (${fieldNames.mkString(", ")}), order: ($orderString))" - } - - private def generateRecordEqualiser(rowTypeInfo: BaseRowTypeInfo): GeneratedRecordEqualiser = { - val generator = new EqualiserCodeGenerator( - rowTypeInfo.getFieldTypes.map(TypeConverters.createInternalTypeFromTypeInfo)) - val equaliserName = s"${if (keepLastRow) "LastRow" else "FirstRow"}ValueEqualiser" - generator.generateRecordEqualiser(equaliserName) + s": (key: ($keyNames), select: (${fieldNames.mkString(", ")}), order: (PROCTIME)" } } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala index 2c58863dd6a40..2f2fc307a9744 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.plan.nodes.physical.stream +import org.apache.flink.runtime.state.KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM import org.apache.flink.table.plan.nodes.common.CommonPhysicalExchange import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} @@ -44,10 +45,8 @@ class StreamExecExchange( relNode: RelNode, relDistribution: RelDistribution) extends CommonPhysicalExchange(cluster, traitSet, relNode, relDistribution) - with StreamPhysicalRel - with StreamExecNode[BaseRow] { - - private val DEFAULT_MAX_PARALLELISM = 1 << 7 + with StreamPhysicalRel + with StreamExecNode[BaseRow] { override def producesUpdates: Boolean = false @@ -88,9 +87,11 @@ class StreamExecExchange( transformation case RelDistribution.Type.HASH_DISTRIBUTED => // TODO Eliminate duplicate keys + val selector = KeySelectorUtil.getBaseRowSelector( relDistribution.getKeys.map(_.toInt).toArray, inputTypeInfo) - val partitioner = new KeyGroupStreamPartitioner(selector, DEFAULT_MAX_PARALLELISM) + val partitioner = new KeyGroupStreamPartitioner(selector, + DEFAULT_LOWER_BOUND_MAX_PARALLELISM) val transformation = new PartitionTransformation( inputTransform, partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala index 609bae8b7d34f..6ec7a27eb3aa3 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala @@ -68,10 +68,6 @@ class StreamExecDeduplicateRule override def convert(rel: RelNode): RelNode = { val rank = rel.asInstanceOf[FlinkLogicalRank] - val fieldCollation = rank.orderKey.getFieldCollations.get(0) - val fieldType = rank.getInput.getRowType.getFieldList.get(fieldCollation.getFieldIndex).getType - val isRowtime = FlinkTypeFactory.isRowtimeIndicatorType(fieldType) - val requiredDistribution = FlinkRelDistribution.hash(rank.partitionKey.toList) val requiredTraitSet = rel.getCluster.getPlanner.emptyTraitSet() .replace(FlinkConventions.STREAM_PHYSICAL) @@ -79,6 +75,7 @@ class StreamExecDeduplicateRule val convInput: RelNode = RelOptRule.convert(rank.getInput, requiredTraitSet) // order by timeIndicator desc ==> lastRow, otherwise is firstRow + val fieldCollation = rank.orderKey.getFieldCollations.get(0) val isLastRow = fieldCollation.direction.isDescending val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL) new StreamExecDeduplicate( @@ -86,7 +83,6 @@ class StreamExecDeduplicateRule providedTraitSet, convInput, 
rank.partitionKey.toArray, - isRowtime, isLastRow) } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala index b218e0e5a6fb9..179394086772a 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala @@ -87,86 +87,4 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) assertEquals(expected.sorted, sink.getRetractResults.sorted) } - // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently - @Test - def testFirstRowOnRowtime(): Unit = { - val data = List( - (3L, 2L, "Hello world", 3), - (2L, 2L, "Hello", 2), - (6L, 3L, "Luke Skywalker", 6), - (5L, 3L, "I am fine.", 5), - (7L, 4L, "Comment#1", 7), - (9L, 4L, "Comment#3", 9), - (10L, 4L, "Comment#4", 10), - (8L, 4L, "Comment#2", 8), - (1L, 1L, "Hi", 1), - (4L, 3L, "Helloworld, how are you?", 4)) - - val t = failingDataSource(data) - .assignTimestampsAndWatermarks( - new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) - .toTable(tEnv, 'rowtime, 'key, 'str, 'int) - tEnv.registerTable("T", t) - - val sql = - """ - |SELECT key, str, `int` - |FROM ( - | SELECT *, - | ROW_NUMBER() OVER (PARTITION BY key ORDER BY rowtime) as rowNum - | FROM T - |) - |WHERE rowNum = 1 - """.stripMargin - - val sink = new TestingUpsertTableSink(Array(1)) - val table = tEnv.sqlQuery(sql) - writeToSink(table, sink) - - env.execute() - val expected = List("1,Hi,1", "2,Hello,2", "3,Helloworld, how are you?,4", "4,Comment#1,7") - assertEquals(expected.sorted, sink.getUpsertResults.sorted) - } - - // TODO Deduplicate does not support sort on rowtime now, so it is translated to Rank currently - @Test - def testLastRowOnRowtime(): Unit = { - val data = List( - (3L, 2L, "Hello world", 3), - (2L, 2L, "Hello", 2), - (6L, 3L, "Luke Skywalker", 6), - (5L, 3L, "I am fine.", 5), - (7L, 4L, "Comment#1", 7), - (9L, 4L, "Comment#3", 9), - (10L, 4L, "Comment#4", 10), - (8L, 4L, "Comment#2", 8), - (1L, 1L, "Hi", 1), - (4L, 3L, "Helloworld, how are you?", 4)) - - val t = failingDataSource(data) - .assignTimestampsAndWatermarks( - new TimestampAndWatermarkWithOffset[(Long, Long, String, Int)](10L)) - .toTable(tEnv, 'rowtime, 'key, 'str, 'int) - tEnv.registerTable("T", t) - - val sql = - """ - |SELECT key, str, `int` - |FROM ( - | SELECT *, - | ROW_NUMBER() OVER (PARTITION BY key ORDER BY rowtime DESC) as rowNum - | FROM T - |) - |WHERE rowNum = 1 - """.stripMargin - - val sink = new TestingUpsertTableSink(Array(1)) - val table = tEnv.sqlQuery(sql) - writeToSink(table, sink) - - env.execute() - val expected = List("1,Hi,1", "2,Hello world,3", "3,Luke Skywalker,6", "4,Comment#4,10") - assertEquals(expected.sorted, sink.getUpsertResults.sorted) - } - } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala index fd2485dc282cb..2e8225651a864 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala +++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -62,8 +62,8 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) env.execute() val expected = List( - "book,1,12,2", "book,2,19,1", + "book,1,12,2", "fruit,3,44,1", "fruit,4,33,2") assertEquals(expected.sorted, sink.getRetractResults.sorted) diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java index e1a9f3485dbb2..1467b1c9f0e99 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java @@ -20,10 +20,9 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.configuration.Configuration; import org.apache.flink.table.dataformat.BaseRow; -import org.apache.flink.table.generated.GeneratedRecordEqualiser; -import org.apache.flink.table.generated.RecordEqualiser; import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; @@ -42,30 +41,33 @@ public class DeduplicateFunction private final BaseRowTypeInfo rowTypeInfo; private final boolean generateRetraction; private final boolean keepLastRow; - private ValueState pkRow; - private GeneratedRecordEqualiser generatedEqualiser; - private transient RecordEqualiser equaliser; + + // state stores complete row if keep last row and generate retraction is true, + // else stores a flag to indicate whether key appears before. + private ValueState state; public DeduplicateFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo rowTypeInfo, - boolean generateRetraction, boolean keepLastRow, GeneratedRecordEqualiser generatedEqualiser) { + boolean generateRetraction, boolean keepLastRow) { super(minRetentionTime, maxRetentionTime); this.rowTypeInfo = rowTypeInfo; this.generateRetraction = generateRetraction; this.keepLastRow = keepLastRow; - this.generatedEqualiser = generatedEqualiser; } @Override public void open(Configuration configure) throws Exception { super.open(configure); - String stateName = keepLastRow ? "DeduplicateFunctionCleanupTime" : "DeduplicateFunctionCleanupTime"; + String stateName = keepLastRow ? "DeduplicateFunctionKeepLastRow" : "DeduplicateFunctionKeepFirstRow"; initCleanupTimeState(stateName); - ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); - pkRow = getRuntimeContext().getState(rowStateDesc); - - // compile equaliser - equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); - generatedEqualiser = null; + ValueStateDescriptor stateDesc = null; + if (keepLastRow && generateRetraction) { + // if need generate retraction and keep last row, stores complete row into state + stateDesc = new ValueStateDescriptor("deduplicateFunction", rowTypeInfo); + } else { + // else stores a flag to indicator whether pk appears before. 
+ stateDesc = new ValueStateDescriptor("fistValueState", Types.BOOLEAN); + } + state = getRuntimeContext().getState(stateDesc); } @Override @@ -74,26 +76,17 @@ public void processElement(BaseRow input, Context ctx, Collector out) t // register state-cleanup timer registerProcessingCleanupTimer(ctx, currentTime); - BaseRow preRow = pkRow.value(); if (keepLastRow) { - processLastRow(preRow, input, generateRetraction, stateCleaningEnabled, pkRow, equaliser, out); + processLastRow(input, generateRetraction, state, out); } else { - processFirstRow(preRow, input, pkRow, out); + processFirstRow(input, state, out); } } @Override - public void close() throws Exception { - super.close(); - } - - @Override - public void onTimer( - long timestamp, - OnTimerContext ctx, - Collector out) throws Exception { + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { if (stateCleaningEnabled) { - cleanupState(pkRow); + cleanupState(state); } } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java index d3b4f2c5102bb..831a2cfbd6389 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java @@ -21,7 +21,6 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.dataformat.util.BaseRowUtil; -import org.apache.flink.table.generated.RecordEqualiser; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; @@ -30,35 +29,52 @@ */ class DeduplicateFunctionHelper { - static void processLastRow(BaseRow preRow, BaseRow currentRow, boolean generateRetraction, - boolean stateCleaningEnabled, ValueState pkRow, RecordEqualiser equaliser, + /** + * Processes element to deduplicate on keys, sends current element as last row, retracts previous element if + * needed. + * + * @param currentRow latest row received by deduplicate function + * @param generateRetraction whether need to send retract message to downstream + * @param state state of function + * @param out underlying collector + * @throws Exception + */ + static void processLastRow(BaseRow currentRow, boolean generateRetraction, ValueState state, Collector out) throws Exception { - // should be accumulate msg. + // should be accumulate msg Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); - if (!stateCleaningEnabled && preRow != null && - equaliser.equalsWithoutHeader(preRow, currentRow)) { - // If state cleaning is not enabled, don't emit retraction and acc message. But if state cleaning is - // enabled, we have to emit message to prevent too early state eviction of downstream operators. 
- return; - } - pkRow.update(currentRow); - if (preRow != null && generateRetraction) { - preRow.setHeader(BaseRowUtil.RETRACT_MSG); - out.collect(preRow); + if (generateRetraction) { + // state stores complete row if generateRetraction is true + BaseRow preRow = (BaseRow) state.value(); + state.update(currentRow); + if (preRow != null) { + preRow.setHeader(BaseRowUtil.RETRACT_MSG); + out.collect(preRow); + } + } else { + // state stores a flag to indicator whether pk appears before + state.update(true); } out.collect(currentRow); } - static void processFirstRow(BaseRow preRow, BaseRow currentRow, ValueState pkRow, - Collector out) throws Exception { + /** + * Processes element to deduplicate on keys, sends current element if it is first row. + * + * @param currentRow latest row received by deduplicate function + * @param state state of function + * @param out underlying collector + * @throws Exception + */ + static void processFirstRow(BaseRow currentRow, ValueState state, Collector out) + throws Exception { // should be accumulate msg. Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); // ignore record with timestamp bigger than preRow - if (preRow != null) { + if (state.value() != null) { return; } - - pkRow.update(currentRow); + state.update(true); out.collect(currentRow); } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java index aa3c4cdec6820..8b8f9832e1748 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java @@ -20,10 +20,9 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.table.dataformat.BaseRow; -import org.apache.flink.table.generated.GeneratedRecordEqualiser; -import org.apache.flink.table.generated.RecordEqualiser; import org.apache.flink.table.runtime.bundle.MapBundleFunction; import org.apache.flink.table.runtime.context.ExecutionContext; import org.apache.flink.table.typeutils.BaseRowTypeInfo; @@ -46,33 +45,35 @@ public class MiniBatchDeduplicateFunction private BaseRowTypeInfo rowTypeInfo; private boolean generateRetraction; private boolean keepLastRow; - private ValueState pkRow; + + // state stores complete row if keep last row and generate retraction is true, + // else stores a flag to indicate whether key appears before. 
+ private ValueState state; private TypeSerializer ser; - private GeneratedRecordEqualiser generatedEqualiser; - private transient RecordEqualiser equaliser; public MiniBatchDeduplicateFunction( BaseRowTypeInfo rowTypeInfo, boolean generateRetraction, TypeSerializer typeSerializer, - boolean keepLastRow, - GeneratedRecordEqualiser generatedEqualiser) { + boolean keepLastRow) { this.rowTypeInfo = rowTypeInfo; - this.generateRetraction = generateRetraction; this.keepLastRow = keepLastRow; + this.generateRetraction = generateRetraction; ser = typeSerializer; - this.generatedEqualiser = generatedEqualiser; } @Override public void open(ExecutionContext ctx) throws Exception { super.open(ctx); - ValueStateDescriptor rowStateDesc = new ValueStateDescriptor("rowState", rowTypeInfo); - pkRow = ctx.getRuntimeContext().getState(rowStateDesc); - - // compile equaliser - equaliser = generatedEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); - generatedEqualiser = null; + ValueStateDescriptor stateDesc = null; + if (keepLastRow && generateRetraction) { + // if need generate retraction and keep last row, stores complete row into state + stateDesc = new ValueStateDescriptor("deduplicateFunction", rowTypeInfo); + } else { + // else stores a flag to indicator whether pk appears before. + stateDesc = new ValueStateDescriptor("fistValueState", Types.BOOLEAN); + } + state = ctx.getRuntimeContext().getState(stateDesc); } @Override @@ -93,12 +94,11 @@ public void finishBundle( BaseRow currentKey = entry.getKey(); BaseRow currentRow = entry.getValue(); ctx.setCurrentKey(currentKey); - BaseRow preRow = pkRow.value(); if (keepLastRow) { - processLastRow(preRow, currentRow, generateRetraction, false, pkRow, equaliser, out); + processLastRow(currentRow, generateRetraction, state, out); } else { - processFirstRow(preRow, currentRow, pkRow, out); + processFirstRow(currentRow, state, out); } } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java index 41ea9d703ca55..6464637e2a871 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -24,7 +24,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.typeutils.ListTypeInfo; import org.apache.flink.configuration.Configuration; -import org.apache.flink.table.api.TableConfigOptions; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.generated.GeneratedRecordComparator; import org.apache.flink.table.generated.GeneratedRecordEqualiser; @@ -61,6 +60,8 @@ public class AppendRankFunction extends AbstractRankFunction { // the buffer stores mapping from sort key to records list, a heap mirror to dataState private transient TopNBuffer buffer; + + // the kvSortedMap stores mapping from partition key to it's buffer private transient Map kvSortedMap; public AppendRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java index d5e90834538fb..1965e52e2ea2c 
100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -76,11 +76,13 @@ public class UpdateRankFunction extends AbstractRankFunction implements Checkpoi // a buffer stores mapping from sort key to rowKey list private transient TopNBuffer buffer; + // the kvSortedMap stores mapping from partition key to it's buffer private transient Map kvSortedMap; // a HashMap stores mapping from rowKey to record, a heap mirror to dataState private transient Map rowKeyMap; + // the kvRowKeyMap store mapping from partitionKey to its rowKeyMap. private transient LRUMap> kvRowKeyMap; private final TypeSerializer inputRowSer; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java index ea5b8200f17fb..63f4f8f71eeac 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java @@ -116,4 +116,4 @@ public String toString() { public TypeSerializerSnapshot> snapshotConfiguration() { return new SortedMapSerializerSnapshot<>(this); } -} \ No newline at end of file +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java deleted file mode 100644 index d8a475c45241f..0000000000000 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/BaseDeduplicateFunctionTest.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.deduplicate; - -import org.apache.flink.table.generated.GeneratedRecordEqualiser; -import org.apache.flink.table.generated.RecordEqualiser; -import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; -import org.apache.flink.table.runtime.util.BaseRowRecordEqualiser; -import org.apache.flink.table.runtime.util.BinaryRowKeySelector; -import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; -import org.apache.flink.table.type.InternalTypes; -import org.apache.flink.table.typeutils.BaseRowTypeInfo; - -/** - * Base Tests for {@link DeduplicateFunction} and {@link MiniBatchDeduplicateFunction}. 
- */ -public abstract class BaseDeduplicateFunctionTest { - - protected BaseRowTypeInfo inputRowType = new BaseRowTypeInfo( - InternalTypes.STRING, - InternalTypes.LONG, - InternalTypes.INT); - - protected GeneratedRecordEqualiser generatedEqualiser = new GeneratedRecordEqualiser("", "", new Object[0]) { - - private static final long serialVersionUID = -5080236034372380295L; - - @Override - public RecordEqualiser newInstance(ClassLoader classLoader) { - return new BaseRowRecordEqualiser(); - } - }; - - private int rowKeyIdx = 1; - protected BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, - inputRowType.getInternalTypes()); - - - protected BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( - inputRowType.getFieldTypes(), - new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); - -} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java index d99e6783c9c75..6210f1945f83c 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java @@ -23,6 +23,11 @@ import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.junit.Test; @@ -35,15 +40,25 @@ /** * Tests for {@link DeduplicateFunction}. 
*/ -public class DeduplicateFunctionTest extends BaseDeduplicateFunctionTest { +public class DeduplicateFunctionTest { private Time minTime = Time.milliseconds(10); private Time maxTime = Time.milliseconds(20); + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + private DeduplicateFunction createFunction(boolean generateRetraction, boolean keepLastRow) { DeduplicateFunction func = new DeduplicateFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), - inputRowType, generateRetraction, keepLastRow, - generatedEqualiser); + inputRowType, generateRetraction, keepLastRow); return func; } @@ -118,8 +133,8 @@ public void testKeepLastRowWithGenerateRetraction() throws Exception { List expectedOutputOutput = new ArrayList<>(); expectedOutputOutput.add(record("book", 1L, 12)); expectedOutputOutput.add(retractRecord("book", 1L, 12)); - expectedOutputOutput.add(record("book", 2L, 11)); expectedOutputOutput.add(record("book", 1L, 13)); + expectedOutputOutput.add(record("book", 2L, 11)); assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); } } diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java index aa4602ff7631d..e47525db45e84 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java @@ -26,6 +26,11 @@ import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator; import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.junit.Assert; import org.junit.Test; @@ -39,12 +44,25 @@ /** * Tests for {@link MiniBatchDeduplicateFunction}. 
*/ -public class MiniBatchDeduplicateFunctionTest extends BaseDeduplicateFunctionTest { +public class MiniBatchDeduplicateFunctionTest { + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + private TypeSerializer typeSerializer = inputRowType.createSerializer(new ExecutionConfig()); private MiniBatchDeduplicateFunction createFunction(boolean generateRetraction, boolean keepLastRow) { MiniBatchDeduplicateFunction func = new MiniBatchDeduplicateFunction(inputRowType, generateRetraction, - typeSerializer, keepLastRow, generatedEqualiser); + typeSerializer, keepLastRow); return func; } @@ -119,6 +137,7 @@ public void testKeepLastWithoutGenerateRetraction() throws Exception { testHarness.processElement(record("book", 3L, 11)); expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); expectedOutputOutput.add(record("book", 3L, 11)); testHarness.close(); assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); @@ -148,6 +167,8 @@ public void testKeepLastRowWithGenerateRetraction() throws Exception { // this will send retract message to downstream expectedOutputOutput.add(retractRecord("book", 1L, 13)); expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(retractRecord("book", 2L, 11)); + expectedOutputOutput.add(record("book", 2L, 11)); expectedOutputOutput.add(record("book", 3L, 11)); testHarness.close(); assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java index 9abc9d43e0845..0d766a546c686 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java @@ -31,8 +31,8 @@ import org.apache.flink.table.runtime.sort.IntRecordComparator; import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; import org.apache.flink.table.runtime.util.BaseRowRecordEqualiser; -import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.typeutils.BaseRowTypeInfo; From 7de15170d14681167502cbb493fba522165907da Mon Sep 17 00:00:00 2001 From: beyond1920 Date: Mon, 15 Apr 2019 10:36:48 +0800 Subject: [PATCH 5/5] 1. split DeduplicateFunction into DeduplicateKeepFirstRowFunction and DeduplicateKeepLastRowFunction 2. other minor update. 
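To make the behavioural difference between the two new functions easier to review, the following is a rough, framework-free sketch of what they do per key. It is illustrative only: plain HashMaps stand in for the keyed ValueState that the real DeduplicateKeepFirstRowFunction and DeduplicateKeepLastRowFunction obtain from the runtime context, String rows stand in for BaseRow, the "+I"/"-D" markers are shorthand for the accumulate/retract headers, and the class and method names are invented for the example.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Illustrative sketch only: HashMaps stand in for Flink's keyed ValueState,
    // String rows stand in for BaseRow. Not the actual runtime classes.
    public class DeduplicateSketch {

        // keep-first semantics: remember only that a key was seen, emit the first row per key
        static List<String> keepFirstRow(List<Map.Entry<String, String>> keyedInput) {
            Map<String, Boolean> seen = new HashMap<>();
            List<String> out = new ArrayList<>();
            for (Map.Entry<String, String> e : keyedInput) {
                if (seen.putIfAbsent(e.getKey(), true) == null) {
                    out.add("+I " + e.getValue());   // first row for this key -> emit
                }                                    // later rows for the key are dropped
            }
            return out;
        }

        // keep-last semantics with retraction: remember the previous row per key
        // and retract it before emitting the newer one
        static List<String> keepLastRowWithRetraction(List<Map.Entry<String, String>> keyedInput) {
            Map<String, String> previous = new HashMap<>();
            List<String> out = new ArrayList<>();
            for (Map.Entry<String, String> e : keyedInput) {
                String preRow = previous.put(e.getKey(), e.getValue());
                if (preRow != null) {
                    out.add("-D " + preRow);         // retract the previously emitted row
                }
                out.add("+I " + e.getValue());       // always emit the newest row
            }
            return out;
        }

        public static void main(String[] args) {
            List<Map.Entry<String, String>> rows = List.of(
                    Map.entry("book", "book,1,12"),
                    Map.entry("book", "book,1,13"),
                    Map.entry("fruit", "fruit,4,33"));
            System.out.println(keepFirstRow(rows));              // [+I book,1,12, +I fruit,4,33]
            System.out.println(keepLastRowWithRetraction(rows)); // [+I book,1,12, -D book,1,12, +I book,1,13, +I fruit,4,33]
        }
    }

The real functions additionally register processing-time state-cleanup timers through KeyedProcessFunctionWithCleanupState (the keep-last variant only when it actually holds a row in state), which the sketch leaves out.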
--- .../apache/flink/table/api/TableConfig.scala | 6 + .../calcite/FlinkLogicalRelFactories.scala | 3 +- .../flink/table/calcite/FlinkRelBuilder.scala | 3 +- .../table/calcite/FlinkRelFactories.scala | 4 +- .../stream/StreamExecDeduplicate.scala | 31 ++-- .../physical/stream/StreamExecRank.scala | 4 +- .../table/plan/optimize/StreamOptimizer.scala | 1 + .../table/runtime/stream/sql/RankITCase.scala | 152 +++++++++++------- .../table/runtime/utils/StreamTestSink.scala | 5 +- .../runtime/utils/StreamingTestBase.scala | 6 - .../utils/StreamingWithStateTestBase.scala | 26 --- .../flink/table/runtime/utils/TableUtil.scala | 12 +- .../DeduplicateFunctionHelper.java | 13 +- ...a => DeduplicateKeepFirstRowFunction.java} | 42 ++--- .../DeduplicateKeepLastRowFunction.java | 78 +++++++++ ...iBatchDeduplicateKeepFirstRowFunction.java | 81 ++++++++++ ...iBatchDeduplicateKeepLastRowFunction.java} | 56 ++----- .../keyselector/BinaryRowKeySelector.java | 2 +- .../keyselector/NullBinaryRowKeySelector.java | 4 +- .../runtime/rank/AbstractRankFunction.java | 29 +--- .../runtime/rank/AppendRankFunction.java | 16 +- .../runtime/rank/RetractRankFunction.java | 11 +- .../runtime/rank/UpdateRankFunction.java | 28 +--- .../DeduplicateKeepFirstRowFunctionTest.java | 83 ++++++++++ ...> DeduplicateKeepLastRowFunctionTest.java} | 58 ++----- ...chDeduplicateKeepFirstRowFunctionTest.java | 91 +++++++++++ ...chDeduplicateKeepLastRowFunctionTest.java} | 63 ++------ .../runtime/rank/AppendRankFunctionTest.java | 3 +- .../runtime/rank/BaseRankFunctionTest.java | 12 +- .../runtime/rank/RetractRankFunctionTest.java | 3 +- .../runtime/rank/UpdateRankFunctionTest.java | 4 +- .../runtime/util/BinaryRowKeySelector.java | 4 +- .../runtime/window/WindowOperatorTest.java | 4 +- 33 files changed, 546 insertions(+), 392 deletions(-) rename flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/{DeduplicateFunction.java => DeduplicateKeepFirstRowFunction.java} (58%) create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java rename flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/{MiniBatchDeduplicateFunction.java => MiniBatchDeduplicateKeepLastRowFunction.java} (57%) create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java rename flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/{DeduplicateFunctionTest.java => DeduplicateKeepLastRowFunctionTest.java} (62%) create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java rename flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/{MiniBatchDeduplicateFunctionTest.java => MiniBatchDeduplicateKeepLastRowFunctionTest.java} (67%) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala index a1844b0e4fd28..4d5b2c3d1226c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala 
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala @@ -188,10 +188,16 @@ class TableConfig { this } + /** + * Returns the minimum time until state which was not updated will be retained. + */ def getMinIdleStateRetentionTime: Long = { this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) } + /** + * Returns the maximum time until state which was not updated will be retained. + */ def getMaxIdleStateRetentionTime: Long = { // only min idle ttl provided. if (this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala index 4e26b65b378ff..c60bb12f60d95 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala @@ -19,10 +19,9 @@ package org.apache.flink.table.calcite import org.apache.flink.table.calcite.FlinkRelFactories.{ExpandFactory, RankFactory, SinkFactory} -import org.apache.flink.table.plan.nodes.calcite.RankRange -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType import org.apache.flink.table.plan.nodes.logical._ import org.apache.flink.table.plan.schema.FlinkRelOptTable +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink import com.google.common.collect.ImmutableList diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala index 3f462fe3d2156..b8e2c1a755913 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala @@ -20,8 +20,7 @@ package org.apache.flink.table.calcite import org.apache.flink.table.calcite.FlinkRelFactories.{ExpandFactory, RankFactory, SinkFactory} import org.apache.flink.table.expressions.WindowProperty -import org.apache.flink.table.plan.nodes.calcite.RankRange -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink import org.apache.calcite.config.{CalciteConnectionConfigImpl, CalciteConnectionProperty} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala index d18e5ec82cd66..05dc497f6b1bf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala @@ -18,8 +18,8 @@ package org.apache.flink.table.calcite -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalSink, RankRange} +import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalSink} +import 
org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink import org.apache.calcite.plan.Contexts diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index bacd3629acde5..58d56b84ba472 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -26,8 +26,7 @@ import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.util.KeySelectorUtil import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger -import org.apache.flink.table.runtime.deduplicate.{DeduplicateFunction, -MiniBatchDeduplicateFunction} +import org.apache.flink.table.runtime.deduplicate.{DeduplicateKeepFirstRowFunction, DeduplicateKeepLastRowFunction, MiniBatchDeduplicateKeepFirstRowFunction, MiniBatchDeduplicateKeepLastRowFunction} import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules import org.apache.flink.table.typeutils.BaseRowTypeInfo @@ -95,10 +94,7 @@ class StreamExecDeduplicate( val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(getInput) if (inputIsAccRetract) { - throw new TableException( - "Deduplicate: Retraction on Deduplicate is not supported yet.\n" + - "please re-check sql grammar. \n" + - "Note: Deduplicate should not follow a non-windowed GroupBy aggregation.") + throw new TableException("Deduplicate doesn't support retraction input stream currently.") } val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) @@ -111,11 +107,12 @@ class StreamExecDeduplicate( TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) > 0 val operator = if (isMiniBatchEnabled) { val exeConfig = tableEnv.execEnv.getConfig - val processFunction = new MiniBatchDeduplicateFunction( - rowTypeInfo, - generateRetraction, - rowTypeInfo.createSerializer(exeConfig), - keepLastRow) + val rowSerializer = rowTypeInfo.createSerializer(exeConfig) + val processFunction = if (keepLastRow) { + new MiniBatchDeduplicateKeepLastRowFunction(rowTypeInfo, generateRetraction, rowSerializer) + } else { + new MiniBatchDeduplicateKeepFirstRowFunction(rowSerializer) + } val trigger = new CountBundleTrigger[BaseRow]( tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE)) new KeyedMapBundleOperator( @@ -124,12 +121,12 @@ class StreamExecDeduplicate( } else { val minRetentionTime = tableConfig.getMinIdleStateRetentionTime val maxRetentionTime = tableConfig.getMaxIdleStateRetentionTime - val processFunction = new DeduplicateFunction( - minRetentionTime, - maxRetentionTime, - rowTypeInfo, - generateRetraction, - keepLastRow) + val processFunction = if (keepLastRow) { + new DeduplicateKeepLastRowFunction(minRetentionTime, maxRetentionTime, rowTypeInfo, + generateRetraction) + } else { + new DeduplicateKeepFirstRowFunction(minRetentionTime, maxRetentionTime) + } new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](processFunction) } val ret = new OneInputTransformation( diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala index ea2b5e48bd8c8..e27c85b4b7370 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala @@ -21,7 +21,8 @@ import org.apache.flink.streaming.api.operators.KeyedProcessOperator import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.codegen.{EqualiserCodeGenerator, SortCodeGenerator} +import org.apache.flink.table.codegen.EqualiserCodeGenerator +import org.apache.flink.table.codegen.sort.SortCodeGenerator import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.nodes.calcite.Rank import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} @@ -111,7 +112,6 @@ class StreamExecRank( .item("select", getRowType.getFieldNames.mkString(", ")) } - //~ ExecNode methods ----------------------------------------------------------- override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala index 19fbd3df9021a..822359c5ca520 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala @@ -41,6 +41,7 @@ class StreamOptimizer(tEnv: StreamTableEnvironment) extends Optimizer { n.sink match { case _: RetractStreamTableSink[_] => true case s: DataStreamTableSink[_] => s.updatesAsRetraction + case _ => false } case o => o.getTraitSet.getTrait(UpdateAsRetractionTraitDef.INSTANCE).sendsUpdatesAsRetractions diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala index 2e8225651a864..ff046bc90a390 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -125,9 +125,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 2 """.stripMargin - val sink = new TestingUpsertTableSink(Array(0, 3)) val table = tEnv.sqlQuery(sql) - writeToSink(table, sink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -186,9 +188,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 3 """.stripMargin - val sink = new TestingUpsertTableSink(Array(0, 3)) val table = tEnv.sqlQuery(sql) - writeToSink(table, sink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val updatedExpected = List( @@ -248,7 +252,9 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin val table = tEnv.sqlQuery(sql) - val sink = new TestingUpsertTableSink(Array(0, 3)) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) tEnv.writeToSink(table, sink) env.execute() @@ -292,8 +298,10 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin val table = tEnv.sqlQuery(sql) - val sink = new TestingUpsertTableSink(Array(0, 3)) - writeToSink(table, sink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val updatedExpected = List( @@ -344,8 +352,10 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin val table = tEnv.sqlQuery(sql) - val sink = new TestingUpsertTableSink(Array(0, 3)) - writeToSink(table, sink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val updatedExpected = List( @@ -387,15 +397,17 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin val table = tEnv.sqlQuery(sql) - val tableSink = new TestingUpsertTableSink(Array(0, 3)) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val updatedExpected = List( "book,2,19,2", "fruit,5,34,2") - assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -524,8 +536,10 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin val table = tEnv.sqlQuery(sql) - val tableSink = new TestingUpsertTableSink(Array(0, 1)) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -536,7 +550,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "fruit,1,3,5", "fruit,2,2,4", "fruit,3,1,3") - assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(expected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -581,14 +595,16 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin val table = tEnv.sqlQuery(sql) - val tableSink = new TestingUpsertTableSink(Array(0, 1)) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( "book,3,2,2", "fruit,3,1,3") - assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(expected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -638,9 +654,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 4 """.stripMargin - val tableSink = new TestingUpsertTableSink(Array(0)) val table = tEnv.sqlQuery(sql2) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -657,10 +675,10 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "(true,3,book,d,4,2)", "(true,4,book,b,3,2)", "(true,4,book,b,3,2)") - assertEquals(expected.mkString("\n"), tableSink.getRawResults.mkString("\n")) + assertEquals(expected.mkString("\n"), sink.getRawResults.mkString("\n")) val expected2 = List("1,fruit,b,6,1", "2,book,e,5,1", "3,book,d,4,2", "4,book,b,3,2") - assertEquals(expected2, tableSink.getUpsertResults.sorted) + assertEquals(expected2, sink.getUpsertResults.sorted) } // FIXME @@ -699,9 +717,12 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode """.stripMargin tEnv.getConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE, 1) - val tableSink = new TestingUpsertTableSink(Array(0)) + val table = tEnv.sqlQuery(sql) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -724,7 +745,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "(true,3,book,b,3,2)", "(true,4,book,a,1,2)") - assertEquals(expected, tableSink.getRawResults) + assertEquals(expected, sink.getRawResults) } // FIXME @@ -768,9 +789,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= topSize """.stripMargin - val tableSink = new TestingUpsertTableSink(Array(0, 1)) val table = tEnv.sqlQuery(sql) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -780,7 +803,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "book,4,1,3", "fruit,1,3,5", "fruit,2,2,4") - assertEquals(expected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(expected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -829,9 +852,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 3 """.stripMargin - val tableSink = new TestingUpsertTableSink(Array(0, 3)) val table = tEnv.sqlQuery(sql) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -862,14 +887,14 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "(true,book,12,10,3)") - assertEquals(expected.mkString("\n"), tableSink.getRawResults.mkString("\n")) + assertEquals(expected.mkString("\n"), sink.getRawResults.mkString("\n")) val updatedExpected = List( "book,2,9,1", "book,7,10,2", "book,12,10,3") - assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -906,9 +931,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 3 """.stripMargin - val tableSink = new TestingUpsertTableSink(Array(0, 1)) val table = tEnv.sqlQuery(sql) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -924,7 +951,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "(true,book,1,225.0)", "(true,fruit,5,100.0)") - assertEquals(expected, tableSink.getRawResults) + assertEquals(expected, sink.getRawResults) val updatedExpected = List( "book,1,225.0", @@ -932,7 +959,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "book,4,310.0", "fruit,5,100.0") - assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -977,9 +1004,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 3 """.stripMargin - val tableSink = new TestingUpsertTableSink(Array(0, 1)) val table = tEnv.sqlQuery(sql) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -1002,7 +1031,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "(true,fruit,4,2)", "(true,fruit,5,2)", "(true,fruit,5,3)") - assertEquals(expected, tableSink.getRawResults) + assertEquals(expected, sink.getRawResults) val updatedExpected = List( "book,4,5", @@ -1011,7 +1040,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "fruit,5,3", "fruit,4,2", "fruit,3,1") - assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } @Test @@ -1040,9 +1069,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode |WHERE rank_num <= 2 """.stripMargin - val tableSink = new TestingUpsertTableSink(Array(0, 2)) val table = tEnv.sqlQuery(sql) - writeToSink(table, tableSink) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 2)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) env.execute() val expected = List( @@ -1054,14 +1085,14 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode "(true,fruit,44,3)", "(false,fruit,33,4)", "(true,fruit,40,1)") - assertEquals(expected, tableSink.getRawResults) + assertEquals(expected, sink.getRawResults) val updatedExpected = List( "book,19,2", "book,20,5", "fruit,40,1", "fruit,44,3") - assertEquals(updatedExpected.sorted, tableSink.getUpsertResults.sorted) + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) } // FIXME @@ -1113,7 +1144,7 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode assertEquals(expected1.sorted, sink1.getRetractResults.sorted) val sink2 = new TestingRetractSink - val table2 = tEnv.sqlQuery( + tEnv.sqlQuery( s""" |SELECT * |FROM ( @@ -1122,7 +1153,6 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode | FROM MyView) |WHERE rank_num <= 2 |""".stripMargin).toRetractStream[Row].addSink(sink2).setParallelism(1) - env.execute() val expected2 = List( @@ -1160,7 +1190,6 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode val t1 = tEnv.sqlQuery(subquery) tEnv.registerTable("MyView", t1) - val sink1 = new TestingUpsertTableSink(Array(0, 3)) val table1 = tEnv.sqlQuery( s""" |SELECT * @@ -1170,9 +1199,12 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode | FROM MyView) |WHERE rank_num <= 2 |""".stripMargin) - writeToSink(table1, sink1) + val schema1 = table1.getSchema + val sink1 = new TestingUpsertTableSink(Array(0, 3)). + configure(schema1.getFieldNames, schema1 + .getFieldTypes) + tEnv.writeToSink(table1, sink1) - val sink2 = new TestingUpsertTableSink(Array(0, 3)) val table2 = tEnv.sqlQuery( s""" |SELECT * @@ -1182,9 +1214,13 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode | FROM MyView) |WHERE rank_num <= 2 |""".stripMargin) + val schema2 = table2.getSchema + val sink2 = new TestingUpsertTableSink(Array(0, 3)). 
+ configure(schema2.getFieldNames, schema2 + .getFieldTypes) + tEnv.writeToSink(table2, sink2) - writeToSink(table2, sink2) - + env.execute() val expected1 = List( "book,1,25,1", "book,2,19,2", @@ -1227,7 +1263,6 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode val t1 = tEnv.sqlQuery(subquery) tEnv.registerTable("MyView", t1) - val sink1 = new TestingRetractTableSink val table1 = tEnv.sqlQuery( s""" |SELECT * @@ -1237,9 +1272,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode | FROM MyView) |WHERE rank_num <= 2 |""".stripMargin) - writeToSink(table1, sink1) + val schema1 = table1.getSchema + val sink1 = new TestingRetractTableSink(). + configure(schema1.getFieldNames, schema1.getFieldTypes) + tEnv.writeToSink(table1, sink1) - val sink2 = new TestingRetractTableSink val table2 = tEnv.sqlQuery( s""" |SELECT * @@ -1249,8 +1286,11 @@ class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode | FROM MyView) |WHERE rank_num <= 2 |""".stripMargin) - - writeToSink(table2, sink2) + val schema2 = table2.getSchema + val sink2 = new TestingRetractTableSink(). + configure(schema2.getFieldNames, schema2.getFieldTypes) + tEnv.writeToSink(table2, sink2) + env.execute() val expected1 = List( "book,1,2,1", diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala index e3fc360a90ddd..e91a1523b3ac9 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala @@ -293,8 +293,7 @@ final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) override def configure( fieldNames: Array[String], - fieldTypes: Array[TypeInformation[_]]) - : TableSink[JTuple2[JBoolean, BaseRow]] = { + fieldTypes: Array[TypeInformation[_]]): TestingUpsertTableSink = { val copy = new TestingUpsertTableSink(keys, tz) copy.fNames = fieldNames copy.fTypes = fieldTypes @@ -501,7 +500,7 @@ final class TestingRetractTableSink(tz: TimeZone) extends RetractStreamTableSink override def configure( fieldNames: Array[String], - fieldTypes: Array[TypeInformation[_]]): TableSink[JTuple2[JBoolean, Row]] = { + fieldTypes: Array[TypeInformation[_]]): TestingRetractTableSink = { val copy = new TestingRetractTableSink(tz) copy.fNames = fieldNames copy.fTypes = fieldTypes diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala index 94b814fe833a8..1335e897bda1b 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingTestBase.scala @@ -20,9 +20,7 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment -import org.apache.flink.table.api.{Table, TableImpl} import org.apache.flink.table.api.scala.StreamTableEnvironment -import org.apache.flink.table.sinks.TableSink import org.apache.flink.test.util.AbstractTestBase import 
org.junit.rules.{ExpectedException, TemporaryFolder} @@ -55,8 +53,4 @@ class StreamingTestBase extends AbstractTestBase { this.tEnv = StreamTableEnvironment.create(env) } - def writeToSink(table: Table, sink: TableSink[_]): Unit = { - TableUtil.writeToSink(table.asInstanceOf[TableImpl], sink) - } - } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala index ae80d9c0525bf..9a3836ba76703 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala @@ -109,32 +109,6 @@ class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestB failingDataSource(result)(newTypeInfo.asInstanceOf[TypeInformation[BaseRow]]) } - /** - * Creates a DataStream from the given non-empty [[Seq]]. - */ - def retainStateDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = { - env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE) - env.setRestartStrategy(RestartStrategies.noRestart()) - env.setParallelism(1) - // reset failedBefore flag to false - FailingCollectionSource.reset() - - require(data != null, "Data must not be null.") - val typeInfo = implicitly[TypeInformation[T]] - - val collection = scala.collection.JavaConversions.asJavaCollection(data) - // must not have null elements and mixed elements - FromElementsFunction.checkCollection(data, typeInfo.getTypeClass) - - val function = new FailingCollectionSource[T]( - typeInfo.createSerializer(env.getConfig), - collection, - data.length, // fail after half elements - true) - - env.addSource(function)(typeInfo).setMaxParallelism(1) - } - /** * Creates a DataStream from the given non-empty [[Seq]]. 
*/ diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala index 1522bb63d38ec..4595ac81841bd 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.`type`.TypeConverters.createExternalTypeInfoFromInternalType import org.apache.flink.table.api.{BatchTableEnvironment, TableImpl} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.sinks.{CollectRowTableSink, CollectTableSink, TableSink} +import org.apache.flink.table.sinks.{CollectRowTableSink, CollectTableSink} import org.apache.flink.types.Row import _root_.scala.collection.JavaConversions._ @@ -61,14 +61,4 @@ object TableUtil { table, configuredSink.asInstanceOf[CollectTableSink[T]], jobName) } - def writeToSink(table: TableImpl, sink: TableSink[_]): Unit = { - // get schema information of table - val rowType = table.getRelNode.getRowType - val fieldNames = rowType.getFieldNames.asScala.toArray - val fieldTypes = rowType.getFieldList - .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray - val configuredSink = sink.configure( - fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType)) - table.tableEnv.writeToSink(table, configuredSink) - } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java index 831a2cfbd6389..1c554bd5c1497 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java @@ -39,21 +39,18 @@ class DeduplicateFunctionHelper { * @param out underlying collector * @throws Exception */ - static void processLastRow(BaseRow currentRow, boolean generateRetraction, ValueState state, + static void processLastRow(BaseRow currentRow, boolean generateRetraction, ValueState state, Collector out) throws Exception { - // should be accumulate msg + // Check message should be accumulate Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); if (generateRetraction) { // state stores complete row if generateRetraction is true - BaseRow preRow = (BaseRow) state.value(); + BaseRow preRow = state.value(); state.update(currentRow); if (preRow != null) { preRow.setHeader(BaseRowUtil.RETRACT_MSG); out.collect(preRow); } - } else { - // state stores a flag to indicator whether pk appears before - state.update(true); } out.collect(currentRow); } @@ -66,9 +63,9 @@ static void processLastRow(BaseRow currentRow, boolean generateRetraction, Value * @param out underlying collector * @throws Exception */ - static void processFirstRow(BaseRow currentRow, ValueState state, Collector out) + static void processFirstRow(BaseRow currentRow, ValueState state, Collector out) throws Exception { - // should be accumulate msg. 
+ // Check message should be accumulate Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); // ignore record with timestamp bigger than preRow if (state.value() != null) { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java similarity index 58% rename from flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java rename to flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java index 1467b1c9f0e99..14feaf981fd32 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java @@ -24,49 +24,30 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; -import org.apache.flink.table.typeutils.BaseRowTypeInfo; import org.apache.flink.util.Collector; import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; -import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; /** - * This function is used to deduplicate on keys and keeps only first row or last row. + * This function is used to deduplicate on keys and keeps only first row. */ -public class DeduplicateFunction +public class DeduplicateKeepFirstRowFunction extends KeyedProcessFunctionWithCleanupState { - private static final long serialVersionUID = 4950071982706870944L; + private static final long serialVersionUID = 5865777137707602549L; - private final BaseRowTypeInfo rowTypeInfo; - private final boolean generateRetraction; - private final boolean keepLastRow; + // state stores a boolean flag to indicate whether key appears before. + private ValueState state; - // state stores complete row if keep last row and generate retraction is true, - // else stores a flag to indicate whether key appears before. - private ValueState state; - - public DeduplicateFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo rowTypeInfo, - boolean generateRetraction, boolean keepLastRow) { + public DeduplicateKeepFirstRowFunction(long minRetentionTime, long maxRetentionTime) { super(minRetentionTime, maxRetentionTime); - this.rowTypeInfo = rowTypeInfo; - this.generateRetraction = generateRetraction; - this.keepLastRow = keepLastRow; } @Override public void open(Configuration configure) throws Exception { super.open(configure); - String stateName = keepLastRow ? "DeduplicateFunctionKeepLastRow" : "DeduplicateFunctionKeepFirstRow"; - initCleanupTimeState(stateName); - ValueStateDescriptor stateDesc = null; - if (keepLastRow && generateRetraction) { - // if need generate retraction and keep last row, stores complete row into state - stateDesc = new ValueStateDescriptor("deduplicateFunction", rowTypeInfo); - } else { - // else stores a flag to indicator whether pk appears before. 
- stateDesc = new ValueStateDescriptor("fistValueState", Types.BOOLEAN); - } + initCleanupTimeState("DeduplicateFunctionKeepFirstRow"); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("existsState", Types.BOOLEAN); state = getRuntimeContext().getState(stateDesc); } @@ -75,12 +56,7 @@ public void processElement(BaseRow input, Context ctx, Collector out) t long currentTime = ctx.timerService().currentProcessingTime(); // register state-cleanup timer registerProcessingCleanupTimer(ctx, currentTime); - - if (keepLastRow) { - processLastRow(input, generateRetraction, state, out); - } else { - processFirstRow(input, state, out); - } + processFirstRow(input, state, out); } @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java new file mode 100644 index 0000000000000..3dcde66eb5ea6 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; + +/** + * This function is used to deduplicate on keys and keeps only last row. + */ +public class DeduplicateKeepLastRowFunction + extends KeyedProcessFunctionWithCleanupState { + + private static final long serialVersionUID = -291348892087180350L; + private final BaseRowTypeInfo rowTypeInfo; + private final boolean generateRetraction; + + // state stores complete row. 
+ private ValueState state; + + public DeduplicateKeepLastRowFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo rowTypeInfo, + boolean generateRetraction) { + super(minRetentionTime, maxRetentionTime); + this.rowTypeInfo = rowTypeInfo; + this.generateRetraction = generateRetraction; + } + + @Override + public void open(Configuration configure) throws Exception { + super.open(configure); + if (generateRetraction) { + // state stores complete row if need generate retraction, otherwise do not need a state + initCleanupTimeState("DeduplicateFunctionKeepLastRow"); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("preRowState", rowTypeInfo); + state = getRuntimeContext().getState(stateDesc); + } + } + + @Override + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { + if (generateRetraction) { + long currentTime = ctx.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(ctx, currentTime); + } + processLastRow(input, generateRetraction, state, out); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { + if (stateCleaningEnabled) { + cleanupState(state); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java new file mode 100644 index 0000000000000..b97d1a670ac3b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.MapBundleFunction; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.util.Collector; + +import javax.annotation.Nullable; + +import java.util.Map; + +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; + +/** + * This function is used to get the first row for every key partition in miniBatch mode. 
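Editor's note (illustration only, not part of the patch): the mini-batch variants are MapBundleFunctions rather than KeyedProcessFunctions, so they buffer one row per key and only touch state when the bundle is flushed. A minimal wiring sketch in Java, reusing only classes that already appear in this patch's tests; the row type, the bundle size of 100, and the planner-side attachment are assumptions, not the actual StreamExecDeduplicate translation:

    // Hedged sketch: wrap the mini-batch keep-first-row function in a keyed bundle operator.
    BaseRowTypeInfo rowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, InternalTypes.INT);
    TypeSerializer<BaseRow> serializer = rowType.createSerializer(new ExecutionConfig());
    MiniBatchDeduplicateKeepFirstRowFunction func = new MiniBatchDeduplicateKeepFirstRowFunction(serializer);
    // flush the per-key buffer after 100 buffered elements (the count is chosen arbitrarily here)
    CountBundleTrigger<Tuple2<BaseRow, BaseRow>> trigger = new CountBundleTrigger<>(100);
    KeyedMapBundleOperator op = new KeyedMapBundleOperator(func, trigger);
    // the planner would attach 'op' to the input transformation keyed by the deduplicate keys

End of editor's note.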
+ */ +public class MiniBatchDeduplicateKeepFirstRowFunction + extends MapBundleFunction { + + private static final long serialVersionUID = -7994602893547654994L; + + private final TypeSerializer typeSerializer; + + // state stores a boolean flag to indicate whether key appears before. + private ValueState state; + + public MiniBatchDeduplicateKeepFirstRowFunction(TypeSerializer typeSerializer) { + this.typeSerializer = typeSerializer; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("existsState", Types.BOOLEAN); + state = ctx.getRuntimeContext().getState(stateDesc); + } + + @Override + public BaseRow addInput(@Nullable BaseRow value, BaseRow input) { + if (value == null) { + // put the input into buffer + return typeSerializer.copy(input); + } else { + // the input is not first row, ignore it + return value; + } + } + + @Override + public void finishBundle( + Map buffer, Collector out) throws Exception { + for (Map.Entry entry : buffer.entrySet()) { + BaseRow currentKey = entry.getKey(); + BaseRow currentRow = entry.getValue(); + ctx.setCurrentKey(currentKey); + processFirstRow(currentRow, state, out); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java similarity index 57% rename from flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java rename to flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java index 8b8f9832e1748..c1f2ec40cf5b8 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java @@ -20,7 +20,6 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.runtime.bundle.MapBundleFunction; @@ -32,59 +31,41 @@ import java.util.Map; -import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; /** - * This function is used to get the first row or last row for every key partition in miniBatch - * mode. + * This function is used to get the last row for every key partition in miniBatch mode. */ -public class MiniBatchDeduplicateFunction +public class MiniBatchDeduplicateKeepLastRowFunction extends MapBundleFunction { - private BaseRowTypeInfo rowTypeInfo; - private boolean generateRetraction; - private boolean keepLastRow; + private static final long serialVersionUID = -8981813609115029119L; - // state stores complete row if keep last row and generate retraction is true, - // else stores a flag to indicate whether key appears before. 
- private ValueState state; - private TypeSerializer ser; + private final BaseRowTypeInfo rowTypeInfo; + private final boolean generateRetraction; + private final TypeSerializer typeSerializer; - public MiniBatchDeduplicateFunction( - BaseRowTypeInfo rowTypeInfo, - boolean generateRetraction, - TypeSerializer typeSerializer, - boolean keepLastRow) { + // state stores complete row. + private ValueState state; + + public MiniBatchDeduplicateKeepLastRowFunction(BaseRowTypeInfo rowTypeInfo, boolean generateRetraction, + TypeSerializer typeSerializer) { this.rowTypeInfo = rowTypeInfo; - this.keepLastRow = keepLastRow; this.generateRetraction = generateRetraction; - ser = typeSerializer; + this.typeSerializer = typeSerializer; } @Override public void open(ExecutionContext ctx) throws Exception { super.open(ctx); - ValueStateDescriptor stateDesc = null; - if (keepLastRow && generateRetraction) { - // if need generate retraction and keep last row, stores complete row into state - stateDesc = new ValueStateDescriptor("deduplicateFunction", rowTypeInfo); - } else { - // else stores a flag to indicator whether pk appears before. - stateDesc = new ValueStateDescriptor("fistValueState", Types.BOOLEAN); - } + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("preRowState", rowTypeInfo); state = ctx.getRuntimeContext().getState(stateDesc); } @Override public BaseRow addInput(@Nullable BaseRow value, BaseRow input) { - if (value == null || keepLastRow || (!keepLastRow && value == null)) { - // put the input into buffer - return ser.copy(input); - } else { - // the input is not last row, ignore it - return value; - } + // always put the input into buffer + return typeSerializer.copy(input); } @Override @@ -94,12 +75,7 @@ public void finishBundle( BaseRow currentKey = entry.getKey(); BaseRow currentRow = entry.getValue(); ctx.setCurrentKey(currentKey); - - if (keepLastRow) { - processLastRow(currentRow, generateRetraction, state, out); - } else { - processFirstRow(currentRow, state, out); - } + processLastRow(currentRow, generateRetraction, state, out); } } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java index 2b51e7a2db55a..3d7887a57afcd 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java @@ -25,7 +25,7 @@ import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** - * A utility class which will extract key from BaseRow. + * A utility class which extracts key from BaseRow. The key type is BinaryRow. 
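Editor's note (illustration only, not part of the patch): the key selectors are what turn the input into a keyed stream before the deduplicate and rank functions run. A minimal sketch of that wiring, mirroring the test harnesses later in this patch; the constructor arguments follow the test-utility BinaryRowKeySelector, and the key index, retention times, and retraction flag are assumptions:

    // Hedged sketch: key the stream on field 1 and apply keep-last-row deduplication.
    BaseRowTypeInfo rowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, InternalTypes.INT);
    BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[] { 1 }, rowType.getInternalTypes());
    DeduplicateKeepLastRowFunction func =
            new DeduplicateKeepLastRowFunction(10000L, 20000L, rowType, true);
    KeyedProcessOperator operator = new KeyedProcessOperator<>(func);
    // in the planner, 'operator' runs on the input keyed by 'keySelector'

End of editor's note.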
*/ public class BinaryRowKeySelector implements BaseRowKeySelector { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java index ed83f9e2adfa5..795bdbf5a3796 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java @@ -23,10 +23,12 @@ import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** - * A utility class which key is always empty. + * A utility class which key is always empty no matter what the input row is. */ public class NullBinaryRowKeySelector implements BaseRowKeySelector { + private static final long serialVersionUID = -2079386198687082032L; + private final BaseRowTypeInfo returnType = new BaseRowTypeInfo(); @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java index 431d25fb8f977..5c44133ded5a6 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java @@ -155,7 +155,7 @@ public void open(Configuration parameters) throws Exception { outputRow = new JoinedRow(); if (!isConstantRankEnd) { - ValueStateDescriptor rankStateDesc = new ValueStateDescriptor("rankEnd", Types.LONG); + ValueStateDescriptor rankStateDesc = new ValueStateDescriptor<>("rankEnd", Types.LONG); rankEndState = getRuntimeContext().getState(rankStateDesc); } // compile equaliser @@ -223,7 +223,7 @@ protected boolean checkSortKeyInBufferRange(BaseRow sortKey, TopNBuffer buffer) if (compare < 0) { return true; } else { - return buffer.getCurrentTopNum() < getMaxSizeOfBuffer(); + return buffer.getCurrentTopNum() < getDefaultTopNSize(); } } } @@ -231,24 +231,10 @@ protected boolean checkSortKeyInBufferRange(BaseRow sortKey, TopNBuffer buffer) protected void registerMetric(long heapSize) { getRuntimeContext().getMetricGroup().>gauge( "topn.cache.hitRate", - new Gauge() { - - @Override - public Double getValue() { - return requestCount == 0 ? 1.0 : - Long.valueOf(hitCount).doubleValue() / requestCount; - } - }); + () -> requestCount == 0 ? 1.0 : Long.valueOf(hitCount).doubleValue() / requestCount); getRuntimeContext().getMetricGroup().>gauge( - "topn.cache.size", - new Gauge() { - - @Override - public Long getValue() { - return heapSize; - } - }); + "topn.cache.size", () -> heapSize); } protected void collect(Collector out, BaseRow inputRow) { @@ -313,13 +299,6 @@ private BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { } } - /** - * Gets buffer size limit. Implementations may vary depending on each rank who has in-memory buffer. - * - * @return buffer size limit - */ - protected abstract long getMaxSizeOfBuffer(); - /** * Sets keyContext to RankFunction. 
* diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java index 6464637e2a871..73e7edb631a3c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -40,7 +40,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.function.Supplier; /** * AppendRankFunction's input stream only contains append record. @@ -82,7 +81,7 @@ public void open(Configuration parameters) throws Exception { LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize); ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); - MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor<>( "data-state-with-append", sortKeyType, valueTypeInfo); dataState = getRuntimeContext().getMapState(mapStateDescriptor); @@ -129,23 +128,12 @@ public void onTimer( } } - @Override - protected long getMaxSizeOfBuffer() { - return getDefaultTopNSize(); - } - private void initHeapStates() throws Exception { requestCount += 1; BaseRow currentKey = (BaseRow) keyContext.getCurrentKey(); buffer = kvSortedMap.get(currentKey); if (buffer == null) { - buffer = new TopNBuffer(sortKeyComparator, new Supplier>() { - - @Override - public Collection get() { - return new ArrayList<>(); - } - }); + buffer = new TopNBuffer(sortKeyComparator, ArrayList::new); kvSortedMap.put(currentKey, buffer); // restore buffer Iterator>> iter = dataState.iterator(); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java index 21c3e934ab004..eb0250d83de25 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java @@ -81,13 +81,13 @@ public RetractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRow public void open(Configuration parameters) throws Exception { super.open(parameters); ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); - MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor<>( "data-state", sortKeyType, valueTypeInfo); dataState = getRuntimeContext().getMapState(mapStateDescriptor); - ValueStateDescriptor> valueStateDescriptor = new ValueStateDescriptor( + ValueStateDescriptor> valueStateDescriptor = new ValueStateDescriptor<>( "sorted-map", - new SortedMapTypeInfo(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator)); + new SortedMapTypeInfo<>(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator)); treeMap = getRuntimeContext().getState(valueStateDescriptor); } @@ -242,9 +242,4 @@ private void emitRecordsWithRowNumber( } } - @Override - protected long getMaxSizeOfBuffer() { - // just let it go, retract rank has no interest in this - return 0L; - } } diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java index 1965e52e2ea2c..a78b00d5e11ea 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -49,7 +49,6 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.function.Supplier; /** * A fast version of rank process function which only hold top n data in state, and keep sorted map in heap. @@ -104,20 +103,20 @@ public UpdateRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowT @Override public void open(Configuration parameters) throws Exception { super.open(parameters); - int lruCacheSize = Math.max(1, (int) (cacheSize / getMaxSizeOfBuffer())); + int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopNSize())); // make sure the cached map is in a fixed size, avoid OOM kvSortedMap = new HashMap<>(lruCacheSize); kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener()); - LOG.info("Top{} operator is using LRU caches key-size: {}", getMaxSizeOfBuffer(), lruCacheSize); + LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize); TupleTypeInfo> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT); - MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor( + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor<>( "data-state-with-update", rowKeyType, valueTypeInfo); dataState = getRuntimeContext().getMapState(mapStateDescriptor); // metrics - registerMetric(kvSortedMap.size() * getMaxSizeOfBuffer()); + registerMetric(kvSortedMap.size() * getDefaultTopNSize()); } @Override @@ -157,11 +156,6 @@ public void processElement( } } - @Override - protected long getMaxSizeOfBuffer() { - return getDefaultTopNSize(); - } - @Override public void snapshotState(FunctionSnapshotContext context) throws Exception { Iterator>> iter = kvRowKeyMap.entrySet().iterator(); @@ -170,7 +164,7 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { BaseRow partitionKey = entry.getKey(); Map currentRowKeyMap = entry.getValue(); keyContext.setCurrentKey(partitionKey); - synchronizeState(currentRowKeyMap); + flushBufferToState(currentRowKeyMap); } } @@ -180,13 +174,7 @@ private void initHeapStates() throws Exception { buffer = kvSortedMap.get(partitionKey); rowKeyMap = kvRowKeyMap.get(partitionKey); if (buffer == null) { - buffer = new TopNBuffer(sortKeyComparator, new Supplier>() { - - @Override - public Collection get() { - return new LinkedHashSet<>(); - } - }); + buffer = new TopNBuffer(sortKeyComparator, LinkedHashSet::new); rowKeyMap = new HashMap<>(); kvSortedMap.put(partitionKey, buffer); kvRowKeyMap.put(partitionKey, rowKeyMap); @@ -432,7 +420,7 @@ private void processElementWithoutRowNumber(BaseRow inputRow, Collector } } - private void synchronizeState(Map curRowKeyMap) throws Exception { + private void flushBufferToState(Map curRowKeyMap) throws Exception { Iterator> iter = curRowKeyMap.entrySet().iterator(); while (iter.hasNext()) { Map.Entry entry = iter.next(); @@ -473,7 +461,7 @@ public void onRemoval(Map.Entry> eldest) { keyContext.setCurrentKey(partitionKey); kvSortedMap.remove(partitionKey); try { - synchronizeState(currentRowKeyMap); + 
flushBufferToState(currentRowKeyMap); } catch (Throwable e) { LOG.error("Fail to synchronize state!", e); throw new RuntimeException(e); diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java new file mode 100644 index 0000000000000..f23acd727a00d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; + +/** + * Tests for {@link DeduplicateKeepFirstRowFunction}. 
+ */ +public class DeduplicateKeepFirstRowFunctionTest { + + private Time minTime = Time.milliseconds(10); + private Time maxTime = Time.milliseconds(20); + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + + private OneInputStreamOperatorTestHarness createTestHarness( + DeduplicateKeepFirstRowFunction func) + throws Exception { + KeyedProcessOperator operator = new KeyedProcessOperator<>(func); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void test() throws Exception { + DeduplicateKeepFirstRowFunction func = new DeduplicateKeepFirstRowFunction(minTime.toMilliseconds(), + maxTime.toMilliseconds()); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + // Keep FirstRow in deduplicate will not send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunctionTest.java similarity index 62% rename from flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java rename to flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunctionTest.java index 6210f1945f83c..f6c3419185f43 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunctionTest.java @@ -38,9 +38,9 @@ import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; /** - * Tests for {@link DeduplicateFunction}. + * Tests for {@link DeduplicateKeepLastRowFunction}. 
*/ -public class DeduplicateFunctionTest { +public class DeduplicateKeepLastRowFunctionTest { private Time minTime = Time.milliseconds(10); private Time maxTime = Time.milliseconds(20); @@ -56,55 +56,21 @@ public class DeduplicateFunctionTest { inputRowType.getFieldTypes(), new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); - private DeduplicateFunction createFunction(boolean generateRetraction, boolean keepLastRow) { - DeduplicateFunction func = new DeduplicateFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), - inputRowType, generateRetraction, keepLastRow); - return func; + private DeduplicateKeepLastRowFunction createFunction(boolean generateRetraction) { + return new DeduplicateKeepLastRowFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), inputRowType, + generateRetraction); } private OneInputStreamOperatorTestHarness createTestHarness( - DeduplicateFunction func) + DeduplicateKeepLastRowFunction func) throws Exception { - KeyedProcessOperator operator = new KeyedProcessOperator(func); - return new KeyedOneInputStreamOperatorTestHarness(operator, rowKeySelector, rowKeySelector.getProducedType()); + KeyedProcessOperator operator = new KeyedProcessOperator<>(func); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, rowKeySelector, rowKeySelector.getProducedType()); } @Test - public void testKeepFirstRowWithoutGenerateRetraction() throws Exception { - DeduplicateFunction func = createFunction(false, false); - OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); - testHarness.open(); - testHarness.processElement(record("book", 1L, 12)); - testHarness.processElement(record("book", 2L, 11)); - testHarness.processElement(record("book", 1L, 13)); - testHarness.close(); - - List expectedOutputOutput = new ArrayList<>(); - expectedOutputOutput.add(record("book", 1L, 12)); - expectedOutputOutput.add(record("book", 2L, 11)); - assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); - } - - @Test - public void testKeepFirstRowWithGenerateRetraction() throws Exception { - DeduplicateFunction func = createFunction(true, false); - OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); - testHarness.open(); - testHarness.processElement(record("book", 1L, 12)); - testHarness.processElement(record("book", 2L, 11)); - testHarness.processElement(record("book", 1L, 13)); - testHarness.close(); - - // Keep FirstRow in deduplicate will not send retraction - List expectedOutputOutput = new ArrayList<>(); - expectedOutputOutput.add(record("book", 1L, 12)); - expectedOutputOutput.add(record("book", 2L, 11)); - assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); - } - - @Test - public void testKeepLastWithoutGenerateRetraction() throws Exception { - DeduplicateFunction func = createFunction(false, true); + public void testWithoutGenerateRetraction() throws Exception { + DeduplicateKeepLastRowFunction func = createFunction(false); OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); testHarness.open(); testHarness.processElement(record("book", 1L, 12)); @@ -120,8 +86,8 @@ public void testKeepLastWithoutGenerateRetraction() throws Exception { } @Test - public void testKeepLastRowWithGenerateRetraction() throws Exception { - DeduplicateFunction func = createFunction(true, true); + public void testWithGenerateRetraction() throws Exception { + DeduplicateKeepLastRowFunction func = 
createFunction(true); OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); testHarness.open(); testHarness.processElement(record("book", 1L, 12)); diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java new file mode 100644 index 0000000000000..af3dbde80c9ce --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; + +/** + * Tests for {@link MiniBatchDeduplicateKeepFirstRowFunction}. 
+ */ +public class MiniBatchDeduplicateKeepFirstRowFunctionTest { + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + + private TypeSerializer typeSerializer = inputRowType.createSerializer(new ExecutionConfig()); + + private OneInputStreamOperatorTestHarness createTestHarness( + MiniBatchDeduplicateKeepFirstRowFunction func) + throws Exception { + CountBundleTrigger> trigger = new CountBundleTrigger<>(3); + KeyedMapBundleOperator op = new KeyedMapBundleOperator(func, trigger); + return new KeyedOneInputStreamOperatorTestHarness<>(op, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void testKeepFirstRowWithGenerateRetraction() throws Exception { + MiniBatchDeduplicateKeepFirstRowFunction func = new MiniBatchDeduplicateKeepFirstRowFunction(typeSerializer); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + + // output is empty because bundle not trigger yet. + Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + // Keep FirstRow in deduplicate will not send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + testHarness.close(); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunctionTest.java similarity index 67% rename from flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java rename to flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunctionTest.java index e47525db45e84..efbf33552da8f 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunctionTest.java @@ -42,9 +42,9 @@ import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; /** - * Tests for {@link MiniBatchDeduplicateFunction}. + * Tests for {@link MiniBatchDeduplicateKeepLastRowFunction}. 
*/ -public class MiniBatchDeduplicateFunctionTest { +public class MiniBatchDeduplicateKeepLastRowFunctionTest { private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, InternalTypes.INT); @@ -60,64 +60,21 @@ public class MiniBatchDeduplicateFunctionTest { private TypeSerializer typeSerializer = inputRowType.createSerializer(new ExecutionConfig()); - private MiniBatchDeduplicateFunction createFunction(boolean generateRetraction, boolean keepLastRow) { - MiniBatchDeduplicateFunction func = new MiniBatchDeduplicateFunction(inputRowType, generateRetraction, - typeSerializer, keepLastRow); - return func; + private MiniBatchDeduplicateKeepLastRowFunction createFunction(boolean generateRetraction) { + return new MiniBatchDeduplicateKeepLastRowFunction(inputRowType, generateRetraction, typeSerializer); } private OneInputStreamOperatorTestHarness createTestHarness( - MiniBatchDeduplicateFunction func) + MiniBatchDeduplicateKeepLastRowFunction func) throws Exception { CountBundleTrigger> trigger = new CountBundleTrigger<>(3); KeyedMapBundleOperator op = new KeyedMapBundleOperator(func, trigger); - return new KeyedOneInputStreamOperatorTestHarness(op, rowKeySelector, rowKeySelector.getProducedType()); + return new KeyedOneInputStreamOperatorTestHarness<>(op, rowKeySelector, rowKeySelector.getProducedType()); } @Test - public void testKeepFirstRowWithoutGenerateRetraction() throws Exception { - MiniBatchDeduplicateFunction func = createFunction(false, false); - OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); - testHarness.open(); - testHarness.processElement(record("book", 1L, 12)); - testHarness.processElement(record("book", 2L, 11)); - - // output is empty because bundle not trigger yet. - Assert.assertTrue(testHarness.getOutput().isEmpty()); - - testHarness.processElement(record("book", 1L, 13)); - // output is not empty because bundle is trigger. - List expectedOutputOutput = new ArrayList<>(); - expectedOutputOutput.add(record("book", 1L, 12)); - expectedOutputOutput.add(record("book", 2L, 11)); - assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); - testHarness.close(); - } - - @Test - public void testKeepFirstRowWithGenerateRetraction() throws Exception { - MiniBatchDeduplicateFunction func = createFunction(true, false); - OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); - testHarness.open(); - testHarness.processElement(record("book", 1L, 12)); - testHarness.processElement(record("book", 2L, 11)); - - // output is empty because bundle not trigger yet. 
- Assert.assertTrue(testHarness.getOutput().isEmpty()); - - testHarness.processElement(record("book", 1L, 13)); - - // Keep FirstRow in deduplicate will not send retraction - List expectedOutputOutput = new ArrayList<>(); - expectedOutputOutput.add(record("book", 1L, 12)); - expectedOutputOutput.add(record("book", 2L, 11)); - assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); - testHarness.close(); - } - - @Test - public void testKeepLastWithoutGenerateRetraction() throws Exception { - MiniBatchDeduplicateFunction func = createFunction(false, true); + public void testWithoutGenerateRetraction() throws Exception { + MiniBatchDeduplicateKeepLastRowFunction func = createFunction(false); OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); testHarness.open(); testHarness.processElement(record("book", 1L, 10)); @@ -144,8 +101,8 @@ public void testKeepLastWithoutGenerateRetraction() throws Exception { } @Test - public void testKeepLastRowWithGenerateRetraction() throws Exception { - MiniBatchDeduplicateFunction func = createFunction(true, true); + public void testWithGenerateRetraction() throws Exception { + MiniBatchDeduplicateKeepLastRowFunction func = createFunction(true); OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); testHarness.open(); testHarness.processElement(record("book", 1L, 10)); diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java index 288be22469343..49b9366c71f51 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java @@ -37,10 +37,9 @@ public class AppendRankFunctionTest extends BaseRankFunctionTest { @Override protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, boolean generateRetraction, boolean outputRankNumber) { - AbstractRankFunction rankFunction = new AppendRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + return new AppendRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), inputRowType, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber, cacheSize); - return rankFunction; } @Test diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java index 0d766a546c686..deff2ac2ef98d 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java @@ -50,9 +50,9 @@ */ abstract class BaseRankFunctionTest { - protected Time minTime = Time.milliseconds(10); - protected Time maxTime = Time.milliseconds(20); - protected long cacheSize = 10000L; + Time minTime = Time.milliseconds(10); + Time maxTime = Time.milliseconds(20); + long cacheSize = 10000L; BaseRowTypeInfo inputRowType = new BaseRowTypeInfo( InternalTypes.STRING, @@ -323,12 +323,12 @@ public void testConstantRankRangeWithoutOffset() throws Exception { 
.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); } - protected OneInputStreamOperatorTestHarness createTestHarness( + OneInputStreamOperatorTestHarness createTestHarness( AbstractRankFunction rankFunction) throws Exception { - KeyedProcessOperator operator = new KeyedProcessOperator(rankFunction); + KeyedProcessOperator operator = new KeyedProcessOperator<>(rankFunction); rankFunction.setKeyContext(operator); - return new KeyedOneInputStreamOperatorTestHarness(operator, keySelector, keySelector.getProducedType()); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, keySelector.getProducedType()); } protected abstract AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java index c068412470a02..ed38f631267ea 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java @@ -39,10 +39,9 @@ public class RetractRankFunctionTest extends BaseRankFunctionTest { @Override protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, boolean generateRetraction, boolean outputRankNumber) { - AbstractRankFunction rankFunction = new RetractRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + return new RetractRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), inputRowType, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber); - return rankFunction; } @Test diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java index b657dab26d85c..92b76cb8eb955 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java @@ -38,11 +38,9 @@ public class UpdateRankFunctionTest extends BaseRankFunctionTest { @Override protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, boolean generateRetraction, boolean outputRankNumber) { - - AbstractRankFunction rankFunction = new UpdateRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + return new UpdateRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), inputRowType, rowKeySelector, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber, cacheSize); - return rankFunction; } @Test diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java index e697ed80c227a..46475662a0b51 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java +++ 
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java @@ -28,10 +28,12 @@ import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** - * A utility class which will extract key from BaseRow. + * A utility class which extracts key from BaseRow. */ public class BinaryRowKeySelector implements BaseRowKeySelector { + private static final long serialVersionUID = -2327761762415377059L; + private final int[] keyFields; private final InternalType[] inputFieldTypes; private final InternalType[] keyFieldTypes; diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java index c0338f1a4453a..5f4deb0818d2e 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java @@ -69,10 +69,10 @@ public class WindowOperatorTest { // For counting if close() is called the correct number of times on the SumReducer private static AtomicInteger closeCalled = new AtomicInteger(0); - private InternalType[] inputFieldTypes = new InternalType[] { + private InternalType[] inputFieldTypes = new InternalType[]{ InternalTypes.STRING, InternalTypes.INT, - InternalTypes.LONG }; + InternalTypes.LONG}; private BaseRowTypeInfo outputType = new BaseRowTypeInfo( InternalTypes.STRING,