diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java
new file mode 100644
index 0000000000000..50082c2beacfa
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.codegen.CodeGeneratorContext;
+import org.apache.flink.table.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
+import org.apache.flink.table.runtime.keyselector.BinaryRowKeySelector;
+import org.apache.flink.table.runtime.keyselector.NullBinaryRowKeySelector;
+import org.apache.flink.table.type.InternalType;
+import org.apache.flink.table.type.RowType;
+import org.apache.flink.table.typeutils.BaseRowSerializer;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.table.typeutils.TypeCheckUtils;
+
+/**
+ * Utility for KeySelector.
+ */
+public class KeySelectorUtil {
+
+ /**
+ * Creates a BaseRowKeySelector to extract keys from a DataStream whose type is BaseRowTypeInfo.
+ *
+ * @param keyFields indices of the key fields
+ * @param rowType type of the DataStream to extract keys from
+ * @return a BaseRowKeySelector that extracts keys from a DataStream whose type is BaseRowTypeInfo
+ */
+ public static BaseRowKeySelector getBaseRowSelector(int[] keyFields, BaseRowTypeInfo rowType) {
+ if (keyFields.length > 0) {
+ InternalType[] inputFieldTypes = rowType.getInternalTypes();
+ String[] inputFieldNames = rowType.getFieldNames();
+ InternalType[] keyFieldTypes = new InternalType[keyFields.length];
+ String[] keyFieldNames = new String[keyFields.length];
+ for (int i = 0; i < keyFields.length; ++i) {
+ keyFieldTypes[i] = inputFieldTypes[keyFields[i]];
+ keyFieldNames[i] = inputFieldNames[keyFields[i]];
+ }
+ RowType returnType = new RowType(keyFieldTypes, keyFieldNames);
+ RowType inputType = new RowType(inputFieldTypes, rowType.getFieldNames());
+ GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection(
+ CodeGeneratorContext.apply(new TableConfig()),
+ BaseRowSerializer.class.getSimpleName(), inputType, returnType, keyFields);
+ BaseRowTypeInfo keyRowType = returnType.toTypeInfo();
+ // check if type implements proper equals/hashCode
+ TypeCheckUtils.validateEqualsHashCode("grouping", keyRowType);
+ return new BinaryRowKeySelector(keyRowType, generatedProjection);
+ } else {
+ return new NullBinaryRowKeySelector();
+ }
+ }
+
+}
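
For reviewers, a minimal usage sketch of the new utility, assuming the internal constructors behave as they are used elsewhere in this patch (the schema and key indices are made up for illustration):

```scala
import org.apache.flink.table.`type`.InternalTypes
import org.apache.flink.table.plan.util.KeySelectorUtil
import org.apache.flink.table.typeutils.BaseRowTypeInfo

object KeySelectorUtilExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical input schema: (id BIGINT, name STRING, score DOUBLE).
    val rowType = new BaseRowTypeInfo(
      Array(InternalTypes.LONG, InternalTypes.STRING, InternalTypes.DOUBLE),
      Array("id", "name", "score"))
    // Key on fields 0 and 1; this generates a projection and returns a BinaryRowKeySelector.
    // An empty key array would instead yield a NullBinaryRowKeySelector.
    val selector = KeySelectorUtil.getBaseRowSelector(Array(0, 1), rowType)
    println(selector.getProducedType) // BaseRowTypeInfo of (id, name)
  }
}
```
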
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala
index e6c0bc92419a8..4d5b2c3d1226c 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala
@@ -187,6 +187,26 @@ class TableConfig {
this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS, maxTime.toMilliseconds)
this
}
+
+ /**
+ * Returns the minimum time that idle state (i.e. state which was not updated) is retained.
+ */
+ def getMinIdleStateRetentionTime: Long = {
+ this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS)
+ }
+
+ /**
+ * Returns the maximum time that idle state (i.e. state which was not updated) is retained.
+ */
+ def getMaxIdleStateRetentionTime: Long = {
+ // only the min idle ttl was provided; default the max to twice the min.
+ if (this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MS)
+ && !this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS)) {
+ this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS,
+ getMinIdleStateRetentionTime * 2)
+ }
+ this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS)
+ }
}
object TableConfig {
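
A quick sketch of the new getters' contract, derived from the code above (the option keys and the getConf accessor are the ones this class already uses):

```scala
import org.apache.flink.table.api.{TableConfig, TableConfigOptions}

object IdleStateRetentionExample {
  def main(args: Array[String]): Unit = {
    val config = new TableConfig
    // Set only the minimum idle-state TTL (1 hour); leave the maximum unset.
    config.getConf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS, 60 * 60 * 1000L)
    println(config.getMinIdleStateRetentionTime) // 3600000
    // When only the minimum is provided, the maximum defaults to twice the minimum.
    println(config.getMaxIdleStateRetentionTime) // 7200000
  }
}
```
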
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala
index 4e26b65b378ff..c60bb12f60d95 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala
@@ -19,10 +19,9 @@
package org.apache.flink.table.calcite
import org.apache.flink.table.calcite.FlinkRelFactories.{ExpandFactory, RankFactory, SinkFactory}
-import org.apache.flink.table.plan.nodes.calcite.RankRange
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
+import org.apache.flink.table.runtime.rank.{RankRange, RankType}
import org.apache.flink.table.sinks.TableSink
import com.google.common.collect.ImmutableList
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
index 3f462fe3d2156..b8e2c1a755913 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
@@ -20,8 +20,7 @@ package org.apache.flink.table.calcite
import org.apache.flink.table.calcite.FlinkRelFactories.{ExpandFactory, RankFactory, SinkFactory}
import org.apache.flink.table.expressions.WindowProperty
-import org.apache.flink.table.plan.nodes.calcite.RankRange
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
+import org.apache.flink.table.runtime.rank.{RankRange, RankType}
import org.apache.flink.table.sinks.TableSink
import org.apache.calcite.config.{CalciteConnectionConfigImpl, CalciteConnectionProperty}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala
index d18e5ec82cd66..05dc497f6b1bf 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala
@@ -18,8 +18,8 @@
package org.apache.flink.table.calcite
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
-import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalSink, RankRange}
+import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalSink}
+import org.apache.flink.table.runtime.rank.{RankRange, RankType}
import org.apache.flink.table.sinks.TableSink
import org.apache.calcite.plan.Contexts
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala
index 0d79320f5956b..87114d0ca7a06 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala
@@ -115,7 +115,7 @@ object CodeGenUtils {
case InternalTypes.DATE => "int"
case InternalTypes.TIME => "int"
- case InternalTypes.TIMESTAMP => "long"
+ case _: TimestampType => "long"
case InternalTypes.INTERVAL_MONTHS => "int"
case InternalTypes.INTERVAL_MILLIS => "long"
@@ -135,7 +135,7 @@ object CodeGenUtils {
case InternalTypes.DATE => boxedTypeTermForType(InternalTypes.INT)
case InternalTypes.TIME => boxedTypeTermForType(InternalTypes.INT)
- case InternalTypes.TIMESTAMP => boxedTypeTermForType(InternalTypes.LONG)
+ case _: TimestampType => boxedTypeTermForType(InternalTypes.LONG)
case InternalTypes.STRING => BINARY_STRING
case InternalTypes.BINARY => "byte[]"
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala
new file mode 100644
index 0000000000000..09658663e6ca0
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.codegen
+
+import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.codegen.CodeGenUtils._
+import org.apache.flink.table.codegen.Indenter.toISC
+import org.apache.flink.table.generated.{GeneratedRecordEqualiser, RecordEqualiser}
+import org.apache.flink.table.`type`.{DateType, InternalType, PrimitiveType, RowType, TimeType, TimestampType}
+
+class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) {
+
+ private val RECORD_EQUALISER = className[RecordEqualiser]
+ private val LEFT_INPUT = "left"
+ private val RIGHT_INPUT = "right"
+
+ def generateRecordEqualiser(name: String): GeneratedRecordEqualiser = {
+ // ignore time zone
+ val ctx = CodeGeneratorContext(new TableConfig)
+ val className = newName(name)
+ val header =
+ s"""
+ |if ($LEFT_INPUT.getHeader() != $RIGHT_INPUT.getHeader()) {
+ | return false;
+ |}
+ """.stripMargin
+
+ val codes = for (i <- fieldTypes.indices) yield {
+ val fieldType = fieldTypes(i)
+ val fieldTypeTerm = primitiveTypeTermForType(fieldType)
+ val result = s"cmp$i"
+ val leftNullTerm = "leftIsNull$" + i
+ val rightNullTerm = "rightIsNull$" + i
+ val leftFieldTerm = "leftField$" + i
+ val rightFieldTerm = "rightField$" + i
+ val equalsCode = if (isInternalPrimitive(fieldType)) {
+ s"$leftFieldTerm == $rightFieldTerm"
+ } else if (isBaseRow(fieldType)) {
+ val equaliserGenerator =
+ new EqualiserCodeGenerator(fieldType.asInstanceOf[RowType].getFieldTypes)
+ val generatedEqualiser = equaliserGenerator
+ .generateRecordEqualiser("field$" + i + "GeneratedEqualiser")
+ val generatedEqualiserTerm = ctx.addReusableObject(
+ generatedEqualiser, "field$" + i + "GeneratedEqualiser")
+ val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
+ val equaliserTerm = newName("equaliser")
+ ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;")
+ ctx.addReusableInitStatement(
+ s"""
+ |$equaliserTerm = ($equaliserTypeTerm)
+ | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
+ |""".stripMargin)
+ s"$equaliserTerm.equalsWithoutHeader($leftFieldTerm, $rightFieldTerm)"
+ } else {
+ s"$leftFieldTerm.equals($rightFieldTerm)"
+ }
+ val leftReadCode = baseRowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType)
+ val rightReadCode = baseRowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType)
+ s"""
+ |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i);
+ |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i);
+ |boolean $result;
+ |if ($leftNullTerm && $rightNullTerm) {
+ | $result = true;
+ |} else if ($leftNullTerm || $rightNullTerm) {
+ | $result = false;
+ |} else {
+ | $fieldTypeTerm $leftFieldTerm = $leftReadCode;
+ | $fieldTypeTerm $rightFieldTerm = $rightReadCode;
+ | $result = $equalsCode;
+ |}
+ |if (!$result) {
+ | return false;
+ |}
+ """.stripMargin
+ }
+
+ val functionCode =
+ j"""
+ public final class $className implements $RECORD_EQUALISER {
+
+ ${ctx.reuseMemberCode()}
+
+ public $className(Object[] references) throws Exception {
+ ${ctx.reuseInitCode()}
+ }
+
+ @Override
+ public boolean equals($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) {
+ if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) {
+ return $LEFT_INPUT.equals($RIGHT_INPUT);
+ } else {
+ $header
+ ${ctx.reuseLocalVariableCode()}
+ ${codes.mkString("\n")}
+ return true;
+ }
+ }
+
+ @Override
+ public boolean equalsWithoutHeader($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) {
+ if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) {
+ return (($BINARY_ROW)$LEFT_INPUT).equalsWithoutHeader((($BINARY_ROW)$RIGHT_INPUT));
+ } else {
+ ${ctx.reuseLocalVariableCode()}
+ ${codes.mkString("\n")}
+ return true;
+ }
+ }
+ }
+ """.stripMargin
+
+ new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray)
+ }
+
+ private def isInternalPrimitive(t: InternalType): Boolean = t match {
+ case _: PrimitiveType => true
+
+ case _: DateType => true
+ case TimeType.INSTANCE => true
+ case _: TimestampType => true
+
+ case _ => false
+ }
+
+ private def isBaseRow(t: InternalType): Boolean = t match {
+ case _: RowType => true
+ case _ => false
+ }
+}
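
A hedged sketch of how the generated equaliser is meant to be instantiated and used, mirroring the newInstance call the generator itself emits (row construction uses the GenericRow API shown later in this patch):

```scala
import org.apache.flink.table.`type`.InternalTypes
import org.apache.flink.table.codegen.EqualiserCodeGenerator
import org.apache.flink.table.dataformat.{BinaryString, GenericRow}

object EqualiserExample {
  def main(args: Array[String]): Unit = {
    val generator = new EqualiserCodeGenerator(Seq(InternalTypes.INT, InternalTypes.STRING))
    val generated = generator.generateRecordEqualiser("DemoEqualiser")
    // Compile the generated class and instantiate it with the current classloader.
    val equaliser = generated.newInstance(Thread.currentThread().getContextClassLoader)

    def row(i: Int, s: String): GenericRow = {
      val r = new GenericRow(2)
      r.setField(0, i)
      r.setField(1, BinaryString.fromString(s))
      r
    }

    println(equaliser.equals(row(1, "a"), row(1, "a"))) // true
    println(equaliser.equals(row(1, "a"), row(2, "a"))) // false
  }
}
```
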
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala
index 29431c2fb6458..ce388c8bc9f7f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala
@@ -31,6 +31,7 @@ import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeIn
import org.apache.flink.table.api.{Table, TableConfig, TableException, Types}
import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, genToExternal}
import org.apache.flink.table.codegen.OperatorCodeGenerator.generateCollect
+import org.apache.flink.table.dataformat.util.BaseRowUtil
import org.apache.flink.table.dataformat.{BaseRow, GenericRow}
import org.apache.flink.table.runtime.OneInputOperatorWrapper
import org.apache.flink.table.sinks.{DataStreamTableSink, TableSink}
@@ -93,6 +94,10 @@ object SinkCodeGenerator {
new RowTypeInfo(
inputTypeInfo.getFieldTypes,
inputTypeInfo.getFieldNames)
+ case gt: GenericTypeInfo[BaseRow] if gt.getTypeClass == classOf[BaseRow] =>
+ new BaseRowTypeInfo(
+ inputTypeInfo.getInternalTypes,
+ inputTypeInfo.getFieldNames)
case _ => requestedTypeInfo
}
@@ -154,13 +159,14 @@ object SinkCodeGenerator {
val retractProcessCode = if (!withChangeFlag) {
generateCollect(genToExternal(ctx, outputTypeInfo, afterIndexModify))
} else {
- val flagResultTerm = s"$afterIndexModify.getHeader() == $BASE_ROW.ACCUMULATE_MSG"
+ val flagResultTerm =
+ s"${classOf[BaseRowUtil].getCanonicalName}.isAccumulateMsg($afterIndexModify)"
val resultTerm = CodeGenUtils.newName("result")
val genericRowField = classOf[GenericRow].getCanonicalName
s"""
|$genericRowField $resultTerm = new $genericRowField(2);
- |$resultTerm.update(0, $flagResultTerm);
- |$resultTerm.update(1, $afterIndexModify);
+ |$resultTerm.setField(0, $flagResultTerm);
+ |$resultTerm.setField(1, $afterIndexModify);
|${generateCollect(genToExternal(ctx, outputTypeInfo, resultTerm))}
""".stripMargin
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala
index c207bd685a52f..86b79a9e590a0 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala
@@ -22,7 +22,6 @@ import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataformat.{BaseRow, GenericRow}
import org.apache.flink.table.runtime.values.ValuesInputFormat
-import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.RexLiteral
@@ -56,8 +55,7 @@ object ValuesCodeGenerator {
generatedRecords.map(_.code),
outputType)
- val baseRowTypeInfo = new BaseRowTypeInfo(outputType.getFieldTypes, outputType.getFieldNames)
- new ValuesInputFormat(generatedFunction, baseRowTypeInfo)
+ new ValuesInputFormat(generatedFunction, outputType.toTypeInfo)
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala
index a3fefe15c691e..57a64212a0b7e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala
@@ -449,6 +449,9 @@ class SortCodeGenerator(
case InternalTypes.FLOAT => 4
case InternalTypes.DOUBLE => 8
case InternalTypes.LONG => 8
+ case _: TimestampType => 8
+ case _: DateType => 4
+ case InternalTypes.TIME => 4
case dt: DecimalType if Decimal.isCompact(dt.precision()) => 8
case InternalTypes.STRING | InternalTypes.BINARY => Int.MaxValue
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala
index de4238ea29c27..e9e4c0231367b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.plan.nodes.calcite
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
+import org.apache.flink.table.runtime.rank.{RankRange, RankType}
import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataTypeField
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala
index 9e19e028e3dbc..0ab70a3406a65 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala
@@ -19,8 +19,8 @@
package org.apache.flink.table.plan.nodes.calcite
import org.apache.flink.table.api.TableException
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
import org.apache.flink.table.plan.util._
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType, VariableRankRange}
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
@@ -71,18 +71,19 @@ abstract class Rank(
rankRange match {
case r: ConstantRankRange =>
- if (r.rankEnd <= 0) {
- throw new TableException(s"Rank end can't smaller than zero. The rank end is ${r.rankEnd}")
+ if (r.getRankEnd <= 0) {
+ throw new TableException(
+ s"Rank end can't smaller than zero. The rank end is ${r.getRankEnd}")
}
- if (r.rankStart > r.rankEnd) {
+ if (r.getRankStart > r.getRankEnd) {
throw new TableException(
- s"Rank start '${r.rankStart}' can't greater than rank end '${r.rankEnd}'.")
+ s"Rank start '${r.getRankStart}' can't greater than rank end '${r.getRankEnd}'.")
}
case v: VariableRankRange =>
- if (v.rankEndIndex < 0) {
+ if (v.getRankEndIndex < 0) {
throw new TableException(s"Rank end index can't smaller than zero.")
}
- if (v.rankEndIndex >= input.getRowType.getFieldCount) {
+ if (v.getRankEndIndex >= input.getRowType.getFieldCount) {
throw new TableException(s"Rank end index can't greater than input field count.")
}
}
@@ -105,7 +106,7 @@ abstract class Rank(
}.mkString(", ")
super.explainTerms(pw)
.item("rankType", rankType)
- .item("rankRange", rankRange.toString())
+ .item("rankRange", rankRange)
.item("partitionBy", partitionKey.map(i => s"$$$i").mkString(","))
.item("orderBy", RelExplainUtil.collationToString(orderKey))
.item("select", select)
@@ -134,64 +135,3 @@ abstract class Rank(
}
}
-
-/**
- * An enumeration of rank type, usable to tell the [[Rank]] node how exactly generate rank number.
- */
-object RankType extends Enumeration {
- type RankType = Value
-
- /**
- * Returns a unique sequential number for each row within the partition based on the order,
- * starting at 1 for the first row in each partition and without repeating or skipping
- * numbers in the ranking result of each partition. If there are duplicate values within the
- * row set, the ranking numbers will be assigned arbitrarily.
- */
- val ROW_NUMBER: RankType.Value = Value
-
- /**
- * Returns a unique rank number for each distinct row within the partition based on the order,
- * starting at 1 for the first row in each partition, with the same rank for duplicate values
- * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate
- * values.
- */
- val RANK: RankType.Value = Value
-
- /**
- * is similar to the RANK by generating a unique rank number for each distinct row
- * within the partition based on the order, starting at 1 for the first row in each partition,
- * ranking the rows with equal values with the same rank number, except that it does not skip
- * any rank, leaving no gaps between the ranks.
- */
- val DENSE_RANK: RankType.Value = Value
-}
-
-sealed trait RankRange extends Serializable {
- def toString(inputFieldNames: Seq[String]): String
-}
-
-/** [[ConstantRankRangeWithoutEnd]] is a RankRange which not specify RankEnd. */
-case class ConstantRankRangeWithoutEnd(rankStart: Long) extends RankRange {
- override def toString(inputFieldNames: Seq[String]): String = this.toString
-
- override def toString: String = s"rankStart=$rankStart"
-}
-
-/** rankStart and rankEnd are inclusive, rankStart always start from one. */
-case class ConstantRankRange(rankStart: Long, rankEnd: Long) extends RankRange {
-
- override def toString(inputFieldNames: Seq[String]): String = this.toString
-
- override def toString: String = s"rankStart=$rankStart, rankEnd=$rankEnd"
-}
-
-/** changing rank limit depends on input */
-case class VariableRankRange(rankEndIndex: Int) extends RankRange {
- override def toString(inputFieldNames: Seq[String]): String = {
- s"rankEnd=${inputFieldNames(rankEndIndex)}"
- }
-
- override def toString: String = {
- s"rankEnd=$$$rankEndIndex"
- }
-}
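
The Scala Enumeration and case classes removed above now live on the runtime side (per the new org.apache.flink.table.runtime.rank imports), presumably as plain Java classes, hence the new-style construction and getters used throughout this patch. A small sketch:

```scala
import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType, VariableRankRange}

object RankRangeExample {
  def main(args: Array[String]): Unit = {
    // Top-10: rank numbers 1..10, both ends inclusive; rank start is always >= 1.
    val top10 = new ConstantRankRange(1, 10)
    println(s"${top10.getRankStart}..${top10.getRankEnd}") // 1..10

    // Rank end read from input field 3 at runtime.
    val dynamic = new VariableRankRange(3)
    println(dynamic.getRankEndIndex) // 3

    // RankType values match as plain constants, e.g. in translateToPlanInternal.
    println(RankType.ROW_NUMBER)
  }
}
```
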
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala
index 806dadfa0edd0..55d9575637fc5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala
@@ -18,9 +18,9 @@
package org.apache.flink.table.plan.nodes.logical
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
-import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank, RankRange}
+import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank}
import org.apache.flink.table.plan.util.RelExplainUtil
+import org.apache.flink.table.runtime.rank.{RankRange, RankType}
import org.apache.calcite.plan._
import org.apache.calcite.rel.`type`.RelDataTypeField
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala
index 84f05aae636dc..05cd3252d438b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala
@@ -25,10 +25,10 @@ import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.sort.ComparatorCodeGenerator
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory}
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
-import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, Rank, RankRange, RankType}
+import org.apache.flink.table.plan.nodes.calcite.Rank
import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode}
import org.apache.flink.table.plan.util.RelExplainUtil
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType}
import org.apache.flink.table.runtime.sort.RankOperator
import org.apache.calcite.plan._
@@ -72,7 +72,7 @@ class BatchExecRank(
require(rankType == RankType.RANK, "Only RANK is supported now")
val (rankStart, rankEnd) = rankRange match {
- case r: ConstantRankRange => (r.rankStart, r.rankEnd)
+ case r: ConstantRankRange => (r.getRankStart, r.getRankEnd)
case o => throw new TableException(s"$o is not supported now")
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala
index 81fe874c56938..a8a42b189c11b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala
@@ -31,6 +31,7 @@ import org.apache.flink.table.functions.sql.StreamRecordTimestampSqlFunction
import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode}
import org.apache.flink.table.plan.schema.DataStreamTable
import org.apache.flink.table.plan.util.ScanUtil
+import org.apache.flink.table.runtime.AbstractProcessStreamOperator
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo.{ROWTIME_INDICATOR, ROWTIME_STREAM_MARKER}
import org.apache.calcite.plan._
@@ -112,16 +113,17 @@ class StreamExecDataStreamScan(
} else {
("", "")
}
-
+ val ctx = CodeGeneratorContext(config).setOperatorBaseClass(
+ classOf[AbstractProcessStreamOperator[BaseRow]])
ScanUtil.convertToInternalRow(
- CodeGeneratorContext(config),
+ ctx,
transform,
dataStreamTable.fieldIndexes,
dataStreamTable.typeInfo,
getRowType,
getTable.getQualifiedName,
config,
- None,
+ rowtimeExpr,
beforeConvert = extractElement,
afterConvert = resetElement)
} else {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala
index e76af21841def..58d56b84ba472 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala
@@ -18,27 +18,42 @@
package org.apache.flink.table.plan.nodes.physical.stream
+import org.apache.flink.streaming.api.operators.KeyedProcessOperator
+import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation}
+import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException}
+import org.apache.flink.table.dataformat.BaseRow
+import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode}
+import org.apache.flink.table.plan.util.KeySelectorUtil
+import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator
+import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger
+import org.apache.flink.table.runtime.deduplicate.{DeduplicateKeepFirstRowFunction, DeduplicateKeepLastRowFunction, MiniBatchDeduplicateKeepFirstRowFunction, MiniBatchDeduplicateKeepLastRowFunction}
+import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules
+import org.apache.flink.table.typeutils.BaseRowTypeInfo
+
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
-import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rel.`type`.RelDataType
import java.util
+import scala.collection.JavaConversions._
+
/**
* Stream physical RelNode which deduplicate on keys and keeps only first row or last row.
* This node is an optimization of [[StreamExecRank]] for some special cases.
* Compared to [[StreamExecRank]], this node could use mini-batch and access less state.
- * NOTES: only supports sort on proctime now.
+ *
+ * NOTES: only supports sort on proctime now; sort on rowtime will not be translated into a
+ * StreamExecDeduplicate node.
*/
class StreamExecDeduplicate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
uniqueKeys: Array[Int],
- isRowtime: Boolean,
keepLastRow: Boolean)
extends SingleRel(cluster, traitSet, inputRel)
- with StreamPhysicalRel {
+ with StreamPhysicalRel
+ with StreamExecNode[BaseRow] {
def getUniqueKeys: Array[Int] = uniqueKeys
@@ -60,17 +75,77 @@ class StreamExecDeduplicate(
traitSet,
inputs.get(0),
uniqueKeys,
- isRowtime,
keepLastRow)
}
override def explainTerms(pw: RelWriter): RelWriter = {
val fieldNames = getRowType.getFieldNames
- val orderString = if (isRowtime) "ROWTIME" else "PROCTIME"
super.explainTerms(pw)
.item("keepLastRow", keepLastRow)
.item("key", uniqueKeys.map(fieldNames.get).mkString(", "))
- .item("order", orderString)
+ .item("order", "PROCTIME")
+ }
+
+ //~ ExecNode methods -----------------------------------------------------------
+
+ override protected def translateToPlanInternal(
+ tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {
+
+ val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(getInput)
+
+ if (inputIsAccRetract) {
+ throw new TableException("Deduplicate doesn't support retraction input stream currently.")
+ }
+
+ val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv)
+ .asInstanceOf[StreamTransformation[BaseRow]]
+
+ val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo]
+ val generateRetraction = StreamExecRetractionRules.isAccRetract(this)
+ val tableConfig = tableEnv.getConfig
+ val isMiniBatchEnabled = tableConfig.getConf.getLong(
+ TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) > 0
+ val operator = if (isMiniBatchEnabled) {
+ val exeConfig = tableEnv.execEnv.getConfig
+ val rowSerializer = rowTypeInfo.createSerializer(exeConfig)
+ val processFunction = if (keepLastRow) {
+ new MiniBatchDeduplicateKeepLastRowFunction(rowTypeInfo, generateRetraction, rowSerializer)
+ } else {
+ new MiniBatchDeduplicateKeepFirstRowFunction(rowSerializer)
+ }
+ val trigger = new CountBundleTrigger[BaseRow](
+ tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE))
+ new KeyedMapBundleOperator(
+ processFunction,
+ trigger)
+ } else {
+ val minRetentionTime = tableConfig.getMinIdleStateRetentionTime
+ val maxRetentionTime = tableConfig.getMaxIdleStateRetentionTime
+ val processFunction = if (keepLastRow) {
+ new DeduplicateKeepLastRowFunction(minRetentionTime, maxRetentionTime, rowTypeInfo,
+ generateRetraction)
+ } else {
+ new DeduplicateKeepFirstRowFunction(minRetentionTime, maxRetentionTime)
+ }
+ new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](processFunction)
+ }
+ val ret = new OneInputTransformation(
+ inputTransform, getOperatorName, operator, rowTypeInfo, inputTransform.getParallelism)
+ val selector = KeySelectorUtil.getBaseRowSelector(uniqueKeys, rowTypeInfo)
+ ret.setStateKeySelector(selector)
+ ret.setStateKeyType(selector.getProducedType)
+ ret
+ }
+
+ override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = {
+ List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]])
+ }
+
+ private def getOperatorName: String = {
+ val fieldNames = getRowType.getFieldNames
+ val keyNames = uniqueKeys.map(fieldNames.get).mkString(", ")
+ s"${if (keepLastRow) "keepLastRow" else "KeepFirstRow"}" +
+ s": (key: ($keyNames), select: (${fieldNames.mkString(", ")}), order: (PROCTIME)"
}
}
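
The operator choice in translateToPlanInternal above hinges on two existing config options; a sketch of the settings that would exercise the mini-batch path (the values are illustrative):

```scala
import org.apache.flink.table.api.{TableConfig, TableConfigOptions}

object DeduplicateMiniBatchExample {
  def main(args: Array[String]): Unit = {
    val config = new TableConfig
    // Any positive allowed latency makes translateToPlanInternal pick the
    // KeyedMapBundleOperator with the MiniBatchDeduplicate* functions.
    config.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY, 5000L)
    // The CountBundleTrigger fires once this many records are buffered.
    config.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE, 1000L)
  }
}
```
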
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala
index b503fb2d9bded..2f2fc307a9744 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala
@@ -18,11 +18,24 @@
package org.apache.flink.table.plan.nodes.physical.stream
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM
import org.apache.flink.table.plan.nodes.common.CommonPhysicalExchange
+import org.apache.flink.table.dataformat.BaseRow
+import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode}
+import org.apache.flink.streaming.api.transformations.{PartitionTransformation, StreamTransformation}
+import org.apache.flink.streaming.runtime.partitioner.{GlobalPartitioner, KeyGroupStreamPartitioner, StreamPartitioner}
+import org.apache.flink.table.api.StreamTableEnvironment
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.plan.util.KeySelectorUtil
+import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.{RelDistribution, RelNode}
+import java.util
+
+import scala.collection.JavaConversions._
+
/**
* Stream physical RelNode for [[org.apache.calcite.rel.core.Exchange]].
*/
@@ -32,7 +45,8 @@ class StreamExecExchange(
relNode: RelNode,
relDistribution: RelDistribution)
extends CommonPhysicalExchange(cluster, traitSet, relNode, relDistribution)
- with StreamPhysicalRel {
+ with StreamPhysicalRel
+ with StreamExecNode[BaseRow] {
override def producesUpdates: Boolean = false
@@ -50,4 +64,43 @@ class StreamExecExchange(
newDistribution: RelDistribution): StreamExecExchange = {
new StreamExecExchange(cluster, traitSet, newInput, newDistribution)
}
+
+ //~ ExecNode methods -----------------------------------------------------------
+
+ override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = {
+ List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]])
+ }
+
+ override protected def translateToPlanInternal(
+ tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {
+ val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv)
+ .asInstanceOf[StreamTransformation[BaseRow]]
+ val inputTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo]
+ val outputTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo
+ relDistribution.getType match {
+ case RelDistribution.Type.SINGLETON =>
+ val partitioner = new GlobalPartitioner[BaseRow]
+ val transformation = new PartitionTransformation(
+ inputTransform,
+ partitioner.asInstanceOf[StreamPartitioner[BaseRow]])
+ transformation.setOutputType(outputTypeInfo)
+ transformation
+ case RelDistribution.Type.HASH_DISTRIBUTED =>
+ // TODO Eliminate duplicate keys
+
+ val selector = KeySelectorUtil.getBaseRowSelector(
+ relDistribution.getKeys.map(_.toInt).toArray, inputTypeInfo)
+ val partitioner = new KeyGroupStreamPartitioner(selector,
+ DEFAULT_LOWER_BOUND_MAX_PARALLELISM)
+ val transformation = new PartitionTransformation(
+ inputTransform,
+ partitioner.asInstanceOf[StreamPartitioner[BaseRow]])
+ transformation.setOutputType(outputTypeInfo)
+ transformation
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"RelDistribution type '${relDistribution.getType}' is not supported now!")
+ }
+ }
+
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala
index ac833b548ea65..e27c85b4b7370 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala
@@ -17,9 +17,18 @@
*/
package org.apache.flink.table.plan.nodes.physical.stream
-import org.apache.flink.table.plan.nodes.calcite.RankType.RankType
-import org.apache.flink.table.plan.nodes.calcite.{Rank, RankRange}
-import org.apache.flink.table.plan.util.{RankProcessStrategy, RelExplainUtil, RetractStrategy}
+import org.apache.flink.streaming.api.operators.KeyedProcessOperator
+import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation}
+import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException}
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.codegen.EqualiserCodeGenerator
+import org.apache.flink.table.codegen.sort.SortCodeGenerator
+import org.apache.flink.table.dataformat.BaseRow
+import org.apache.flink.table.plan.nodes.calcite.Rank
+import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode}
+import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules
+import org.apache.flink.table.plan.util._
+import org.apache.flink.table.runtime.rank._
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel._
@@ -53,7 +62,8 @@ class StreamExecRank(
rankRange,
rankNumberType,
outputRankNumber)
- with StreamPhysicalRel {
+ with StreamPhysicalRel
+ with StreamExecNode[BaseRow] {
/** please uses [[getStrategy]] instead of this field */
private var strategy: RankProcessStrategy = _
@@ -101,4 +111,121 @@ class StreamExecRank(
.item("orderBy", RelExplainUtil.collationToString(orderKey, inputRowType))
.item("select", getRowType.getFieldNames.mkString(", "))
}
+
+ //~ ExecNode methods -----------------------------------------------------------
+
+ override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = {
+ List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]])
+ }
+
+ override protected def translateToPlanInternal(
+ tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {
+ val tableConfig = tableEnv.getConfig
+ rankType match {
+ case RankType.ROW_NUMBER => // ignore
+ case RankType.RANK =>
+ throw new TableException("RANK() on streaming table is not supported currently")
+ case RankType.DENSE_RANK =>
+ throw new TableException("DENSE_RANK() on streaming table is not supported currently")
+ case k =>
+ throw new TableException(s"Streaming tables do not support $k rank function.")
+ }
+
+ val inputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getInput.getRowType).toTypeInfo
+ val fieldCollations = orderKey.getFieldCollations
+ val (sortFields, sortDirections, nullsIsLast) = SortUtil.getKeysAndOrders(fieldCollations)
+ val sortKeySelector = KeySelectorUtil.getBaseRowSelector(sortFields, inputRowTypeInfo)
+ val sortKeyType = sortKeySelector.getProducedType
+ val sortCodeGen = new SortCodeGenerator(
+ tableConfig, sortFields.indices.toArray, sortKeyType.getInternalTypes,
+ sortDirections, nullsIsLast)
+ val sortKeyComparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator")
+ val generateRetraction = StreamExecRetractionRules.isAccRetract(this)
+ val cacheSize = tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE)
+ val minIdleStateRetentionTime = tableConfig.getMinIdleStateRetentionTime
+ val maxIdleStateRetentionTime = tableConfig.getMaxIdleStateRetentionTime
+ val equaliserCodeGenerator = new EqualiserCodeGenerator(inputRowTypeInfo.getInternalTypes)
+ val generatedEqualiser = equaliserCodeGenerator.generateRecordEqualiser("RankValueEqualiser")
+ val processFunction = getStrategy(true) match {
+ case AppendFastStrategy =>
+ new AppendRankFunction(
+ minIdleStateRetentionTime,
+ maxIdleStateRetentionTime,
+ inputRowTypeInfo,
+ sortKeyComparator,
+ sortKeySelector,
+ rankType,
+ rankRange,
+ generatedEqualiser,
+ generateRetraction,
+ outputRankNumber,
+ cacheSize)
+
+ case UpdateFastStrategy(primaryKeys) =>
+ val rowKeySelector = KeySelectorUtil.getBaseRowSelector(primaryKeys, inputRowTypeInfo)
+ new UpdateRankFunction(
+ minIdleStateRetentionTime,
+ maxIdleStateRetentionTime,
+ inputRowTypeInfo,
+ rowKeySelector,
+ sortKeyComparator,
+ sortKeySelector,
+ rankType,
+ rankRange,
+ generatedEqualiser,
+ generateRetraction,
+ outputRankNumber,
+ cacheSize)
+
+ // TODO UnaryUpdateRank after SortedMapState is merged
+ case RetractStrategy | UnaryUpdateStrategy(_) =>
+ new RetractRankFunction(
+ minIdleStateRetentionTime,
+ maxIdleStateRetentionTime,
+ inputRowTypeInfo,
+ sortKeyComparator,
+ sortKeySelector,
+ rankType,
+ rankRange,
+ generatedEqualiser,
+ generateRetraction,
+ outputRankNumber)
+ }
+ val rankOpName = getOperatorName
+ val operator = new KeyedProcessOperator(processFunction)
+ processFunction.setKeyContext(operator)
+ val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv)
+ .asInstanceOf[StreamTransformation[BaseRow]]
+ val outputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo
+ val ret = new OneInputTransformation(
+ inputTransform,
+ rankOpName,
+ operator,
+ outputRowTypeInfo,
+ inputTransform.getParallelism)
+
+ if (partitionKey.isEmpty) {
+ ret.setParallelism(1)
+ ret.setMaxParallelism(1)
+ }
+
+ // set KeyType and Selector for state
+ val selector = KeySelectorUtil.getBaseRowSelector(partitionKey.toArray, inputRowTypeInfo)
+ ret.setStateKeySelector(selector)
+ ret.setStateKeyType(selector.getProducedType)
+ ret
+ }
+
+ private def getOperatorName: String = {
+ val inputRowType = inputRel.getRowType
+ var result = getStrategy().toString
+ result += s"(orderBy: (${RelExplainUtil.collationToString(orderKey, inputRowType)})"
+ if (partitionKey.nonEmpty) {
+ val partitionKeys = partitionKey.toArray
+ result += s", partitionBy: (${RelExplainUtil.fieldToString(partitionKeys, inputRowType)})"
+ }
+ result += s", ${getRowType.getFieldNames.mkString(", ")}"
+ result += s", ${rankRange.toString(inputRowType.getFieldNames)})"
+ result
+ }
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala
index 1bff861eb7287..822359c5ca520 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala
@@ -22,7 +22,7 @@ import org.apache.flink.table.api.{StreamTableEnvironment, TableConfig}
import org.apache.flink.table.plan.`trait`.UpdateAsRetractionTraitDef
import org.apache.flink.table.plan.nodes.calcite.Sink
import org.apache.flink.table.plan.optimize.program.{FlinkStreamProgram, StreamOptimizeContext}
-import org.apache.flink.table.sinks.RetractStreamTableSink
+import org.apache.flink.table.sinks.{DataStreamTableSink, RetractStreamTableSink}
import org.apache.flink.util.Preconditions
import org.apache.calcite.plan.volcano.VolcanoPlanner
@@ -38,7 +38,11 @@ class StreamOptimizer(tEnv: StreamTableEnvironment) extends Optimizer {
roots.map { root =>
val retractionFromRoot = root match {
case n: Sink =>
- n.sink.isInstanceOf[RetractStreamTableSink[_]]
+ n.sink match {
+ case _: RetractStreamTableSink[_] => true
+ case s: DataStreamTableSink[_] => s.updatesAsRetraction
+ case _ => false
+ }
case o =>
o.getTraitSet.getTrait(UpdateAsRetractionTraitDef.INSTANCE).sendsUpdatesAsRetractions
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala
index 3a53e2c1223d5..3b21083e2a608 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala
@@ -19,9 +19,9 @@ package org.apache.flink.table.plan.rules.logical
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkContext
-import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType}
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverWindow, FlinkLogicalRank}
import org.apache.flink.table.plan.util.RankUtil
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType}
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
@@ -85,8 +85,9 @@ abstract class FlinkLogicalRankRuleBase
require(rankNumberType.isDefined)
rankRange match {
- case Some(ConstantRankRange(_, rankEnd)) if rankEnd <= 0 =>
- throw new TableException(s"Rank end should not less than zero, but now is $rankEnd")
+ case Some(crr: ConstantRankRange) if crr.getRankEnd <= 0 =>
+ throw new TableException(
+ s"Rank end should not less than zero, but now is ${crr.getRankEnd}")
case _ => // do nothing
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala
index 1e8950b6091ea..918211b9cda67 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala
@@ -21,10 +21,10 @@ package org.apache.flink.table.plan.rules.physical.batch
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType}
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank
import org.apache.flink.table.plan.util.RelFieldCollationUtil
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.convert.ConverterRule
@@ -58,7 +58,7 @@ class BatchExecRankRule
def convert(rel: RelNode): RelNode = {
val rank = rel.asInstanceOf[FlinkLogicalRank]
val (_, rankEnd) = rank.rankRange match {
- case r: ConstantRankRange => (r.rankStart, r.rankEnd)
+ case r: ConstantRankRange => (r.getRankStart, r.getRankEnd)
case o => throw new TableException(s"$o is not supported now")
}
@@ -71,7 +71,7 @@ class BatchExecRankRule
val newLocalInput = RelOptRule.convert(rank.getInput, localRequiredTraitSet)
// create local BatchExecRank
- val localRankRange = ConstantRankRange(1, rankEnd) // local rank always start from 1
+ val localRankRange = new ConstantRankRange(1, rankEnd) // local rank always starts from 1
val localRank = new BatchExecRank(
cluster,
emptyTraits,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala
index b07ba7a040989..6ec7a27eb3aa3 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala
@@ -21,9 +21,9 @@ package org.apache.flink.table.plan.rules.physical.stream
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType}
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank
import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecDeduplicate, StreamExecRank}
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.`type`.RelDataType
@@ -68,10 +68,6 @@ class StreamExecDeduplicateRule
override def convert(rel: RelNode): RelNode = {
val rank = rel.asInstanceOf[FlinkLogicalRank]
- val fieldCollation = rank.orderKey.getFieldCollations.get(0)
- val fieldType = rank.getInput.getRowType.getFieldList.get(fieldCollation.getFieldIndex).getType
- val isRowtime = FlinkTypeFactory.isRowtimeIndicatorType(fieldType)
-
val requiredDistribution = FlinkRelDistribution.hash(rank.partitionKey.toList)
val requiredTraitSet = rel.getCluster.getPlanner.emptyTraitSet()
.replace(FlinkConventions.STREAM_PHYSICAL)
@@ -79,6 +75,7 @@ class StreamExecDeduplicateRule
val convInput: RelNode = RelOptRule.convert(rank.getInput, requiredTraitSet)
// order by timeIndicator desc ==> lastRow, otherwise is firstRow
+ val fieldCollation = rank.orderKey.getFieldCollations.get(0)
val isLastRow = fieldCollation.direction.isDescending
val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
new StreamExecDeduplicate(
@@ -86,7 +83,6 @@ class StreamExecDeduplicateRule
providedTraitSet,
convInput,
rank.partitionKey.toArray,
- isRowtime,
isLastRow)
}
}
@@ -111,7 +107,8 @@ object StreamExecDeduplicateRule {
val isRowNumberType = rank.rankType == RankType.ROW_NUMBER
val isLimit1 = rankRange match {
- case ConstantRankRange(rankStart, rankEnd) => rankStart == 1 && rankEnd == 1
+ case rankRange: ConstantRankRange =>
+ rankRange.getRankStart() == 1 && rankRange.getRankEnd() == 1
case _ => false
}
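
The match condition above (ROW_NUMBER over the range [1, 1]) corresponds to the classic deduplication pattern; a sketch of the check against a hand-built range (the SQL in the comment is illustrative):

```scala
import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType}

object DeduplicateMatchExample {
  def main(args: Array[String]): Unit = {
    // The rule matches queries of the shape:
    //   SELECT * FROM (
    //     SELECT *, ROW_NUMBER() OVER (PARTITION BY k ORDER BY proctime) AS rn FROM t
    //   ) WHERE rn = 1
    val range = new ConstantRankRange(1, 1)
    val isLimit1 = range.getRankStart == 1 && range.getRankEnd == 1
    println(s"${RankType.ROW_NUMBER}, limit1 = $isLimit1") // limit1 = true
  }
}
```
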
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
index fbb16f50260cb..b88cfffbdc2f0 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.util
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataformat.BinaryRow
-import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankRange}
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange}
import org.apache.flink.table.runtime.sort.BinaryIndexedSortable
import org.apache.flink.table.typeutils.BinaryRowSerializer
@@ -37,7 +37,7 @@ import scala.collection.JavaConversions._
object FlinkRelMdUtil {
def getRankRangeNdv(rankRange: RankRange): Double = rankRange match {
- case r: ConstantRankRange => (r.rankEnd - r.rankStart + 1).toDouble
+ case r: ConstantRankRange => (r.getRankEnd - r.getRankStart + 1).toDouble
case _ => 100D // default value now
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala
index 338564c42c164..f7f7378aa5dff 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala
@@ -19,7 +19,9 @@
package org.apache.flink.table.plan.util
import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig}
-import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, Rank, RankRange, VariableRankRange}
+import org.apache.flink.table.plan.nodes.calcite.Rank
+import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankRange, VariableRankRange}
+import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil}
@@ -98,20 +100,20 @@ object RankUtil {
val sortBounds = limitPreds.map(computeWindowBoundFromPredicate(_, rexBuilder, config))
val rankRange = sortBounds match {
case Seq(Some(LowerBoundary(x)), Some(UpperBoundary(y))) =>
- ConstantRankRange(x, y)
+ new ConstantRankRange(x, y)
case Seq(Some(UpperBoundary(x)), Some(LowerBoundary(y))) =>
- ConstantRankRange(y, x)
+ new ConstantRankRange(y, x)
case Seq(Some(LowerBoundary(x))) =>
// only offset
- ConstantRankRangeWithoutEnd(x)
+ new ConstantRankRangeWithoutEnd(x)
case Seq(Some(UpperBoundary(x))) =>
// rankStart starts from one
- ConstantRankRange(1, x)
+ new ConstantRankRange(1, x)
case Seq(Some(BothBoundary(x, y))) =>
// nth rank
- ConstantRankRange(x, y)
+ new ConstantRankRange(x, y)
case Seq(Some(InputRefBoundary(x))) =>
- VariableRankRange(x)
+ new VariableRankRange(x)
case _ =>
// TopN requires at least one rank comparison predicate
return (None, Some(predicate))
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
index 32c9f4f61c3bb..07b9d93abcf92 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
@@ -18,8 +18,12 @@
package org.apache.flink.table.typeutils
-import org.apache.flink.table.`type`._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.PojoTypeInfo
+import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.codegen.GeneratedExpression
+import org.apache.flink.table.`type`._
object TypeCheckUtils {
@@ -96,4 +100,63 @@ object TypeCheckUtils {
def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType)
+ /**
+ * Checks whether a type implements its own hashCode() and equals() methods for storing
+ * an instance in Flink's state or performing a keyBy operation.
+ *
+ * @param name name of the operation.
+ * @param t type information to be validated
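+ *
+ * For example, a key class that inherits hashCode()/equals() from Object fails
+ * this check with a ValidationException.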
+ */
+ def validateEqualsHashCode(name: String, t: TypeInformation[_]): Unit = t match {
+
+ // make sure that a POJO class is a valid state type
+ case pt: PojoTypeInfo[_] =>
+ // we don't check the types recursively to give a chance of wrapping
+ // proper hashCode/equals methods around an immutable type
+ validateEqualsHashCode(name, pt.getTypeClass)
+ // BinaryRow hashes directly on its serialized bytes, so field types need no check
+ case _: BaseRowTypeInfo =>
+ // recursively check composite types
+ case ct: CompositeType[_] =>
+ validateEqualsHashCode(name, t.getTypeClass)
+ // we check recursively for entering Flink types such as tuples and rows
+ for (i <- 0 until ct.getArity) {
+ val subtype = ct.getTypeAt(i)
+ validateEqualsHashCode(name, subtype)
+ }
+ // check other type information only based on the type class
+ case _: TypeInformation[_] =>
+ validateEqualsHashCode(name, t.getTypeClass)
+ }
+
+ /**
+ * Checks whether a class implements its own hashCode() and equals() methods for storing
+ * an instance in Flink's state or performing a keyBy operation.
+ *
+ * @param name name of the operation
+ * @param c class to be validated
+ */
+ def validateEqualsHashCode(name: String, c: Class[_]): Unit = {
+
+ // skip primitives
+ if (!c.isPrimitive) {
+ // check the component type of arrays
+ if (c.isArray) {
+ validateEqualsHashCode(name, c.getComponentType)
+ }
+ // check type for methods
+ else {
+ if (c.getMethod("hashCode").getDeclaringClass eq classOf[Object]) {
+ throw new ValidationException(
+ s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " +
+ s"does not implement a proper hashCode() method.")
+ }
+ if (c.getMethod("equals", classOf[Object]).getDeclaringClass eq classOf[Object]) {
+ throw new ValidationException(
+ s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " +
+ s"does not implement a proper equals() method.")
+ }
+ }
+ }
+ }
}
diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java
new file mode 100644
index 0000000000000..3d95e3c3a6407
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.utils;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+
+/**
+ * The FailingCollectionSource fails after emitting a specified number of elements. This is
+ * used to exercise checkpointing and restore in IT cases.
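+ *
+ * <p>The elements are serialized up front and replayed from bytes. The first execution
+ * attempt throws an artificial exception once a completed checkpoint covers at least
+ * {@code failureAfterNumElements} elements; after restore, the already checkpointed
+ * elements are skipped, so every element is emitted exactly once overall.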
+ */
+public class FailingCollectionSource<T>
+ implements SourceFunction<T>, CheckpointedFunction, CheckpointListener {
+
+ public static volatile boolean failedBefore = false;
+
+ private static final long serialVersionUID = 1L;
+
+ /** The (de)serializer to be used for the data elements. */
+ private final TypeSerializer<T> serializer;
+
+ /** The actual data elements, in serialized form. */
+ private final byte[] elementsSerialized;
+
+ /** The number of serialized elements. */
+ private final int numElements;
+
+ /** The number of elements emitted already. */
+ private volatile int numElementsEmitted;
+
+ /** The number of elements to skip initially. */
+ private volatile int numElementsToSkip;
+
+ /** Flag to make the source cancelable. */
+ private volatile boolean isRunning = true;
+
+ private transient ListState<Integer> checkpointedState;
+
+ /** A failure will occur when the given number of elements have been processed. */
+ private final int failureAfterNumElements;
+
+ /** The number of completed checkpoints. */
+ private volatile int numSuccessfulCheckpoints;
+
+ /** The checkpointed number of emitted elements. */
+ private final Map<Long, Integer> checkpointedEmittedNums;
+
+ /** The last successful checkpointed number of emitted elements. */
+ private volatile int lastCheckpointedEmittedNum = 0;
+
+ /** Whether to perform a checkpoint before the job finishes. */
+ private final boolean performCheckpointBeforeJobFinished;
+
+ public FailingCollectionSource(
+ TypeSerializer<T> serializer,
+ Iterable<T> elements,
+ int failureAfterNumElements,
+ boolean performCheckpointBeforeJobFinished) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(baos);
+
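+ // eagerly serialize all elements; run() replays these bytes, which also lets a
+ // restored instance skip over the records it has already emitted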
+ int count = 0;
+ try {
+ for (T element : elements) {
+ serializer.serialize(element, wrapper);
+ count++;
+ }
+ } catch (Exception e) {
+ throw new IOException("Serializing the source elements failed: " + e.getMessage(), e);
+ }
+
+ this.serializer = serializer;
+ this.elementsSerialized = baos.toByteArray();
+ this.numElements = count;
+ checkArgument(failureAfterNumElements > 0);
+ this.failureAfterNumElements = failureAfterNumElements;
+ this.performCheckpointBeforeJobFinished = performCheckpointBeforeJobFinished;
+ this.checkpointedEmittedNums = new HashMap<>();
+ }
+
+ @Override
+ public void initializeState(FunctionInitializationContext context) throws Exception {
+ Preconditions.checkState(
+ this.checkpointedState == null,
+ "The " + getClass().getSimpleName() + " has already been initialized.");
+
+ this.checkpointedState = context.getOperatorStateStore().getListState(
+ new ListStateDescriptor<>(
+ "from-elements-state",
+ IntSerializer.INSTANCE
+ )
+ );
+
+ if (failedBefore && context.isRestored()) {
+ List<Integer> retrievedStates = new ArrayList<>();
+ for (Integer entry : this.checkpointedState.get()) {
+ retrievedStates.add(entry);
+ }
+
+ // given that the parallelism of the function is 1, we can only have 1 state
+ Preconditions.checkArgument(
+ retrievedStates.size() == 1,
+ getClass().getSimpleName() + " retrieved invalid state.");
+
+ this.numElementsToSkip = retrievedStates.get(0);
+ }
+ }
+
+ @Override
+ public void run(SourceContext<T> ctx) throws Exception {
+ ByteArrayInputStream bais = new ByteArrayInputStream(elementsSerialized);
+ final DataInputView input = new DataInputViewStreamWrapper(bais);
+
+ // if we are restored from a checkpoint and need to skip elements, skip them now.
+ int toSkip = numElementsToSkip;
+ if (toSkip > 0) {
+ try {
+ while (toSkip > 0) {
+ serializer.deserialize(input);
+ toSkip--;
+ }
+ } catch (Exception e) {
+ throw new IOException(
+ "Failed to deserialize an element from the source. " +
+ "If you are using user-defined serialization (Value and Writable types), check the " +
+ "serialization functions.\nSerializer is " + serializer);
+ }
+
+ this.numElementsEmitted = this.numElementsToSkip;
+ }
+
+ while (isRunning && numElementsEmitted < numElements) {
+ if (!failedBefore) {
+ // delay a bit, if we have not failed before
+ Thread.sleep(1);
+ if (numSuccessfulCheckpoints >= 1 && lastCheckpointedEmittedNum >= failureAfterNumElements) {
+ // cause a failure if we have not failed before and have reached
+ // enough completed checkpoints and elements
+ failedBefore = true;
+ throw new Exception("Artificial Failure");
+ }
+ }
+
+ if (failedBefore || numElementsEmitted < failureAfterNumElements) {
+ // the function failed before, or we are in the elements before the failure
+ T next;
+ try {
+ next = serializer.deserialize(input);
+ } catch (Exception e) {
+ throw new IOException(
+ "Failed to deserialize an element from the source. " +
+ "If you are using user-defined serialization (Value and Writable types), check the " +
+ "serialization functions.\nSerializer is " + serializer);
+ }
+
+ synchronized (ctx.getCheckpointLock()) {
+ ctx.collect(next);
+ numElementsEmitted++;
+ }
+ } else {
+ // if our work is done, delay a bit to prevent busy waiting
+ Thread.sleep(1);
+ }
+ }
+
+ if (performCheckpointBeforeJobFinished) {
+ while (isRunning) {
+ // wait until the latest checkpoint records everything
+ if (lastCheckpointedEmittedNum < numElements) {
+ // delay a bit to prevent busy waiting
+ Thread.sleep(1);
+ } else {
+ // cause a failure to retain the last checkpoint
+ throw new Exception("Job finished normally");
+ }
+ }
+ }
+ }
+
+ @Override
+ public void cancel() {
+ isRunning = false;
+ }
+
+ /**
+ * Gets the number of elements produced in total by this function.
+ *
+ * @return The number of elements produced in total.
+ */
+ public int getNumElements() {
+ return numElements;
+ }
+
+ /**
+ * Gets the number of elements emitted so far.
+ *
+ * @return The number of elements emitted so far.
+ */
+ public int getNumElementsEmitted() {
+ return numElementsEmitted;
+ }
+
+ // ------------------------------------------------------------------------
+ // Checkpointing
+ // ------------------------------------------------------------------------
+
+ @Override
+ public void snapshotState(FunctionSnapshotContext context) throws Exception {
+ Preconditions.checkState(
+ this.checkpointedState != null,
+ "The " + getClass().getSimpleName() + " has not been properly initialized.");
+
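+ // remember the emitted count per checkpoint id, so that notifyCheckpointComplete()
+ // can record the count of the last confirmed checkpoint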
+ this.checkpointedState.clear();
+ this.checkpointedState.add(this.numElementsEmitted);
+ long checkpointId = context.getCheckpointId();
+ checkpointedEmittedNums.put(checkpointId, numElementsEmitted);
+ }
+
+ @Override
+ public void notifyCheckpointComplete(long checkpointId) throws Exception {
+ numSuccessfulCheckpoints++;
+ lastCheckpointedEmittedNum = checkpointedEmittedNums.get(checkpointId);
+ }
+
+ public static void reset() {
+ failedBefore = false;
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala
new file mode 100644
index 0000000000000..179394086772a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.stream.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode
+import org.apache.flink.table.runtime.utils._
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset
+import org.apache.flink.types.Row
+
+import org.junit.Assert._
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+@RunWith(classOf[Parameterized])
+class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode)
+ extends StreamingWithMiniBatchTestBase(miniBatch, mode) {
+
+ @Test
+ def testFirstRowOnProctime(): Unit = {
+ val t = failingDataSource(StreamTestData.get3TupleData)
+ .toTable(tEnv, 'a, 'b, 'c, 'proctime)
+ tEnv.registerTable("T", t)
+
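+ // keeping only rowNum = 1 in processing-time order deduplicates the stream on b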
+ val sql =
+ """
+ |SELECT a, b, c
+ |FROM (
+ | SELECT *,
+ | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime) as rowNum
+ | FROM T
+ |)
+ |WHERE rowNum = 1
+ """.stripMargin
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink)
+ env.execute()
+
+ val expected = List("1,1,Hi", "2,2,Hello", "4,3,Hello world, how are you?",
+ "7,4,Comment#1", "11,5,Comment#5", "16,6,Comment#10")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testLastRowOnProctime(): Unit = {
+ val t = failingDataSource(StreamTestData.get3TupleData)
+ .toTable(tEnv, 'a, 'b, 'c, 'proctime)
+ tEnv.registerTable("T", t)
+
+ val sql =
+ """
+ |SELECT a, b, c
+ |FROM (
+ | SELECT *,
+ | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime DESC) as rowNum
+ | FROM T
+ |)
+ |WHERE rowNum = 1
+ """.stripMargin
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink)
+ env.execute()
+
+ val expected = List("1,1,Hi", "3,2,Hello world", "6,3,Luke Skywalker",
+ "10,4,Comment#4", "15,5,Comment#9", "21,6,Comment#15")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala
new file mode 100644
index 0000000000000..ff046bc90a390
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala
@@ -0,0 +1,1310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.stream.sql
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.{TableConfigOptions, TableException}
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.runtime.utils.{TestingRetractTableSink, TestingUpsertTableSink, _}
+import org.apache.flink.types.Row
+
+import org.junit.Assert._
+import org.junit.{Ignore, _}
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+@RunWith(classOf[Parameterized])
+class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) {
+
+ @Test
+ def testTopN(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
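+ // ROW_NUMBER() OVER (PARTITION BY ... ORDER BY ...) filtered on rank_num is the
+ // TopN pattern these tests exercise end-to-end with a failing source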
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num
+ | FROM T)
+ |WHERE rank_num <= 2
+ """.stripMargin
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ val expected = List(
+ "book,2,19,1",
+ "book,1,12,2",
+ "fruit,3,44,1",
+ "fruit,4,33,2")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testTopNth(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num
+ | FROM T)
+ |WHERE rank_num = 2
+ """.stripMargin
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "book,1,12,2",
+ "fruit,4,33,2")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testTopNWithUpsertSink(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num
+ | FROM T)
+ |WHERE rank_num <= 2
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
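+ // upsert keys are (category, rank_num): each rank position per category is updated in place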
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "book,4,11,1",
+ "book,1,12,2",
+ "fruit,5,22,1",
+ "fruit,4,33,2")
+ assertEquals(expected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode and SortedMapState is supported")
+ @Test(expected = classOf[TableException])
+ def testTopNWithUnary(): Unit = {
+ val data = List(
+ ("book", 11, 100),
+ ("book", 11, 200),
+ ("book", 12, 400),
+ ("book", 12, 500),
+ ("book", 10, 600),
+ ("book", 10, 700),
+ ("book", 9, 800),
+ ("book", 9, 900),
+ ("book", 10, 500),
+ ("book", 8, 110),
+ ("book", 8, 120),
+ ("book", 7, 1800),
+ ("book", 9, 300),
+ ("book", 6, 1900),
+ ("book", 7, 50),
+ ("book", 11, 1800),
+ ("book", 7, 50),
+ ("book", 8, 2000),
+ ("book", 6, 700),
+ ("book", 5, 800),
+ ("book", 4, 910),
+ ("book", 3, 1000),
+ ("book", 2, 1100),
+ ("book", 1, 1200)
+ )
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num
+ | FROM (
+ | SELECT category, shopId, sum(num) as num
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 3
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val updatedExpected = List(
+ "book,5,800,1",
+ "book,12,900,2",
+ "book,4,910,3")
+
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode and SortedMapState is supported")
+ @Test
+ def testUnarySortTopNOnString(): Unit = {
+ val data = List(
+ ("book", 11, "100"),
+ ("book", 11, "200"),
+ ("book", 12, "400"),
+ ("book", 12, "600"),
+ ("book", 10, "600"),
+ ("book", 10, "700"),
+ ("book", 9, "800"),
+ ("book", 9, "900"),
+ ("book", 10, "500"),
+ ("book", 8, "110"),
+ ("book", 8, "120"),
+ ("book", 7, "812"),
+ ("book", 9, "300"),
+ ("book", 6, "900"),
+ ("book", 7, "50"),
+ ("book", 11, "800"),
+ ("book", 7, "50"),
+ ("book", 8, "200"),
+ ("book", 6, "700"),
+ ("book", 5, "800"),
+ ("book", 4, "910"),
+ ("book", 3, "110"),
+ ("book", 2, "900"),
+ ("book", 1, "700")
+ )
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'price)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, max_price,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_price ASC) as rank_num
+ | FROM (
+ | SELECT category, shopId, max(price) as max_price
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 3
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val updatedExpected = List(
+ "book,3,110,1",
+ "book,8,200,2",
+ "book,12,600,3")
+
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithGroupBy(): Unit = {
+ val data = List(
+ ("book", 1, 11),
+ ("book", 2, 19),
+ ("book", 4, 13),
+ ("book", 1, 11),
+ ("fruit", 4, 33),
+ ("fruit", 5, 12),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, sum(num) as num
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 2
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val updatedExpected = List(
+ "book,1,22,1",
+ "book,2,19,2",
+ "fruit,3,44,1",
+ "fruit,5,34,2")
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithSumAndCondition(): Unit = {
+ val data = List(
+ Row.of("book", Int.box(11), Double.box(100)),
+ Row.of("book", Int.box(11), Double.box(200)),
+ Row.of("book", Int.box(12), Double.box(400)),
+ Row.of("book", Int.box(12), Double.box(500)),
+ Row.of("book", Int.box(10), Double.box(600)),
+ Row.of("book", Int.box(10), Double.box(700)))
+
+ implicit val tpe: TypeInformation[Row] = new RowTypeInfo(
+ BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO,
+ BasicTypeInfo.DOUBLE_TYPE_INFO) // the implicit tpe is picked up by fromCollection below
+
+ val ds = env.fromCollection(data)
+ val t = ds.toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", t)
+
+ val subquery =
+ """
+ |SELECT category, shopId, sum(num) as sum_num
+ |FROM T
+ |WHERE num >= cast(1.1 as double)
+ |GROUP BY category, shopId
+ """.stripMargin
+
+ val sql =
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, sum_num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num
+ | FROM ($subquery))
+ |WHERE rank_num <= 2
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val updatedExpected = List(
+ "book,10,1300.0,1",
+ "book,12,900.0,2")
+
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNthWithGroupBy(): Unit = {
+ val data = List(
+ ("book", 1, 11),
+ ("book", 2, 19),
+ ("book", 4, 13),
+ ("book", 1, 11),
+ ("fruit", 4, 33),
+ ("fruit", 5, 12),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, sum(num) as num
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num = 2
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val updatedExpected = List(
+ "book,2,19,2",
+ "fruit,5,34,2")
+
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithGroupByAndRetract(): Unit = {
+ val data = List(
+ ("book", 1, 11),
+ ("book", 2, 19),
+ ("book", 4, 13),
+ ("book", 1, 11),
+ ("fruit", 4, 33),
+ ("fruit", 5, 12),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num, cnt,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num
+ | FROM (
+ | SELECT category, shopId, sum(num) as num, count(num) as cnt
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 2
+ """.stripMargin
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "book,1,22,2,1",
+ "book,2,19,1,2",
+ "fruit,3,44,1,1",
+ "fruit,5,34,2,2")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNthWithGroupByAndRetract(): Unit = {
+ val data = List(
+ ("book", 1, 11),
+ ("book", 2, 19),
+ ("book", 4, 13),
+ ("book", 1, 11),
+ ("fruit", 4, 33),
+ ("fruit", 5, 12),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num, cnt,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num
+ | FROM (
+ | SELECT category, shopId, sum(num) as num, count(num) as cnt
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num = 2
+ """.stripMargin
+
+ val sink = new TestingRetractSink
+ tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "book,2,19,1,2",
+ "fruit,5,34,2,2")
+ assertEquals(expected.sorted, sink.getRetractResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithGroupByCount(): Unit = {
+ val data = List(
+ ("book", 1, 1001),
+ ("book", 2, 1002),
+ ("book", 4, 1003),
+ ("book", 1, 1004),
+ ("book", 1, 1005),
+ ("book", 3, 1006),
+ ("book", 2, 1007),
+ ("book", 4, 1008),
+ ("book", 1, 1009),
+ ("book", 4, 1010),
+ ("book", 4, 1012),
+ ("book", 4, 1012),
+ ("fruit", 4, 1013),
+ ("fruit", 5, 1014),
+ ("fruit", 3, 1015),
+ ("fruit", 4, 1017),
+ ("fruit", 5, 1018),
+ ("fruit", 5, 1016))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT category, rank_num, sells, shopId
+ |FROM (
+ | SELECT category, shopId, sells,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, count(sellId) as sells
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 4
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 1)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "book,1,5,4",
+ "book,2,4,1",
+ "book,3,2,2",
+ "book,4,1,3",
+ "fruit,1,3,5",
+ "fruit,2,2,4",
+ "fruit,3,1,3")
+ assertEquals(expected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNthWithGroupByCount(): Unit = {
+ val data = List(
+ ("book", 1, 1001),
+ ("book", 2, 1002),
+ ("book", 4, 1003),
+ ("book", 1, 1004),
+ ("book", 1, 1005),
+ ("book", 3, 1006),
+ ("book", 2, 1007),
+ ("book", 4, 1008),
+ ("book", 1, 1009),
+ ("book", 4, 1010),
+ ("book", 4, 1012),
+ ("book", 4, 1012),
+ ("fruit", 4, 1013),
+ ("fruit", 5, 1014),
+ ("fruit", 3, 1015),
+ ("fruit", 4, 1017),
+ ("fruit", 5, 1018),
+ ("fruit", 5, 1016))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT category, rank_num, sells, shopId
+ |FROM (
+ | SELECT category, shopId, sells,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, count(sellId) as sells
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num = 3
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 1)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "book,3,2,2",
+ "fruit,3,1,3")
+ assertEquals(expected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testNestedTopN(): Unit = {
+ val data = List(
+ ("book", "a", 1),
+ ("book", "b", 1),
+ ("book", "c", 1),
+ ("fruit", "a", 2),
+ ("book", "a", 1),
+ ("book", "d", 0),
+ ("book", "b", 3),
+ ("fruit", "b", 6),
+ ("book", "c", 1),
+ ("book", "e", 5),
+ ("book", "d", 4))
+
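+ // parallelism 1 keeps the raw change stream deterministic for the assertion below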
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT rank_num, cate, shopId, sells, cnt
+ |FROM (
+ | SELECT *,
+ | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num
+ | FROM (
+ | SELECT cate, shopId, count(*) as cnt, max(sells) as sells
+ | FROM T
+ | GROUP BY cate, shopId
+ | ))
+ |WHERE rank_num <= 4
+ """.stripMargin
+
+ val sql2 =
+ s"""
+ |SELECT rank_num, cate, shopId, sells, cnt
+ |FROM (
+ | SELECT cate, shopId, sells, cnt,
+ | ROW_NUMBER() OVER (ORDER BY sells DESC) as rank_num
+ | FROM ($sql)
+ |)
+ |WHERE rank_num <= 4
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql2)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "(true,1,book,a,1,1)", "(true,2,book,b,1,1)", "(true,3,book,c,1,1)",
+ "(true,1,fruit,a,2,1)", "(true,2,book,a,1,1)", "(true,3,book,b,1,1)", "(true,4,book,c,1,1)",
+ "(true,2,book,a,1,2)",
+ "(true,1,book,b,3,2)", "(true,2,fruit,a,2,1)", "(true,3,book,a,1,2)",
+ "(true,3,book,a,1,2)",
+ "(true,1,fruit,b,6,1)", "(true,2,book,b,3,2)", "(true,3,fruit,a,2,1)", "(true,4,book,a,1,2)",
+ "(true,3,fruit,a,2,1)",
+ "(true,2,book,e,5,1)",
+ "(true,3,book,b,3,2)", "(true,4,fruit,a,2,1)",
+ "(true,3,book,b,3,2)",
+ "(true,3,book,d,4,2)",
+ "(true,4,book,b,3,2)",
+ "(true,4,book,b,3,2)")
+ assertEquals(expected.mkString("\n"), sink.getRawResults.mkString("\n"))
+
+ val expected2 = List("1,fruit,b,6,1", "2,book,e,5,1", "3,book,d,4,2", "4,book,b,3,2")
+ assertEquals(expected2, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithoutDeduplicate(): Unit = {
+ val data = List(
+ ("book", "a", 1),
+ ("book", "b", 1),
+ ("book", "c", 1),
+ ("fruit", "a", 2),
+ ("book", "a", 1),
+ ("book", "d", 0),
+ ("book", "b", 3),
+ ("fruit", "b", 6),
+ ("book", "c", 1),
+ ("book", "e", 5),
+ ("book", "d", 4))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT rank_num, cate, shopId, sells, cnt
+ |FROM (
+ | SELECT *,
+ | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num
+ | FROM (
+ | SELECT cate, shopId, count(*) as cnt, max(sells) as sells
+ | FROM T
+ | GROUP BY cate, shopId
+ | ))
+ |WHERE rank_num <= 4
+ """.stripMargin
+
+ tEnv.getConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE, 1)
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "(true,1,book,a,1,1)",
+ "(true,2,book,b,1,1)",
+ "(true,3,book,c,1,1)",
+ "(true,1,fruit,a,2,1)",
+ "(true,1,book,a,1,2)",
+ "(true,4,book,d,0,1)",
+ "(true,1,book,b,3,2)",
+ "(true,2,book,a,1,2)",
+ "(true,1,fruit,b,6,1)",
+ "(true,2,fruit,a,2,1)",
+ "(true,3,book,c,1,2)",
+ "(true,1,book,e,5,1)",
+ "(true,2,book,b,3,2)",
+ "(true,3,book,a,1,2)",
+ "(true,4,book,c,1,2)",
+ "(true,2,book,d,4,2)",
+ "(true,3,book,b,3,2)",
+ "(true,4,book,a,1,2)")
+
+ assertEquals(expected, sink.getRawResults)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithVariableTopSize(): Unit = {
+ val data = List(
+ ("book", 1, 1001, 4),
+ ("book", 2, 1002, 4),
+ ("book", 4, 1003, 4),
+ ("book", 1, 1004, 4),
+ ("book", 1, 1005, 4),
+ ("book", 3, 1006, 4),
+ ("book", 2, 1007, 4),
+ ("book", 4, 1008, 4),
+ ("book", 1, 1009, 4),
+ ("book", 4, 1010, 4),
+ ("book", 4, 1012, 4),
+ ("book", 4, 1012, 4),
+ ("fruit", 4, 1013, 2),
+ ("fruit", 5, 1014, 2),
+ ("fruit", 3, 1015, 2),
+ ("fruit", 4, 1017, 2),
+ ("fruit", 5, 1018, 2),
+ ("fruit", 5, 1016, 2))
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId, 'topSize)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT category, rank_num, sells, shopId
+ |FROM (
+ | SELECT category, shopId, sells, topSize,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, count(sellId) as sells, max(topSize) as topSize
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= topSize
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 1)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "book,1,5,4",
+ "book,2,4,1",
+ "book,3,2,2",
+ "book,4,1,3",
+ "fruit,1,3,5",
+ "fruit,2,2,4")
+ assertEquals(expected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNUnaryComplexScenario(): Unit = {
+ val data = List(
+ ("book", 1, 11),
+ ("book", 2, 19),
+ ("book", 4, 13),
+ ("book", 1, 11), // backward update in heap
+ ("book", 3, 23), // elems exceed topn size after insert
+ ("book", 5, 19), // sort map shirk out some elem after insert
+ ("book", 7, 10), // sort map keeps a little more than topn size elems after insert
+ ("book", 8, 13), // sort map now can shrink out-of-range elems after another insert
+ ("book", 10, 13), // Once again, sort map keeps a little more elems after insert
+ ("book", 8, 6), // backward update from heap to state
+ ("book", 10, 6), // backward update from heap to state, and sort map load more data
+ ("book", 5, 3), // backward update from heap to state
+ ("book", 10, 1), // backward update in heap, and then sort map shrink some data
+ ("book", 5, 1), // backward update in state
+ ("book", 5, -3), // forward update in state
+ ("book", 2, -10), // forward update in heap, and then sort map shrink some data
+ ("book", 10, -7), // forward update from state to heap
+ ("book", 11, 13), // insert into heap
+ ("book", 12, 10), // insert into heap, and sort map shrink some data
+ ("book", 15, 14) // insert into state
+ )
+
+ env.setParallelism(1)
+
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num
+ | FROM (
+ | SELECT category, shopId, sum(num) as num
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 3
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "(true,book,1,11,1)",
+ "(true,book,2,19,2)",
+ "(true,book,4,13,2)",
+ "(true,book,2,19,3)",
+
+ "(true,book,4,13,1)",
+ "(true,book,2,19,2)",
+ "(true,book,1,22,3)",
+
+ "(true,book,5,19,3)",
+
+ "(true,book,7,10,1)",
+ "(true,book,4,13,2)",
+ "(true,book,2,19,3)",
+
+ "(true,book,8,13,3)",
+
+ "(true,book,10,13,3)",
+
+ "(true,book,2,19,3)",
+
+ "(true,book,2,9,1)",
+ "(true,book,7,10,2)",
+ "(true,book,4,13,3)",
+
+ "(true,book,12,10,3)")
+
+ assertEquals(expected.mkString("\n"), sink.getRawResults.mkString("\n"))
+
+ val updatedExpected = List(
+ "book,2,9,1",
+ "book,7,10,2",
+ "book,12,10,3")
+
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithGroupByAvgWithoutRowNumber(): Unit = {
+ val data = List(
+ ("book", 1, 100),
+ ("book", 3, 110),
+ ("book", 4, 120),
+ ("book", 1, 200),
+ ("book", 1, 200),
+ ("book", 2, 300),
+ ("book", 2, 400),
+ ("book", 4, 500),
+ ("book", 1, 400),
+ ("fruit", 5, 100))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT category, shopId, avgSellId
+ |FROM (
+ | SELECT category, shopId, avgSellId,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY avgSellId DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, AVG(sellId) as avgSellId
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 3
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 1)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "(true,book,1,100.0)",
+ "(true,book,3,110.0)",
+ "(true,book,4,120.0)",
+ "(true,book,1,150.0)",
+ "(true,book,1,166.66666666666666)",
+ "(true,book,2,300.0)",
+ "(false,book,3,110.0)",
+ "(true,book,2,350.0)",
+ "(true,book,4,310.0)",
+ "(true,book,1,225.0)",
+ "(true,fruit,5,100.0)")
+
+ assertEquals(expected, sink.getRawResults)
+
+ val updatedExpected = List(
+ "book,1,225.0",
+ "book,2,350.0",
+ "book,4,310.0",
+ "fruit,5,100.0")
+
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testTopNWithGroupByCountWithoutRowNumber(): Unit = {
+ val data = List(
+ ("book", 1, 1001),
+ ("book", 3, 1006),
+ ("book", 4, 1003),
+ ("book", 1, 1004),
+ ("book", 1, 1005),
+ ("book", 2, 1002),
+ ("book", 2, 1007),
+ ("book", 4, 1008),
+ ("book", 1, 1009),
+ ("book", 4, 1010),
+ ("book", 4, 1012),
+ ("book", 4, 1012),
+ ("fruit", 4, 1013),
+ ("fruit", 5, 1014),
+ ("fruit", 3, 1015),
+ ("fruit", 4, 1017),
+ ("fruit", 5, 1018),
+ ("fruit", 5, 1016))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT category, shopId, sells
+ |FROM (
+ | SELECT category, shopId, sells,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num
+ | FROM (
+ | SELECT category, shopId, count(sellId) as sells
+ | FROM T
+ | GROUP BY category, shopId
+ | ))
+ |WHERE rank_num <= 3
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 1)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "(true,book,1,1)",
+ "(true,book,3,1)",
+ "(true,book,4,1)",
+ "(true,book,1,2)",
+ "(true,book,1,3)",
+ "(true,book,2,2)",
+ "(false,book,4,1)",
+ "(true,book,4,2)",
+ "(false,book,3,1)",
+ "(true,book,1,4)",
+ "(true,book,4,3)",
+ "(true,book,4,4)",
+ "(true,book,4,5)",
+ "(true,fruit,4,1)",
+ "(true,fruit,5,1)",
+ "(true,fruit,3,1)",
+ "(true,fruit,4,2)",
+ "(true,fruit,5,2)",
+ "(true,fruit,5,3)")
+ assertEquals(expected, sink.getRawResults)
+
+ val updatedExpected = List(
+ "book,4,5",
+ "book,1,4",
+ "book,2,2",
+ "fruit,5,3",
+ "fruit,4,2",
+ "fruit,3,1")
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ @Test
+ def testTopNWithoutRowNumber(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("book", 5, 20),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22),
+ ("fruit", 1, 40))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val sql =
+ """
+ |SELECT category, num, shopId
+ |FROM (
+ | SELECT category, shopId, num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num
+ | FROM T)
+ |WHERE rank_num <= 2
+ """.stripMargin
+
+ val table = tEnv.sqlQuery(sql)
+ val schema = table.getSchema
+ val sink = new TestingUpsertTableSink(Array(0, 2)).
+ configure(schema.getFieldNames, schema.getFieldTypes)
+ tEnv.writeToSink(table, sink)
+ env.execute()
+
+ val expected = List(
+ "(true,book,12,1)",
+ "(true,book,19,2)",
+ "(false,book,12,1)",
+ "(true,book,20,5)",
+ "(true,fruit,33,4)",
+ "(true,fruit,44,3)",
+ "(false,fruit,33,4)",
+ "(true,fruit,40,1)")
+ assertEquals(expected, sink.getRawResults)
+
+ val updatedExpected = List(
+ "book,19,2",
+ "book,20,5",
+ "fruit,40,1",
+ "fruit,44,3")
+ assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testMultipleRetractTopNAfterAgg(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 1, 13),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val subquery =
+ s"""
+ |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num,
+ | AVG(num) as avg_num, COUNT(num) as cnt
+ |FROM T
+ |GROUP BY category, shopId
+ |""".stripMargin
+
+ val t1 = tEnv.sqlQuery(subquery)
+ tEnv.registerTable("MyView", t1)
+
+ val sink1 = new TestingRetractSink
+ tEnv.sqlQuery(
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, sum_num, avg_num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC, avg_num ASC
+ | ) as rank_num
+ | FROM MyView)
+ |WHERE rank_num <= 2
+ |""".stripMargin).toRetractStream[Row].addSink(sink1).setParallelism(1)
+ env.execute()
+
+ val expected1 = List(
+ "book,1,25,12.5,1",
+ "book,2,19,19.0,2",
+ "fruit,3,44,44.0,1",
+ "fruit,4,33,33.0,2")
+ assertEquals(expected1.sorted, sink1.getRetractResults.sorted)
+
+ val sink2 = new TestingRetractSink
+ tEnv.sqlQuery(
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, max_num, cnt,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC, cnt ASC) as rank_num
+ | FROM MyView)
+ |WHERE rank_num <= 2
+ |""".stripMargin).toRetractStream[Row].addSink(sink2).setParallelism(1)
+ env.execute()
+
+ val expected2 = List(
+ "book,2,19,1,1",
+ "book,1,13,2,2",
+ "fruit,3,44,1,1",
+ "fruit,4,33,1,2")
+ assertEquals(expected2.sorted, sink2.getRetractResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testMultipleUnaryTopNAfterAgg(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 1, 13),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val subquery =
+ s"""
+ |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num
+ |FROM T
+ |GROUP BY category, shopId
+ |""".stripMargin
+
+ val t1 = tEnv.sqlQuery(subquery)
+ tEnv.registerTable("MyView", t1)
+
+ val table1 = tEnv.sqlQuery(
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, sum_num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num
+ | FROM MyView)
+ |WHERE rank_num <= 2
+ |""".stripMargin)
+ val schema1 = table1.getSchema
+ val sink1 = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema1.getFieldNames, schema1.getFieldTypes)
+ tEnv.writeToSink(table1, sink1)
+
+ val table2 = tEnv.sqlQuery(
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, max_num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num
+ | FROM MyView)
+ |WHERE rank_num <= 2
+ |""".stripMargin)
+ val schema2 = table2.getSchema
+ val sink2 = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema2.getFieldNames, schema2.getFieldTypes)
+ tEnv.writeToSink(table2, sink2)
+
+ env.execute()
+ val expected1 = List(
+ "book,1,25,1",
+ "book,2,19,2",
+ "fruit,3,44,1",
+ "fruit,4,33,2")
+ assertEquals(expected1.sorted, sink1.getUpsertResults.sorted)
+
+ val expected2 = List(
+ "book,2,19,1",
+ "book,1,13,2",
+ "fruit,3,44,1",
+ "fruit,4,33,2")
+ assertEquals(expected2.sorted, sink2.getUpsertResults.sorted)
+ }
+
+ // FIXME
+ @Ignore("Enable after agg implements ExecNode")
+ @Test
+ def testMultipleUpdateTopNAfterAgg(): Unit = {
+ val data = List(
+ ("book", 1, 12),
+ ("book", 1, 13),
+ ("book", 2, 19),
+ ("book", 4, 11),
+ ("fruit", 4, 33),
+ ("fruit", 3, 44),
+ ("fruit", 5, 22))
+
+ env.setParallelism(1)
+ val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num)
+ tEnv.registerTable("T", ds)
+
+ val subquery =
+ s"""
+ |SELECT category, shopId, COUNT(num) as cnt_num, MAX(num) as max_num
+ |FROM T
+ |GROUP BY category, shopId
+ |""".stripMargin
+
+ val t1 = tEnv.sqlQuery(subquery)
+ tEnv.registerTable("MyView", t1)
+
+ val table1 = tEnv.sqlQuery(
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, cnt_num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY cnt_num DESC) as rank_num
+ | FROM MyView)
+ |WHERE rank_num <= 2
+ |""".stripMargin)
+ val schema1 = table1.getSchema
+ val sink1 = new TestingRetractTableSink().
+ configure(schema1.getFieldNames, schema1.getFieldTypes)
+ tEnv.writeToSink(table1, sink1)
+
+ val table2 = tEnv.sqlQuery(
+ s"""
+ |SELECT *
+ |FROM (
+ | SELECT category, shopId, max_num,
+ | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num
+ | FROM MyView)
+ |WHERE rank_num <= 2
+ |""".stripMargin)
+ val schema2 = table2.getSchema
+ val sink2 = new TestingRetractTableSink().
+ configure(schema2.getFieldNames, schema2.getFieldTypes)
+ tEnv.writeToSink(table2, sink2)
+ env.execute()
+
+ val expected1 = List(
+ "book,1,2,1",
+ "book,2,1,2",
+ "fruit,4,1,1",
+ "fruit,3,1,2")
+ assertEquals(expected1.sorted, sink1.getRetractResults.sorted)
+
+ val expected2 = List(
+ "book,2,19,1",
+ "book,1,13,2",
+ "fruit,3,44,1",
+ "fruit,4,33,2")
+ assertEquals(expected2.sorted, sink2.getRetractResults.sorted)
+ }
+
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala
index 95138c28edf6d..e91a1523b3ac9 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala
@@ -19,23 +19,26 @@
package org.apache.flink.table.runtime.utils
import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.common.io.OutputFormat
import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo}
import org.apache.flink.configuration.Configuration
import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink}
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
import org.apache.flink.table.api.{TableConfig, Types}
-import org.apache.flink.table.dataformat.BaseRow
-import org.apache.flink.table.sinks.{AppendStreamTableSink, BatchTableSink}
+import org.apache.flink.table.dataformat.{BaseRow, DataFormatConverters, GenericRow}
+import org.apache.flink.table.sinks._
+import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo
import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.flink.table.util.BaseRowTestUtil
import org.apache.flink.types.Row
+import _root_.java.lang.{Boolean => JBoolean}
import _root_.java.util.TimeZone
import _root_.java.util.concurrent.atomic.AtomicInteger
@@ -58,7 +61,7 @@ object StreamTestSink {
private[utils] def getNewSinkId: Int = {
val idx = idCounter.getAndIncrement()
- this.synchronized{
+ this.synchronized {
globalResults.put(idx, mutable.HashMap.empty[Int, ArrayBuffer[String]])
globalRetractResults.put(idx, mutable.HashMap.empty[Int, ArrayBuffer[String]])
globalUpsertResults.put(idx, mutable.HashMap.empty[Int, mutable.Map[String, String]])
@@ -78,7 +81,7 @@ abstract class AbstractExactlyOnceSink[T] extends RichSinkFunction[T] with Check
protected var localResults: ArrayBuffer[String] = _
protected val idx: Int = StreamTestSink.getNewSinkId
- protected var globalResults: mutable.Map[Int, ArrayBuffer[String]]= _
+ protected var globalResults: mutable.Map[Int, ArrayBuffer[String]] = _
protected var globalRetractResults: mutable.Map[Int, ArrayBuffer[String]] = _
protected var globalUpsertResults: mutable.Map[Int, mutable.Map[String, String]] = _
@@ -109,7 +112,7 @@ abstract class AbstractExactlyOnceSink[T] extends RichSinkFunction[T] with Check
protected def clearAndStashGlobalResults(): Unit = {
if (globalResults == null) {
- StreamTestSink.synchronized{
+ StreamTestSink.synchronized {
globalResults = StreamTestSink.globalResults.remove(idx).get
globalRetractResults = StreamTestSink.globalRetractResults.remove(idx).get
globalUpsertResults = StreamTestSink.globalUpsertResults.remove(idx).get
@@ -146,12 +149,166 @@ final class TestingAppendSink(tz: TimeZone) extends AbstractExactlyOnceSink[Row]
def this() {
this(TimeZone.getTimeZone("UTC"))
}
+
def invoke(value: Row): Unit = localResults += TestSinkUtil.rowToString(value, tz)
+
def getAppendResults: List[String] = getResults
}
+final class TestingUpsertSink(keys: Array[Int], tz: TimeZone)
+ extends AbstractExactlyOnceSink[(Boolean, BaseRow)] {
+
+ private var upsertResultsState: ListState[String] = _
+ private var localUpsertResults: mutable.Map[String, String] = _
+ private var fieldTypes: Array[TypeInformation[_]] = _
+
+ def this(keys: Array[Int]) {
+ this(keys, TimeZone.getTimeZone("UTC"))
+ }
+
+ def configureTypes(fieldTypes: Array[TypeInformation[_]]): Unit = {
+ this.fieldTypes = fieldTypes
+ }
+
+ override def initializeState(context: FunctionInitializationContext): Unit = {
+ super.initializeState(context)
+ upsertResultsState = context.getOperatorStateStore.getListState(
+ new ListStateDescriptor[String]("sink-upsert-results", Types.STRING))
+
+ localUpsertResults = mutable.HashMap.empty[String, String]
+
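+ // the upsert map is checkpointed as a flat list of alternating key and value
+ // entries, so restoring walks the list and pairs them back up in order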
+ if (context.isRestored) {
+ var key: String = null
+ var value: String = null
+ for (entry <- upsertResultsState.get().asScala) {
+ if (key == null) {
+ key = entry
+ } else {
+ value = entry
+ localUpsertResults += (key -> value)
+ key = null
+ value = null
+ }
+ }
+ if (key != null) {
+ throw new RuntimeException("The resultState is corrupt.")
+ }
+ }
+
+ val taskId = getRuntimeContext.getIndexOfThisSubtask
+ StreamTestSink.synchronized {
+ StreamTestSink.globalUpsertResults(idx) += (taskId -> localUpsertResults)
+ }
+ }
+
+ override def snapshotState(context: FunctionSnapshotContext): Unit = {
+ super.snapshotState(context)
+ upsertResultsState.clear()
+ for ((key, value) <- localUpsertResults) {
+ upsertResultsState.add(key)
+ upsertResultsState.add(value)
+ }
+ }
+
+ def invoke(d: (Boolean, BaseRow)): Unit = {
+ this.synchronized {
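+ // wrap the change flag and the row into a two-field GenericRow so one
+ // DataFormatConverter can render both as an external (Boolean, Row) tuple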
+ val wrapRow = new GenericRow(2)
+ wrapRow.setField(0, d._1)
+ wrapRow.setField(1, d._2)
+ val converter =
+ DataFormatConverters.getConverterForTypeInfo(
+ new TupleTypeInfo(Types.BOOLEAN, new RowTypeInfo(fieldTypes: _*)))
+ .asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, JTuple2[JBoolean, Row]]]
+ val v = converter.toExternal(wrapRow)
+ val rowString = TestSinkUtil.rowToString(v.f1, tz)
+ val tupleString = "(" + v.f0.toString + "," + rowString + ")"
+ localResults += tupleString
+ val keyString = TestSinkUtil.rowToString(Row.project(v.f1, keys), tz)
+ if (v.f0) {
+ localUpsertResults += (keyString -> rowString)
+ } else {
+ val oldValue = localUpsertResults.remove(keyString)
+ if (oldValue.isEmpty) {
+ throw new RuntimeException("Tried to delete a value that wasn't inserted first. " +
+ "This is probably an incorrectly implemented test. " +
+ "Try to set the parallelism of the sink to 1.")
+ }
+ }
+ }
+ }
+
+ def getRawResults: List[String] = getResults
+
+ def getUpsertResults: List[String] = {
+ clearAndStashGlobalResults()
+ val result = ArrayBuffer.empty[String]
+ this.globalUpsertResults.foreach {
+ case (_, map) => map.foreach(result += _._2)
+ }
+ result.toList
+ }
+}
+
+final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone)
+ extends UpsertStreamTableSink[BaseRow] {
+ var fNames: Array[String] = _
+ var fTypes: Array[TypeInformation[_]] = _
+ var sink = new TestingUpsertSink(keys, tz)
+
+ def this(keys: Array[Int]) {
+ this(keys, TimeZone.getTimeZone("UTC"))
+ }
+
+ override def setKeyFields(keys: Array[String]): Unit = {
+ // ignore
+ }
+
+ override def setIsAppendOnly(isAppendOnly: JBoolean): Unit = {
+ // ignore
+ }
+
+ override def getRecordType: TypeInformation[BaseRow] =
+ new BaseRowTypeInfo(fTypes.map(createInternalTypeFromTypeInfo(_)), fNames)
+
+ override def getFieldNames: Array[String] = fNames
+
+ override def getFieldTypes: Array[TypeInformation[_]] = fTypes
+
+ override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, BaseRow]]) = {
+ dataStream.map(new MapFunction[JTuple2[JBoolean, BaseRow], (Boolean, BaseRow)] {
+ override def map(value: JTuple2[JBoolean, BaseRow]): (Boolean, BaseRow) = {
+ (value.f0, value.f1)
+ }
+ })
+ .addSink(sink)
+ .name(s"TestingUpsertTableSink(keys=${
+ if (keys != null) {
+ "(" + keys.mkString(",") + ")"
+ } else {
+ "null"
+ }
+ })")
+ .setParallelism(1)
+ }
+
+ override def configure(
+ fieldNames: Array[String],
+ fieldTypes: Array[TypeInformation[_]]): TestingUpsertTableSink = {
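+ // return a configured copy that shares the underlying sink instance, so
+ // results stay reachable from the original TestingUpsertTableSink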
+ val copy = new TestingUpsertTableSink(keys, tz)
+ copy.fNames = fieldNames
+ copy.fTypes = fieldTypes
+ sink.configureTypes(fieldTypes)
+ copy.sink = sink
+ copy
+ }
+
+ def getRawResults: List[String] = sink.getRawResults
+
+ def getUpsertResults: List[String] = sink.getUpsertResults
+}
+
final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[Row]
- with BatchTableSink[Row]{
+ with BatchTableSink[Row] {
var fNames: Array[String] = _
var fTypes: Array[TypeInformation[_]] = _
var sink = new TestingAppendSink(tz)
@@ -163,7 +320,7 @@ final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[Row]
override def emitDataStream(dataStream: DataStream[Row]): DataStreamSink[Row] = {
dataStream.addSink(sink).name("TestingAppendTableSink")
- .setParallelism(dataStream.getParallelism)
+ .setParallelism(dataStream.getParallelism)
}
override def emitBoundedStream(
@@ -211,23 +368,25 @@ class TestingOutputFormat[T](tz: TimeZone)
def open(taskNumber: Int, numTasks: Int): Unit = {
localRetractResults = mutable.ArrayBuffer.empty[String]
- StreamTestSink.synchronized{
+ StreamTestSink.synchronized {
StreamTestSink.globalResults(index) += (taskNumber -> localRetractResults)
}
}
- def writeRecord(value: T): Unit = localRetractResults += { value match {
- case r: Row => TestSinkUtil.rowToString(r, tz)
- case tp: JTuple2[java.lang.Boolean, Row] =>
- "(" + tp.f0.toString + "," + TestSinkUtil.rowToString(tp.f1, tz) + ")"
- case _ => ""
- }}
+ def writeRecord(value: T): Unit = localRetractResults += {
+ value match {
+ case r: Row => TestSinkUtil.rowToString(r, tz)
+ case tp: JTuple2[java.lang.Boolean, Row] =>
+ "(" + tp.f0.toString + "," + TestSinkUtil.rowToString(tp.f1, tz) + ")"
+ case _ => ""
+ }
+ }
def close(): Unit = {}
protected def clearAndStashGlobalResults(): Unit = {
if (globalResults == null) {
- StreamTestSink.synchronized{
+ StreamTestSink.synchronized {
globalResults = StreamTestSink.globalResults.remove(index).get
}
}
@@ -242,3 +401,118 @@ class TestingOutputFormat[T](tz: TimeZone)
result.toList
}
}
+
+class TestingRetractSink(tz: TimeZone)
+ extends AbstractExactlyOnceSink[(Boolean, Row)] {
+ protected var retractResultsState: ListState[String] = _
+ protected var localRetractResults: ArrayBuffer[String] = _
+
+ def this() {
+ this(TimeZone.getTimeZone("UTC"))
+ }
+
+ override def initializeState(context: FunctionInitializationContext): Unit = {
+ super.initializeState(context)
+ retractResultsState = context.getOperatorStateStore.getListState(
+ new ListStateDescriptor[String]("sink-retract-results", Types.STRING))
+
+ localRetractResults = mutable.ArrayBuffer.empty[String]
+
+ if (context.isRestored) {
+ for (value <- retractResultsState.get().asScala) {
+ localRetractResults += value
+ }
+ }
+
+ val taskId = getRuntimeContext.getIndexOfThisSubtask
+ StreamTestSink.synchronized {
+ StreamTestSink.globalRetractResults(idx) += (taskId -> localRetractResults)
+ }
+ }
+
+ override def snapshotState(context: FunctionSnapshotContext): Unit = {
+ super.snapshotState(context)
+ retractResultsState.clear()
+ for (value <- localRetractResults) {
+ retractResultsState.add(value)
+ }
+ }
+
+ def invoke(v: (Boolean, Row)): Unit = {
+ this.synchronized {
+ val tupleString = "(" + v._1.toString + "," + TestSinkUtil.rowToString(v._2, tz) + ")"
+ localResults += tupleString
+ val rowString = TestSinkUtil.rowToString(v._2, tz)
+ if (v._1) {
+ localRetractResults += rowString
+ } else {
+ val index = localRetractResults.indexOf(rowString)
+ if (index >= 0) {
+ localRetractResults.remove(index)
+ } else {
+ throw new RuntimeException("Tried to retract a value that wasn't added first. " +
+ "This is probably an incorrectly implemented test. " +
+ "Try to set the parallelism of the sink to 1.")
+ }
+ }
+ }
+ }
+
+ def getRawResults: List[String] = getResults
+
+ def getRetractResults: List[String] = {
+ clearAndStashGlobalResults()
+ val result = ArrayBuffer.empty[String]
+ this.globalRetractResults.foreach {
+ case (_, list) => result ++= list
+ }
+ result.toList
+ }
+}
+
+final class TestingRetractTableSink(tz: TimeZone) extends RetractStreamTableSink[Row] {
+
+ var fNames: Array[String] = _
+ var fTypes: Array[TypeInformation[_]] = _
+ var sink = new TestingRetractSink(tz)
+
+ def this() {
+ this(TimeZone.getTimeZone("UTC"))
+ }
+
+ override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, Row]]) = {
+ dataStream.map(new MapFunction[JTuple2[JBoolean, Row], (Boolean, Row)] {
+ override def map(value: JTuple2[JBoolean, Row]): (Boolean, Row) = {
+ (value.f0, value.f1)
+ }
+ }).setParallelism(dataStream.getParallelism)
+ .addSink(sink)
+ .name("TestingRetractTableSink")
+ .setParallelism(dataStream.getParallelism)
+ }
+
+ override def getRecordType: TypeInformation[Row] =
+ new RowTypeInfo(fTypes, fNames)
+
+ override def getFieldNames: Array[String] = fNames
+
+ override def getFieldTypes: Array[TypeInformation[_]] = fTypes
+
+ override def configure(
+ fieldNames: Array[String],
+ fieldTypes: Array[TypeInformation[_]]): TestingRetractTableSink = {
+ val copy = new TestingRetractTableSink(tz)
+ copy.fNames = fieldNames
+ copy.fTypes = fieldTypes
+ copy.sink = sink
+ copy
+ }
+
+ def getRawResults: List[String] = {
+ sink.getRawResults
+ }
+
+ def getRetractResults: List[String] = {
+ sink.getRetractResults
+ }
+}
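All of these test sinks follow the same exactly-once pattern: each subtask buffers results locally, mirrors the buffer into operator ListState in snapshotState so the results survive the induced failure, and registers the buffer with StreamTestSink's global maps for later assertion. The sketch below is a minimal, self-contained Java rendering of that pattern, with a hypothetical class name and without the global registry:

```java
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;

import java.util.ArrayList;
import java.util.List;

// Hypothetical exactly-once collecting sink: local buffer mirrored into ListState.
public class CollectingSink extends RichSinkFunction<String> implements CheckpointedFunction {
	private transient ListState<String> checkpointedResults;
	private transient List<String> localResults;

	@Override
	public void initializeState(FunctionInitializationContext context) throws Exception {
		checkpointedResults = context.getOperatorStateStore().getListState(
			new ListStateDescriptor<>("results", String.class));
		localResults = new ArrayList<>();
		if (context.isRestored()) {
			// replay results that were checkpointed before the failure
			for (String value : checkpointedResults.get()) {
				localResults.add(value);
			}
		}
	}

	@Override
	public void snapshotState(FunctionSnapshotContext context) throws Exception {
		// mirror the local buffer into operator state
		checkpointedResults.clear();
		for (String value : localResults) {
			checkpointedResults.add(value);
		}
	}

	@Override
	public void invoke(String value) {
		localResults.add(value);
	}
}
```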
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala
new file mode 100644
index 0000000000000..56c2593cd9b48
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.utils
+
+import org.apache.flink.table.api.TableConfigOptions
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
+import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.{MiniBatchMode, MiniBatchOff, MiniBatchOn}
+
+import java.util
+
+import scala.collection.JavaConversions._
+
+import org.junit.runners.Parameterized
+
+abstract class StreamingWithMiniBatchTestBase(
+ miniBatch: MiniBatchMode,
+ state: StateBackendMode)
+ extends StreamingWithStateTestBase(state) {
+
+ override def before(): Unit = {
+ super.before()
+ // set mini batch
+ val tableConfig = tEnv.getConfig
+ miniBatch match {
+ case MiniBatchOn =>
+ tableConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY, 1000L)
+ tableConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE, 3L)
+ case MiniBatchOff =>
+ tableConfig.getConf.removeConfig(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY)
+ }
+ }
+}
+
+object StreamingWithMiniBatchTestBase {
+
+ case class MiniBatchMode(on: Boolean) {
+ override def toString: String = {
+ if (on) {
+ "MiniBatch=ON"
+ } else {
+ "MiniBatch=OFF"
+ }
+ }
+ }
+
+ val MiniBatchOff = MiniBatchMode(false)
+ val MiniBatchOn = MiniBatchMode(true)
+
+ @Parameterized.Parameters(name = "{0}, StateBackend={1}")
+ def parameters(): util.Collection[Array[java.lang.Object]] = {
+ Seq[Array[AnyRef]](
+ Array(MiniBatchOff, HEAP_BACKEND),
+ Array(MiniBatchOff, ROCKSDB_BACKEND),
+ Array(MiniBatchOn, HEAP_BACKEND),
+ Array(MiniBatchOn, ROCKSDB_BACKEND))
+ }
+
+}
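For reference, MiniBatchOn amounts to the following two settings. This is a hedged Java sketch mirroring the Scala above; it assumes both options are declared as ConfigOption&lt;Long&gt;, as their use with setLong suggests:

```java
import org.apache.flink.configuration.Configuration;
import org.apache.flink.table.api.TableConfigOptions;

public class MiniBatchConfigExample {
	public static void main(String[] args) {
		Configuration conf = new Configuration();
		// flush a bundle after at most 1 second...
		conf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY, 1000L);
		// ...or as soon as 3 elements have been buffered, whichever comes first
		conf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE, 3L);
	}
}
```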
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala
new file mode 100644
index 0000000000000..9a3836ba76703
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.utils
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.configuration.{CheckpointingOptions, Configuration}
+import org.apache.flink.runtime.state.memory.MemoryStateBackend
+import org.apache.flink.streaming.api.CheckpointingMode
+import org.apache.flink.streaming.api.functions.source.FromElementsFunction
+import org.apache.flink.streaming.api.scala.DataStream
+import org.apache.flink.table.api.{TableEnvironment, Types}
+import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, BinaryRowWriter, BinaryString}
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
+import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo
+import org.apache.flink.table.`type`.RowType
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.File
+import java.util
+
+import org.junit.runners.Parameterized
+import org.junit.{After, Assert, Before}
+
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend
+
+class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestBase {
+
+ enableObjectReuse = state match {
+ case HEAP_BACKEND => false // TODO gemini does not support object reuse yet.
+ case ROCKSDB_BACKEND => true
+ }
+
+ private val classLoader = Thread.currentThread.getContextClassLoader
+
+ var baseCheckpointPath: File = _
+
+ @Before
+ override def before(): Unit = {
+ super.before()
+ // set state backend
+ baseCheckpointPath = tempFolder.newFolder().getAbsoluteFile
+ state match {
+ case HEAP_BACKEND =>
+ val conf = new Configuration()
+ conf.setBoolean(CheckpointingOptions.ASYNC_SNAPSHOTS, true)
+ env.setStateBackend(new MemoryStateBackend(
+ "file://" + baseCheckpointPath, null).configure(conf, classLoader))
+ case ROCKSDB_BACKEND =>
+ val conf = new Configuration()
+ conf.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true)
+ env.setStateBackend(new RocksDBStateBackend(
+ "file://" + baseCheckpointPath).configure(conf, classLoader))
+ }
+ this.tEnv = TableEnvironment.getTableEnvironment(env)
+ FailingCollectionSource.failedBefore = true
+ }
+
+ @After
+ def after(): Unit = {
+ Assert.assertTrue(FailingCollectionSource.failedBefore)
+ }
+
+ /**
+ * Creates a BinaryRow DataStream from the given non-empty [[Seq]].
+ */
+ def failingBinaryRowSource[T: TypeInformation](data: Seq[T]): DataStream[BaseRow] = {
+ val typeInfo = implicitly[TypeInformation[T]].asInstanceOf[CompositeType[T]]
+ val result = new mutable.MutableList[BaseRow]
+ val reuse = new BinaryRow(typeInfo.getArity)
+ val writer = new BinaryRowWriter(reuse)
+ data.foreach {
+ case p: Product =>
+ for (i <- 0 until typeInfo.getArity) {
+ val fieldType = typeInfo.getTypeAt(i).asInstanceOf[TypeInformation[_]]
+ fieldType match {
+ case Types.INT => writer.writeInt(i, p.productElement(i).asInstanceOf[Int])
+ case Types.LONG => writer.writeLong(i, p.productElement(i).asInstanceOf[Long])
+ case Types.STRING => writer.writeString(i,
+ BinaryString.fromString(p.productElement(i).asInstanceOf[String]))
+ case Types.BOOLEAN => writer.writeBoolean(i, p.productElement(i).asInstanceOf[Boolean])
+ }
+ }
+ writer.complete()
+ result += reuse.copy()
+ case _ => throw new UnsupportedOperationException
+ }
+ val newTypeInfo = createInternalTypeFromTypeInfo(typeInfo).asInstanceOf[RowType].toTypeInfo
+ failingDataSource(result)(newTypeInfo.asInstanceOf[TypeInformation[BaseRow]])
+ }
+
+ /**
+ * Creates a DataStream from the given non-empty [[Seq]].
+ */
+ def failingDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = {
+ env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE)
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0))
+ env.setParallelism(1)
+ // reset failedBefore flag to false
+ FailingCollectionSource.reset()
+
+ require(data != null, "Data must not be null.")
+ val typeInfo = implicitly[TypeInformation[T]]
+
+ val collection = scala.collection.JavaConversions.asJavaCollection(data)
+ // the collection must not contain null elements or elements of mixed types
+ FromElementsFunction.checkCollection(data, typeInfo.getTypeClass)
+
+ val function = new FailingCollectionSource[T](
+ typeInfo.createSerializer(env.getConfig),
+ collection,
+ data.length / 2, // fail after half elements
+ false)
+
+ env.addSource(function)(typeInfo).setMaxParallelism(1)
+ }
+
+ private def mapStrEquals(str1: String, str2: String): Boolean = {
+ val array1 = str1.toCharArray
+ val array2 = str2.toCharArray
+ if (array1.length != array2.length) {
+ return false
+ }
+ val l = array1.length
+ val leftBrace = '{'
+ val rightBrace = '}'
+ val equalsChar = '='
+ val lParenthesis = '('
+ val rParenthesis = ')'
+ val comma = ','
+ val whiteSpace = ' '
+ val map1 = mutable.Map[String, String]()
+ val map2 = mutable.Map[String, String]()
+ var idx = 0
+ def findEquals(ss: CharSequence): Array[Int] = {
+ val ret = new ArrayBuffer[Int]()
+ (0 until ss.length) foreach (idx => if (ss.charAt(idx) == equalsChar) ret += idx)
+ ret.toArray
+ }
+
+ def splitKV(ss: CharSequence, equalsIdx: Int): (String, String) = {
+ // find right: if the value starts with '(', search forward to ')', else to ','
+ var endFlag = false
+ var curIdx = equalsIdx + 1
+ var endChar = if (ss.charAt(curIdx) == lParenthesis) rParenthesis else comma
+ var valueStr: CharSequence = null
+ var keyStr: CharSequence = null
+ while (curIdx < ss.length && !endFlag) {
+ val curChar = ss.charAt(curIdx)
+ if (curChar != endChar && curChar != rightBrace) {
+ curIdx += 1
+ if (curIdx == ss.length) {
+ valueStr = ss.subSequence(equalsIdx + 1, curIdx)
+ }
+ } else {
+ valueStr = ss.subSequence(equalsIdx + 1, curIdx)
+ endFlag = true
+ }
+ }
+
+ // find left: if the key side ends with ')', search back to '(', else back to a space
+ endFlag = false
+ curIdx = equalsIdx - 1
+ endChar = if (ss.charAt(curIdx) == rParenthesis) lParenthesis else whiteSpace
+ while (curIdx >= 0 && !endFlag) {
+ val curChar = ss.charAt(curIdx)
+ if (curChar != endChar && curChar != leftBrace) {
+ curIdx -= 1
+ if (curIdx == -1) {
+ keyStr = ss.subSequence(0, equalsIdx)
+ }
+ } else {
+ keyStr = ss.subSequence(curIdx, equalsIdx)
+ endFlag = true
+ }
+ }
+ require(keyStr != null)
+ require(valueStr != null)
+ (keyStr.toString, valueStr.toString)
+ }
+
+ def appendStrToMap(ss: CharSequence, m: mutable.Map[String, String]): Unit = {
+ val equalsIdxs = findEquals(ss)
+ equalsIdxs.foreach(i => m += splitKV(ss, i))
+ }
+
+ while (idx < l) {
+ val char1 = array1(idx)
+ val char2 = array2(idx)
+ if (char1 != char2) {
+ return false
+ }
+
+ if (char1 == leftBrace) {
+ val rightBraceIdx = array1.subSequence(idx + 1, l).toString.indexOf(rightBrace)
+ appendStrToMap(array1.subSequence(idx + 1, rightBraceIdx + idx + 2), map1)
+ appendStrToMap(array2.subSequence(idx + 1, rightBraceIdx + idx + 2), map2)
+ // skip past the closing brace, otherwise an empty map "{}" never advances idx
+ idx += rightBraceIdx + 2
+ } else {
+ idx += 1
+ }
+ }
+ map1.equals(map2)
+ }
+
+ def assertMapStrEquals(str1: String, str2: String): Unit = {
+ if (!mapStrEquals(str1, str2)) {
+ throw new AssertionError(s"Expected: $str1 \n Actual: $str2")
+ }
+ }
+}
+
+object StreamingWithStateTestBase {
+
+ case class StateBackendMode(backend: String) {
+ override def toString: String = backend
+ }
+
+ val HEAP_BACKEND = StateBackendMode("HEAP")
+ val ROCKSDB_BACKEND = StateBackendMode("ROCKSDB")
+
+ @Parameterized.Parameters(name = "StateBackend={0}")
+ def parameters(): util.Collection[Array[java.lang.Object]] = {
+ Seq[Array[AnyRef]](Array(HEAP_BACKEND), Array(ROCKSDB_BACKEND))
+ }
+}
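failingDataSource builds on FailingCollectionSource: checkpoints run every 100 ms, the source throws once after emitting half of the elements, and the fixed-delay restart strategy restarts the job exactly once, exercising state recovery. Below is a simplified Java sketch of the fail-once idea, with a hypothetical class name and without the emit-position checkpointing the real source performs:

```java
import org.apache.flink.streaming.api.functions.source.SourceFunction;

import java.util.List;

// Hypothetical sketch: emits the list, throws once halfway, succeeds after restart.
public class FailOnceSource implements SourceFunction<Integer> {
	// shared across the restart within one JVM, like FailingCollectionSource.failedBefore
	private static volatile boolean failedBefore = false;

	private final List<Integer> data;
	private volatile boolean running = true;

	public FailOnceSource(List<Integer> data) {
		this.data = data;
	}

	@Override
	public void run(SourceContext<Integer> ctx) throws Exception {
		int emitted = 0;
		for (Integer value : data) {
			if (!running) {
				return;
			}
			if (!failedBefore && emitted == data.size() / 2) {
				failedBefore = true;
				throw new Exception("Artificial failure after half of the elements");
			}
			synchronized (ctx.getCheckpointLock()) {
				ctx.collect(value);
				emitted++;
			}
		}
	}

	@Override
	public void cancel() {
		running = false;
	}
}
```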
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala
index 5ec370015c4b1..4595ac81841bd 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala
@@ -36,28 +36,29 @@ object TableUtil {
* Note: The difference between print() and collect() is
* - print() prints data on workers and collect() collects data to the client.
* - You have to call TableEnvironment.execute() to run the job for print(), while collect()
- * calls execute automatically.
+ * calls execute automatically.
*/
def collect(table: TableImpl): Seq[Row] = collectSink(table, new CollectRowTableSink, None)
def collect(table: TableImpl, jobName: String): Seq[Row] =
collectSink(table, new CollectRowTableSink, Option.apply(jobName))
- def collectAsT[T](table: TableImpl, t: TypeInformation[_], jobName : String = null): Seq[T] =
+ def collectAsT[T](table: TableImpl, t: TypeInformation[_], jobName: String = null): Seq[T] =
collectSink(
table,
new CollectTableSink(_ => t.asInstanceOf[TypeInformation[T]]), Option(jobName))
def collectSink[T](
- table: TableImpl, sink: CollectTableSink[T], jobName : Option[String] = None): Seq[T] = {
+ table: TableImpl, sink: CollectTableSink[T], jobName: Option[String] = None): Seq[T] = {
// get schema information of table
val rowType = table.getRelNode.getRowType
val fieldNames = rowType.getFieldNames.asScala.toArray
val fieldTypes = rowType.getFieldList
- .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray
+ .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray
val configuredSink = sink.configure(
fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType))
BatchTableEnvUtil.collect(table.tableEnv.asInstanceOf[BatchTableEnvironment],
table, configuredSink.asInstanceOf[CollectTableSink[T]], jobName)
}
+
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala
new file mode 100644
index 0000000000000..f93c2c2425b69
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.utils
+
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
+import org.apache.flink.streaming.api.functions.source.SourceFunction
+import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext
+import org.apache.flink.streaming.api.operators.{AbstractStreamOperator, OneInputStreamOperator}
+import org.apache.flink.streaming.api.watermark.Watermark
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+
+object TimeTestUtil {
+
+ class EventTimeSourceFunction[T](
+ dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] {
+
+ override def run(ctx: SourceContext[T]): Unit = {
+ dataWithTimestampList.foreach {
+ case Left(t) => ctx.collectWithTimestamp(t._2, t._1)
+ case Right(w) => ctx.emitWatermark(new Watermark(w))
+ }
+ }
+
+ override def cancel(): Unit = ???
+ }
+
+ class TimestampAndWatermarkWithOffset[T <: Product](
+ offset: Long) extends AssignerWithPunctuatedWatermarks[T] {
+
+ override def checkAndGetNextWatermark(lastElement: T, extractedTimestamp: Long): Watermark = {
+ new Watermark(extractedTimestamp - offset)
+ }
+
+ override def extractTimestamp(element: T, previousElementTimestamp: Long): Long = {
+ element.productElement(0).asInstanceOf[Long]
+ }
+ }
+
+ class EventTimeProcessOperator[T]
+ extends AbstractStreamOperator[T]
+ with OneInputStreamOperator[Either[(Long, T), Long], T] {
+
+ override def processElement(element: StreamRecord[Either[(Long, T), Long]]): Unit = {
+ element.getValue match {
+ case Left(t) => output.collect(new StreamRecord[T](t._2, t._1))
+ case Right(w) => output.emitWatermark(new Watermark(w))
+ }
+ }
+
+ }
+
+}
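TimestampAndWatermarkWithOffset reads the event time from field 0 of each element and emits a punctuated watermark trailing that timestamp by a fixed offset. The following Java analogue for Tuple2&lt;Long, String&gt; elements is an illustrative sketch, not the class added above:

```java
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
import org.apache.flink.streaming.api.watermark.Watermark;

// Field 0 carries the event timestamp; every element advances the watermark
// to (timestamp - offset), allowing elements to arrive up to `offset` ms late.
public class OffsetWatermarkAssigner implements AssignerWithPunctuatedWatermarks<Tuple2<Long, String>> {
	private final long offset;

	public OffsetWatermarkAssigner(long offset) {
		this.offset = offset;
	}

	@Override
	public long extractTimestamp(Tuple2<Long, String> element, long previousElementTimestamp) {
		return element.f0;
	}

	@Override
	public Watermark checkAndGetNextWatermark(Tuple2<Long, String> lastElement, long extractedTimestamp) {
		return new Watermark(extractedTimestamp - offset);
	}
}
```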
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java
index c472c1f934a18..9a47a3b738d13 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java
@@ -133,6 +133,15 @@ public class TableConfigOptions {
.defaultValue(512)
.withDescription("Sets the max buffer memory size for sort. It defines the upper memory for the sort.");
+ // ------------------------------------------------------------------------
+ // topN Options
+ // ------------------------------------------------------------------------
+
+ public static final ConfigOption<Long> SQL_EXEC_TOPN_CACHE_SIZE =
+ key("sql.exec.topn.cache.size")
+ .defaultValue(10000L)
+ .withDescription("Cache size of every topn task.");
+
// ------------------------------------------------------------------------
// MiniBatch Options
// ------------------------------------------------------------------------
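The new topN option can be read back with the typed accessor; when the key is unset, the declared default applies. A small hedged sketch:

```java
import org.apache.flink.configuration.Configuration;
import org.apache.flink.table.api.TableConfigOptions;

public class TopnCacheSizeExample {
	public static void main(String[] args) {
		Configuration conf = new Configuration();
		// no explicit value set, so the declared default applies
		long cacheSize = conf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE);
		System.out.println(cacheSize); // prints 10000
	}
}
```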
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java
index 9be8db632248d..e52938960bd3c 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java
@@ -220,7 +220,7 @@ public void setDecimal(int pos, Decimal value, int precision) {
} else {
byte[] bytes = value.toUnscaledBytes();
- assert(bytes.length <= 16);
+ assert bytes.length <= 16;
// Write the bytes to the variable length portion.
SegmentsUtil.copyFromBytes(segments, offset + cursor, bytes, 0, bytes.length);
@@ -428,4 +428,20 @@ public static String toOriginString(BaseRow row, InternalType[] types) {
build.append(']');
return build.toString();
}
+
+ public boolean equalsWithoutHeader(BaseRow o) {
+ return equalsFrom(o, 1);
+ }
+
+ private boolean equalsFrom(Object o, int startIndex) {
+ if (o instanceof BinaryRow) {
+ BinaryRow other = (BinaryRow) o;
+ return sizeInBytes == other.sizeInBytes &&
+ SegmentsUtil.equals(
+ segments, offset + startIndex,
+ other.segments, other.offset + startIndex, sizeInBytes - startIndex);
+ } else {
+ return false;
+ }
+ }
}
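equalsWithoutHeader exists because byte 0 of a BinaryRow's fixed-length part stores the message header (accumulate vs. retract), so two rows carrying the same fields but different change flags differ in their raw bytes; comparing from offset + 1 skips exactly that byte. A hedged sketch of the intended behavior:

```java
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.dataformat.BinaryRowWriter;
import org.apache.flink.table.dataformat.util.BaseRowUtil;

public class EqualsWithoutHeaderExample {
	public static void main(String[] args) {
		BinaryRow accumulate = rowOf(42);
		BinaryRow retract = rowOf(42);
		retract.setHeader(BaseRowUtil.RETRACT_MSG); // differs only in the header byte

		// the header byte is skipped, so the rows compare equal
		System.out.println(accumulate.equalsWithoutHeader(retract)); // true
	}

	private static BinaryRow rowOf(int value) {
		BinaryRow row = new BinaryRow(1);
		BinaryRowWriter writer = new BinaryRowWriter(row);
		writer.writeInt(0, value);
		writer.complete();
		return row;
	}
}
```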
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java
index b1730ec3c3892..aafcb3feeab9c 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java
@@ -18,12 +18,14 @@
package org.apache.flink.table.dataformat;
import org.apache.flink.table.type.ArrayType;
+import org.apache.flink.table.type.DateType;
import org.apache.flink.table.type.DecimalType;
import org.apache.flink.table.type.GenericType;
import org.apache.flink.table.type.InternalType;
import org.apache.flink.table.type.InternalTypes;
import org.apache.flink.table.type.MapType;
import org.apache.flink.table.type.RowType;
+import org.apache.flink.table.type.TimestampType;
import org.apache.flink.table.typeutils.BaseRowSerializer;
/**
@@ -98,11 +100,11 @@ static void write(BinaryWriter writer, int pos, Object o, InternalType type) {
writer.writeString(pos, (BinaryString) o);
} else if (type.equals(InternalTypes.CHAR)) {
writer.writeChar(pos, (char) o);
- } else if (type.equals(InternalTypes.DATE)) {
+ } else if (type instanceof DateType) {
writer.writeInt(pos, (int) o);
} else if (type.equals(InternalTypes.TIME)) {
writer.writeInt(pos, (int) o);
- } else if (type.equals(InternalTypes.TIMESTAMP)) {
+ } else if (type instanceof TimestampType) {
writer.writeLong(pos, (long) o);
} else if (type instanceof DecimalType) {
DecimalType decimalType = (DecimalType) type;
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java
index 101ff35d16253..ebc59bc3f3bda 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java
@@ -146,9 +146,9 @@ public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeIn
} else if (typeInfo instanceof BinaryMapTypeInfo) {
return BinaryMapConverter.INSTANCE;
} else if (typeInfo instanceof BaseRowTypeInfo) {
- return BaseRowConverter.INSTANCE;
+ return new BaseRowConverter(typeInfo.getArity());
} else if (typeInfo.equals(BasicTypeInfo.BIG_DEC_TYPE_INFO)) {
- return BaseRowConverter.INSTANCE;
+ return new BaseRowConverter(typeInfo.getArity());
} else if (typeInfo instanceof DecimalTypeInfo) {
DecimalTypeInfo decimalType = (DecimalTypeInfo) typeInfo;
return new DecimalConverter(decimalType.precision(), decimalType.scale());
@@ -993,14 +993,13 @@ E toExternalImpl(BaseRow row, int column) {
public static final class BaseRowConverter extends IdentityConverter {
private static final long serialVersionUID = -4470307402371540680L;
+ private final int arity;
- public static final BaseRowConverter INSTANCE = new BaseRowConverter();
-
- private BaseRowConverter() {}
+ private BaseRowConverter(int arity) {
+ this.arity = arity;
+ }
@Override
BaseRow toExternalImpl(BaseRow row, int column) {
- throw new RuntimeException("Not support yet!");
+ return row.getRow(column, arity);
}
}
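With the arity recorded, BaseRowConverter.toExternalImpl can read a nested row column instead of throwing: the binary layout does not store a nested row's field count, so the reader must supply it. An illustrative sketch using GenericRow:

```java
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.GenericRow;

public class NestedRowAccessExample {
	public static void main(String[] args) {
		GenericRow inner = new GenericRow(2);
		inner.setField(0, 1);
		inner.setField(1, 2);

		GenericRow outer = new GenericRow(1);
		outer.setField(0, inner);

		// what the fixed converter now does for a nested-row column:
		// read the nested row, telling the format how many fields it has
		BaseRow nested = outer.getRow(0, 2);
		System.out.println(nested.getInt(0) + ", " + nested.getInt(1)); // 1, 2
	}
}
```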
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java
index f1e080ce89d57..ab263c40c3adc 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java
@@ -18,12 +18,14 @@
package org.apache.flink.table.dataformat;
import org.apache.flink.table.type.ArrayType;
+import org.apache.flink.table.type.DateType;
import org.apache.flink.table.type.DecimalType;
import org.apache.flink.table.type.GenericType;
import org.apache.flink.table.type.InternalType;
import org.apache.flink.table.type.InternalTypes;
import org.apache.flink.table.type.MapType;
import org.apache.flink.table.type.RowType;
+import org.apache.flink.table.type.TimestampType;
/**
* Provide type specialized getters and setters to reduce if/else and eliminate box and unbox.
@@ -193,11 +195,11 @@ static Object get(TypeGetterSetters row, int ordinal, InternalType type) {
return row.getString(ordinal);
} else if (type.equals(InternalTypes.CHAR)) {
return row.getChar(ordinal);
- } else if (type.equals(InternalTypes.DATE)) {
+ } else if (type instanceof DateType) {
return row.getInt(ordinal);
} else if (type.equals(InternalTypes.TIME)) {
return row.getInt(ordinal);
- } else if (type.equals(InternalTypes.TIMESTAMP)) {
+ } else if (type instanceof TimestampType) {
return row.getLong(ordinal);
} else if (type instanceof DecimalType) {
DecimalType decimalType = (DecimalType) type;
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java
index 71dc6c2eacd7b..e50e5463c0327 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java
@@ -18,7 +18,9 @@
package org.apache.flink.table.dataformat.util;
+import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.core.memory.MemoryUtils;
+import org.apache.flink.table.dataformat.BinaryRow;
/**
* Util for binary row. Many of the methods in this class are used in code generation.
@@ -29,6 +31,14 @@ public class BinaryRowUtil {
public static final sun.misc.Unsafe UNSAFE = MemoryUtils.UNSAFE;
public static final int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
+ public static final BinaryRow EMPTY_ROW = new BinaryRow(0);
+
+ static {
+ int size = EMPTY_ROW.getFixedLengthPartSize();
+ byte[] bytes = new byte[size];
+ EMPTY_ROW.pointTo(MemorySegmentFactory.wrap(bytes), 0, size);
+ }
+
public static boolean byteArrayEquals(byte[] left, byte[] right, int length) {
return byteArrayEquals(
left, BYTE_ARRAY_BASE_OFFSET, right, BYTE_ARRAY_BASE_OFFSET, length);
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java
index d37aa31560c95..8f2c76363c008 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java
@@ -57,7 +57,7 @@ public abstract class AbstractMapBundleOperator<K, V, IN, OUT>
private static final long serialVersionUID = 5081841938324118594L;
/** The map in heap to store elements. */
- private final transient Map<K, V> bundle;
+ private transient Map<K, V> bundle;
/** The trigger that determines how many elements should be put into a bundle. */
private final BundleTrigger<IN> bundleTrigger;
@@ -74,7 +74,6 @@ public abstract class AbstractMapBundleOperator
MapBundleFunction<K, V, IN, OUT> function,
BundleTrigger<IN> bundleTrigger) {
chainingStrategy = ChainingStrategy.ALWAYS;
- this.bundle = new HashMap<>();
this.function = checkNotNull(function, "function is null");
this.bundleTrigger = checkNotNull(bundleTrigger, "bundleTrigger is null");
}
@@ -86,6 +85,7 @@ public void open() throws Exception {
this.numOfElements = 0;
this.collector = new StreamRecordCollector<>(output);
+ this.bundle = new HashMap<>();
bundleTrigger.registerCallback(this);
// reset trigger
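Moving the HashMap allocation from the constructor into open() is the standard fix for transient members: the operator is serialized on the client and deserialized on the task manager, which leaves every transient field null. A generic sketch of the pattern (not Flink API):

```java
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

// Generic pattern: transient members are null after deserialization,
// so they must be recreated in the lifecycle open() hook.
public abstract class BufferingOperator<K, V> implements Serializable {
	private transient Map<K, V> buffer; // skipped by Java serialization

	public void open() {
		buffer = new HashMap<>(); // runs on the task manager, after deserialization
	}

	protected Map<K, V> buffer() {
		return buffer;
	}
}
```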
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java
new file mode 100644
index 0000000000000..1c554bd5c1497
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.deduplicate;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.util.BaseRowUtil;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Utility for deduplicate function.
+ */
+class DeduplicateFunctionHelper {
+
+ /**
+ * Processes an element to deduplicate on keys: sends the current element as the last row,
+ * and retracts the previous element if needed.
+ *
+ * @param currentRow latest row received by the deduplicate function
+ * @param generateRetraction whether a retract message needs to be sent downstream
+ * @param state state of the function
+ * @param out underlying collector
+ * @throws Exception if state access fails
+ */
+ static void processLastRow(BaseRow currentRow, boolean generateRetraction, ValueState<BaseRow> state,
+ Collector<BaseRow> out) throws Exception {
+ // check that the input is an accumulate message
+ Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow));
+ if (generateRetraction) {
+ // state stores complete row if generateRetraction is true
+ BaseRow preRow = state.value();
+ state.update(currentRow);
+ if (preRow != null) {
+ preRow.setHeader(BaseRowUtil.RETRACT_MSG);
+ out.collect(preRow);
+ }
+ }
+ out.collect(currentRow);
+ }
+
+ /**
+ * Processes an element to deduplicate on keys, sending the current element only if it is the
+ * first row for its key.
+ *
+ * @param currentRow latest row received by the deduplicate function
+ * @param state state of the function
+ * @param out underlying collector
+ * @throws Exception if state access fails
+ */
+ static void processFirstRow(BaseRow currentRow, ValueState<Boolean> state, Collector<BaseRow> out)
+ throws Exception {
+ // check that the input is an accumulate message
+ Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow));
+ // the key has been seen before, so this is not the first row: ignore it
+ if (state.value() != null) {
+ return;
+ }
+ state.update(true);
+ out.collect(currentRow);
+ }
+
+ private DeduplicateFunctionHelper() {
+
+ }
+}
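The helper's key guarantee is ordering: for last-row semantics, the retraction of the previous version is emitted before the new version. Since the helper is package-private, the sketch below re-implements its logic on plain strings, with a minimal stand-in for ValueState, to show the emitted sequence:

```java
import java.util.ArrayList;
import java.util.List;

public class LastRowSemanticsDemo {
	// Minimal stand-in for Flink's ValueState, for illustration only.
	static final class MemState<T> {
		private T value;
		T value() { return value; }
		void update(T v) { value = v; }
	}

	// Re-implementation of processLastRow's contract on plain strings.
	static void processLastRow(String currentRow, boolean generateRetraction,
			MemState<String> state, List<String> out) {
		if (generateRetraction) {
			String preRow = state.value();
			state.update(currentRow);
			if (preRow != null) {
				out.add("-" + preRow); // retract the previous version first
			}
		}
		out.add("+" + currentRow); // then emit the new version
	}

	public static void main(String[] args) {
		MemState<String> state = new MemState<>();
		List<String> out = new ArrayList<>();
		processLastRow("a,1", true, state, out);
		processLastRow("a,2", true, state, out);
		System.out.println(out); // [+a,1, -a,1, +a,2]
	}
}
```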
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java
new file mode 100644
index 0000000000000..14feaf981fd32
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.deduplicate;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState;
+import org.apache.flink.util.Collector;
+
+import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow;
+
+/**
+ * This function is used to deduplicate on keys and keeps only first row.
+ */
+ public class DeduplicateKeepFirstRowFunction
+ extends KeyedProcessFunctionWithCleanupState<BaseRow, BaseRow, BaseRow> {
+
+ private static final long serialVersionUID = 5865777137707602549L;
+
+ // state stores a boolean flag to indicate whether key appears before.
+ private ValueState<Boolean> state;
+
+ public DeduplicateKeepFirstRowFunction(long minRetentionTime, long maxRetentionTime) {
+ super(minRetentionTime, maxRetentionTime);
+ }
+
+ @Override
+ public void open(Configuration configure) throws Exception {
+ super.open(configure);
+ initCleanupTimeState("DeduplicateFunctionKeepFirstRow");
+ ValueStateDescriptor<Boolean> stateDesc = new ValueStateDescriptor<>("existsState", Types.BOOLEAN);
+ state = getRuntimeContext().getState(stateDesc);
+ }
+
+ @Override
+ public void processElement(BaseRow input, Context ctx, Collector<BaseRow> out) throws Exception {
+ long currentTime = ctx.timerService().currentProcessingTime();
+ // register state-cleanup timer
+ registerProcessingCleanupTimer(ctx, currentTime);
+ processFirstRow(input, state, out);
+ }
+
+ @Override
+ public void onTimer(long timestamp, OnTimerContext ctx, Collector<BaseRow> out) throws Exception {
+ if (stateCleaningEnabled) {
+ cleanupState(state);
+ }
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java
new file mode 100644
index 0000000000000..3dcde66eb5ea6
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.deduplicate;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.util.Collector;
+
+import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow;
+
+/**
+ * This function is used to deduplicate on keys and keeps only last row.
+ */
+ public class DeduplicateKeepLastRowFunction
+ extends KeyedProcessFunctionWithCleanupState<BaseRow, BaseRow, BaseRow> {
+
+ private static final long serialVersionUID = -291348892087180350L;
+ private final BaseRowTypeInfo rowTypeInfo;
+ private final boolean generateRetraction;
+
+ // state stores complete row.
+ private ValueState<BaseRow> state;
+
+ public DeduplicateKeepLastRowFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo rowTypeInfo,
+ boolean generateRetraction) {
+ super(minRetentionTime, maxRetentionTime);
+ this.rowTypeInfo = rowTypeInfo;
+ this.generateRetraction = generateRetraction;
+ }
+
+ @Override
+ public void open(Configuration configure) throws Exception {
+ super.open(configure);
+ if (generateRetraction) {
+ // state stores complete row if need generate retraction, otherwise do not need a state
+ initCleanupTimeState("DeduplicateFunctionKeepLastRow");
+ ValueStateDescriptor<BaseRow> stateDesc = new ValueStateDescriptor<>("preRowState", rowTypeInfo);
+ state = getRuntimeContext().getState(stateDesc);
+ }
+ }
+
+ @Override
+ public void processElement(BaseRow input, Context ctx, Collector<BaseRow> out) throws Exception {
+ if (generateRetraction) {
+ long currentTime = ctx.timerService().currentProcessingTime();
+ // register state-cleanup timer
+ registerProcessingCleanupTimer(ctx, currentTime);
+ }
+ processLastRow(input, generateRetraction, state, out);
+ }
+
+ @Override
+ public void onTimer(long timestamp, OnTimerContext ctx, Collector<BaseRow> out) throws Exception {
+ if (stateCleaningEnabled) {
+ cleanupState(state);
+ }
+ }
+}
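The function is meant to run on a keyed stream, keyed by the deduplication columns. A hedged wiring sketch; the input stream, key selector, and row type info are assumed to be provided by the planner:

```java
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.runtime.deduplicate.DeduplicateKeepLastRowFunction;
import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
import org.apache.flink.table.typeutils.BaseRowTypeInfo;

public class DeduplicateWiring {
	// Hypothetical wiring: the arguments come from the planner in a real plan.
	public static SingleOutputStreamOperator<BaseRow> deduplicate(
			DataStream<BaseRow> input,
			BaseRowKeySelector keySelector,
			BaseRowTypeInfo rowTypeInfo) {
		DeduplicateKeepLastRowFunction function = new DeduplicateKeepLastRowFunction(
			60_000L,  // minRetentionTime: keep state for at least one minute
			120_000L, // maxRetentionTime: clean up after at most two minutes
			rowTypeInfo,
			true);    // generateRetraction: emit retract messages downstream
		return input.keyBy(keySelector).process(function);
	}
}
```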
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java
new file mode 100644
index 0000000000000..b97d1a670ac3b
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.deduplicate;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.runtime.bundle.MapBundleFunction;
+import org.apache.flink.table.runtime.context.ExecutionContext;
+import org.apache.flink.util.Collector;
+
+import javax.annotation.Nullable;
+
+import java.util.Map;
+
+import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow;
+
+/**
+ * This function is used to get the first row for every key partition in miniBatch mode.
+ */
+ public class MiniBatchDeduplicateKeepFirstRowFunction
+ extends MapBundleFunction<BaseRow, BaseRow, BaseRow, BaseRow> {
+
+ private static final long serialVersionUID = -7994602893547654994L;
+
+ private final TypeSerializer<BaseRow> typeSerializer;
+
+ // state stores a boolean flag to indicate whether key appears before.
+ private ValueState<Boolean> state;
+
+ public MiniBatchDeduplicateKeepFirstRowFunction(TypeSerializer<BaseRow> typeSerializer) {
+ this.typeSerializer = typeSerializer;
+ }
+
+ @Override
+ public void open(ExecutionContext ctx) throws Exception {
+ super.open(ctx);
+ ValueStateDescriptor<Boolean> stateDesc = new ValueStateDescriptor<>("existsState", Types.BOOLEAN);
+ state = ctx.getRuntimeContext().getState(stateDesc);
+ }
+
+ @Override
+ public BaseRow addInput(@Nullable BaseRow value, BaseRow input) {
+ if (value == null) {
+ // put the input into buffer
+ return typeSerializer.copy(input);
+ } else {
+ // the input is not first row, ignore it
+ return value;
+ }
+ }
+
+ @Override
+ public void finishBundle(
+ Map<BaseRow, BaseRow> buffer, Collector<BaseRow> out) throws Exception {
+ for (Map.Entry<BaseRow, BaseRow> entry : buffer.entrySet()) {
+ BaseRow currentKey = entry.getKey();
+ BaseRow currentRow = entry.getValue();
+ ctx.setCurrentKey(currentKey);
+ processFirstRow(currentRow, state, out);
+ }
+ }
+}
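Within a single bundle, addInput keeps the first buffered value for a key and drops later ones, so most duplicates never reach state at all. Per key this is exactly Map.putIfAbsent semantics, as the small demo shows:

```java
import java.util.HashMap;
import java.util.Map;

public class FirstRowBundleDemo {
	public static void main(String[] args) {
		// addInput(null, input) buffers the input; addInput(value, input) keeps value.
		Map<String, String> bundle = new HashMap<>();
		bundle.putIfAbsent("key", "row-1"); // first row for the key: buffered
		bundle.putIfAbsent("key", "row-2"); // later row in the same bundle: ignored
		System.out.println(bundle.get("key")); // row-1
	}
}
```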
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java
new file mode 100644
index 0000000000000..c1f2ec40cf5b8
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.deduplicate;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.runtime.bundle.MapBundleFunction;
+import org.apache.flink.table.runtime.context.ExecutionContext;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.util.Collector;
+
+import javax.annotation.Nullable;
+
+import java.util.Map;
+
+import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow;
+
+/**
+ * This function is used to get the last row for every key partition in miniBatch mode.
+ */
+ public class MiniBatchDeduplicateKeepLastRowFunction
+ extends MapBundleFunction<BaseRow, BaseRow, BaseRow, BaseRow> {
+
+ private static final long serialVersionUID = -8981813609115029119L;
+
+ private final BaseRowTypeInfo rowTypeInfo;
+ private final boolean generateRetraction;
+ private final TypeSerializer<BaseRow> typeSerializer;
+
+ // state stores complete row.
+ private ValueState<BaseRow> state;
+
+ public MiniBatchDeduplicateKeepLastRowFunction(BaseRowTypeInfo rowTypeInfo, boolean generateRetraction,
+ TypeSerializer<BaseRow> typeSerializer) {
+ this.rowTypeInfo = rowTypeInfo;
+ this.generateRetraction = generateRetraction;
+ this.typeSerializer = typeSerializer;
+ }
+
+ @Override
+ public void open(ExecutionContext ctx) throws Exception {
+ super.open(ctx);
+ ValueStateDescriptor<BaseRow> stateDesc = new ValueStateDescriptor<>("preRowState", rowTypeInfo);
+ state = ctx.getRuntimeContext().getState(stateDesc);
+ }
+
+ @Override
+ public BaseRow addInput(@Nullable BaseRow value, BaseRow input) {
+ // always put the input into buffer
+ return typeSerializer.copy(input);
+ }
+
+ @Override
+ public void finishBundle(
+ Map<BaseRow, BaseRow> buffer, Collector<BaseRow> out) throws Exception {
+ for (Map.Entry<BaseRow, BaseRow> entry : buffer.entrySet()) {
+ BaseRow currentKey = entry.getKey();
+ BaseRow currentRow = entry.getValue();
+ ctx.setCurrentKey(currentKey);
+ processLastRow(currentRow, generateRetraction, state, out);
+ }
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java
new file mode 100644
index 0000000000000..ddb823c60b003
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.streaming.api.TimerService;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+
+/**
+ * Base interface for clean up state, both for {@link ProcessFunction} and {@link CoProcessFunction}.
+ */
+public interface CleanupState {
+
+ default void registerProcessingCleanupTimer(
+ ValueState<Long> cleanupTimeState,
+ long currentTime,
+ long minRetentionTime,
+ long maxRetentionTime,
+ TimerService timerService) throws Exception {
+
+ // last registered timer
+ Long curCleanupTime = cleanupTimeState.value();
+
+ // check if a cleanup timer is registered and
+ // that the current cleanup timer won't delete state we need to keep
+ if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) {
+ // we need to register a new (later) timer
+ long cleanupTime = currentTime + maxRetentionTime;
+ // register timer and remember clean-up time
+ timerService.registerProcessingTimeTimer(cleanupTime);
+ // delete expired timer
+ if (curCleanupTime != null) {
+ timerService.deleteProcessingTimeTimer(curCleanupTime);
+ }
+ cleanupTimeState.update(cleanupTime);
+ }
+ }
+}
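The condition currentTime + minRetentionTime &gt; curCleanupTime means a new timer is registered only when the existing one would fire before the minimum retention has passed; the replacement fires at currentTime + maxRetentionTime, so timers are extended in coarse steps rather than on every record. A worked example of the arithmetic:

```java
public class CleanupTimerMath {
	public static void main(String[] args) {
		long minRetentionTime = 5_000;
		long maxRetentionTime = 10_000;

		Long curCleanupTime = 12_000L; // timer registered earlier, at time 2_000
		long currentTime = 8_000;

		// 8_000 + 5_000 = 13_000 > 12_000: the existing timer would fire too
		// early, so it is replaced by a later one at 8_000 + 10_000 = 18_000
		if (curCleanupTime == null || currentTime + minRetentionTime > curCleanupTime) {
			long cleanupTime = currentTime + maxRetentionTime;
			System.out.println("re-register at " + cleanupTime); // 18000
		}
	}
}
```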
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java
new file mode 100644
index 0000000000000..c5797632f71c9
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.streaming.api.TimeDomain;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+
+/**
+ * A function that processes elements of a stream, and could cleanup state.
+ * @param <K> Type of the key.
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the output elements.
+ */
+ public abstract class KeyedProcessFunctionWithCleanupState<K, IN, OUT>
+ extends KeyedProcessFunction<K, IN, OUT> implements CleanupState {
+
+ private static final long serialVersionUID = 2084560869233898457L;
+
+ protected final long minRetentionTime;
+ protected final long maxRetentionTime;
+ protected final boolean stateCleaningEnabled;
+
+ // holds the latest registered cleanup timer
+ private ValueState<Long> cleanupTimeState;
+
+ public KeyedProcessFunctionWithCleanupState(long minRetentionTime, long maxRetentionTime) {
+ this.minRetentionTime = minRetentionTime;
+ this.maxRetentionTime = maxRetentionTime;
+ this.stateCleaningEnabled = minRetentionTime > 1;
+ }
+
+ protected void initCleanupTimeState(String stateName) {
+ if (stateCleaningEnabled) {
+ ValueStateDescriptor<Long> inputCntDescriptor = new ValueStateDescriptor<>(stateName, Types.LONG);
+ cleanupTimeState = getRuntimeContext().getState(inputCntDescriptor);
+ }
+ }
+
+ protected void registerProcessingCleanupTimer(Context ctx, long currentTime) throws Exception {
+ if (stateCleaningEnabled) {
+ registerProcessingCleanupTimer(
+ cleanupTimeState,
+ currentTime,
+ minRetentionTime,
+ maxRetentionTime,
+ ctx.timerService()
+ );
+ }
+ }
+
+ protected boolean isProcessingTimeTimer(OnTimerContext ctx) {
+ return ctx.timeDomain() == TimeDomain.PROCESSING_TIME;
+ }
+
+ protected void cleanupState(State... states) {
+ for (State state : states) {
+ state.clear();
+ }
+ this.cleanupTimeState.clear();
+ }
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java
new file mode 100644
index 0000000000000..2ad9e5dd2c393
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.keyselector;
+
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+
+/**
+ * BaseRowKeySelector takes a BaseRow and extracts the deterministic key for it.
+ */
+public interface BaseRowKeySelector extends KeySelector, ResultTypeQueryable {
+
+ BaseRowTypeInfo getProducedType();
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java
new file mode 100644
index 0000000000000..3d7887a57afcd
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.keyselector;
+
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.BinaryRow;
+import org.apache.flink.table.generated.GeneratedProjection;
+import org.apache.flink.table.generated.Projection;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+
+/**
+ * A utility class which extracts the key from a BaseRow. The key type is BinaryRow.
+ */
+public class BinaryRowKeySelector implements BaseRowKeySelector {
+
+ private static final long serialVersionUID = 5375355285015381919L;
+
+ private final BaseRowTypeInfo keyRowType;
+ private final GeneratedProjection generatedProjection;
+ private transient Projection<BaseRow, BinaryRow> projection;
+
+ public BinaryRowKeySelector(BaseRowTypeInfo keyRowType, GeneratedProjection generatedProjection) {
+ this.keyRowType = keyRowType;
+ this.generatedProjection = generatedProjection;
+ }
+
+ @Override
+ public BaseRow getKey(BaseRow value) throws Exception {
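+ // lazily compile and instantiate the generated projection on first access,
+ // as the compiled projection instance is transient and not serialized with the selector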
+ if (projection == null) {
+ projection = generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
+ }
+ return projection.apply(value).copy();
+ }
+
+ @Override
+ public BaseRowTypeInfo getProducedType() {
+ return keyRowType;
+ }
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java
new file mode 100644
index 0000000000000..795bdbf5a3796
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.keyselector;
+
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.util.BinaryRowUtil;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+
+/**
+ * A utility class whose key is always empty, no matter what the input row is.
+ */
+public class NullBinaryRowKeySelector implements BaseRowKeySelector {
+
+ private static final long serialVersionUID = -2079386198687082032L;
+
+ private final BaseRowTypeInfo returnType = new BaseRowTypeInfo();
+
+ @Override
+ public BaseRow getKey(BaseRow value) throws Exception {
+ return BinaryRowUtil.EMPTY_ROW;
+ }
+
+ @Override
+ public BaseRowTypeInfo getProducedType() {
+ return returnType;
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java
new file mode 100644
index 0000000000000..5c44133ded5a6
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Counter;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.streaming.api.operators.KeyContext;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.GenericRow;
+import org.apache.flink.table.dataformat.JoinedRow;
+import org.apache.flink.table.dataformat.util.BaseRowUtil;
+import org.apache.flink.table.generated.GeneratedRecordComparator;
+import org.apache.flink.table.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState;
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
+import org.apache.flink.table.type.InternalType;
+import org.apache.flink.table.type.InternalTypes;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.util.Collector;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.Map;
+
+/**
+ * Base class for Rank Function.
+ */
+public abstract class AbstractRankFunction extends KeyedProcessFunctionWithCleanupState<BaseRow, BaseRow, BaseRow> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AbstractRankFunction.class);
+
+ private static final String RANK_UNSUPPORTED_MSG = "RANK() on streaming table is not supported currently";
+
+ private static final String DENSE_RANK_UNSUPPORTED_MSG = "DENSE_RANK() on streaming table is not supported currently";
+
+ private static final String WITHOUT_RANK_END_UNSUPPORTED_MSG = "Rank end is not specified. Currently rank only support TopN, which means the rank end must be specified.";
+
+ // we set default topN size to 100
+ private static final long DEFAULT_TOPN_SIZE = 100;
+
+ /**
+ * The util to check whether two BaseRows are equal to each other.
+ * As different BaseRow implementations can't be compared directly, we use a code generated util to handle this.
+ */
+ private GeneratedRecordEqualiser generatedEqualiser;
+ protected RecordEqualiser equaliser;
+
+ /**
+ * The util to compare two sort keys with each other.
+ */
+ private GeneratedRecordComparator generatedSortKeyComparator;
+ protected Comparator<BaseRow> sortKeyComparator;
+
+ private final boolean generateRetraction;
+ protected final boolean outputRankNumber;
+ protected final BaseRowTypeInfo inputRowType;
+ protected final KeySelector<BaseRow, BaseRow> sortKeySelector;
+
+ protected KeyContext keyContext;
+ private final boolean isConstantRankEnd;
+ private final long rankStart;
+ protected long rankEnd;
+ private final int rankEndIndex;
+ private ValueState<Long> rankEndState;
+ private Counter invalidCounter;
+ private JoinedRow outputRow;
+
+ // metrics
+ protected long hitCount = 0L;
+ protected long requestCount = 0L;
+
+ AbstractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType,
+ GeneratedRecordComparator generatedSortKeyComparator, BaseRowKeySelector sortKeySelector,
+ RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser,
+ boolean generateRetraction, boolean outputRankNumber) {
+ super(minRetentionTime, maxRetentionTime);
+ // TODO support RANK and DENSE_RANK
+ switch (rankType) {
+ case ROW_NUMBER:
+ break;
+ case RANK:
+ LOG.error(RANK_UNSUPPORTED_MSG);
+ throw new UnsupportedOperationException(RANK_UNSUPPORTED_MSG);
+ case DENSE_RANK:
+ LOG.error(DENSE_RANK_UNSUPPORTED_MSG);
+ throw new UnsupportedOperationException(DENSE_RANK_UNSUPPORTED_MSG);
+ default:
+ LOG.error("Streaming tables do not support {}", rankType.name());
+ throw new UnsupportedOperationException("Streaming tables do not support " + rankType.toString());
+ }
+
+ if (rankRange instanceof ConstantRankRange) {
+ ConstantRankRange constantRankRange = (ConstantRankRange) rankRange;
+ isConstantRankEnd = true;
+ rankStart = constantRankRange.getRankStart();
+ rankEnd = constantRankRange.getRankEnd();
+ rankEndIndex = -1;
+ } else if (rankRange instanceof VariableRankRange) {
+ VariableRankRange variableRankRange = (VariableRankRange) rankRange;
+ int rankEndIdx = variableRankRange.getRankEndIndex();
+ InternalType rankEndIdxType = inputRowType.getInternalTypes()[rankEndIdx];
+ if (!rankEndIdxType.equals(InternalTypes.LONG)) {
+ LOG.error("variable rank index column must be long type, while input type is {}",
+ rankEndIdxType.getClass().getName());
+ throw new UnsupportedOperationException(
+ "variable rank index column must be long type, while input type is " +
+ rankEndIdxType.getClass().getName());
+ }
+ rankEndIndex = rankEndIdx;
+ isConstantRankEnd = false;
+ rankStart = -1;
+ rankEnd = -1;
+
+ } else {
+ LOG.error(WITHOUT_RANK_END_UNSUPPORTED_MSG);
+ throw new UnsupportedOperationException(WITHOUT_RANK_END_UNSUPPORTED_MSG);
+ }
+ this.generatedEqualiser = generatedEqualiser;
+ this.generatedSortKeyComparator = generatedSortKeyComparator;
+ this.generateRetraction = generateRetraction;
+ this.inputRowType = inputRowType;
+ this.outputRankNumber = outputRankNumber;
+ this.sortKeySelector = sortKeySelector;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ initCleanupTimeState("RankFunctionCleanupTime");
+ outputRow = new JoinedRow();
+
+ if (!isConstantRankEnd) {
+ ValueStateDescriptor<Long> rankStateDesc = new ValueStateDescriptor<>("rankEnd", Types.LONG);
+ rankEndState = getRuntimeContext().getState(rankStateDesc);
+ }
+ // compile equaliser
+ equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ generatedEqualiser = null;
+ // compile comparator
+ sortKeyComparator = generatedSortKeyComparator.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ generatedSortKeyComparator = null;
+ invalidCounter = getRuntimeContext().getMetricGroup().counter("topn.invalidTopSize");
+ }
+
+ /**
+ * Gets default topN size.
+ *
+ * @return default topN size
+ */
+ protected long getDefaultTopNSize() {
+ return isConstantRankEnd ? rankEnd : DEFAULT_TOPN_SIZE;
+ }
+
+ /**
+ * Initialize rank end.
+ *
+ * @param row input record
+ * @return rank end
+ * @throws Exception if the rank end state can not be accessed
+ */
+ protected long initRankEnd(BaseRow row) throws Exception {
+ if (isConstantRankEnd) {
+ return rankEnd;
+ } else {
+ Long rankEndValue = rankEndState.value();
+ long curRankEnd = row.getLong(rankEndIndex);
+ if (rankEndValue == null) {
+ rankEnd = curRankEnd;
+ rankEndState.update(rankEnd);
+ return rankEnd;
+ } else {
+ rankEnd = rankEndValue;
+ if (rankEnd != curRankEnd) {
+ // increment the invalid counter when the current rank end is not equal to the previous rank end
+ invalidCounter.inc();
+ }
+ return rankEnd;
+ }
+ }
+ }
+
+ /**
+ * Checks whether the record should be put into the buffer.
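+ * A record is accepted when its sort key sorts before the current worst sort key in the buffer,
+ * or when the buffer holds fewer records than the (default) topN size. For example (hypothetical
+ * values, ascending comparator, topN size 3): a buffer with keys {1, 5} accepts key 3, since it
+ * beats the worst key 5, and also accepts key 7, since the buffer is below capacity.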
+ *
+ * @param sortKey sortKey to test
+ * @param buffer buffer to add
+ * @return true if the record should be put into the buffer.
+ */
+ protected boolean checkSortKeyInBufferRange(BaseRow sortKey, TopNBuffer buffer) {
+ Comparator<BaseRow> comparator = buffer.getSortKeyComparator();
+ Map.Entry<BaseRow, Collection<BaseRow>> worstEntry = buffer.lastEntry();
+ if (worstEntry == null) {
+ // return true if the buffer is empty.
+ return true;
+ } else {
+ BaseRow worstKey = worstEntry.getKey();
+ int compare = comparator.compare(sortKey, worstKey);
+ if (compare < 0) {
+ return true;
+ } else {
+ return buffer.getCurrentTopNum() < getDefaultTopNSize();
+ }
+ }
+ }
+
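+ /**
+ * Registers the "topn.cache.hitRate" and "topn.cache.size" gauges. The hit rate is
+ * hitCount / requestCount and reports 1.0 before any request has been made.
+ */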
+ protected void registerMetric(long heapSize) {
+ getRuntimeContext().getMetricGroup().<Double, Gauge<Double>>gauge(
+ "topn.cache.hitRate",
+ () -> requestCount == 0 ? 1.0 : Long.valueOf(hitCount).doubleValue() / requestCount);
+
+ getRuntimeContext().getMetricGroup().<Long, Gauge<Long>>gauge(
+ "topn.cache.size", () -> heapSize);
+ }
+
+ protected void collect(Collector<BaseRow> out, BaseRow inputRow) {
+ BaseRowUtil.setAccumulate(inputRow);
+ out.collect(inputRow);
+ }
+
+ /**
+ * This is similar to {@link #retract(Collector, BaseRow, long)}, but it always sends a
+ * retraction message regardless of whether generateRetraction is true.
+ */
+ protected void delete(Collector<BaseRow> out, BaseRow inputRow) {
+ BaseRowUtil.setRetract(inputRow);
+ out.collect(inputRow);
+ }
+
+ /**
+ * This is the with-row-number version of the above delete() method.
+ */
+ protected void delete(Collector<BaseRow> out, BaseRow inputRow, long rank) {
+ if (isInRankRange(rank)) {
+ out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG));
+ }
+ }
+
+ protected void collect(Collector<BaseRow> out, BaseRow inputRow, long rank) {
+ if (isInRankRange(rank)) {
+ out.collect(createOutputRow(inputRow, rank, BaseRowUtil.ACCUMULATE_MSG));
+ }
+ }
+
+ protected void retract(Collector<BaseRow> out, BaseRow inputRow, long rank) {
+ if (generateRetraction && isInRankRange(rank)) {
+ out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG));
+ }
+ }
+
+ protected boolean isInRankEnd(long rank) {
+ return rank <= rankEnd;
+ }
+
+ protected boolean isInRankRange(long rank) {
+ return rank <= rankEnd && rank >= rankStart;
+ }
+
+ protected boolean hasOffset() {
+ // rank start is 1-based
+ return rankStart > 1;
+ }
+
+ private BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) {
+ if (outputRankNumber) {
+ GenericRow rankRow = new GenericRow(1);
+ rankRow.setField(0, rank);
+
+ outputRow.replace(inputRow, rankRow);
+ outputRow.setHeader(header);
+ return outputRow;
+ } else {
+ inputRow.setHeader(header);
+ return inputRow;
+ }
+ }
+
+ /**
+ * Sets keyContext to RankFunction.
+ *
+ * @param keyContext keyContext of current function.
+ */
+ public void setKeyContext(KeyContext keyContext) {
+ this.keyContext = keyContext;
+ }
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java
new file mode 100644
index 0000000000000..73e7edb631a3c
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.typeutils.ListTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.generated.GeneratedRecordComparator;
+import org.apache.flink.table.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
+import org.apache.flink.table.runtime.util.LRUMap;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.util.Collector;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * AppendRankFunction's input stream only contains append records.
+ */
+public class AppendRankFunction extends AbstractRankFunction {
+
+ private static final long serialVersionUID = -4708453213104128010L;
+
+ private static final Logger LOG = LoggerFactory.getLogger(AppendRankFunction.class);
+
+ private final BaseRowTypeInfo sortKeyType;
+ private final TypeSerializer<BaseRow> inputRowSer;
+ private final long cacheSize;
+
+ // a map state stores mapping from sort key to records list which is in topN
+ private transient MapState<BaseRow, List<BaseRow>> dataState;
+
+ // the buffer stores mapping from sort key to records list, a heap mirror to dataState
+ private transient TopNBuffer buffer;
+
+ // the kvSortedMap stores mapping from partition key to its buffer
+ private transient Map<BaseRow, TopNBuffer> kvSortedMap;
+
+ public AppendRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType,
+ GeneratedRecordComparator sortKeyGeneratedRecordComparator, BaseRowKeySelector sortKeySelector,
+ RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser,
+ boolean generateRetraction, boolean outputRankNumber, long cacheSize) {
+ super(minRetentionTime, maxRetentionTime, inputRowType, sortKeyGeneratedRecordComparator, sortKeySelector,
+ rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber);
+ this.sortKeyType = sortKeySelector.getProducedType();
+ this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig());
+ this.cacheSize = cacheSize;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
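+ // budget the LRU cache in records: cacheSize records in total divided by topN records per key
+ // gives the number of partition keys to cache (at least one)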
+ int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopNSize()));
+ kvSortedMap = new LRUMap<>(lruCacheSize);
+ LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize);
+
+ ListTypeInfo<BaseRow> valueTypeInfo = new ListTypeInfo<>(inputRowType);
+ MapStateDescriptor<BaseRow, List<BaseRow>> mapStateDescriptor = new MapStateDescriptor<>(
+ "data-state-with-append", sortKeyType, valueTypeInfo);
+ dataState = getRuntimeContext().getMapState(mapStateDescriptor);
+
+ // metrics
+ registerMetric(kvSortedMap.size() * getDefaultTopNSize());
+ }
+
+ @Override
+ public void processElement(BaseRow input, Context context, Collector<BaseRow> out) throws Exception {
+ long currentTime = context.timerService().currentProcessingTime();
+ // register state-cleanup timer
+ registerProcessingCleanupTimer(context, currentTime);
+
+ initHeapStates();
+ initRankEnd(input);
+
+ BaseRow sortKey = sortKeySelector.getKey(input);
+ // check whether the sortKey is in the topN range
+ if (checkSortKeyInBufferRange(sortKey, buffer)) {
+ // insert sort key into buffer
+ buffer.put(sortKey, inputRowSer.copy(input));
+ Collection<BaseRow> inputs = buffer.get(sortKey);
+ // update data state
+ dataState.put(sortKey, (List<BaseRow>) inputs);
+ if (outputRankNumber || hasOffset()) {
+ // the without-number-algorithm can't handle topN with offset,
+ // so use the with-number-algorithm to handle offset
+ processElementWithRowNumber(sortKey, input, out);
+ } else {
+ processElementWithoutRowNumber(input, out);
+ }
+ }
+ }
+
+ @Override
+ public void onTimer(
+ long timestamp,
+ OnTimerContext ctx,
+ Collector<BaseRow> out) throws Exception {
+ if (stateCleaningEnabled) {
+ // cleanup cache
+ kvSortedMap.remove(keyContext.getCurrentKey());
+ cleanupState(dataState);
+ }
+ }
+
+ private void initHeapStates() throws Exception {
+ requestCount += 1;
+ BaseRow currentKey = (BaseRow) keyContext.getCurrentKey();
+ buffer = kvSortedMap.get(currentKey);
+ if (buffer == null) {
+ buffer = new TopNBuffer(sortKeyComparator, ArrayList::new);
+ kvSortedMap.put(currentKey, buffer);
+ // restore buffer
+ Iterator<Map.Entry<BaseRow, List<BaseRow>>> iter = dataState.iterator();
+ if (iter != null) {
+ while (iter.hasNext()) {
+ Map.Entry<BaseRow, List<BaseRow>> entry = iter.next();
+ BaseRow sortKey = entry.getKey();
+ List<BaseRow> values = entry.getValue();
+ // the order is preserved
+ buffer.putAll(sortKey, values);
+ }
+ }
+ } else {
+ hitCount += 1;
+ }
+ }
+
+ private void processElementWithRowNumber(BaseRow sortKey, BaseRow input, Collector<BaseRow> out) throws Exception {
+ Iterator<Map.Entry<BaseRow, Collection<BaseRow>>> iterator = buffer.entrySet().iterator();
+ long curRank = 0L;
+ boolean findsSortKey = false;
+ while (iterator.hasNext() && isInRankEnd(curRank)) {
+ Map.Entry<BaseRow, Collection<BaseRow>> entry = iterator.next();
+ Collection<BaseRow> records = entry.getValue();
+ // meet its own sort key
+ if (!findsSortKey && entry.getKey().equals(sortKey)) {
+ curRank += records.size();
+ collect(out, input, curRank);
+ findsSortKey = true;
+ } else if (findsSortKey) {
+ Iterator<BaseRow> recordsIter = records.iterator();
+ while (recordsIter.hasNext() && isInRankEnd(curRank)) {
+ curRank += 1;
+ BaseRow prevRow = recordsIter.next();
+ retract(out, prevRow, curRank - 1);
+ collect(out, prevRow, curRank);
+ }
+ } else {
+ curRank += records.size();
+ }
+ }
+
+ // remove the records associated to the sort key which is out of topN
+ List<BaseRow> toDeleteSortKeys = new ArrayList<>();
+ while (iterator.hasNext()) {
+ Map.Entry<BaseRow, Collection<BaseRow>> entry = iterator.next();
+ BaseRow key = entry.getKey();
+ dataState.remove(key);
+ toDeleteSortKeys.add(key);
+ }
+ for (BaseRow toDeleteKey : toDeleteSortKeys) {
+ buffer.removeAll(toDeleteKey);
+ }
+ }
+
+ private void processElementWithoutRowNumber(BaseRow input, Collector<BaseRow> out) throws Exception {
+ // remove retired element
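+ // e.g. (hypothetical, rankEnd = 2): the buffer holds [a, b] and c was just inserted; the worst
+ // of the three is evicted again, and if c itself is the worst it is dropped without any output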
+ if (buffer.getCurrentTopNum() > rankEnd) {
+ Map.Entry<BaseRow, Collection<BaseRow>> lastEntry = buffer.lastEntry();
+ BaseRow lastKey = lastEntry.getKey();
+ List<BaseRow> lastList = (List<BaseRow>) lastEntry.getValue();
+ // remove last one
+ BaseRow lastElement = lastList.remove(lastList.size() - 1);
+ if (lastList.isEmpty()) {
+ buffer.removeAll(lastKey);
+ dataState.remove(lastKey);
+ } else {
+ dataState.put(lastKey, lastList);
+ }
+ if (input.equals(lastElement)) {
+ return;
+ } else {
+ // lastElement shouldn't be null
+ delete(out, lastElement);
+ }
+ }
+ collect(out, input);
+ }
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java
new file mode 100644
index 0000000000000..b4bb201308293
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import java.util.List;
+
+/** rankStart and rankEnd are inclusive, and rankStart always starts from one. */
+public class ConstantRankRange implements RankRange {
+
+ private static final long serialVersionUID = 9062345289888078376L;
+ private long rankStart;
+ private long rankEnd;
+
+ public ConstantRankRange(long rankStart, long rankEnd) {
+ this.rankStart = rankStart;
+ this.rankEnd = rankEnd;
+ }
+
+ public long getRankStart() {
+ return rankStart;
+ }
+
+ public long getRankEnd() {
+ return rankEnd;
+ }
+
+ @Override
+ public String toString(List<String> inputFieldNames) {
+ return toString();
+ }
+
+ @Override
+ public String toString() {
+ return "rankStart=" + rankStart + ", rankEnd=" + rankEnd;
+ }
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java
new file mode 100644
index 0000000000000..0d0ddd3dcc2fe
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import java.util.List;
+
+/** ConstantRankRangeWithoutEnd is a RankRange which does not specify a rank end. */
+public class ConstantRankRangeWithoutEnd implements RankRange {
+
+ private static final long serialVersionUID = -1944057111062598696L;
+
+ private final long rankStart;
+
+ public ConstantRankRangeWithoutEnd(long rankStart) {
+ this.rankStart = rankStart;
+ }
+
+ @Override
+ public String toString(List<String> inputFieldNames) {
+ return toString();
+ }
+
+ @Override
+ public String toString() {
+ return "rankStart=" + rankStart;
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java
new file mode 100644
index 0000000000000..e14b4bfff9dd9
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * RankRange for Rank, including the following 3 types:
+ * ConstantRankRange, ConstantRankRangeWithoutEnd, VariableRankRange.
+ */
+public interface RankRange extends Serializable {
+
+ String toString(List<String> inputFieldNames);
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java
new file mode 100644
index 0000000000000..d6573dd96b24b
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+/**
+ * An enumeration of rank types, which determines how exactly the rank numbers are generated.
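+ *
+ * <p>For example, for a hypothetical score column ordered descending as 90, 80, 80, 70:
+ * ROW_NUMBER yields 1, 2, 3, 4; RANK yields 1, 2, 2, 4; DENSE_RANK yields 1, 2, 2, 3.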
+ */
+public enum RankType {
+
+ /**
+ * Returns a unique sequential number for each row within the partition based on the order,
+ * starting at 1 for the first row in each partition and without repeating or skipping
+ * numbers in the ranking result of each partition. If there are duplicate values within the
+ * row set, the ranking numbers will be assigned arbitrarily.
+ */
+ ROW_NUMBER,
+
+ /**
+ * Returns a unique rank number for each distinct row within the partition based on the order,
+ * starting at 1 for the first row in each partition, with the same rank for duplicate values
+ * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate
+ * values.
+ */
+ RANK,
+
+ /**
+ * Is similar to RANK in that it generates a unique rank number for each distinct row
+ * within the partition based on the order, starting at 1 for the first row in each partition,
+ * ranking the rows with equal values with the same rank number, except that it does not skip
+ * any rank, leaving no gaps between the ranks.
+ */
+ DENSE_RANK
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java
new file mode 100644
index 0000000000000..eb0250d83de25
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.ListTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.util.BaseRowUtil;
+import org.apache.flink.table.generated.GeneratedRecordComparator;
+import org.apache.flink.table.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.table.typeutils.SortedMapTypeInfo;
+import org.apache.flink.util.Collector;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+/**
+ * RetractRankFunction's input stream could contain append records, update records and delete records.
+ */
+public class RetractRankFunction extends AbstractRankFunction {
+
+ private static final long serialVersionUID = 1365312180599454479L;
+
+ private static final Logger LOG = LoggerFactory.getLogger(RetractRankFunction.class);
+
+ // Message to indicate that the state is cleared because of the ttl restriction; it is either logged or thrown.
+ private static final String STATE_CLEARED_WARN_MSG = "The state is cleared because of state ttl. " +
+ "This will result in incorrect results. You can increase the state ttl to avoid this.";
+
+ private final BaseRowTypeInfo sortKeyType;
+
+ // flag to skip a record that does not exist in state instead of failing; true by default
+ private final boolean lenient = true;
+
+ // a map state stores mapping from sort key to records list
+ private transient MapState<BaseRow, List<BaseRow>> dataState;
+
+ // a sorted map stores mapping from sort key to records count
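+ // (the counts allow computing ranks without reading the record lists of preceding sort keys)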
+ private transient ValueState<SortedMap<BaseRow, Long>> treeMap;
+
+ public RetractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType,
+ GeneratedRecordComparator generatedRecordComparator, BaseRowKeySelector sortKeySelector,
+ RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser,
+ boolean generateRetraction, boolean outputRankNumber) {
+ super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType,
+ rankRange, generatedEqualiser, generateRetraction, outputRankNumber);
+ this.sortKeyType = sortKeySelector.getProducedType();
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ ListTypeInfo<BaseRow> valueTypeInfo = new ListTypeInfo<>(inputRowType);
+ MapStateDescriptor<BaseRow, List<BaseRow>> mapStateDescriptor = new MapStateDescriptor<>(
+ "data-state", sortKeyType, valueTypeInfo);
+ dataState = getRuntimeContext().getMapState(mapStateDescriptor);
+
+ ValueStateDescriptor<SortedMap<BaseRow, Long>> valueStateDescriptor = new ValueStateDescriptor<>(
+ "sorted-map",
+ new SortedMapTypeInfo<>(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator));
+ treeMap = getRuntimeContext().getState(valueStateDescriptor);
+ }
+
+ @Override
+ public void processElement(BaseRow input, Context ctx, Collector<BaseRow> out) throws Exception {
+ initRankEnd(input);
+ SortedMap<BaseRow, Long> sortedMap = treeMap.value();
+ if (sortedMap == null) {
+ sortedMap = new TreeMap<>(sortKeyComparator);
+ }
+ BaseRow sortKey = sortKeySelector.getKey(input);
+ if (BaseRowUtil.isAccumulateMsg(input)) {
+ // update sortedMap
+ if (sortedMap.containsKey(sortKey)) {
+ sortedMap.put(sortKey, sortedMap.get(sortKey) + 1);
+ } else {
+ sortedMap.put(sortKey, 1L);
+ }
+
+ // emit
+ emitRecordsWithRowNumber(sortedMap, sortKey, input, out);
+
+ // update data state
+ List<BaseRow> inputs = dataState.get(sortKey);
+ if (inputs == null) {
+ // the sort key is never seen
+ inputs = new ArrayList<>();
+ }
+ inputs.add(input);
+ dataState.put(sortKey, inputs);
+ } else {
+ // emit updates first
+ retractRecordWithRowNumber(sortedMap, sortKey, input, out);
+
+ // and then update sortedMap
+ if (sortedMap.containsKey(sortKey)) {
+ long count = sortedMap.get(sortKey) - 1;
+ if (count == 0) {
+ sortedMap.remove(sortKey);
+ } else {
+ sortedMap.put(sortKey, count);
+ }
+ } else {
+ if (sortedMap.isEmpty()) {
+ if (lenient) {
+ LOG.warn(STATE_CLEARED_WARN_MSG);
+ } else {
+ throw new RuntimeException(STATE_CLEARED_WARN_MSG);
+ }
+ } else {
+ throw new RuntimeException("Can not retract a non-existent record: ${inputBaseRow.toString}. " +
+ "This should never happen.");
+ }
+ }
+
+ }
+ treeMap.update(sortedMap);
+ }
+
+ // ------------- ROW_NUMBER -------------------------------
+
+ private void retractRecordWithRowNumber(
+ SortedMap<BaseRow, Long> sortedMap, BaseRow sortKey, BaseRow inputRow, Collector<BaseRow> out)
+ throws Exception {
+ Iterator<Map.Entry<BaseRow, Long>> iterator = sortedMap.entrySet().iterator();
+ long curRank = 0L;
+ boolean findsSortKey = false;
+ while (iterator.hasNext() && isInRankEnd(curRank)) {
+ Map.Entry<BaseRow, Long> entry = iterator.next();
+ BaseRow key = entry.getKey();
+ if (!findsSortKey && key.equals(sortKey)) {
+ List<BaseRow> inputs = dataState.get(key);
+ if (inputs == null) {
+ // Skip the data if its state is cleared because of state ttl.
+ if (lenient) {
+ LOG.warn(STATE_CLEARED_WARN_MSG);
+ } else {
+ throw new RuntimeException(STATE_CLEARED_WARN_MSG);
+ }
+ } else {
+ Iterator<BaseRow> inputIter = inputs.iterator();
+ while (inputIter.hasNext() && isInRankEnd(curRank)) {
+ curRank += 1;
+ BaseRow prevRow = inputIter.next();
+ if (!findsSortKey && equaliser.equalsWithoutHeader(prevRow, inputRow)) {
+ delete(out, prevRow, curRank);
+ curRank -= 1;
+ findsSortKey = true;
+ inputIter.remove();
+ } else if (findsSortKey) {
+ retract(out, prevRow, curRank + 1);
+ collect(out, prevRow, curRank);
+ }
+ }
+ if (inputs.isEmpty()) {
+ dataState.remove(key);
+ } else {
+ dataState.put(key, inputs);
+ }
+ }
+ } else if (findsSortKey) {
+ List<BaseRow> inputs = dataState.get(key);
+ int i = 0;
+ while (i < inputs.size() && isInRankEnd(curRank)) {
+ curRank += 1;
+ BaseRow prevRow = inputs.get(i);
+ retract(out, prevRow, curRank + 1);
+ collect(out, prevRow, curRank);
+ i++;
+ }
+ } else {
+ curRank += entry.getValue();
+ }
+ }
+ }
+
+ private void emitRecordsWithRowNumber(
+ SortedMap<BaseRow, Long> sortedMap, BaseRow sortKey, BaseRow inputRow, Collector<BaseRow> out)
+ throws Exception {
+ Iterator<Map.Entry<BaseRow, Long>> iterator = sortedMap.entrySet().iterator();
+ long curRank = 0L;
+ boolean findsSortKey = false;
+ while (iterator.hasNext() && isInRankEnd(curRank)) {
+ Map.Entry<BaseRow, Long> entry = iterator.next();
+ BaseRow key = entry.getKey();
+ if (!findsSortKey && key.equals(sortKey)) {
+ curRank += entry.getValue();
+ collect(out, inputRow, curRank);
+ findsSortKey = true;
+ } else if (findsSortKey) {
+ List<BaseRow> inputs = dataState.get(key);
+ if (inputs == null) {
+ // Skip the data if its state is cleared because of state ttl.
+ if (lenient) {
+ LOG.warn(STATE_CLEARED_WARN_MSG);
+ } else {
+ throw new RuntimeException(STATE_CLEARED_WARN_MSG);
+ }
+ } else {
+ int i = 0;
+ while (i < inputs.size() && isInRankEnd(curRank)) {
+ curRank += 1;
+ BaseRow prevRow = inputs.get(i);
+ retract(out, prevRow, curRank - 1);
+ collect(out, prevRow, curRank);
+ i++;
+ }
+ }
+ } else {
+ curRank += entry.getValue();
+ }
+ }
+ }
+
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java
new file mode 100644
index 0000000000000..8600d919c7e93
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import org.apache.flink.table.dataformat.BaseRow;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Supplier;
+
+/**
+ * TopNBuffer stores a mapping from sort key to a list of records, where both the sort key and
+ * each record are of BaseRow type. TopNBuffer can also track the rank number of each record.
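+ *
+ * <p>For example (hypothetical keys and rows): put(k1, r1); put(k1, r2); put(k2, r3) results in
+ * getCurrentTopNum() == 3 and get(k1) returning [r1, r2] in insertion order.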
+ */
+class TopNBuffer implements Serializable {
+
+ private static final long serialVersionUID = 6824488508991990228L;
+
+ private final Supplier<Collection<BaseRow>> valueSupplier;
+ private final Comparator<BaseRow> sortKeyComparator;
+ private int currentTopNum = 0;
+ private TreeMap<BaseRow, Collection<BaseRow>> treeMap;
+
+ TopNBuffer(Comparator<BaseRow> sortKeyComparator, Supplier<Collection<BaseRow>> valueSupplier) {
+ this.valueSupplier = valueSupplier;
+ this.sortKeyComparator = sortKeyComparator;
+ this.treeMap = new TreeMap<>(sortKeyComparator);
+ }
+
+ /**
+ * Appends a record into the buffer.
+ *
+ * @param sortKey sort key with which the specified value is to be associated
+ * @param value record which is to be appended
+ * @return the size of the collection under the sortKey.
+ */
+ public int put(BaseRow sortKey, BaseRow value) {
+ currentTopNum += 1;
+ // update treeMap
+ Collection<BaseRow> collection = treeMap.get(sortKey);
+ if (collection == null) {
+ collection = valueSupplier.get();
+ treeMap.put(sortKey, collection);
+ }
+ collection.add(value);
+ return collection.size();
+ }
+
+ /**
+ * Puts a record list into the buffer under the sortKey.
+ * Note: if buffer already contains sortKey, putAll will overwrite the previous value
+ *
+ * @param sortKey sort key with which the specified values are to be associated
+ * @param values record lists to be associated with the specified key
+ */
+ void putAll(BaseRow sortKey, Collection<BaseRow> values) {
+ treeMap.put(sortKey, values);
+ currentTopNum += values.size();
+ }
+
+ /**
+ * Gets the record list from the buffer under the sortKey.
+ *
+ * @param sortKey key to get
+ * @return the record list from the buffer under the sortKey
+ */
+ public Collection<BaseRow> get(BaseRow sortKey) {
+ return treeMap.get(sortKey);
+ }
+
+ public void remove(BaseRow sortKey, BaseRow value) {
+ Collection<BaseRow> list = treeMap.get(sortKey);
+ if (list != null) {
+ if (list.remove(value)) {
+ currentTopNum -= 1;
+ }
+ if (list.size() == 0) {
+ treeMap.remove(sortKey);
+ }
+ }
+ }
+
+ /**
+ * Removes all records from the buffer under the sortKey.
+ *
+ * @param sortKey key to remove
+ */
+ void removeAll(BaseRow sortKey) {
+ Collection<BaseRow> list = treeMap.get(sortKey);
+ if (list != null) {
+ currentTopNum -= list.size();
+ treeMap.remove(sortKey);
+ }
+ }
+
+ /**
+ * Removes the last record of the last Entry in the buffer.
+ *
+ * @return removed record
+ */
+ BaseRow removeLast() {
+ Map.Entry<BaseRow, Collection<BaseRow>> last = treeMap.lastEntry();
+ BaseRow lastElement = null;
+ if (last != null) {
+ Collection<BaseRow> list = last.getValue();
+ lastElement = getLastElement(list);
+ if (lastElement != null) {
+ if (list.remove(lastElement)) {
+ currentTopNum -= 1;
+ }
+ if (list.size() == 0) {
+ treeMap.remove(last.getKey());
+ }
+ }
+ }
+ return lastElement;
+ }
+
+ /**
+ * Gets the record whose rank is the given value.
+ *
+ * @param rank rank value to search
+ * @return the record whose rank is the given value
+ */
+ BaseRow getElement(int rank) {
+ int curRank = 0;
+ Iterator<Map.Entry<BaseRow, Collection<BaseRow>>> iter = treeMap.entrySet().iterator();
+ while (iter.hasNext()) {
+ Map.Entry<BaseRow, Collection<BaseRow>> entry = iter.next();
+ Collection<BaseRow> list = entry.getValue();
+
+ Iterator<BaseRow> listIter = list.iterator();
+ while (listIter.hasNext()) {
+ BaseRow elem = listIter.next();
+ curRank += 1;
+ if (curRank == rank) {
+ return elem;
+ }
+ }
+ }
+ return null;
+ }
+
+ private BaseRow getLastElement(Collection<BaseRow> list) {
+ BaseRow element = null;
+ if (list != null && !list.isEmpty()) {
+ Iterator<BaseRow> iter = list.iterator();
+ while (iter.hasNext()) {
+ element = iter.next();
+ }
+ }
+ return element;
+ }
+
+ /**
+ * Returns a {@link Set} view of the mappings contained in the buffer.
+ */
+ Set<Map.Entry<BaseRow, Collection<BaseRow>>> entrySet() {
+ return treeMap.entrySet();
+ }
+
+ /**
+ * Returns the last Entry in the buffer. Returns null if the TreeMap is empty.
+ */
+ Map.Entry<BaseRow, Collection<BaseRow>> lastEntry() {
+ return treeMap.lastEntry();
+ }
+
+ /**
+ * Returns {@code true} if the buffer contains a mapping for the specified key.
+ *
+ * @param key key whose presence in the buffer is to be tested
+ * @return {@code true} if the buffer contains a mapping for the specified key
+ */
+ boolean containsKey(BaseRow key) {
+ return treeMap.containsKey(key);
+ }
+
+ /**
+ * Gets the total number of records.
+ *
+ * @return the total number of records
+ */
+ int getCurrentTopNum() {
+ return currentTopNum;
+ }
+
+ /**
+ * Gets sort key comparator used by buffer.
+ *
+ * @return sort key comparator used by buffer
+ */
+ Comparator<BaseRow> getSortKeyComparator() {
+ return sortKeyComparator;
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java
new file mode 100644
index 0000000000000..a78b00d5e11ea
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java
@@ -0,0 +1,486 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.generated.GeneratedRecordComparator;
+import org.apache.flink.table.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
+import org.apache.flink.table.runtime.util.LRUMap;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.util.Collector;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * A fast version of the rank process function which only holds the top n data in state and keeps
+ * the sorted map in heap. This only works in some special scenarios:
+ * 1. the sort field collation is ascending and its mono is decreasing, or the sort field collation
+ * is descending and its mono is increasing
+ * 2. the input data has unique keys
+ * 3. the input stream can not contain delete records or retract records
+ */
+public class UpdateRankFunction extends AbstractRankFunction implements CheckpointedFunction {
+
+ private static final long serialVersionUID = 6786508184355952780L;
+
+ private static final Logger LOG = LoggerFactory.getLogger(UpdateRankFunction.class);
+
+ private final BaseRowTypeInfo rowKeyType;
+ private final long cacheSize;
+
+ // a map state stores mapping from row key to record which is in topN
+ // in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key
+ // the f1 is used to preserve the record order in the same sort_key
+ private transient MapState<BaseRow, Tuple2<BaseRow, Integer>> dataState;
+
+ // a buffer stores mapping from sort key to rowKey list
+ private transient TopNBuffer buffer;
+
+ // the kvSortedMap stores mapping from partition key to its buffer
+ private transient Map<BaseRow, TopNBuffer> kvSortedMap;
+
+ // a HashMap stores mapping from rowKey to record, a heap mirror to dataState
+ private transient Map<BaseRow, RankRow> rowKeyMap;
+
+ // the kvRowKeyMap stores mapping from partition key to its rowKeyMap
+ private transient LRUMap<BaseRow, Map<BaseRow, RankRow>> kvRowKeyMap;
+
+ private final TypeSerializer<BaseRow> inputRowSer;
+ private final KeySelector<BaseRow, BaseRow> rowKeySelector;
+
+ public UpdateRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType,
+ BaseRowKeySelector rowKeySelector, GeneratedRecordComparator generatedRecordComparator,
+ BaseRowKeySelector sortKeySelector, RankType rankType,
+ RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction,
+ boolean outputRankNumber, long cacheSize) {
+ super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType,
+ rankRange, generatedEqualiser, generateRetraction, outputRankNumber);
+ this.rowKeyType = rowKeySelector.getProducedType();
+ this.cacheSize = cacheSize;
+ this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig());
+ this.rowKeySelector = rowKeySelector;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopNSize()));
+ // make sure the cached map is in a fixed size, avoid OOM
+ kvSortedMap = new HashMap<>(lruCacheSize);
+ kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener());
+
+ LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize);
+
+ TupleTypeInfo<Tuple2<BaseRow, Integer>> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT);
+ MapStateDescriptor<BaseRow, Tuple2<BaseRow, Integer>> mapStateDescriptor = new MapStateDescriptor<>(
+ "data-state-with-update", rowKeyType, valueTypeInfo);
+ dataState = getRuntimeContext().getMapState(mapStateDescriptor);
+
+ // metrics
+ registerMetric(kvSortedMap.size() * getDefaultTopNSize());
+ }
+
+ @Override
+ public void onTimer(
+ long timestamp,
+ OnTimerContext ctx,
+ Collector<BaseRow> out) throws Exception {
+ if (stateCleaningEnabled) {
+ BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey();
+ // cleanup cache
+ kvRowKeyMap.remove(partitionKey);
+ kvSortedMap.remove(partitionKey);
+ cleanupState(dataState);
+ }
+ }
+
+ @Override
+ public void initializeState(FunctionInitializationContext context) throws Exception {
+ // nothing to do
+ }
+
+ @Override
+ public void processElement(
+ BaseRow input, Context context, Collector<BaseRow> out) throws Exception {
+ long currentTime = context.timerService().currentProcessingTime();
+ // register state-cleanup timer
+ registerProcessingCleanupTimer(context, currentTime);
+
+ initHeapStates();
+ initRankEnd(input);
+ if (outputRankNumber || hasOffset()) {
+ // the without-number-algorithm can't handle topN with offset,
+ // so use the with-number-algorithm to handle offset
+ processElementWithRowNumber(input, out);
+ } else {
+ processElementWithoutRowNumber(input, out);
+ }
+ }
+
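+ // on checkpoint, flush the in-heap rowKeyMap of every cached partition key back to the data state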
+ @Override
+ public void snapshotState(FunctionSnapshotContext context) throws Exception {
+ Iterator<Map.Entry<BaseRow, Map<BaseRow, RankRow>>> iter = kvRowKeyMap.entrySet().iterator();
+ while (iter.hasNext()) {
+ Map.Entry<BaseRow, Map<BaseRow, RankRow>> entry = iter.next();
+ BaseRow partitionKey = entry.getKey();
+ Map<BaseRow, RankRow> currentRowKeyMap = entry.getValue();
+ keyContext.setCurrentKey(partitionKey);
+ flushBufferToState(currentRowKeyMap);
+ }
+ }
+
+ private void initHeapStates() throws Exception {
+ requestCount += 1;
+ BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey();
+ buffer = kvSortedMap.get(partitionKey);
+ rowKeyMap = kvRowKeyMap.get(partitionKey);
+ if (buffer == null) {
+ buffer = new TopNBuffer(sortKeyComparator, LinkedHashSet::new);
+ rowKeyMap = new HashMap<>();
+ kvSortedMap.put(partitionKey, buffer);
+ kvRowKeyMap.put(partitionKey, rowKeyMap);
+
+ // restore sorted map
+ Iterator<Map.Entry<BaseRow, Tuple2<BaseRow, Integer>>> iter = dataState.iterator();
+ if (iter != null) {
+ // a temp map that associates each sort key with a TreeMap from inner rank to row key
+ Map<BaseRow, TreeMap<Integer, BaseRow>> tempSortedMap = new HashMap<>();
+ while (iter.hasNext()) {
+ Map.Entry<BaseRow, Tuple2<BaseRow, Integer>> entry = iter.next();
+ BaseRow rowKey = entry.getKey();
+ Tuple2<BaseRow, Integer> recordAndInnerRank = entry.getValue();
+ BaseRow record = recordAndInnerRank.f0;
+ Integer innerRank = recordAndInnerRank.f1;
+ rowKeyMap.put(rowKey, new RankRow(record, innerRank, false));
+
+ // insert into temp sort map to preserve the record order in the same sort key
+ BaseRow sortKey = sortKeySelector.getKey(record);
+ TreeMap<Integer, BaseRow> treeMap = tempSortedMap.get(sortKey);
+ if (treeMap == null) {
+ treeMap = new TreeMap<>();
+ tempSortedMap.put(sortKey, treeMap);
+ }
+ treeMap.put(innerRank, rowKey);
+ }
+
+ // build sorted map from the temp map
+ Iterator<Map.Entry<BaseRow, TreeMap<Integer, BaseRow>>> tempIter = tempSortedMap.entrySet().iterator();
+ while (tempIter.hasNext()) {
+ Map.Entry<BaseRow, TreeMap<Integer, BaseRow>> entry = tempIter.next();
+ BaseRow sortKey = entry.getKey();
+ TreeMap<Integer, BaseRow> treeMap = entry.getValue();
+ Iterator<Map.Entry<Integer, BaseRow>> treeMapIter = treeMap.entrySet().iterator();
+ while (treeMapIter.hasNext()) {
+ Map.Entry<Integer, BaseRow> treeMapEntry = treeMapIter.next();
+ Integer innerRank = treeMapEntry.getKey();
+ BaseRow recordRowKey = treeMapEntry.getValue();
+ int size = buffer.put(sortKey, recordRowKey);
+ if (innerRank != size) {
+ LOG.warn("Failed to build sorted map from state, this may result in wrong result. " +
+ "The sort key is {}, partition key is {}, " +
+ "treeMap is {}. The expected inner rank is {}, but current size is {}.",
+ sortKey, partitionKey, treeMap, innerRank, size);
+ }
+ }
+ }
+ }
+ } else {
+ hitCount += 1;
+ }
+ }
+
+ private void processElementWithRowNumber(BaseRow inputRow, Collector<BaseRow> out) throws Exception {
+ BaseRow sortKey = sortKeySelector.getKey(inputRow);
+ BaseRow rowKey = rowKeySelector.getKey(inputRow);
+ if (rowKeyMap.containsKey(rowKey)) {
+ // it is an updated record which is in the topN; in this scenario,
+ // the new sort key must be higher than the old sort key, which is guaranteed by the rules
+ RankRow oldRow = rowKeyMap.get(rowKey);
+ BaseRow oldSortKey = sortKeySelector.getKey(oldRow.row);
+ if (oldSortKey.equals(sortKey)) {
+ // sort key is not changed, so the rank is the same, only output the row
+ Tuple2<Integer, Integer> rankAndInnerRank = rowNumber(sortKey, rowKey, buffer);
+ int rank = rankAndInnerRank.f0;
+ int innerRank = rankAndInnerRank.f1;
+ rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), innerRank, true));
+ retract(out, oldRow.row, rank); // retract old record
+ collect(out, inputRow, rank);
+ return;
+ }
+
+ Tuple2<Integer, Integer> oldRankAndInnerRank = rowNumber(oldSortKey, rowKey, buffer);
+ int oldRank = oldRankAndInnerRank.f0;
+ // remove old sort key
+ buffer.remove(oldSortKey, rowKey);
+ // add new sort key
+ int size = buffer.put(sortKey, rowKey);
+ rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
+ // update inner rank of records under the old sort key
+ updateInnerRank(oldSortKey);
+
+ // emit records
+ emitRecordsWithRowNumber(sortKey, inputRow, out, oldSortKey, oldRow, oldRank);
+ } else if (checkSortKeyInBufferRange(sortKey, buffer)) {
+ // it is a new record that falls into the topN; insert its sort key into the buffer
+ int size = buffer.put(sortKey, rowKey);
+ rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
+
+ // emit records
+ emitRecordsWithRowNumber(sortKey, inputRow, out);
+ }
+ }
+
+ private Tuple2<Integer, Integer> rowNumber(BaseRow sortKey, BaseRow rowKey, TopNBuffer buffer) {
+ Iterator<Map.Entry<BaseRow, Collection<BaseRow>>> iterator = buffer.entrySet().iterator();
+ int curRank = 1;
+ while (iterator.hasNext()) {
+ Map.Entry<BaseRow, Collection<BaseRow>> entry = iterator.next();
+ BaseRow curKey = entry.getKey();
+ Collection<BaseRow> rowKeys = entry.getValue();
+ if (curKey.equals(sortKey)) {
+ Iterator<BaseRow> rowKeysIter = rowKeys.iterator();
+ int innerRank = 1;
+ while (rowKeysIter.hasNext()) {
+ if (rowKey.equals(rowKeysIter.next())) {
+ return Tuple2.of(curRank, innerRank);
+ } else {
+ innerRank += 1;
+ curRank += 1;
+ }
+ }
+ } else {
+ curRank += rowKeys.size();
+ }
+ }
+ LOG.error("Failed to find the sortKey: {}, rowkey: {} in the buffer. This should never happen", sortKey,
+ rowKey);
+ throw new RuntimeException("Failed to find the sortKey, rowkey in the buffer. This should never happen");
+ }
+
+ private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collector<BaseRow> out) throws Exception {
+ emitRecordsWithRowNumber(sortKey, inputRow, out, null, null, -1);
+ }
+
+ private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collector<BaseRow> out, BaseRow oldSortKey,
+ RankRow oldRow, int oldRank) throws Exception {
+
+ int oldInnerRank = oldRow == null ? -1 : oldRow.innerRank;
+ Iterator<Map.Entry<BaseRow, Collection<BaseRow>>> iterator = buffer.entrySet().iterator();
+ int curRank = 0;
+ // whether we have found the sort key in the buffer
+ boolean findsSortKey = false;
+ while (iterator.hasNext() && isInRankEnd(curRank)) {
+ Map.Entry<BaseRow, Collection<BaseRow>> entry = iterator.next();
+ BaseRow curSortKey = entry.getKey();
+ Collection<BaseRow> rowKeys = entry.getValue();
+ // reached the sort key of the input row
+ if (!findsSortKey && curSortKey.equals(sortKey)) {
+ curRank += rowKeys.size();
+ if (oldRow != null) {
+ retract(out, oldRow.row, oldRank);
+ }
+ collect(out, inputRow, curRank);
+ findsSortKey = true;
+ } else if (findsSortKey) {
+ if (oldSortKey == null) {
+ // this is a new row; emit updates for all following rows in the topN
+ Iterator<BaseRow> rowKeyIter = rowKeys.iterator();
+ while (rowKeyIter.hasNext() && isInRankEnd(curRank)) {
+ curRank += 1;
+ BaseRow rowKey = rowKeyIter.next();
+ RankRow prevRow = rowKeyMap.get(rowKey);
+ retract(out, prevRow.row, curRank - 1);
+ collect(out, prevRow.row, curRank);
+ }
+ } else {
+ // the current sort key is higher than the old sort key, so the rank of the
+ // current record has changed and the following ranks need to be updated
+ int compare = sortKeyComparator.compare(curSortKey, oldSortKey);
+ if (compare <= 0) {
+ Iterator<BaseRow> rowKeyIter = rowKeys.iterator();
+ int curInnerRank = 0;
+ while (rowKeyIter.hasNext() && isInRankEnd(curRank)) {
+ curRank += 1;
+ curInnerRank += 1;
+ if (compare == 0 && curInnerRank >= oldInnerRank) {
+ // reached the previous position; the following ranks are unchanged
+ return;
+ }
+
+ BaseRow rowKey = rowKeyIter.next();
+ RankRow prevRow = rowKeyMap.get(rowKey);
+ retract(out, prevRow.row, curRank - 1);
+ collect(out, prevRow.row, curRank);
+ }
+ } else {
+ // the current sort key is smaller than the old sort key; the ranks are unchanged, so skip
+ return;
+ }
+ }
+ } else {
+ curRank += rowKeys.size();
+ }
+ }
+
+ // remove the records associated with sort keys that fall out of the topN
+ List<BaseRow> toDeleteSortKeys = new ArrayList<>();
+ while (iterator.hasNext()) {
+ Map.Entry<BaseRow, Collection<BaseRow>> entry = iterator.next();
+ Collection<BaseRow> rowKeys = entry.getValue();
+ Iterator<BaseRow> rowKeyIter = rowKeys.iterator();
+ while (rowKeyIter.hasNext()) {
+ BaseRow rowKey = rowKeyIter.next();
+ rowKeyMap.remove(rowKey);
+ dataState.remove(rowKey);
+ }
+ toDeleteSortKeys.add(entry.getKey());
+ }
+ for (BaseRow toDeleteKey : toDeleteSortKeys) {
+ buffer.removeAll(toDeleteKey);
+ }
+ }
+
+ private void processElementWithoutRowNumber(BaseRow inputRow, Collector<BaseRow> out) throws Exception {
+ BaseRow sortKey = sortKeySelector.getKey(inputRow);
+ BaseRow rowKey = rowKeySelector.getKey(inputRow);
+ if (rowKeyMap.containsKey(rowKey)) {
+ // it is an update of a record that is already in the topN; in this scenario,
+ // the new sort key must be higher than the old sort key (guaranteed by the planner rules)
+ RankRow oldRow = rowKeyMap.get(rowKey);
+ BaseRow oldSortKey = sortKeySelector.getKey(oldRow.row);
+ if (!oldSortKey.equals(sortKey)) {
+ // remove old sort key
+ buffer.remove(oldSortKey, rowKey);
+ // add new sort key
+ int size = buffer.put(sortKey, rowKey);
+ rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
+ // update inner rank of records under the old sort key
+ updateInnerRank(oldSortKey);
+ } else {
+ // row content may change, so we need to update row in map
+ rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.innerRank, true));
+ }
+ // row content may change, so a retract is needed
+ retract(out, oldRow.row, oldRow.innerRank);
+ collect(out, inputRow);
+ } else if (checkSortKeyInBufferRange(sortKey, buffer)) {
+ // it is a new record that falls into the topN; insert its sort key into the buffer
+ int size = buffer.put(sortKey, rowKey);
+ rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
+ collect(out, inputRow);
+ // remove retired element
+ if (buffer.getCurrentTopNum() > rankEnd) {
+ BaseRow lastRowKey = buffer.removeLast();
+ if (lastRowKey != null) {
+ RankRow lastRow = rowKeyMap.remove(lastRowKey);
+ dataState.remove(lastRowKey);
+ // always send a retraction message
+ delete(out, lastRow.row);
+ }
+ }
+ }
+ }
+
+ private void flushBufferToState(Map<BaseRow, RankRow> curRowKeyMap) throws Exception {
+ Iterator<Map.Entry<BaseRow, RankRow>> iter = curRowKeyMap.entrySet().iterator();
+ while (iter.hasNext()) {
+ Map.Entry<BaseRow, RankRow> entry = iter.next();
+ BaseRow key = entry.getKey();
+ RankRow rankRow = entry.getValue();
+ if (rankRow.dirty) {
+ // should update state
+ dataState.put(key, Tuple2.of(rankRow.row, rankRow.innerRank));
+ rankRow.dirty = false;
+ }
+ }
+ }
+
+ private void updateInnerRank(BaseRow oldSortKey) {
+ Collection<BaseRow> list = buffer.get(oldSortKey);
+ if (list != null) {
+ Iterator<BaseRow> iter = list.iterator();
+ int innerRank = 1;
+ while (iter.hasNext()) {
+ BaseRow rowKey = iter.next();
+ RankRow row = rowKeyMap.get(rowKey);
+ if (row.innerRank != innerRank) {
+ row.innerRank = innerRank;
+ row.dirty = true;
+ }
+ innerRank += 1;
+ }
+ }
+ }
+
+ private class CacheRemovalListener implements LRUMap.RemovalListener<BaseRow, Map<BaseRow, RankRow>> {
+
+ @Override
+ public void onRemoval(Map.Entry<BaseRow, Map<BaseRow, RankRow>> eldest) {
+ BaseRow previousKey = (BaseRow) keyContext.getCurrentKey();
+ BaseRow partitionKey = eldest.getKey();
+ Map<BaseRow, RankRow> currentRowKeyMap = eldest.getValue();
+ keyContext.setCurrentKey(partitionKey);
+ kvSortedMap.remove(partitionKey);
+ try {
+ flushBufferToState(currentRowKeyMap);
+ } catch (Throwable e) {
+ LOG.error("Fail to synchronize state!", e);
+ throw new RuntimeException(e);
+ } finally {
+ keyContext.setCurrentKey(previousKey);
+ }
+ }
+ }
+
+ private class RankRow {
+ private final BaseRow row;
+ private int innerRank;
+ private boolean dirty;
+
+ private RankRow(BaseRow row, int innerRank, boolean dirty) {
+ this.row = row;
+ this.innerRank = innerRank;
+ this.dirty = dirty;
+ }
+ }
+
+}
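
The rowNumber(...) helper above derives a record's global rank by walking the buffer in sort-key order: groups with a non-matching sort key contribute their whole size to the rank, and the position of the row key inside the matching group yields the inner rank. Below is a minimal, self-contained sketch of that walk, using plain JDK types as assumptions in place of TopNBuffer and BaseRow (all names are illustrative):

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class RowNumberSketch {

    // Sketch of rowNumber(...): returns {rank, innerRank} for rowKey under sortKey.
    static int[] rowNumber(TreeMap<Integer, List<String>> buffer, int sortKey, String rowKey) {
        int curRank = 1;
        for (Map.Entry<Integer, List<String>> entry : buffer.entrySet()) {
            if (entry.getKey().intValue() == sortKey) {
                int innerRank = 1;
                for (String candidate : entry.getValue()) {
                    if (candidate.equals(rowKey)) {
                        return new int[] {curRank, innerRank};
                    }
                    innerRank++;
                    curRank++;
                }
            } else {
                // skip the whole group: every row under this sort key ranks before us
                curRank += entry.getValue().size();
            }
        }
        throw new IllegalStateException("rowKey not found in buffer");
    }

    public static void main(String[] args) {
        TreeMap<Integer, List<String>> buffer = new TreeMap<>();
        buffer.put(10, Arrays.asList("a", "b")); // ranks 1 and 2
        buffer.put(20, Arrays.asList("c"));      // rank 3
        int[] r = rowNumber(buffer, 20, "c");
        System.out.println("rank=" + r[0] + ", innerRank=" + r[1]); // rank=3, innerRank=1
    }
}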
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java
new file mode 100644
index 0000000000000..a065fe3ea1d2c
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.rank;
+
+import java.util.List;
+
+/**
+ * A rank range whose rank end is variable: the limit is read from a field of the input row.
+ */
+public class VariableRankRange implements RankRange {
+
+ private static final long serialVersionUID = 5579785886506433955L;
+ private int rankEndIndex;
+
+ public VariableRankRange(int rankEndIndex) {
+ this.rankEndIndex = rankEndIndex;
+ }
+
+ public int getRankEndIndex() {
+ return rankEndIndex;
+ }
+
+ @Override
+ public String toString(List<String> inputFieldNames) {
+ return "rankEnd=" + inputFieldNames.get(rankEndIndex);
+ }
+
+ @Override
+ public String toString() {
+ return "rankEnd=$" + rankEndIndex;
+ }
+
+}
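
Since VariableRankRange only carries a field index, its behavior is easiest to see in a small usage sketch (the wrapper class and field names here are hypothetical):

import java.util.Arrays;

import org.apache.flink.table.runtime.rank.VariableRankRange;

public class VariableRankRangeExample {
    public static void main(String[] args) {
        // field index 2 ("num") is assumed to hold the per-record rank limit
        VariableRankRange range = new VariableRankRange(2);
        System.out.println(range);                                          // rankEnd=$2
        System.out.println(range.toString(Arrays.asList("a", "b", "num"))); // rankEnd=num
    }
}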
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java
index 45ddd4ce33162..49cc0734dcf51 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java
@@ -41,7 +41,7 @@ public class ValuesInputFormat
private static final Logger LOG = LoggerFactory.getLogger(ValuesInputFormat.class);
private GeneratedInput<GenericInputFormat<BaseRow>> generatedInput;
- private BaseRowTypeInfo returnType;
+ private final BaseRowTypeInfo returnType;
private GenericInputFormat<BaseRow> format;
public ValuesInputFormat(GeneratedInput<GenericInputFormat<BaseRow>> generatedInput, BaseRowTypeInfo returnType) {
@@ -54,7 +54,9 @@ public void open(GenericInputSplit split) {
LOG.debug("Compiling GenericInputFormat: $name \n\n Code:\n$code",
generatedInput.getClassName(), generatedInput.getCode());
LOG.debug("Instantiating GenericInputFormat.");
+
format = generatedInput.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ generatedInput = null;
}
@Override
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java
index a7b9199c962ca..12aaf6192b4bc 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java
@@ -93,6 +93,8 @@ public class TypeConverters {
internalTypeToInfo.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO);
internalTypeToInfo.put(InternalTypes.DATE, BasicTypeInfo.INT_TYPE_INFO);
internalTypeToInfo.put(InternalTypes.TIMESTAMP, BasicTypeInfo.LONG_TYPE_INFO);
+ internalTypeToInfo.put(InternalTypes.PROCTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO);
+ internalTypeToInfo.put(InternalTypes.ROWTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO);
internalTypeToInfo.put(InternalTypes.TIME, BasicTypeInfo.INT_TYPE_INFO);
internalTypeToInfo.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO);
internalTypeToInfo.put(InternalTypes.INTERVAL_MONTHS, BasicTypeInfo.INT_TYPE_INFO);
@@ -111,6 +113,8 @@ public class TypeConverters {
itToEti.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO);
itToEti.put(InternalTypes.DATE, SqlTimeTypeInfo.DATE);
itToEti.put(InternalTypes.TIMESTAMP, SqlTimeTypeInfo.TIMESTAMP);
+ itToEti.put(InternalTypes.PROCTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP);
+ itToEti.put(InternalTypes.ROWTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP);
itToEti.put(InternalTypes.TIME, SqlTimeTypeInfo.TIME);
itToEti.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO);
INTERNAL_TYPE_TO_EXTERNAL_TYPE_INFO = Collections.unmodifiableMap(itToEti);
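
The two new entries register the time indicator types with both converter maps: internally a time indicator is a millisecond epoch long, while the external (user-facing) representation is a SQL timestamp. A small illustration of the two representations of the same instant (values are made up):

// internal representation: millisecond epoch long
long internalMillis = 1_546_300_800_000L; // 2019-01-01T00:00:00Z
// external representation: java.sql.Timestamp, as mapped by SqlTimeTypeInfo.TIMESTAMP
java.sql.Timestamp external = new java.sql.Timestamp(internalMillis);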
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java
new file mode 100644
index 0000000000000..344fb978c5fdc
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.typeutils;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Base serializer for {@link Map}. The serializer relies on a key serializer
+ * and a value serializer for the serialization of the map's key-value pairs.
+ *
+ * <p>The serialization format for the map is as follows: four bytes for the
+ * length of the map, followed by the serialized representation of each
+ * key-value pair. To allow null values, each value is prefixed by a null flag.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ * @param <M> The type of the map.
+ */
+abstract class AbstractMapSerializer<K, V, M extends Map<K, V>> extends TypeSerializer<M> {
+
+ private static final long serialVersionUID = 1L;
+
+ /** The serializer for the keys in the map. */
+ final TypeSerializer<K> keySerializer;
+
+ /** The serializer for the values in the map. */
+ final TypeSerializer<V> valueSerializer;
+
+ /**
+ * Creates a map serializer that uses the given serializers to serialize the key-value pairs in the map.
+ *
+ * @param keySerializer The serializer for the keys in the map
+ * @param valueSerializer The serializer for the values in the map
+ */
+ AbstractMapSerializer(TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer) {
+ Preconditions.checkNotNull(keySerializer, "The key serializer must not be null.");
+ Preconditions.checkNotNull(valueSerializer, "The value serializer must not be null.");
+ this.keySerializer = keySerializer;
+ this.valueSerializer = valueSerializer;
+ }
+
+ // ------------------------------------------------------------------------
+
+ /**
+ * Returns the serializer for the keys in the map.
+ *
+ * @return The serializer for the keys in the map.
+ */
+ public TypeSerializer<K> getKeySerializer() {
+ return keySerializer;
+ }
+
+ /**
+ * Returns the serializer for the values in the map.
+ *
+ * @return The serializer for the values in the map.
+ */
+ public TypeSerializer<V> getValueSerializer() {
+ return valueSerializer;
+ }
+
+ // ------------------------------------------------------------------------
+ // Type Serializer implementation
+ // ------------------------------------------------------------------------
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public M copy(M from) {
+ M newMap = createInstance();
+
+ for (Map.Entry<K, V> entry : from.entrySet()) {
+ K newKey = entry.getKey() == null ? null : keySerializer.copy(entry.getKey());
+ V newValue = entry.getValue() == null ? null : valueSerializer.copy(entry.getValue());
+
+ newMap.put(newKey, newValue);
+ }
+
+ return newMap;
+ }
+
+ @Override
+ public M copy(M from, M reuse) {
+ return copy(from);
+ }
+
+ @Override
+ public int getLength() {
+ return -1; // var length
+ }
+
+ @Override
+ public void serialize(M map, DataOutputView target) throws IOException {
+ final int size = map.size();
+ target.writeInt(size);
+
+ for (Map.Entry<K, V> entry : map.entrySet()) {
+ keySerializer.serialize(entry.getKey(), target);
+
+ if (entry.getValue() == null) {
+ target.writeBoolean(true);
+ } else {
+ target.writeBoolean(false);
+ valueSerializer.serialize(entry.getValue(), target);
+ }
+ }
+ }
+
+ @Override
+ public M deserialize(DataInputView source) throws IOException {
+ final int size = source.readInt();
+
+ final M map = createInstance();
+ for (int i = 0; i < size; ++i) {
+ K key = keySerializer.deserialize(source);
+
+ boolean isNull = source.readBoolean();
+ V value = isNull ? null : valueSerializer.deserialize(source);
+
+ map.put(key, value);
+ }
+
+ return map;
+ }
+
+ @Override
+ public M deserialize(M reuse, DataInputView source) throws IOException {
+ return deserialize(source);
+ }
+
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws IOException {
+ final int size = source.readInt();
+ target.writeInt(size);
+
+ for (int i = 0; i < size; ++i) {
+ keySerializer.copy(source, target);
+
+ boolean isNull = source.readBoolean();
+ target.writeBoolean(isNull);
+
+ if (!isNull) {
+ valueSerializer.copy(source, target);
+ }
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ AbstractMapSerializer<?, ?, ?> that = (AbstractMapSerializer<?, ?, ?>) o;
+
+ return keySerializer.equals(that.keySerializer) && valueSerializer.equals(that.valueSerializer);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = keySerializer.hashCode();
+ result = 31 * result + valueSerializer.hashCode();
+ return result;
+ }
+}
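
To make the wire format of serialize(...)/deserialize(...) concrete, here is a self-contained sketch of the same layout (a 4-byte size, then each key followed by a null flag and, if present, the value), using java.io streams as stand-ins for Flink's DataOutputView/DataInputView; all names are illustrative:

import java.io.*;
import java.util.HashMap;
import java.util.Map;

public class MapFormatSketch {

    static void writeMap(Map<String, Integer> map, DataOutput out) throws IOException {
        out.writeInt(map.size());                  // 4-byte length prefix
        for (Map.Entry<String, Integer> e : map.entrySet()) {
            out.writeUTF(e.getKey());              // keys may not be null
            if (e.getValue() == null) {
                out.writeBoolean(true);            // null flag set, no value bytes follow
            } else {
                out.writeBoolean(false);
                out.writeInt(e.getValue());
            }
        }
    }

    static Map<String, Integer> readMap(DataInput in) throws IOException {
        int size = in.readInt();
        Map<String, Integer> map = new HashMap<>();
        for (int i = 0; i < size; i++) {
            String key = in.readUTF();
            boolean isNull = in.readBoolean();
            map.put(key, isNull ? null : in.readInt());
        }
        return map;
    }

    public static void main(String[] args) throws IOException {
        Map<String, Integer> map = new HashMap<>();
        map.put("a", 1);
        map.put("b", null);
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        writeMap(map, new DataOutputStream(bytes));
        Map<String, Integer> back =
                readMap(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(back); // {a=1, b=null}
    }
}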
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java
new file mode 100644
index 0000000000000..2321d28494cbe
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.typeutils;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Map;
+
+/**
+ * Base type information for maps.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ */
+abstract class AbstractMapTypeInfo<K, V, M extends Map<K, V>> extends TypeInformation<M> {
+
+ private static final long serialVersionUID = 1L;
+
+ /* The type information for the keys in the map. */
+ final TypeInformation<K> keyTypeInfo;
+
+ /* The type information for the values in the map. */
+ final TypeInformation<V> valueTypeInfo;
+
+ /**
+ * Constructor with given type information for the keys and the values in
+ * the map.
+ *
+ * @param keyTypeInfo The type information for the keys in the map.
+ * @param valueTypeInfo The type information for the values in the map.
+ */
+ AbstractMapTypeInfo(TypeInformation<K> keyTypeInfo, TypeInformation<V> valueTypeInfo) {
+ Preconditions.checkNotNull(keyTypeInfo, "The type information for the keys cannot be null.");
+ Preconditions.checkNotNull(valueTypeInfo, "The type information for the values cannot be null.");
+ this.keyTypeInfo = keyTypeInfo;
+ this.valueTypeInfo = valueTypeInfo;
+ }
+
+ /**
+ * Constructor with the classes of the keys and the values in the map.
+ *
+ * @param keyClass The class of the keys in the map.
+ * @param valueClass The class of the values in the map.
+ */
+ AbstractMapTypeInfo(Class<K> keyClass, Class<V> valueClass) {
+ Preconditions.checkNotNull(keyClass, "The key class cannot be null.");
+ Preconditions.checkNotNull(valueClass, "The value class cannot be null.");
+
+ this.keyTypeInfo = TypeInformation.of(keyClass);
+ this.valueTypeInfo = TypeInformation.of(valueClass);
+ }
+
+ // ------------------------------------------------------------------------
+
+ /**
+ * Returns the type information for the keys in the map.
+ *
+ * @return The type information for the keys in the map.
+ */
+ public TypeInformation<K> getKeyTypeInfo() {
+ return keyTypeInfo;
+ }
+
+ /**
+ * Returns the type information for the values in the map.
+ *
+ * @return The type information for the values in the map.
+ */
+ public TypeInformation<V> getValueTypeInfo() {
+ return valueTypeInfo;
+ }
+
+ // ------------------------------------------------------------------------
+
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ @Override
+ public int getArity() {
+ return 0;
+ }
+
+ @Override
+ public int getTotalFields() {
+ return 2;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ // ------------------------------------------------------------------------
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ AbstractMapTypeInfo<?, ?, ?> that = (AbstractMapTypeInfo<?, ?, ?>) o;
+
+ return keyTypeInfo.equals(that.keyTypeInfo) && valueTypeInfo.equals(that.valueTypeInfo);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = keyTypeInfo.hashCode();
+ result = 31 * result + valueTypeInfo.hashCode();
+ return result;
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java
new file mode 100644
index 0000000000000..63f4f8f71eeac
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.typeutils;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+
+import java.util.Comparator;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+/**
+ * A serializer for {@link SortedMap}. The serializer relies on a key serializer
+ * and a value serializer for the serialization of the map's key-value pairs.
+ * It also uses a comparator to maintain the order of the keys.
+ *
+ * <p>The serialization format for the map is as follows: four bytes for the
+ * length of the map, followed by the serialized representation of each
+ * key-value pair. To allow null values, each value is prefixed by a null flag.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ */
+public final class SortedMapSerializer<K, V> extends AbstractMapSerializer<K, V, SortedMap<K, V>> {
+
+ private static final long serialVersionUID = 1L;
+
+ /** The comparator for the keys in the map. */
+ private final Comparator<K> comparator;
+
+ /**
+ * Constructor with given comparator, and the serializers for the keys and
+ * values in the map.
+ *
+ * @param comparator The comparator for the keys in the map.
+ * @param keySerializer The serializer for the keys in the map.
+ * @param valueSerializer The serializer for the values in the map.
+ */
+ public SortedMapSerializer(
+ Comparator<K> comparator,
+ TypeSerializer<K> keySerializer,
+ TypeSerializer<V> valueSerializer) {
+ super(keySerializer, valueSerializer);
+ this.comparator = comparator;
+ }
+
+ /**
+ * Returns the comparator for the keys in the map.
+ *
+ * @return The comparator for the keys in the map.
+ */
+ public Comparator<K> getComparator() {
+ return comparator;
+ }
+
+ @Override
+ public TypeSerializer<SortedMap<K, V>> duplicate() {
+ TypeSerializer<K> keySerializer = getKeySerializer().duplicate();
+ TypeSerializer<V> valueSerializer = getValueSerializer().duplicate();
+
+ return new SortedMapSerializer<>(comparator, keySerializer, valueSerializer);
+ }
+
+ @Override
+ public SortedMap<K, V> createInstance() {
+ return new TreeMap<>(comparator);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!super.equals(o)) {
+ return false;
+ }
+
+ SortedMapSerializer<?, ?> that = (SortedMapSerializer<?, ?>) o;
+ return comparator.equals(that.comparator);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = super.hashCode();
+ result = result * 31 + comparator.hashCode();
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return "SortedMapSerializer{" +
+ "comparator = " + comparator +
+ ", keySerializer = " + keySerializer +
+ ", valueSerializer = " + valueSerializer +
+ "}";
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Serializer configuration snapshot
+ // --------------------------------------------------------------------------------------------
+
+ @Override
+ public TypeSerializerSnapshot<SortedMap<K, V>> snapshotConfiguration() {
+ return new SortedMapSerializerSnapshot<>(this);
+ }
+}
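
A hedged round-trip sketch for this serializer, assuming Flink's DataOutputSerializer/DataInputDeserializer in-memory views and the IntSerializer/StringSerializer singletons from flink-core are on the classpath; in real use the comparator must also be serializable:

import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.table.typeutils.SortedMapSerializer;

import java.util.Comparator;
import java.util.SortedMap;
import java.util.TreeMap;

public class SortedMapRoundTrip {
    public static void main(String[] args) throws Exception {
        // reverseOrder() returns a serializable JDK singleton
        Comparator<Integer> cmp = Comparator.reverseOrder();
        SortedMapSerializer<Integer, String> serializer =
                new SortedMapSerializer<>(cmp, IntSerializer.INSTANCE, StringSerializer.INSTANCE);

        SortedMap<Integer, String> map = new TreeMap<>(cmp);
        map.put(1, "one");
        map.put(2, null); // null values are allowed via the null flag

        DataOutputSerializer out = new DataOutputSerializer(64);
        serializer.serialize(map, out);

        SortedMap<Integer, String> copy =
                serializer.deserialize(new DataInputDeserializer(out.getCopyOfBuffer()));
        System.out.println(copy); // {2=null, 1=one} -- reverse key order preserved
    }
}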
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java
new file mode 100644
index 0000000000000..d7a1dcf68ce1d
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.typeutils;
+
+import org.apache.flink.api.common.typeutils.NestedSerializersSnapshotDelegate;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream;
+import org.apache.flink.api.java.typeutils.runtime.DataOutputViewStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.InstantiationUtil;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.SortedMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Snapshot class for the {@link SortedMapSerializer}.
+ */
+public class SortedMapSerializerSnapshot<K, V> implements TypeSerializerSnapshot<SortedMap<K, V>> {
+
+ private Comparator<K> comparator;
+
+ private NestedSerializersSnapshotDelegate nestedSerializersSnapshotDelegate;
+
+ private static final int CURRENT_VERSION = 3;
+
+ @SuppressWarnings("unused")
+ public SortedMapSerializerSnapshot() {
+ // this constructor is used when restoring from a checkpoint/savepoint.
+ }
+
+ SortedMapSerializerSnapshot(SortedMapSerializer<K, V> sortedMapSerializer) {
+ this.comparator = sortedMapSerializer.getComparator();
+ TypeSerializer<?>[] typeSerializers =
+ new TypeSerializer<?>[] { sortedMapSerializer.getKeySerializer(), sortedMapSerializer.getValueSerializer() };
+ this.nestedSerializersSnapshotDelegate = new NestedSerializersSnapshotDelegate(typeSerializers);
+ }
+
+ @Override
+ public int getCurrentVersion() {
+ return CURRENT_VERSION;
+ }
+
+ @Override
+ public void writeSnapshot(DataOutputView out) throws IOException {
+ checkState(comparator != null, "Comparator cannot be null.");
+ InstantiationUtil.serializeObject(new DataOutputViewStream(out), comparator);
+ nestedSerializersSnapshotDelegate.writeNestedSerializerSnapshots(out);
+ }
+
+ @Override
+ public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException {
+ try {
+ comparator = InstantiationUtil.deserializeObject(
+ new DataInputViewStream(in), userCodeClassLoader);
+ } catch (ClassNotFoundException e) {
+ throw new IOException(e);
+ }
+ this.nestedSerializersSnapshotDelegate = NestedSerializersSnapshotDelegate.readNestedSerializerSnapshots(
+ in,
+ userCodeClassLoader);
+ }
+
+ @Override
+ public SortedMapSerializer<K, V> restoreSerializer() {
+ TypeSerializer<?>[] nestedSerializers = nestedSerializersSnapshotDelegate.getRestoredNestedSerializers();
+ @SuppressWarnings("unchecked")
+ TypeSerializer<K> keySerializer = (TypeSerializer<K>) nestedSerializers[0];
+
+ @SuppressWarnings("unchecked")
+ TypeSerializer<V> valueSerializer = (TypeSerializer<V>) nestedSerializers[1];
+
+ return new SortedMapSerializer<>(comparator, keySerializer, valueSerializer);
+ }
+
+ @Override
+ public TypeSerializerSchemaCompatibility<SortedMap<K, V>> resolveSchemaCompatibility(
+ TypeSerializer<SortedMap<K, V>> newSerializer) {
+ if (!(newSerializer instanceof SortedMapSerializer)) {
+ return TypeSerializerSchemaCompatibility.incompatible();
+ }
+ SortedMapSerializer<?, ?> newSortedMapSerializer = (SortedMapSerializer<?, ?>) newSerializer;
+ if (!comparator.equals(newSortedMapSerializer.getComparator())) {
+ return TypeSerializerSchemaCompatibility.incompatible();
+ } else {
+ return TypeSerializerSchemaCompatibility.compatibleAsIs();
+ }
+ }
+}
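
The compatibility rule above can be exercised directly: a new serializer is accepted as-is only if its comparator equals the snapshotted one. A sketch, assuming the flink-core base serializers (the JDK natural/reverse order comparators are singletons, so equals() behaves as expected):

import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.table.typeutils.SortedMapSerializer;

import java.util.Comparator;
import java.util.SortedMap;

public class CompatibilityExample {
    public static void main(String[] args) {
        SortedMapSerializer<Integer, String> oldSerializer = new SortedMapSerializer<>(
                Comparator.<Integer>naturalOrder(), IntSerializer.INSTANCE, StringSerializer.INSTANCE);
        TypeSerializerSnapshot<SortedMap<Integer, String>> snapshot = oldSerializer.snapshotConfiguration();

        // same comparator -> compatible as is
        SortedMapSerializer<Integer, String> sameComparator = new SortedMapSerializer<>(
                Comparator.<Integer>naturalOrder(), IntSerializer.INSTANCE, StringSerializer.INSTANCE);
        System.out.println(snapshot.resolveSchemaCompatibility(sameComparator).isCompatibleAsIs()); // true

        // different comparator -> incompatible
        SortedMapSerializer<Integer, String> reversed = new SortedMapSerializer<>(
                Comparator.<Integer>reverseOrder(), IntSerializer.INSTANCE, StringSerializer.INSTANCE);
        System.out.println(snapshot.resolveSchemaCompatibility(reversed).isIncompatible()); // true
    }
}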
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java
new file mode 100644
index 0000000000000..0e3cef7297044
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.typeutils;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.SortedMap;
+
+/**
+ * The type information for sorted maps.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ */
+@PublicEvolving
+public class SortedMapTypeInfo<K, V> extends AbstractMapTypeInfo<K, V, SortedMap<K, V>> {
+
+ private static final long serialVersionUID = 1L;
+
+ /** The comparator for the keys in the map. */
+ private final Comparator<K> comparator;
+
+ public SortedMapTypeInfo(
+ TypeInformation