diff --git a/flink-table/flink-table-planner-blink/pom.xml b/flink-table/flink-table-planner-blink/pom.xml
index f8ab62e3a264e..4fad3028883d8 100644
--- a/flink-table/flink-table-planner-blink/pom.xml
+++ b/flink-table/flink-table-planner-blink/pom.xml
@@ -200,6 +200,13 @@ under the License.
 			<type>test-jar</type>
 			<scope>test</scope>
 		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-statebackend-rocksdb_${scala.binary.version}</artifactId>
+			<version>${project.version}</version>
+			<scope>test</scope>
+		</dependency>
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java
new file mode 100644
index 0000000000000..50082c2beacfa
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/util/KeySelectorUtil.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.codegen.CodeGeneratorContext;
+import org.apache.flink.table.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector;
+import org.apache.flink.table.runtime.keyselector.BinaryRowKeySelector;
+import org.apache.flink.table.runtime.keyselector.NullBinaryRowKeySelector;
+import org.apache.flink.table.type.InternalType;
+import org.apache.flink.table.type.RowType;
+import org.apache.flink.table.typeutils.BaseRowSerializer;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.table.typeutils.TypeCheckUtils;
+
+/**
+ * Utility for KeySelector.
+ */
+public class KeySelectorUtil {
+
+	/**
+	 * Creates a BaseRowKeySelector to extract keys from a DataStream whose type is BaseRowTypeInfo.
+	 *
+	 * @param keyFields key fields
+	 * @param rowType type of the DataStream to extract keys from
+	 * @return a BaseRowKeySelector that extracts keys from a DataStream whose type is BaseRowTypeInfo.
+ */ + public static BaseRowKeySelector getBaseRowSelector(int[] keyFields, BaseRowTypeInfo rowType) { + if (keyFields.length > 0) { + InternalType[] inputFieldTypes = rowType.getInternalTypes(); + String[] inputFieldNames = rowType.getFieldNames(); + InternalType[] keyFieldTypes = new InternalType[keyFields.length]; + String[] keyFieldNames = new String[keyFields.length]; + for (int i = 0; i < keyFields.length; ++i) { + keyFieldTypes[i] = inputFieldTypes[keyFields[i]]; + keyFieldNames[i] = inputFieldNames[keyFields[i]]; + } + RowType returnType = new RowType(keyFieldTypes, keyFieldNames); + RowType inputType = new RowType(inputFieldTypes, rowType.getFieldNames()); + GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection( + CodeGeneratorContext.apply(new TableConfig()), + BaseRowSerializer.class.getSimpleName(), inputType, returnType, keyFields); + BaseRowTypeInfo keyRowType = returnType.toTypeInfo(); + // check if type implements proper equals/hashCode + TypeCheckUtils.validateEqualsHashCode("grouping", keyRowType); + return new BinaryRowKeySelector(keyRowType, generatedProjection); + } else { + return new NullBinaryRowKeySelector(); + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala index e6c0bc92419a8..4d5b2c3d1226c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableConfig.scala @@ -187,6 +187,26 @@ class TableConfig { this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS, maxTime.toMilliseconds) this } + + /** + * Returns the minimum time until state which was not updated will be retained. + */ + def getMinIdleStateRetentionTime: Long = { + this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) + } + + /** + * Returns the maximum time until state which was not updated will be retained. + */ + def getMaxIdleStateRetentionTime: Long = { + // only min idle ttl provided. 
+ if (this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MS) + && !this.conf.contains(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS)) { + this.conf.setLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS, + getMinIdleStateRetentionTime * 2) + } + this.conf.getLong(TableConfigOptions.SQL_EXEC_STATE_TTL_MAX_MS) + } } object TableConfig { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala index 4e26b65b378ff..c60bb12f60d95 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLogicalRelFactories.scala @@ -19,10 +19,9 @@ package org.apache.flink.table.calcite import org.apache.flink.table.calcite.FlinkRelFactories.{ExpandFactory, RankFactory, SinkFactory} -import org.apache.flink.table.plan.nodes.calcite.RankRange -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType import org.apache.flink.table.plan.nodes.logical._ import org.apache.flink.table.plan.schema.FlinkRelOptTable +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink import com.google.common.collect.ImmutableList diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala index 3f462fe3d2156..b8e2c1a755913 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala @@ -20,8 +20,7 @@ package org.apache.flink.table.calcite import org.apache.flink.table.calcite.FlinkRelFactories.{ExpandFactory, RankFactory, SinkFactory} import org.apache.flink.table.expressions.WindowProperty -import org.apache.flink.table.plan.nodes.calcite.RankRange -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink import org.apache.calcite.config.{CalciteConnectionConfigImpl, CalciteConnectionProperty} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala index d18e5ec82cd66..05dc497f6b1bf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkRelFactories.scala @@ -18,8 +18,8 @@ package org.apache.flink.table.calcite -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalSink, RankRange} +import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalSink} +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink import org.apache.calcite.plan.Contexts diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala index 0d79320f5956b..87114d0ca7a06 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala @@ -115,7 +115,7 @@ object CodeGenUtils { case InternalTypes.DATE => "int" case InternalTypes.TIME => "int" - case InternalTypes.TIMESTAMP => "long" + case _: TimestampType => "long" case InternalTypes.INTERVAL_MONTHS => "int" case InternalTypes.INTERVAL_MILLIS => "long" @@ -135,7 +135,7 @@ object CodeGenUtils { case InternalTypes.DATE => boxedTypeTermForType(InternalTypes.INT) case InternalTypes.TIME => boxedTypeTermForType(InternalTypes.INT) - case InternalTypes.TIMESTAMP => boxedTypeTermForType(InternalTypes.LONG) + case _: TimestampType => boxedTypeTermForType(InternalTypes.LONG) case InternalTypes.STRING => BINARY_STRING case InternalTypes.BINARY => "byte[]" diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala new file mode 100644 index 0000000000000..09658663e6ca0 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/EqualiserCodeGenerator.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.codegen + +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.generated.{GeneratedRecordEqualiser, RecordEqualiser} +import org.apache.flink.table.`type`.{DateType, InternalType, PrimitiveType, RowType, TimeType, +TimestampType} + +class EqualiserCodeGenerator(fieldTypes: Seq[InternalType]) { + + private val RECORD_EQUALISER = className[RecordEqualiser] + private val LEFT_INPUT = "left" + private val RIGHT_INPUT = "right" + + def generateRecordEqualiser(name: String): GeneratedRecordEqualiser = { + // ignore time zone + val ctx = CodeGeneratorContext(new TableConfig) + val className = newName(name) + val header = + s""" + |if ($LEFT_INPUT.getHeader() != $RIGHT_INPUT.getHeader()) { + | return false; + |} + """.stripMargin + + val codes = for (i <- fieldTypes.indices) yield { + val fieldType = fieldTypes(i) + val fieldTypeTerm = primitiveTypeTermForType(fieldType) + val result = s"cmp$i" + val leftNullTerm = "leftIsNull$" + i + val rightNullTerm = "rightIsNull$" + i + val leftFieldTerm = "leftField$" + i + val rightFieldTerm = "rightField$" + i + val equalsCode = if (isInternalPrimitive(fieldType)) { + s"$leftFieldTerm == $rightFieldTerm" + } else if (isBaseRow(fieldType)) { + val equaliserGenerator = + new EqualiserCodeGenerator(fieldType.asInstanceOf[RowType].getFieldTypes) + val generatedEqualiser = equaliserGenerator + .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") + val generatedEqualiserTerm = ctx.addReusableObject( + generatedEqualiser, "field$" + i + "GeneratedEqualiser") + val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName + val equaliserTerm = newName("equaliser") + ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") + ctx.addReusableInitStatement( + s""" + |$equaliserTerm = ($equaliserTypeTerm) + | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); + |""".stripMargin) + s"$equaliserTerm.equalsWithoutHeader($leftFieldTerm, $rightFieldTerm)" + } else { + s"$leftFieldTerm.equals($rightFieldTerm)" + } + val leftReadCode = baseRowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType) + val rightReadCode = baseRowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType) + s""" + |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i); + |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i); + |boolean $result; + |if ($leftNullTerm && $rightNullTerm) { + | $result = true; + |} else if ($leftNullTerm|| $rightNullTerm) { + | $result = false; + |} else { + | $fieldTypeTerm $leftFieldTerm = $leftReadCode; + | $fieldTypeTerm $rightFieldTerm = $rightReadCode; + | $result = $equalsCode; + |} + |if (!$result) { + | return false; + |} + """.stripMargin + } + + val functionCode = + j""" + public final class $className implements $RECORD_EQUALISER { + + ${ctx.reuseMemberCode()} + + public $className(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + @Override + public boolean equals($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) { + if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { + return $LEFT_INPUT.equals($RIGHT_INPUT); + } else { + $header + ${ctx.reuseLocalVariableCode()} + ${codes.mkString("\n")} + return true; + } + } + + @Override + public boolean equalsWithoutHeader($BASE_ROW $LEFT_INPUT, $BASE_ROW $RIGHT_INPUT) { + if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { + return 
(($BINARY_ROW)$LEFT_INPUT).equalsWithoutHeader((($BINARY_ROW)$RIGHT_INPUT)); + } else { + ${ctx.reuseLocalVariableCode()} + ${codes.mkString("\n")} + return true; + } + } + } + """.stripMargin + + new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray) + } + + private def isInternalPrimitive(t: InternalType): Boolean = t match { + case _: PrimitiveType => true + + case _: DateType => true + case TimeType.INSTANCE => true + case _: TimestampType => true + + case _ => false + } + + private def isBaseRow(t: InternalType): Boolean = t match { + case _: RowType => true + case _ => false + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala index 29431c2fb6458..ce388c8bc9f7f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeIn import org.apache.flink.table.api.{Table, TableConfig, TableException, Types} import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, genToExternal} import org.apache.flink.table.codegen.OperatorCodeGenerator.generateCollect +import org.apache.flink.table.dataformat.util.BaseRowUtil import org.apache.flink.table.dataformat.{BaseRow, GenericRow} import org.apache.flink.table.runtime.OneInputOperatorWrapper import org.apache.flink.table.sinks.{DataStreamTableSink, TableSink} @@ -93,6 +94,10 @@ object SinkCodeGenerator { new RowTypeInfo( inputTypeInfo.getFieldTypes, inputTypeInfo.getFieldNames) + case gt: GenericTypeInfo[BaseRow] if gt.getTypeClass == classOf[BaseRow] => + new BaseRowTypeInfo( + inputTypeInfo.getInternalTypes, + inputTypeInfo.getFieldNames) case _ => requestedTypeInfo } @@ -154,13 +159,14 @@ object SinkCodeGenerator { val retractProcessCode = if (!withChangeFlag) { generateCollect(genToExternal(ctx, outputTypeInfo, afterIndexModify)) } else { - val flagResultTerm = s"$afterIndexModify.getHeader() == $BASE_ROW.ACCUMULATE_MSG" + val flagResultTerm = + s"${classOf[BaseRowUtil].getCanonicalName}.isAccumulateMsg($afterIndexModify)" val resultTerm = CodeGenUtils.newName("result") val genericRowField = classOf[GenericRow].getCanonicalName s""" |$genericRowField $resultTerm = new $genericRowField(2); - |$resultTerm.update(0, $flagResultTerm); - |$resultTerm.update(1, $afterIndexModify); + |$resultTerm.setField(0, $flagResultTerm); + |$resultTerm.setField(1, $afterIndexModify); |${generateCollect(genToExternal(ctx, outputTypeInfo, resultTerm))} """.stripMargin } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala index c207bd685a52f..86b79a9e590a0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ValuesCodeGenerator.scala @@ -22,7 +22,6 @@ import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.{BaseRow, GenericRow} import 
org.apache.flink.table.runtime.values.ValuesInputFormat -import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex.RexLiteral @@ -56,8 +55,7 @@ object ValuesCodeGenerator { generatedRecords.map(_.code), outputType) - val baseRowTypeInfo = new BaseRowTypeInfo(outputType.getFieldTypes, outputType.getFieldNames) - new ValuesInputFormat(generatedFunction, baseRowTypeInfo) + new ValuesInputFormat(generatedFunction, outputType.toTypeInfo) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala index a3fefe15c691e..57a64212a0b7e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/sort/SortCodeGenerator.scala @@ -449,6 +449,9 @@ class SortCodeGenerator( case InternalTypes.FLOAT => 4 case InternalTypes.DOUBLE => 8 case InternalTypes.LONG => 8 + case _: TimestampType => 8 + case _: DateType => 4 + case InternalTypes.TIME => 4 case dt: DecimalType if Decimal.isCompact(dt.precision()) => 8 case InternalTypes.STRING | InternalTypes.BINARY => Int.MaxValue } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala index de4238ea29c27..e9e4c0231367b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/LogicalRank.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.plan.nodes.calcite -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataTypeField diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala index 9e19e028e3dbc..0ab70a3406a65 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Rank.scala @@ -19,8 +19,8 @@ package org.apache.flink.table.plan.nodes.calcite import org.apache.flink.table.api.TableException -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType import org.apache.flink.table.plan.util._ +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType, VariableRankRange} import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} @@ -71,18 +71,19 @@ abstract class Rank( rankRange match { case r: ConstantRankRange => - if (r.rankEnd <= 0) { - throw new TableException(s"Rank end can't smaller than zero. The rank end is ${r.rankEnd}") + if (r.getRankEnd <= 0) { + throw new TableException( + s"Rank end can't smaller than zero. 
The rank end is ${r.getRankEnd}") } - if (r.rankStart > r.rankEnd) { + if (r.getRankStart > r.getRankEnd) { throw new TableException( - s"Rank start '${r.rankStart}' can't greater than rank end '${r.rankEnd}'.") + s"Rank start '${r.getRankStart}' can't greater than rank end '${r.getRankEnd}'.") } case v: VariableRankRange => - if (v.rankEndIndex < 0) { + if (v.getRankEndIndex < 0) { throw new TableException(s"Rank end index can't smaller than zero.") } - if (v.rankEndIndex >= input.getRowType.getFieldCount) { + if (v.getRankEndIndex >= input.getRowType.getFieldCount) { throw new TableException(s"Rank end index can't greater than input field count.") } } @@ -105,7 +106,7 @@ abstract class Rank( }.mkString(", ") super.explainTerms(pw) .item("rankType", rankType) - .item("rankRange", rankRange.toString()) + .item("rankRange", rankRange) .item("partitionBy", partitionKey.map(i => s"$$$i").mkString(",")) .item("orderBy", RelExplainUtil.collationToString(orderKey)) .item("select", select) @@ -134,64 +135,3 @@ abstract class Rank( } } - -/** - * An enumeration of rank type, usable to tell the [[Rank]] node how exactly generate rank number. - */ -object RankType extends Enumeration { - type RankType = Value - - /** - * Returns a unique sequential number for each row within the partition based on the order, - * starting at 1 for the first row in each partition and without repeating or skipping - * numbers in the ranking result of each partition. If there are duplicate values within the - * row set, the ranking numbers will be assigned arbitrarily. - */ - val ROW_NUMBER: RankType.Value = Value - - /** - * Returns a unique rank number for each distinct row within the partition based on the order, - * starting at 1 for the first row in each partition, with the same rank for duplicate values - * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate - * values. - */ - val RANK: RankType.Value = Value - - /** - * is similar to the RANK by generating a unique rank number for each distinct row - * within the partition based on the order, starting at 1 for the first row in each partition, - * ranking the rows with equal values with the same rank number, except that it does not skip - * any rank, leaving no gaps between the ranks. - */ - val DENSE_RANK: RankType.Value = Value -} - -sealed trait RankRange extends Serializable { - def toString(inputFieldNames: Seq[String]): String -} - -/** [[ConstantRankRangeWithoutEnd]] is a RankRange which not specify RankEnd. */ -case class ConstantRankRangeWithoutEnd(rankStart: Long) extends RankRange { - override def toString(inputFieldNames: Seq[String]): String = this.toString - - override def toString: String = s"rankStart=$rankStart" -} - -/** rankStart and rankEnd are inclusive, rankStart always start from one. 
*/ -case class ConstantRankRange(rankStart: Long, rankEnd: Long) extends RankRange { - - override def toString(inputFieldNames: Seq[String]): String = this.toString - - override def toString: String = s"rankStart=$rankStart, rankEnd=$rankEnd" -} - -/** changing rank limit depends on input */ -case class VariableRankRange(rankEndIndex: Int) extends RankRange { - override def toString(inputFieldNames: Seq[String]): String = { - s"rankEnd=${inputFieldNames(rankEndIndex)}" - } - - override def toString: String = { - s"rankEnd=$$$rankEndIndex" - } -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala index 806dadfa0edd0..55d9575637fc5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalRank.scala @@ -18,9 +18,9 @@ package org.apache.flink.table.plan.nodes.logical import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank, RankRange} +import org.apache.flink.table.plan.nodes.calcite.{LogicalRank, Rank} import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.runtime.rank.{RankRange, RankType} import org.apache.calcite.plan._ import org.apache.calcite.rel.`type`.RelDataTypeField diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala index 84f05aae636dc..05cd3252d438b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala @@ -25,10 +25,10 @@ import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.sort.ComparatorCodeGenerator import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, Rank, RankRange, RankType} +import org.apache.flink.table.plan.nodes.calcite.Rank import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType} import org.apache.flink.table.runtime.sort.RankOperator import org.apache.calcite.plan._ @@ -72,7 +72,7 @@ class BatchExecRank( require(rankType == RankType.RANK, "Only RANK is supported now") val (rankStart, rankEnd) = rankRange match { - case r: ConstantRankRange => (r.rankStart, r.rankEnd) + case r: ConstantRankRange => (r.getRankStart, r.getRankEnd) case o => throw new TableException(s"$o is not supported now") } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala index 81fe874c56938..a8a42b189c11b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDataStreamScan.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.functions.sql.StreamRecordTimestampSqlFunction import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.schema.DataStreamTable import org.apache.flink.table.plan.util.ScanUtil +import org.apache.flink.table.runtime.AbstractProcessStreamOperator import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo.{ROWTIME_INDICATOR, ROWTIME_STREAM_MARKER} import org.apache.calcite.plan._ @@ -112,16 +113,17 @@ class StreamExecDataStreamScan( } else { ("", "") } - + val ctx = CodeGeneratorContext(config).setOperatorBaseClass( + classOf[AbstractProcessStreamOperator[BaseRow]]) ScanUtil.convertToInternalRow( - CodeGeneratorContext(config), + ctx, transform, dataStreamTable.fieldIndexes, dataStreamTable.typeInfo, getRowType, getTable.getQualifiedName, config, - None, + rowtimeExpr, beforeConvert = extractElement, afterConvert = resetElement) } else { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala index e76af21841def..58d56b84ba472 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecDeduplicate.scala @@ -18,27 +18,42 @@ package org.apache.flink.table.plan.nodes.physical.stream +import org.apache.flink.streaming.api.operators.KeyedProcessOperator +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.plan.util.KeySelectorUtil +import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator +import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger +import org.apache.flink.table.runtime.deduplicate.{DeduplicateKeepFirstRowFunction, DeduplicateKeepLastRowFunction, MiniBatchDeduplicateKeepFirstRowFunction, MiniBatchDeduplicateKeepLastRowFunction} +import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules +import org.apache.flink.table.typeutils.BaseRowTypeInfo + import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.calcite.rel.`type`.RelDataType import java.util +import scala.collection.JavaConversions._ + /** * Stream physical RelNode which deduplicate on keys and keeps only first row or last row. * This node is an optimization of [[StreamExecRank]] for some special cases. * Compared to [[StreamExecRank]], this node could use mini-batch and access less state. 
- * <p>NOTES: only supports sort on proctime now.
+ * <p>
NOTES: only supports sort on proctime now, sort on rowtime will not translated into + * StreamExecDeduplicate node. */ class StreamExecDeduplicate( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode, uniqueKeys: Array[Int], - isRowtime: Boolean, keepLastRow: Boolean) extends SingleRel(cluster, traitSet, inputRel) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { def getUniqueKeys: Array[Int] = uniqueKeys @@ -60,17 +75,77 @@ class StreamExecDeduplicate( traitSet, inputs.get(0), uniqueKeys, - isRowtime, keepLastRow) } override def explainTerms(pw: RelWriter): RelWriter = { val fieldNames = getRowType.getFieldNames - val orderString = if (isRowtime) "ROWTIME" else "PROCTIME" super.explainTerms(pw) .item("keepLastRow", keepLastRow) .item("key", uniqueKeys.map(fieldNames.get).mkString(", ")) - .item("order", orderString) + .item("order", "PROCTIME") + } + + //~ ExecNode methods ----------------------------------------------------------- + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + + val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(getInput) + + if (inputIsAccRetract) { + throw new TableException("Deduplicate doesn't support retraction input stream currently.") + } + + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + + val rowTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + val generateRetraction = StreamExecRetractionRules.isAccRetract(this) + val tableConfig = tableEnv.getConfig + val isMiniBatchEnabled = tableConfig.getConf.getLong( + TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) > 0 + val operator = if (isMiniBatchEnabled) { + val exeConfig = tableEnv.execEnv.getConfig + val rowSerializer = rowTypeInfo.createSerializer(exeConfig) + val processFunction = if (keepLastRow) { + new MiniBatchDeduplicateKeepLastRowFunction(rowTypeInfo, generateRetraction, rowSerializer) + } else { + new MiniBatchDeduplicateKeepFirstRowFunction(rowSerializer) + } + val trigger = new CountBundleTrigger[BaseRow]( + tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE)) + new KeyedMapBundleOperator( + processFunction, + trigger) + } else { + val minRetentionTime = tableConfig.getMinIdleStateRetentionTime + val maxRetentionTime = tableConfig.getMaxIdleStateRetentionTime + val processFunction = if (keepLastRow) { + new DeduplicateKeepLastRowFunction(minRetentionTime, maxRetentionTime, rowTypeInfo, + generateRetraction) + } else { + new DeduplicateKeepFirstRowFunction(minRetentionTime, maxRetentionTime) + } + new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](processFunction) + } + val ret = new OneInputTransformation( + inputTransform, getOperatorName, operator, rowTypeInfo, inputTransform.getParallelism) + val selector = KeySelectorUtil.getBaseRowSelector(uniqueKeys, rowTypeInfo) + ret.setStateKeySelector(selector) + ret.setStateKeyType(selector.getProducedType) + ret + } + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + private def getOperatorName: String = { + val fieldNames = getRowType.getFieldNames + val keyNames = uniqueKeys.map(fieldNames.get).mkString(", ") + s"${if (keepLastRow) "keepLastRow" else "KeepFirstRow"}" + + s": (key: ($keyNames), select: (${fieldNames.mkString(", ")}), order: (PROCTIME)" } } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala index b503fb2d9bded..2f2fc307a9744 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecExchange.scala @@ -18,11 +18,24 @@ package org.apache.flink.table.plan.nodes.physical.stream +import org.apache.flink.runtime.state.KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM import org.apache.flink.table.plan.nodes.common.CommonPhysicalExchange +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.streaming.api.transformations.{PartitionTransformation, StreamTransformation} +import org.apache.flink.streaming.runtime.partitioner.{GlobalPartitioner, KeyGroupStreamPartitioner, StreamPartitioner} +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.util.KeySelectorUtil +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.{RelDistribution, RelNode} +import java.util + +import scala.collection.JavaConversions._ + /** * Stream physical RelNode for [[org.apache.calcite.rel.core.Exchange]]. */ @@ -32,7 +45,8 @@ class StreamExecExchange( relNode: RelNode, relDistribution: RelDistribution) extends CommonPhysicalExchange(cluster, traitSet, relNode, relDistribution) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { override def producesUpdates: Boolean = false @@ -50,4 +64,43 @@ class StreamExecExchange( newDistribution: RelDistribution): StreamExecExchange = { new StreamExecExchange(cluster, traitSet, newInput, newDistribution) } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val inputTypeInfo = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo] + val outputTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo + relDistribution.getType match { + case RelDistribution.Type.SINGLETON => + val partitioner = new GlobalPartitioner[BaseRow] + val transformation = new PartitionTransformation( + inputTransform, + partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) + transformation.setOutputType(outputTypeInfo) + transformation + case RelDistribution.Type.HASH_DISTRIBUTED => + // TODO Eliminate duplicate keys + + val selector = KeySelectorUtil.getBaseRowSelector( + relDistribution.getKeys.map(_.toInt).toArray, inputTypeInfo) + val partitioner = new KeyGroupStreamPartitioner(selector, + DEFAULT_LOWER_BOUND_MAX_PARALLELISM) + val transformation = new PartitionTransformation( + inputTransform, + partitioner.asInstanceOf[StreamPartitioner[BaseRow]]) + 
transformation.setOutputType(outputTypeInfo) + transformation + case _ => + throw new UnsupportedOperationException( + s"not support RelDistribution: ${relDistribution.getType} now!") + } + } + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala index ac833b548ea65..e27c85b4b7370 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecRank.scala @@ -17,9 +17,18 @@ */ package org.apache.flink.table.plan.nodes.physical.stream -import org.apache.flink.table.plan.nodes.calcite.RankType.RankType -import org.apache.flink.table.plan.nodes.calcite.{Rank, RankRange} -import org.apache.flink.table.plan.util.{RankProcessStrategy, RelExplainUtil, RetractStrategy} +import org.apache.flink.streaming.api.operators.KeyedProcessOperator +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.api.{StreamTableEnvironment, TableConfigOptions, TableException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.EqualiserCodeGenerator +import org.apache.flink.table.codegen.sort.SortCodeGenerator +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.calcite.Rank +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules +import org.apache.flink.table.plan.util._ +import org.apache.flink.table.runtime.rank._ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel._ @@ -53,7 +62,8 @@ class StreamExecRank( rankRange, rankNumberType, outputRankNumber) - with StreamPhysicalRel { + with StreamPhysicalRel + with StreamExecNode[BaseRow] { /** please uses [[getStrategy]] instead of this field */ private var strategy: RankProcessStrategy = _ @@ -101,4 +111,121 @@ class StreamExecRank( .item("orderBy", RelExplainUtil.collationToString(orderKey, inputRowType)) .item("select", getRowType.getFieldNames.mkString(", ")) } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val tableConfig = tableEnv.getConfig + rankType match { + case RankType.ROW_NUMBER => // ignore + case RankType.RANK => + throw new TableException("RANK() on streaming table is not supported currently") + case RankType.DENSE_RANK => + throw new TableException("DENSE_RANK() on streaming table is not supported currently") + case k => + throw new TableException(s"Streaming tables do not support $k rank function.") + } + + val inputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getInput.getRowType).toTypeInfo + val fieldCollations = orderKey.getFieldCollations + val (sortFields, sortDirections, nullsIsLast) = SortUtil.getKeysAndOrders(fieldCollations) + val sortKeySelector = KeySelectorUtil.getBaseRowSelector(sortFields, inputRowTypeInfo) + val sortKeyType = 
sortKeySelector.getProducedType + val sortCodeGen = new SortCodeGenerator( + tableConfig, sortFields.indices.toArray, sortKeyType.getInternalTypes, + sortDirections, nullsIsLast) + val sortKeyComparator = sortCodeGen.generateRecordComparator("StreamExecSortComparator") + val generateRetraction = StreamExecRetractionRules.isAccRetract(this) + val cacheSize = tableConfig.getConf.getLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE) + val minIdleStateRetentionTime = tableConfig.getMinIdleStateRetentionTime + val maxIdleStateRetentionTime = tableConfig.getMaxIdleStateRetentionTime + val equaliserCodeGenerator = new EqualiserCodeGenerator(inputRowTypeInfo.getInternalTypes) + val generatedEqualiser = equaliserCodeGenerator.generateRecordEqualiser("RankValueEqualiser") + val processFunction = getStrategy(true) match { + case AppendFastStrategy => + new AppendRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + sortKeyComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction, + outputRankNumber, + cacheSize) + + case UpdateFastStrategy(primaryKeys) => + val rowKeySelector = KeySelectorUtil.getBaseRowSelector(primaryKeys, inputRowTypeInfo) + new UpdateRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + rowKeySelector, + sortKeyComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction, + outputRankNumber, + cacheSize) + + // TODO UnaryUpdateRank after SortedMapState is merged + case RetractStrategy | UnaryUpdateStrategy(_) => + new RetractRankFunction( + minIdleStateRetentionTime, + maxIdleStateRetentionTime, + inputRowTypeInfo, + sortKeyComparator, + sortKeySelector, + rankType, + rankRange, + generatedEqualiser, + generateRetraction, + outputRankNumber) + } + val rankOpName = getOperatorName + val operator = new KeyedProcessOperator(processFunction) + processFunction.setKeyContext(operator) + val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val outputRowTypeInfo = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo + val ret = new OneInputTransformation( + inputTransform, + rankOpName, + operator, + outputRowTypeInfo, + inputTransform.getParallelism) + + if (partitionKey.isEmpty) { + ret.setParallelism(1) + ret.setMaxParallelism(1) + } + + // set KeyType and Selector for state + val selector = KeySelectorUtil.getBaseRowSelector(partitionKey.toArray, inputRowTypeInfo) + ret.setStateKeySelector(selector) + ret.setStateKeyType(selector.getProducedType) + ret + } + + private def getOperatorName: String = { + val inputRowType = inputRel.getRowType + var result = getStrategy().toString + result += s"(orderBy: (${RelExplainUtil.collationToString(orderKey, inputRowType)})" + if (partitionKey.nonEmpty) { + val partitionKeys = partitionKey.toArray + result += s", partitionBy: (${RelExplainUtil.fieldToString(partitionKeys, inputRowType)})" + } + result += s", ${getRowType.getFieldNames.mkString(", ")}" + result += s", ${rankRange.toString(inputRowType.getFieldNames)})" + result + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala index 1bff861eb7287..822359c5ca520 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala 
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/StreamOptimizer.scala @@ -22,7 +22,7 @@ import org.apache.flink.table.api.{StreamTableEnvironment, TableConfig} import org.apache.flink.table.plan.`trait`.UpdateAsRetractionTraitDef import org.apache.flink.table.plan.nodes.calcite.Sink import org.apache.flink.table.plan.optimize.program.{FlinkStreamProgram, StreamOptimizeContext} -import org.apache.flink.table.sinks.RetractStreamTableSink +import org.apache.flink.table.sinks.{DataStreamTableSink, RetractStreamTableSink} import org.apache.flink.util.Preconditions import org.apache.calcite.plan.volcano.VolcanoPlanner @@ -38,7 +38,11 @@ class StreamOptimizer(tEnv: StreamTableEnvironment) extends Optimizer { roots.map { root => val retractionFromRoot = root match { case n: Sink => - n.sink.isInstanceOf[RetractStreamTableSink[_]] + n.sink match { + case _: RetractStreamTableSink[_] => true + case s: DataStreamTableSink[_] => s.updatesAsRetraction + case _ => false + } case o => o.getTraitSet.getTrait(UpdateAsRetractionTraitDef.INSTANCE).sendsUpdatesAsRetractions } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala index 3a53e2c1223d5..3b21083e2a608 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala @@ -19,9 +19,9 @@ package org.apache.flink.table.plan.rules.logical import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkContext -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType} import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverWindow, FlinkLogicalRank} import org.apache.flink.table.plan.util.RankUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType} import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} @@ -85,8 +85,9 @@ abstract class FlinkLogicalRankRuleBase require(rankNumberType.isDefined) rankRange match { - case Some(ConstantRankRange(_, rankEnd)) if rankEnd <= 0 => - throw new TableException(s"Rank end should not less than zero, but now is $rankEnd") + case Some(crr: ConstantRankRange) if crr.getRankEnd <= 0 => + throw new TableException( + s"Rank end should not less than zero, but now is ${crr.getRankEnd}") case _ => // do nothing } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala index 1e8950b6091ea..918211b9cda67 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecRankRule.scala @@ -21,10 +21,10 @@ package org.apache.flink.table.plan.rules.physical.batch import org.apache.flink.table.api.TableException import org.apache.flink.table.plan.`trait`.FlinkRelDistribution 
import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank import org.apache.flink.table.plan.util.RelFieldCollationUtil +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.convert.ConverterRule @@ -58,7 +58,7 @@ class BatchExecRankRule def convert(rel: RelNode): RelNode = { val rank = rel.asInstanceOf[FlinkLogicalRank] val (_, rankEnd) = rank.rankRange match { - case r: ConstantRankRange => (r.rankStart, r.rankEnd) + case r: ConstantRankRange => (r.getRankStart, r.getRankEnd) case o => throw new TableException(s"$o is not supported now") } @@ -71,7 +71,7 @@ class BatchExecRankRule val newLocalInput = RelOptRule.convert(rank.getInput, localRequiredTraitSet) // create local BatchExecRank - val localRankRange = ConstantRankRange(1, rankEnd) // local rank always start from 1 + val localRankRange = new ConstantRankRange(1, rankEnd) // local rank always start from 1 val localRank = new BatchExecRank( cluster, emptyTraits, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala index b07ba7a040989..6ec7a27eb3aa3 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecDeduplicateRule.scala @@ -21,9 +21,9 @@ package org.apache.flink.table.plan.rules.physical.stream import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankType} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRank import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecDeduplicate, StreamExecRank} +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.`type`.RelDataType @@ -68,10 +68,6 @@ class StreamExecDeduplicateRule override def convert(rel: RelNode): RelNode = { val rank = rel.asInstanceOf[FlinkLogicalRank] - val fieldCollation = rank.orderKey.getFieldCollations.get(0) - val fieldType = rank.getInput.getRowType.getFieldList.get(fieldCollation.getFieldIndex).getType - val isRowtime = FlinkTypeFactory.isRowtimeIndicatorType(fieldType) - val requiredDistribution = FlinkRelDistribution.hash(rank.partitionKey.toList) val requiredTraitSet = rel.getCluster.getPlanner.emptyTraitSet() .replace(FlinkConventions.STREAM_PHYSICAL) @@ -79,6 +75,7 @@ class StreamExecDeduplicateRule val convInput: RelNode = RelOptRule.convert(rank.getInput, requiredTraitSet) // order by timeIndicator desc ==> lastRow, otherwise is firstRow + val fieldCollation = rank.orderKey.getFieldCollations.get(0) val isLastRow = fieldCollation.direction.isDescending val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL) new StreamExecDeduplicate( @@ -86,7 +83,6 @@ class 
StreamExecDeduplicateRule providedTraitSet, convInput, rank.partitionKey.toArray, - isRowtime, isLastRow) } } @@ -111,7 +107,8 @@ object StreamExecDeduplicateRule { val isRowNumberType = rank.rankType == RankType.ROW_NUMBER val isLimit1 = rankRange match { - case ConstantRankRange(rankStart, rankEnd) => rankStart == 1 && rankEnd == 1 + case rankRange: ConstantRankRange => + rankRange.getRankStart() == 1 && rankRange.getRankEnd() == 1 case _ => false } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala index fbb16f50260cb..b88cfffbdc2f0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.BinaryRow -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, RankRange} +import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange} import org.apache.flink.table.runtime.sort.BinaryIndexedSortable import org.apache.flink.table.typeutils.BinaryRowSerializer @@ -37,7 +37,7 @@ import scala.collection.JavaConversions._ object FlinkRelMdUtil { def getRankRangeNdv(rankRange: RankRange): Double = rankRange match { - case r: ConstantRankRange => (r.rankEnd - r.rankStart + 1).toDouble + case r: ConstantRankRange => (r.getRankEnd - r.getRankStart + 1).toDouble case _ => 100D // default value now } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala index 338564c42c164..f7f7378aa5dff 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RankUtil.scala @@ -19,7 +19,9 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} -import org.apache.flink.table.plan.nodes.calcite.{ConstantRankRange, ConstantRankRangeWithoutEnd, Rank, RankRange, VariableRankRange} +import org.apache.flink.table.plan.nodes.calcite.Rank +import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankRange, VariableRankRange} +import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil} @@ -98,20 +100,20 @@ object RankUtil { val sortBounds = limitPreds.map(computeWindowBoundFromPredicate(_, rexBuilder, config)) val rankRange = sortBounds match { case Seq(Some(LowerBoundary(x)), Some(UpperBoundary(y))) => - ConstantRankRange(x, y) + new ConstantRankRange(x, y) case Seq(Some(UpperBoundary(x)), Some(LowerBoundary(y))) => - ConstantRankRange(y, x) + new ConstantRankRange(y, x) case Seq(Some(LowerBoundary(x))) => // only offset - ConstantRankRangeWithoutEnd(x) + new ConstantRankRangeWithoutEnd(x) case Seq(Some(UpperBoundary(x))) => // rankStart starts from one - ConstantRankRange(1, x) + new ConstantRankRange(1, x) case Seq(Some(BothBoundary(x, y))) => // nth rank - ConstantRankRange(x, 
y) + new ConstantRankRange(x, y) case Seq(Some(InputRefBoundary(x))) => - VariableRankRange(x) + new VariableRankRange(x) case _ => // TopN requires at least one rank comparison predicate return (None, Some(predicate)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala index 32c9f4f61c3bb..07b9d93abcf92 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala @@ -18,8 +18,12 @@ package org.apache.flink.table.typeutils -import org.apache.flink.table.`type`._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.PojoTypeInfo +import org.apache.flink.table.api.ValidationException import org.apache.flink.table.codegen.GeneratedExpression +import org.apache.flink.table.`type`._ object TypeCheckUtils { @@ -96,4 +100,63 @@ object TypeCheckUtils { def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) + /** + * Checks whether a type implements own hashCode() and equals() methods for storing an instance + * in Flink's state or performing a keyBy operation. + * + * @param name name of the operation. + * @param t type information to be validated + */ + def validateEqualsHashCode(name: String, t: TypeInformation[_]): Unit = t match { + + // make sure that a POJO class is a valid state type + case pt: PojoTypeInfo[_] => + // we don't check the types recursively to give a chance of wrapping + // proper hashCode/equals methods around an immutable type + validateEqualsHashCode(name, pt.getClass) + // BinaryRow direct hash in bytes, no need to check field types. + case bt: BaseRowTypeInfo => + // recursively check composite types + case ct: CompositeType[_] => + validateEqualsHashCode(name, t.getTypeClass) + // we check recursively for entering Flink types such as tuples and rows + for (i <- 0 until ct.getArity) { + val subtype = ct.getTypeAt(i) + validateEqualsHashCode(name, subtype) + } + // check other type information only based on the type class + case _: TypeInformation[_] => + validateEqualsHashCode(name, t.getTypeClass) + } + + /** + * Checks whether a class implements own hashCode() and equals() methods for storing an instance + * in Flink's state or performing a keyBy operation. 
+ * + * @param name name of the operation + * @param c class to be validated + */ + def validateEqualsHashCode(name: String, c: Class[_]): Unit = { + + // skip primitives + if (!c.isPrimitive) { + // check the component type of arrays + if (c.isArray) { + validateEqualsHashCode(name, c.getComponentType) + } + // check type for methods + else { + if (c.getMethod("hashCode").getDeclaringClass eq classOf[Object]) { + throw new ValidationException( + s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " + + s"does not implement a proper hashCode() method.") + } + if (c.getMethod("equals", classOf[Object]).getDeclaringClass eq classOf[Object]) { + throw new ValidationException( + s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " + + s"does not implement a proper equals() method.") + } + } + } + } } diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java new file mode 100644 index 0000000000000..3d95e3c3a6407 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/runtime/utils/FailingCollectionSource.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.utils; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.Preconditions; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * The FailingCollectionSource will fail after emitting a specified number of elements. This is used + * to perform checkpoint and restore in IT cases.
+ */ +public class FailingCollectionSource<T> + implements SourceFunction<T>, CheckpointedFunction, CheckpointListener { + + public static volatile boolean failedBefore = false; + + private static final long serialVersionUID = 1L; + + /** The (de)serializer to be used for the data elements. */ + private final TypeSerializer<T> serializer; + + /** The actual data elements, in serialized form. */ + private final byte[] elementsSerialized; + + /** The number of serialized elements. */ + private final int numElements; + + /** The number of elements emitted already. */ + private volatile int numElementsEmitted; + + /** The number of elements to skip initially. */ + private volatile int numElementsToSkip; + + /** Flag to make the source cancelable. */ + private volatile boolean isRunning = true; + + private transient ListState<Integer> checkpointedState; + + /** A failure will occur when the given number of elements have been processed. */ + private final int failureAfterNumElements; + + /** The number of completed checkpoints. */ + private volatile int numSuccessfulCheckpoints; + + /** The checkpointed number of emitted elements. */ + private final Map<Long, Integer> checkpointedEmittedNums; + + /** The last successful checkpointed number of emitted elements. */ + private volatile int lastCheckpointedEmittedNum = 0; + + /** Whether to perform a checkpoint before the job finishes. */ + private final boolean performCheckpointBeforeJobFinished; + + public FailingCollectionSource( + TypeSerializer<T> serializer, + Iterable<T> elements, + int failureAfterNumElements, + boolean performCheckpointBeforeJobFinished) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(baos); + + int count = 0; + try { + for (T element : elements) { + serializer.serialize(element, wrapper); + count++; + } + } catch (Exception e) { + throw new IOException("Serializing the source elements failed: " + e.getMessage(), e); + } + + this.serializer = serializer; + this.elementsSerialized = baos.toByteArray(); + this.numElements = count; + checkArgument(failureAfterNumElements > 0); + this.failureAfterNumElements = failureAfterNumElements; + this.performCheckpointBeforeJobFinished = performCheckpointBeforeJobFinished; + this.checkpointedEmittedNums = new HashMap<>(); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + Preconditions.checkState( + this.checkpointedState == null, + "The " + getClass().getSimpleName() + " has already been initialized."); + + this.checkpointedState = context.getOperatorStateStore().getListState( + new ListStateDescriptor<>( + "from-elements-state", + IntSerializer.INSTANCE + ) + ); + + if (failedBefore && context.isRestored()) { + List<Integer> retrievedStates = new ArrayList<>(); + for (Integer entry : this.checkpointedState.get()) { + retrievedStates.add(entry); + } + + // given that the parallelism of the function is 1, we can only have 1 state + Preconditions.checkArgument( + retrievedStates.size() == 1, + getClass().getSimpleName() + " retrieved invalid state."); + + this.numElementsToSkip = retrievedStates.get(0); + } + } + + @Override + public void run(SourceContext<T> ctx) throws Exception { + ByteArrayInputStream bais = new ByteArrayInputStream(elementsSerialized); + final DataInputView input = new DataInputViewStreamWrapper(bais); + + // if we are restored from a checkpoint and need to skip elements, skip them now.
+ int toSkip = numElementsToSkip; + if (toSkip > 0) { + try { + while (toSkip > 0) { + serializer.deserialize(input); + toSkip--; + } + } catch (Exception e) { + throw new IOException( + "Failed to deserialize an element from the source. " + + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); + } + + this.numElementsEmitted = this.numElementsToSkip; + } + + while (isRunning && numElementsEmitted < numElements) { + if (!failedBefore) { + // delay a bit, if we have not failed before + Thread.sleep(1); + if (numSuccessfulCheckpoints >= 1 && lastCheckpointedEmittedNum >= failureAfterNumElements) { + // cause a failure if we have not failed before and have reached + // enough completed checkpoints and elements + failedBefore = true; + throw new Exception("Artificial Failure"); + } + } + + if (failedBefore || numElementsEmitted < failureAfterNumElements) { + // the function failed before, or we are in the elements before the failure + T next; + try { + next = serializer.deserialize(input); + } catch (Exception e) { + throw new IOException( + "Failed to deserialize an element from the source. " + + "If you are using user-defined serialization (Value and Writable types), check the " + + "serialization functions.\nSerializer is " + serializer); + } + + synchronized (ctx.getCheckpointLock()) { + ctx.collect(next); + numElementsEmitted++; + } + } else { + // if our work is done, delay a bit to prevent busy waiting + Thread.sleep(1); + } + } + + if (performCheckpointBeforeJobFinished) { + while (isRunning) { + // wait until the latest checkpoint records everything + if (lastCheckpointedEmittedNum < numElements) { + // delay a bit to prevent busy waiting + Thread.sleep(1); + } else { + // cause a failure to retain the last checkpoint + throw new Exception("Job finished normally"); + } + } + } + } + + @Override + public void cancel() { + isRunning = false; + } + + + /** + * Gets the number of elements produced in total by this function. + * + * @return The number of elements produced in total. + */ + public int getNumElements() { + return numElements; + } + + /** + * Gets the number of elements emitted so far. + * + * @return The number of elements emitted so far. 
+ */ + public int getNumElementsEmitted() { + return numElementsEmitted; + } + + // ------------------------------------------------------------------------ + // Checkpointing + // ------------------------------------------------------------------------ + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Preconditions.checkState( + this.checkpointedState != null, + "The " + getClass().getSimpleName() + " has not been properly initialized."); + + this.checkpointedState.clear(); + this.checkpointedState.add(this.numElementsEmitted); + long checkpointId = context.getCheckpointId(); + checkpointedEmittedNums.put(checkpointId, numElementsEmitted); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + numSuccessfulCheckpoints++; + lastCheckpointedEmittedNum = checkpointedEmittedNums.get(checkpointId); + } + + public static void reset() { + failedBefore = false; + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala new file mode 100644 index 0000000000000..179394086772a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode +import org.apache.flink.table.runtime.utils._ +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset +import org.apache.flink.types.Row + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) + extends StreamingWithMiniBatchTestBase(miniBatch, mode) { + + @Test + def testFirstRowOnProctime(): Unit = { + val t = failingDataSource(StreamTestData.get3TupleData) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT a, b, c + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,Hi", "2,2,Hello", "4,3,Hello world, how are you?", + "7,4,Comment#1", "11,5,Comment#5", "16,6,Comment#10") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testLastRowOnProctime(): Unit = { + val t = failingDataSource(StreamTestData.get3TupleData) + .toTable(tEnv, 'a, 'b, 'c, 'proctime) + tEnv.registerTable("T", t) + + val sql = + """ + |SELECT a, b, c + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY b ORDER BY proctime DESC) as rowNum + | FROM T + |) + |WHERE rowNum = 1 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,Hi", "3,2,Hello world", "6,3,Luke Skywalker", + "10,4,Comment#4", "15,5,Comment#9", "21,6,Comment#15") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala new file mode 100644 index 0000000000000..ff046bc90a390 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/RankITCase.scala @@ -0,0 +1,1310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.{TableConfigOptions, TableException} +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.{TestingRetractTableSink, TestingUpsertTableSink, _} +import org.apache.flink.types.Row + +import org.junit.Assert._ +import org.junit.{Ignore, _} +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class RankITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) { + + @Test + def testTopN(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + val expected = List( + "book,2,19,1", + "book,1,12,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testTopNth(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num = 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,1,12,2", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testTopNWithUpsertSink(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "book,4,11,1", + "book,1,12,2", + "fruit,5,22,1", + "fruit,4,33,2") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode and SortedMapState is supported") + @Test(expected = classOf[TableException]) + def testTopNWithUnary(): Unit = { + val data = List( + ("book", 11, 100), + ("book", 11, 200), + ("book", 12, 400), + ("book", 12, 500), + ("book", 10, 600), + ("book", 10, 700), + ("book", 9, 800), + ("book", 9, 900), + ("book", 10, 500), + ("book", 8, 110), + ("book", 8, 120), + ("book", 7, 1800), + ("book", 9, 300), + ("book", 6, 1900), + ("book", 7, 50), + ("book", 11, 1800), + ("book", 7, 50), + ("book", 8, 2000), + ("book", 6, 700), + ("book", 5, 800), + ("book", 4, 910), + ("book", 3, 1000), + ("book", 2, 1100), + ("book", 1, 1200) + ) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,5,800,1", + "book,12,900,2", + "book,4,910,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode and SortedMapState is supported") + @Test + def testUnarySortTopNOnString(): Unit = { + val data = List( + ("book", 11, "100"), + ("book", 11, "200"), + ("book", 12, "400"), + ("book", 12, "600"), + ("book", 10, "600"), + ("book", 10, "700"), + ("book", 9, "800"), + ("book", 9, "900"), + ("book", 10, "500"), + ("book", 8, "110"), + ("book", 8, "120"), + ("book", 7, "812"), + ("book", 9, "300"), + ("book", 6, "900"), + ("book", 7, "50"), + ("book", 11, "800"), + ("book", 7, "50"), + ("book", 8, "200"), + ("book", 6, "700"), + ("book", 5, "800"), + ("book", 4, "910"), + ("book", 3, "110"), + ("book", 2, "900"), + ("book", 1, "700") + ) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'price) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, max_price, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_price ASC) as rank_num + | FROM ( + | SELECT category, shopId, max(price) as max_price + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,3,110,1", + "book,8,200,2", + "book,12,600,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithGroupBy(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,1,22,1", + "book,2,19,2", + "fruit,3,44,1", + "fruit,5,34,2") + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithSumAndCondition(): Unit = { + val data = List( + Row.of("book", Int.box(11), Double.box(100)), + Row.of("book", Int.box(11), Double.box(200)), + Row.of("book", Int.box(12), Double.box(400)), + Row.of("book", Int.box(12), Double.box(500)), + Row.of("book", Int.box(10), Double.box(600)), + Row.of("book", Int.box(10), Double.box(700))) + + implicit val tpe: TypeInformation[Row] = new RowTypeInfo( + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO) // tpe is automatically + + val ds = env.fromCollection(data) + val t = ds.toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", t) + + val subquery = + """ + |SELECT category, shopId, sum(num) as sum_num + |FROM T + |WHERE num >= cast(1.1 as double) + |GROUP BY category, shopId + """.stripMargin + + val sql = + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num + | FROM ($subquery)) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,10,1300.0,1", + "book,12,900.0,2") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNthWithGroupBy(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val updatedExpected = List( + "book,2,19,2", + "fruit,5,34,2") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithGroupByAndRetract(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num, count(num) as cnt + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,1,22,2,1", + "book,2,19,1,2", + "fruit,3,44,1,1", + "fruit,5,34,2,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNthWithGroupByAndRetract(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), + ("fruit", 4, 33), + ("fruit", 5, 12), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC, cnt ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num, count(num) as cnt + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 2 + """.stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List( + "book,2,19,1,2", + "fruit,5,34,2,2") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithGroupByCount(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 2, 1002), + ("book", 4, 1003), + ("book", 1, 1004), + 
("book", 1, 1005), + ("book", 3, 1006), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "book,1,5,4", + "book,2,4,1", + "book,3,2,2", + "book,4,1,3", + "fruit,1,3,5", + "fruit,2,2,4", + "fruit,3,1,3") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNthWithGroupByCount(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 2, 1002), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 3, 1006), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num = 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "book,3,2,2", + "fruit,3,1,3") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testNestedTopN(): Unit = { + val data = List( + ("book", "a", 1), + ("book", "b", 1), + ("book", "c", 1), + ("fruit", "a", 2), + ("book", "a", 1), + ("book", "d", 0), + ("book", "b", 3), + ("fruit", "b", 6), + ("book", "c", 1), + ("book", "e", 5), + ("book", "d", 4)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT cate, shopId, count(*) as cnt, max(sells) as sells + | FROM T + | GROUP BY cate, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + + val sql2 = + s""" + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT cate, shopId, sells, cnt, + | ROW_NUMBER() OVER (ORDER BY sells DESC) as rank_num + | FROM ($sql) + |) + |WHERE rank_num <= 4 + """.stripMargin + + val table = tEnv.sqlQuery(sql2) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "(true,1,book,a,1,1)", "(true,2,book,b,1,1)", "(true,3,book,c,1,1)", + "(true,1,fruit,a,2,1)", "(true,2,book,a,1,1)", "(true,3,book,b,1,1)", "(true,4,book,c,1,1)", + "(true,2,book,a,1,2)", + "(true,1,book,b,3,2)", "(true,2,fruit,a,2,1)", "(true,3,book,a,1,2)", + "(true,3,book,a,1,2)", + "(true,1,fruit,b,6,1)", "(true,2,book,b,3,2)", "(true,3,fruit,a,2,1)", "(true,4,book,a,1,2)", + "(true,3,fruit,a,2,1)", + "(true,2,book,e,5,1)", + "(true,3,book,b,3,2)", "(true,4,fruit,a,2,1)", + "(true,3,book,b,3,2)", + "(true,3,book,d,4,2)", + "(true,4,book,b,3,2)", + "(true,4,book,b,3,2)") + assertEquals(expected.mkString("\n"), sink.getRawResults.mkString("\n")) + + val expected2 = List("1,fruit,b,6,1", "2,book,e,5,1", "3,book,d,4,2", "4,book,b,3,2") + assertEquals(expected2, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithoutDeduplicate(): Unit = { + val data = List( + ("book", "a", 1), + ("book", "b", 1), + ("book", "c", 1), + ("fruit", "a", 2), + ("book", "a", 1), + ("book", "d", 0), + ("book", "b", 3), + ("fruit", "b", 6), + ("book", "c", 1), + ("book", "e", 5), + ("book", "d", 4)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'cate, 'shopId, 'sells) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT rank_num, cate, shopId, sells, cnt + |FROM ( + | SELECT *, + | ROW_NUMBER() OVER (PARTITION BY cate ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT cate, shopId, count(*) as cnt, max(sells) as sells + | FROM T + | GROUP BY cate, shopId + | )) + |WHERE rank_num <= 4 + """.stripMargin + + tEnv.getConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_TOPN_CACHE_SIZE, 1) + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "(true,1,book,a,1,1)", + "(true,2,book,b,1,1)", + "(true,3,book,c,1,1)", + "(true,1,fruit,a,2,1)", + "(true,1,book,a,1,2)", + "(true,4,book,d,0,1)", + "(true,1,book,b,3,2)", + "(true,2,book,a,1,2)", + "(true,1,fruit,b,6,1)", + "(true,2,fruit,a,2,1)", + "(true,3,book,c,1,2)", + "(true,1,book,e,5,1)", + "(true,2,book,b,3,2)", + "(true,3,book,a,1,2)", + "(true,4,book,c,1,2)", + "(true,2,book,d,4,2)", + "(true,3,book,b,3,2)", + "(true,4,book,a,1,2)") + + assertEquals(expected, sink.getRawResults) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithVariableTopSize(): Unit = { + val data = List( + ("book", 1, 1001, 4), + ("book", 2, 1002, 4), + ("book", 4, 1003, 4), + ("book", 1, 1004, 4), + ("book", 1, 1005, 4), + ("book", 3, 1006, 4), + ("book", 2, 1007, 4), + ("book", 4, 1008, 4), + ("book", 1, 1009, 4), + ("book", 4, 1010, 4), + ("book", 4, 1012, 4), + ("book", 4, 1012, 4), + ("fruit", 4, 1013, 2), + ("fruit", 5, 1014, 2), + ("fruit", 3, 1015, 2), + ("fruit", 4, 1017, 2), + ("fruit", 5, 1018, 2), + ("fruit", 5, 1016, 2)) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId, 'topSize) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, rank_num, sells, shopId + |FROM ( + | SELECT category, shopId, sells, topSize, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells, max(topSize) as topSize + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= topSize + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "book,1,5,4", + "book,2,4,1", + "book,3,2,2", + "book,4,1,3", + "fruit,1,3,5", + "fruit,2,2,4") + assertEquals(expected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNUnaryComplexScenario(): Unit = { + val data = List( + ("book", 1, 11), + ("book", 2, 19), + ("book", 4, 13), + ("book", 1, 11), // backward update in heap + ("book", 3, 23), // elems exceed topn size after insert + ("book", 5, 19), // sort map shrinks out some elems after insert + ("book", 7, 10), // sort map keeps a little more than topn size elems after insert + ("book", 8, 13), // sort map now can shrink out-of-range elems after another insert + ("book", 10, 13), // Once again, sort map keeps a little more elems after insert + ("book", 8, 6), // backward update from heap to state + ("book", 10, 6), // backward update from heap to state, and sort map loads more data + ("book", 5, 3), // backward update from heap to state + ("book", 10, 1), // backward update in heap, and then sort map shrink some data + ("book", 5, 1), // backward update in state + ("book", 5, -3), // forward update in state + ("book", 2, -10), // forward update in heap, and then sort map shrink some data + ("book", 10, -7), // forward update from state to heap + ("book", 11, 13), // insert into heap + ("book", 12, 10), // insert into heap, and sort map shrink some data + ("book", 15, 14) // insert into state + ) + + env.setParallelism(1) + + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT * + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num ASC) as rank_num + | FROM ( + | SELECT category, shopId, sum(num) as num + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 3)).
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "(true,book,1,11,1)", + "(true,book,2,19,2)", + "(true,book,4,13,2)", + "(true,book,2,19,3)", + + "(true,book,4,13,1)", + "(true,book,2,19,2)", + "(true,book,1,22,3)", + + "(true,book,5,19,3)", + + "(true,book,7,10,1)", + "(true,book,4,13,2)", + "(true,book,2,19,3)", + + "(true,book,8,13,3)", + + "(true,book,10,13,3)", + + "(true,book,2,19,3)", + + "(true,book,2,9,1)", + "(true,book,7,10,2)", + "(true,book,4,13,3)", + + "(true,book,12,10,3)") + + assertEquals(expected.mkString("\n"), sink.getRawResults.mkString("\n")) + + val updatedExpected = List( + "book,2,9,1", + "book,7,10,2", + "book,12,10,3") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithGroupByAvgWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 100), + ("book", 3, 110), + ("book", 4, 120), + ("book", 1, 200), + ("book", 1, 200), + ("book", 2, 300), + ("book", 2, 400), + ("book", 4, 500), + ("book", 1, 400), + ("fruit", 5, 100)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, shopId, avgSellId + |FROM ( + | SELECT category, shopId, avgSellId, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY avgSellId DESC) as rank_num + | FROM ( + | SELECT category, shopId, AVG(sellId) as avgSellId + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "(true,book,1,100.0)", + "(true,book,3,110.0)", + "(true,book,4,120.0)", + "(true,book,1,150.0)", + "(true,book,1,166.66666666666666)", + "(true,book,2,300.0)", + "(false,book,3,110.0)", + "(true,book,2,350.0)", + "(true,book,4,310.0)", + "(true,book,1,225.0)", + "(true,fruit,5,100.0)") + + assertEquals(expected, sink.getRawResults) + + val updatedExpected = List( + "book,1,225.0", + "book,2,350.0", + "book,4,310.0", + "fruit,5,100.0") + + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testTopNWithGroupByCountWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 1001), + ("book", 3, 1006), + ("book", 4, 1003), + ("book", 1, 1004), + ("book", 1, 1005), + ("book", 2, 1002), + ("book", 2, 1007), + ("book", 4, 1008), + ("book", 1, 1009), + ("book", 4, 1010), + ("book", 4, 1012), + ("book", 4, 1012), + ("fruit", 4, 1013), + ("fruit", 5, 1014), + ("fruit", 3, 1015), + ("fruit", 4, 1017), + ("fruit", 5, 1018), + ("fruit", 5, 1016)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'sellId) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, shopId, sells + |FROM ( + | SELECT category, shopId, sells, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sells DESC) as rank_num + | FROM ( + | SELECT category, shopId, count(sellId) as sells + | FROM T + | GROUP BY category, shopId + | )) + |WHERE rank_num <= 3 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 1)). 
+ configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "(true,book,1,1)", + "(true,book,3,1)", + "(true,book,4,1)", + "(true,book,1,2)", + "(true,book,1,3)", + "(true,book,2,2)", + "(false,book,4,1)", + "(true,book,4,2)", + "(false,book,3,1)", + "(true,book,1,4)", + "(true,book,4,3)", + "(true,book,4,4)", + "(true,book,4,5)", + "(true,fruit,4,1)", + "(true,fruit,5,1)", + "(true,fruit,3,1)", + "(true,fruit,4,2)", + "(true,fruit,5,2)", + "(true,fruit,5,3)") + assertEquals(expected, sink.getRawResults) + + val updatedExpected = List( + "book,4,5", + "book,1,4", + "book,2,2", + "fruit,5,3", + "fruit,4,2", + "fruit,3,1") + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + @Test + def testTopNWithoutRowNumber(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 2, 19), + ("book", 4, 11), + ("book", 5, 20), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22), + ("fruit", 1, 40)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val sql = + """ + |SELECT category, num, shopId + |FROM ( + | SELECT category, shopId, num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY num DESC) as rank_num + | FROM T) + |WHERE rank_num <= 2 + """.stripMargin + + val table = tEnv.sqlQuery(sql) + val schema = table.getSchema + val sink = new TestingUpsertTableSink(Array(0, 2)). + configure(schema.getFieldNames, schema.getFieldTypes) + tEnv.writeToSink(table, sink) + env.execute() + + val expected = List( + "(true,book,12,1)", + "(true,book,19,2)", + "(false,book,12,1)", + "(true,book,20,5)", + "(true,fruit,33,4)", + "(true,fruit,44,3)", + "(false,fruit,33,4)", + "(true,fruit,40,1)") + assertEquals(expected, sink.getRawResults) + + val updatedExpected = List( + "book,19,2", + "book,20,5", + "fruit,40,1", + "fruit,44,3") + assertEquals(updatedExpected.sorted, sink.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testMultipleRetractTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num, + | AVG(num) as avg_num, COUNT(num) as cnt + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val sink1 = new TestingRetractSink + tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, avg_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC, avg_num ASC + | ) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin).toRetractStream[Row].addSink(sink1).setParallelism(1) + env.execute() + + val expected1 = List( + "book,1,25,12.5,1", + "book,2,19,19.0,2", + "fruit,3,44,44.0,1", + "fruit,4,33,33.0,2") + assertEquals(expected1.sorted, sink1.getRetractResults.sorted) + + val sink2 = new TestingRetractSink + tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, cnt, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC, cnt ASC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin).toRetractStream[Row].addSink(sink2).setParallelism(1) + 
env.execute() + + val expected2 = List( + "book,2,19,1,1", + "book,1,13,2,2", + "fruit,3,44,1,1", + "fruit,4,33,1,2") + assertEquals(expected2.sorted, sink2.getRetractResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testMultipleUnaryTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, SUM(num) as sum_num, MAX(num) as max_num + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val table1 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, sum_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY sum_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + val schema1 = table1.getSchema + val sink1 = new TestingUpsertTableSink(Array(0, 3)). + configure(schema1.getFieldNames, schema1 + .getFieldTypes) + tEnv.writeToSink(table1, sink1) + + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + val schema2 = table2.getSchema + val sink2 = new TestingUpsertTableSink(Array(0, 3)). + configure(schema2.getFieldNames, schema2 + .getFieldTypes) + tEnv.writeToSink(table2, sink2) + + env.execute() + val expected1 = List( + "book,1,25,1", + "book,2,19,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected1.sorted, sink1.getUpsertResults.sorted) + + val expected2 = List( + "book,2,19,1", + "book,1,13,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected2.sorted, sink2.getUpsertResults.sorted) + } + + // FIXME + @Ignore("Enable after agg implements ExecNode") + @Test + def testMultipleUpdateTopNAfterAgg(): Unit = { + val data = List( + ("book", 1, 12), + ("book", 1, 13), + ("book", 2, 19), + ("book", 4, 11), + ("fruit", 4, 33), + ("fruit", 3, 44), + ("fruit", 5, 22)) + + env.setParallelism(1) + val ds = failingDataSource(data).toTable(tEnv, 'category, 'shopId, 'num) + tEnv.registerTable("T", ds) + + val subquery = + s""" + |SELECT category, shopId, COUNT(num) as cnt_num, MAX(num) as max_num + |FROM T + |GROUP BY category, shopId + |""".stripMargin + + val t1 = tEnv.sqlQuery(subquery) + tEnv.registerTable("MyView", t1) + + val table1 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, cnt_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY cnt_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + val schema1 = table1.getSchema + val sink1 = new TestingRetractTableSink(). + configure(schema1.getFieldNames, schema1.getFieldTypes) + tEnv.writeToSink(table1, sink1) + + val table2 = tEnv.sqlQuery( + s""" + |SELECT * + |FROM ( + | SELECT category, shopId, max_num, + | ROW_NUMBER() OVER (PARTITION BY category ORDER BY max_num DESC) as rank_num + | FROM MyView) + |WHERE rank_num <= 2 + |""".stripMargin) + val schema2 = table2.getSchema + val sink2 = new TestingRetractTableSink(). 
+ configure(schema2.getFieldNames, schema2.getFieldTypes) + tEnv.writeToSink(table2, sink2) + env.execute() + + val expected1 = List( + "book,1,2,1", + "book,2,1,2", + "fruit,4,1,1", + "fruit,3,1,2") + assertEquals(expected1.sorted, sink1.getRetractResults.sorted) + + val expected2 = List( + "book,2,19,1", + "book,1,13,2", + "fruit,3,44,1", + "fruit,4,33,2") + assertEquals(expected2.sorted, sink2.getRetractResults.sorted) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala index 95138c28edf6d..e91a1523b3ac9 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala @@ -19,23 +19,26 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.common.io.OutputFormat import org.apache.flink.api.common.state.{ListState, ListStateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo} import org.apache.flink.configuration.Configuration import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink} import org.apache.flink.streaming.api.functions.sink.RichSinkFunction import org.apache.flink.table.api.{TableConfig, Types} -import org.apache.flink.table.dataformat.BaseRow -import org.apache.flink.table.sinks.{AppendStreamTableSink, BatchTableSink} +import org.apache.flink.table.dataformat.{BaseRow, DataFormatConverters, GenericRow} +import org.apache.flink.table.sinks._ +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.table.util.BaseRowTestUtil import org.apache.flink.types.Row +import _root_.java.lang.{Boolean => JBoolean} import _root_.java.util.TimeZone import _root_.java.util.concurrent.atomic.AtomicInteger @@ -58,7 +61,7 @@ object StreamTestSink { private[utils] def getNewSinkId: Int = { val idx = idCounter.getAndIncrement() - this.synchronized{ + this.synchronized { globalResults.put(idx, mutable.HashMap.empty[Int, ArrayBuffer[String]]) globalRetractResults.put(idx, mutable.HashMap.empty[Int, ArrayBuffer[String]]) globalUpsertResults.put(idx, mutable.HashMap.empty[Int, mutable.Map[String, String]]) @@ -78,7 +81,7 @@ abstract class AbstractExactlyOnceSink[T] extends RichSinkFunction[T] with Check protected var localResults: ArrayBuffer[String] = _ protected val idx: Int = StreamTestSink.getNewSinkId - protected var globalResults: mutable.Map[Int, ArrayBuffer[String]]= _ + protected var globalResults: mutable.Map[Int, ArrayBuffer[String]] = _ protected var globalRetractResults: mutable.Map[Int, ArrayBuffer[String]] = _ protected var globalUpsertResults: mutable.Map[Int, mutable.Map[String, String]] = _ @@ -109,7 +112,7 @@ abstract class AbstractExactlyOnceSink[T] extends RichSinkFunction[T] with Check protected 
def clearAndStashGlobalResults(): Unit = { if (globalResults == null) { - StreamTestSink.synchronized{ + StreamTestSink.synchronized { globalResults = StreamTestSink.globalResults.remove(idx).get globalRetractResults = StreamTestSink.globalRetractResults.remove(idx).get globalUpsertResults = StreamTestSink.globalUpsertResults.remove(idx).get @@ -146,12 +149,166 @@ final class TestingAppendSink(tz: TimeZone) extends AbstractExactlyOnceSink[Row] def this() { this(TimeZone.getTimeZone("UTC")) } + def invoke(value: Row): Unit = localResults += TestSinkUtil.rowToString(value, tz) + def getAppendResults: List[String] = getResults } +final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) + extends AbstractExactlyOnceSink[(Boolean, BaseRow)] { + + private var upsertResultsState: ListState[String] = _ + private var localUpsertResults: mutable.Map[String, String] = _ + private var fieldTypes: Array[TypeInformation[_]] = _ + + def this(keys: Array[Int]) { + this(keys, TimeZone.getTimeZone("UTC")) + } + + def configureTypes(fieldTypes: Array[TypeInformation[_]]): Unit = { + this.fieldTypes = fieldTypes + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + super.initializeState(context) + upsertResultsState = context.getOperatorStateStore.getListState( + new ListStateDescriptor[String]("sink-upsert-results", Types.STRING)) + + localUpsertResults = mutable.HashMap.empty[String, String] + + if (context.isRestored) { + var key: String = null + var value: String = null + for (entry <- upsertResultsState.get().asScala) { + if (key == null) { + key = entry + } else { + value = entry + localUpsertResults += (key -> value) + key = null + value = null + } + } + if (key != null) { + throw new RuntimeException("The resultState is corrupt.") + } + } + + val taskId = getRuntimeContext.getIndexOfThisSubtask + StreamTestSink.synchronized { + StreamTestSink.globalUpsertResults(idx) += (taskId -> localUpsertResults) + } + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + super.snapshotState(context) + upsertResultsState.clear() + for ((key, value) <- localUpsertResults) { + upsertResultsState.add(key) + upsertResultsState.add(value) + } + } + + def invoke(d: (Boolean, BaseRow)): Unit = { + this.synchronized { + val wrapRow = new GenericRow(2) + wrapRow.setField(0, d._1) + wrapRow.setField(1, d._2) + val converter = + DataFormatConverters.getConverterForTypeInfo( + new TupleTypeInfo(Types.BOOLEAN, new RowTypeInfo(fieldTypes: _*))) + .asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, JTuple2[JBoolean, Row]]] + val v = converter.toExternal(wrapRow) + val rowString = TestSinkUtil.rowToString(v.f1, tz) + val tupleString = "(" + v.f0.toString + "," + rowString + ")" + localResults += tupleString + val keyString = TestSinkUtil.rowToString(Row.project(v.f1, keys), tz) + if (v.f0) { + localUpsertResults += (keyString -> rowString) + } else { + val oldValue = localUpsertResults.remove(keyString) + if (oldValue.isEmpty) { + throw new RuntimeException("Tried to delete a value that wasn't inserted first. " + + "This is probably an incorrectly implemented test. 
" + + "Try to set the parallelism of the sink to 1.") + } + } + } + } + + def getRawResults: List[String] = getResults + + def getUpsertResults: List[String] = { + clearAndStashGlobalResults() + val result = ArrayBuffer.empty[String] + this.globalUpsertResults.foreach { + case (_, map) => map.foreach(result += _._2) + } + result.toList + } +} + +final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) + extends UpsertStreamTableSink[BaseRow] { + var fNames: Array[String] = _ + var fTypes: Array[TypeInformation[_]] = _ + var sink = new TestingUpsertSink(keys, tz) + + def this(keys: Array[Int]) { + this(keys, TimeZone.getTimeZone("UTC")) + } + + override def setKeyFields(keys: Array[String]): Unit = { + // ignore + } + + override def setIsAppendOnly(isAppendOnly: JBoolean): Unit = { + // ignore + } + + override def getRecordType: TypeInformation[BaseRow] = + new BaseRowTypeInfo(fTypes.map(createInternalTypeFromTypeInfo(_)), fNames) + + override def getFieldNames: Array[String] = fNames + + override def getFieldTypes: Array[TypeInformation[_]] = fTypes + + override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, BaseRow]]) = { + dataStream.map(new MapFunction[JTuple2[JBoolean, BaseRow], (Boolean, BaseRow)] { + override def map(value: JTuple2[JBoolean, BaseRow]): (Boolean, BaseRow) = { + (value.f0, value.f1) + } + }) + .addSink(sink) + .name(s"TestingUpsertTableSink(keys=${ + if (keys != null) { + "(" + keys.mkString(",") + ")" + } else { + "null" + } + })") + .setParallelism(1) + } + + override def configure( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]]): TestingUpsertTableSink = { + val copy = new TestingUpsertTableSink(keys, tz) + copy.fNames = fieldNames + copy.fTypes = fieldTypes + sink.configureTypes(fieldTypes) + copy.sink = sink + copy + } + + def getRawResults: List[String] = sink.getRawResults + + def getUpsertResults: List[String] = sink.getUpsertResults +} + final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[Row] - with BatchTableSink[Row]{ + with BatchTableSink[Row] { var fNames: Array[String] = _ var fTypes: Array[TypeInformation[_]] = _ var sink = new TestingAppendSink(tz) @@ -163,7 +320,7 @@ final class TestingAppendTableSink(tz: TimeZone) extends AppendStreamTableSink[R override def emitDataStream(dataStream: DataStream[Row]): DataStreamSink[Row] = { dataStream.addSink(sink).name("TestingAppendTableSink") - .setParallelism(dataStream.getParallelism) + .setParallelism(dataStream.getParallelism) } override def emitBoundedStream( @@ -211,23 +368,25 @@ class TestingOutputFormat[T](tz: TimeZone) def open(taskNumber: Int, numTasks: Int): Unit = { localRetractResults = mutable.ArrayBuffer.empty[String] - StreamTestSink.synchronized{ + StreamTestSink.synchronized { StreamTestSink.globalResults(index) += (taskNumber -> localRetractResults) } } - def writeRecord(value: T): Unit = localRetractResults += { value match { - case r: Row => TestSinkUtil.rowToString(r, tz) - case tp: JTuple2[java.lang.Boolean, Row] => - "(" + tp.f0.toString + "," + TestSinkUtil.rowToString(tp.f1, tz) + ")" - case _ => "" - }} + def writeRecord(value: T): Unit = localRetractResults += { + value match { + case r: Row => TestSinkUtil.rowToString(r, tz) + case tp: JTuple2[java.lang.Boolean, Row] => + "(" + tp.f0.toString + "," + TestSinkUtil.rowToString(tp.f1, tz) + ")" + case _ => "" + } + } def close(): Unit = {} protected def clearAndStashGlobalResults(): Unit = { if (globalResults == null) { - StreamTestSink.synchronized{ + 
StreamTestSink.synchronized { globalResults = StreamTestSink.globalResults.remove(index).get } } @@ -242,3 +401,118 @@ class TestingOutputFormat[T](tz: TimeZone) result.toList } } + +class TestingRetractSink(tz: TimeZone) + extends AbstractExactlyOnceSink[(Boolean, Row)] { + protected var retractResultsState: ListState[String] = _ + protected var localRetractResults: ArrayBuffer[String] = _ + + def this() { + this(TimeZone.getTimeZone("UTC")) + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + super.initializeState(context) + retractResultsState = context.getOperatorStateStore.getListState( + new ListStateDescriptor[String]("sink-retract-results", Types.STRING)) + + localRetractResults = mutable.ArrayBuffer.empty[String] + + if (context.isRestored) { + for (value <- retractResultsState.get().asScala) { + localRetractResults += value + } + } + + val taskId = getRuntimeContext.getIndexOfThisSubtask + StreamTestSink.synchronized { + StreamTestSink.globalRetractResults(idx) += (taskId -> localRetractResults) + } + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + super.snapshotState(context) + retractResultsState.clear() + for (value <- localRetractResults) { + retractResultsState.add(value) + } + } + + def invoke(v: (Boolean, Row)): Unit = { + this.synchronized { + val tupleString = "(" + v._1.toString + "," + TestSinkUtil.rowToString(v._2, tz) + ")" + localResults += tupleString + val rowString = TestSinkUtil.rowToString(v._2, tz) + if (v._1) { + localRetractResults += rowString + } else { + val index = localRetractResults.indexOf(rowString) + if (index >= 0) { + localRetractResults.remove(index) + } else { + throw new RuntimeException("Tried to retract a value that wasn't added first. " + + "This is probably an incorrectly implemented test. 
" + + "Try to set the parallelism of the sink to 1.") + } + } + } + } + + def getRawResults: List[String] = getResults + + def getRetractResults: List[String] = { + clearAndStashGlobalResults() + val result = ArrayBuffer.empty[String] + this.globalRetractResults.foreach { + case (_, list) => result ++= list + } + result.toList + } +} + +final class TestingRetractTableSink(tz: TimeZone) extends RetractStreamTableSink[Row] { + + var fNames: Array[String] = _ + var fTypes: Array[TypeInformation[_]] = _ + var sink = new TestingRetractSink(tz) + + def this() { + this(TimeZone.getTimeZone("UTC")) + } + + override def emitDataStream(dataStream: DataStream[JTuple2[JBoolean, Row]]) = { + dataStream.map(new MapFunction[JTuple2[JBoolean, Row], (Boolean, Row)] { + override def map(value: JTuple2[JBoolean, Row]): (Boolean, Row) = { + (value.f0, value.f1) + } + }).setParallelism(dataStream.getParallelism) + .addSink(sink) + .name("TestingRetractTableSink") + .setParallelism(dataStream.getParallelism) + } + + override def getRecordType: TypeInformation[Row] = + new RowTypeInfo(fTypes, fNames) + + override def getFieldNames: Array[String] = fNames + + override def getFieldTypes: Array[TypeInformation[_]] = fTypes + + override def configure( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]]): TestingRetractTableSink = { + val copy = new TestingRetractTableSink(tz) + copy.fNames = fieldNames + copy.fTypes = fieldTypes + copy.sink = sink + copy + } + + def getRawResults: List[String] = { + sink.getRawResults + } + + def getRetractResults: List[String] = { + sink.getRetractResults + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala new file mode 100644 index 0000000000000..56c2593cd9b48 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithMiniBatchTestBase.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.utils + +import org.apache.flink.table.api.TableConfigOptions +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode} +import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.{MiniBatchMode, MiniBatchOff, MiniBatchOn} + +import java.util + +import scala.collection.JavaConversions._ + +import org.junit.runners.Parameterized + +abstract class StreamingWithMiniBatchTestBase( + miniBatch: MiniBatchMode, + state: StateBackendMode) + extends StreamingWithStateTestBase(state) { + + override def before(): Unit = { + super.before() + // set mini batch + val tableConfig = tEnv.getConfig + miniBatch match { + case MiniBatchOn => + tableConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY, 1000L) + tableConfig.getConf.setLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE, 3L) + case MiniBatchOff => + tableConfig.getConf.removeConfig(TableConfigOptions.SQL_EXEC_MINIBATCH_ALLOW_LATENCY) + } + } +} + +object StreamingWithMiniBatchTestBase { + + case class MiniBatchMode(on: Boolean) { + override def toString: String = { + if (on) { + "MiniBatch=ON" + } else { + "MiniBatch=OFF" + } + } + } + + val MiniBatchOff = MiniBatchMode(false) + val MiniBatchOn = MiniBatchMode(true) + + @Parameterized.Parameters(name = "{0}, StateBackend={1}") + def parameters(): util.Collection[Array[java.lang.Object]] = { + Seq[Array[AnyRef]]( + Array(MiniBatchOff, HEAP_BACKEND), + Array(MiniBatchOff, ROCKSDB_BACKEND), + Array(MiniBatchOn, HEAP_BACKEND), + Array(MiniBatchOn, ROCKSDB_BACKEND)) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala new file mode 100644 index 0000000000000..9a3836ba76703 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamingWithStateTestBase.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.api.common.restartstrategy.RestartStrategies +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.configuration.{CheckpointingOptions, Configuration} +import org.apache.flink.runtime.state.memory.MemoryStateBackend +import org.apache.flink.streaming.api.CheckpointingMode +import org.apache.flink.streaming.api.functions.source.FromElementsFunction +import org.apache.flink.streaming.api.scala.DataStream +import org.apache.flink.table.api.{TableEnvironment, Types} +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, BinaryRowWriter, BinaryString} +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode} +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo +import org.apache.flink.table.`type`.RowType + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import java.io.File +import java.util + +import org.junit.runners.Parameterized +import org.junit.{After, Assert, Before} + +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend + +class StreamingWithStateTestBase(state: StateBackendMode) extends StreamingTestBase { + + enableObjectReuse = state match { + case HEAP_BACKEND => false // TODO gemini not support obj reuse now. + case ROCKSDB_BACKEND => true + } + + private val classLoader = Thread.currentThread.getContextClassLoader + + var baseCheckpointPath: File = _ + + @Before + override def before(): Unit = { + super.before() + // set state backend + baseCheckpointPath = tempFolder.newFolder().getAbsoluteFile + state match { + case HEAP_BACKEND => + val conf = new Configuration() + conf.setBoolean(CheckpointingOptions.ASYNC_SNAPSHOTS, true) + env.setStateBackend(new MemoryStateBackend( + "file://" + baseCheckpointPath, null).configure(conf, classLoader)) + case ROCKSDB_BACKEND => + val conf = new Configuration() + conf.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true) + env.setStateBackend(new RocksDBStateBackend( + "file://" + baseCheckpointPath).configure(conf, classLoader)) + } + this.tEnv = TableEnvironment.getTableEnvironment(env) + FailingCollectionSource.failedBefore = true + } + + @After + def after(): Unit = { + Assert.assertTrue(FailingCollectionSource.failedBefore) + } + + /** + * Creates a BinaryRow DataStream from the given non-empty [[Seq]]. 
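+ * Note: elements must be [[Product]] instances (e.g. tuples or case classes) whose fields are Int, Long, String (already held as [[BinaryString]]) or Boolean; any other element type causes an UnsupportedOperationException.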
+ */ + def failingBinaryRowSource[T: TypeInformation](data: Seq[T]): DataStream[BaseRow] = { + val typeInfo = implicitly[TypeInformation[_]].asInstanceOf[CompositeType[_]] + val result = new mutable.MutableList[BaseRow] + val reuse = new BinaryRow(typeInfo.getArity) + val writer = new BinaryRowWriter(reuse) + data.foreach { + case p: Product => + for (i <- 0 until typeInfo.getArity) { + val fieldType = typeInfo.getTypeAt(i).asInstanceOf[TypeInformation[_]] + fieldType match { + case Types.INT => writer.writeInt(i, p.productElement(i).asInstanceOf[Int]) + case Types.LONG => writer.writeLong(i, p.productElement(i).asInstanceOf[Long]) + case Types.STRING => writer.writeString(i, + p.productElement(i).asInstanceOf[BinaryString]) + case Types.BOOLEAN => writer.writeBoolean(i, p.productElement(i).asInstanceOf[Boolean]) + } + } + writer.complete() + result += reuse.copy() + case _ => throw new UnsupportedOperationException + } + val newTypeInfo = createInternalTypeFromTypeInfo(typeInfo).asInstanceOf[RowType].toTypeInfo + failingDataSource(result)(newTypeInfo.asInstanceOf[TypeInformation[BaseRow]]) + } + + /** + * Creates a DataStream from the given non-empty [[Seq]]. + */ + def failingDataSource[T: TypeInformation](data: Seq[T]): DataStream[T] = { + env.enableCheckpointing(100, CheckpointingMode.EXACTLY_ONCE) + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0)) + env.setParallelism(1) + // reset failedBefore flag to false + FailingCollectionSource.reset() + + require(data != null, "Data must not be null.") + val typeInfo = implicitly[TypeInformation[T]] + + val collection = scala.collection.JavaConversions.asJavaCollection(data) + // must not have null elements and mixed elements + FromElementsFunction.checkCollection(data, typeInfo.getTypeClass) + + val function = new FailingCollectionSource[T]( + typeInfo.createSerializer(env.getConfig), + collection, + data.length / 2, // fail after half elements + false) + + env.addSource(function)(typeInfo).setMaxParallelism(1) + } + + private def mapStrEquals(str1: String, str2: String): Boolean = { + val array1 = str1.toCharArray + val array2 = str2.toCharArray + if (array1.length != array2.length) { + return false + } + val l = array1.length + val leftBrace = "{".charAt(0) + val rightBrace = "}".charAt(0) + val equalsChar = "=".charAt(0) + val lParenthesis = "(".charAt(0) + val rParenthesis = ")".charAt(0) + val dot = ",".charAt(0) + val whiteSpace = " ".charAt(0) + val map1 = Map[String, String]() + val map2 = Map[String, String]() + var idx = 0 + def findEquals(ss: CharSequence): Array[Int] = { + val ret = new ArrayBuffer[Int]() + (0 until ss.length) foreach (idx => if (ss.charAt(idx) == equalsChar) ret += idx) + ret.toArray + } + + def splitKV(ss: CharSequence, equalsIdx: Int): (String, String) = { + // find right, if starts with '(' find until the ')', else until the ',' + var endFlag = false + var curIdx = equalsIdx + 1 + var endChar = if (ss.charAt(curIdx) == lParenthesis) rParenthesis else dot + var valueStr: CharSequence = null + var keyStr: CharSequence = null + while (curIdx < ss.length && !endFlag) { + val curChar = ss.charAt(curIdx) + if (curChar != endChar && curChar != rightBrace) { + curIdx += 1 + if (curIdx == ss.length) { + valueStr = ss.subSequence(equalsIdx + 1, curIdx) + } + } else { + valueStr = ss.subSequence(equalsIdx + 1, curIdx) + endFlag = true + } + } + + // find left, if starts with ')' find until the '(', else until the ' ,' + endFlag = false + curIdx = equalsIdx - 1 + endChar = if (ss.charAt(curIdx) == 
rParenthesis) lParenthesis else whiteSpace + while (curIdx >= 0 && !endFlag) { + val curChar = ss.charAt(curIdx) + if (curChar != endChar && curChar != leftBrace) { + curIdx -= 1 + if (curIdx == -1) { + keyStr = ss.subSequence(0, equalsIdx) + } + } else { + keyStr = ss.subSequence(curIdx, equalsIdx) + endFlag = true + } + } + require(keyStr != null) + require(valueStr != null) + (keyStr.toString, valueStr.toString) + } + + def appendStrToMap(ss: CharSequence, m: Map[String, String]): Unit = { + val equalsIdxs = findEquals(ss) + equalsIdxs.foreach(idx => m + splitKV(ss, idx)) + } + + while (idx < l) { + val char1 = array1(idx) + val char2 = array2(idx) + if (char1 != char2) { + return false + } + + if (char1 == leftBrace) { + val rightBraceIdx = array1.subSequence(idx + 1, l).toString.indexOf(rightBrace) + appendStrToMap(array1.subSequence(idx + 1, rightBraceIdx + idx + 2), map1) + idx += rightBraceIdx + } else { + idx += 1 + } + } + map1.equals(map2) + } + + def assertMapStrEquals(str1: String, str2: String): Unit = { + if (!mapStrEquals(str1, str2)) { + throw new AssertionError(s"Expected: $str1 \n Actual: $str2") + } + } +} + +object StreamingWithStateTestBase { + + case class StateBackendMode(backend: String) { + override def toString: String = backend.toString + } + + val HEAP_BACKEND = StateBackendMode("HEAP") + val ROCKSDB_BACKEND = StateBackendMode("ROCKSDB") + + @Parameterized.Parameters(name = "StateBackend={0}") + def parameters(): util.Collection[Array[java.lang.Object]] = { + Seq[Array[AnyRef]](Array(HEAP_BACKEND), Array(ROCKSDB_BACKEND)) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala index 5ec370015c4b1..4595ac81841bd 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TableUtil.scala @@ -36,28 +36,29 @@ object TableUtil { * Note: The difference between print() and collect() is * - print() prints data on workers and collect() collects data to the client. * - You have to call TableEnvironment.execute() to run the job for print(), while collect() - * calls execute automatically. + * calls execute automatically. 
*/ def collect(table: TableImpl): Seq[Row] = collectSink(table, new CollectRowTableSink, None) def collect(table: TableImpl, jobName: String): Seq[Row] = collectSink(table, new CollectRowTableSink, Option.apply(jobName)) - def collectAsT[T](table: TableImpl, t: TypeInformation[_], jobName : String = null): Seq[T] = + def collectAsT[T](table: TableImpl, t: TypeInformation[_], jobName: String = null): Seq[T] = collectSink( table, new CollectTableSink(_ => t.asInstanceOf[TypeInformation[T]]), Option(jobName)) def collectSink[T]( - table: TableImpl, sink: CollectTableSink[T], jobName : Option[String] = None): Seq[T] = { + table: TableImpl, sink: CollectTableSink[T], jobName: Option[String] = None): Seq[T] = { // get schema information of table val rowType = table.getRelNode.getRowType val fieldNames = rowType.getFieldNames.asScala.toArray val fieldTypes = rowType.getFieldList - .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray + .map(field => FlinkTypeFactory.toInternalType(field.getType)).toArray val configuredSink = sink.configure( fieldNames, fieldTypes.map(createExternalTypeInfoFromInternalType)) BatchTableEnvUtil.collect(table.tableEnv.asInstanceOf[BatchTableEnvironment], table, configuredSink.asInstanceOf[CollectTableSink[T]], jobName) } + } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala new file mode 100644 index 0000000000000..f93c2c2425b69 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TimeTestUtil.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.functions.source.SourceFunction +import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext +import org.apache.flink.streaming.api.operators.{AbstractStreamOperator, OneInputStreamOperator} +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord + +object TimeTestUtil { + + class EventTimeSourceFunction[T]( + dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] { + + override def run(ctx: SourceContext[T]): Unit = { + dataWithTimestampList.foreach { + case Left(t) => ctx.collectWithTimestamp(t._2, t._1) + case Right(w) => ctx.emitWatermark(new Watermark(w)) + } + } + + override def cancel(): Unit = ??? 
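+ // cancel() is intentionally a stub (???): these test sources emit a finite list and are presumably never cancelled.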
+ } + + class TimestampAndWatermarkWithOffset[T <: Product]( + offset: Long) extends AssignerWithPunctuatedWatermarks[T] { + + override def checkAndGetNextWatermark(lastElement: T, extractedTimestamp: Long): Watermark = { + new Watermark(extractedTimestamp - offset) + } + + override def extractTimestamp(element: T, previousElementTimestamp: Long): Long = { + element.productElement(0).asInstanceOf[Long] + } + } + + class EventTimeProcessOperator[T] + extends AbstractStreamOperator[T] + with OneInputStreamOperator[Either[(Long, T), Long], T] { + + override def processElement(element: StreamRecord[Either[(Long, T), Long]]): Unit = { + element.getValue match { + case Left(t) => output.collect(new StreamRecord[T](t._2, t._1)) + case Right(w) => output.emitWatermark(new Watermark(w)) + } + } + + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java index c472c1f934a18..9a47a3b738d13 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/api/TableConfigOptions.java @@ -133,6 +133,15 @@ public class TableConfigOptions { .defaultValue(512) .withDescription("Sets the max buffer memory size for sort. It defines the upper memory for the sort."); + // ------------------------------------------------------------------------ + // topN Options + // ------------------------------------------------------------------------ + + public static final ConfigOption SQL_EXEC_TOPN_CACHE_SIZE = + key("sql.exec.topn.cache.size") + .defaultValue(10000L) + .withDescription("Cache size of every topn task."); + // ------------------------------------------------------------------------ // MiniBatch Options // ------------------------------------------------------------------------ diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java index 9be8db632248d..e52938960bd3c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryRow.java @@ -220,7 +220,7 @@ public void setDecimal(int pos, Decimal value, int precision) { } else { byte[] bytes = value.toUnscaledBytes(); - assert(bytes.length <= 16); + assert bytes.length <= 16; // Write the bytes to the variable length portion. 
SegmentsUtil.copyFromBytes(segments, offset + cursor, bytes, 0, bytes.length); @@ -428,4 +428,20 @@ public static String toOriginString(BaseRow row, InternalType[] types) { build.append(']'); return build.toString(); } + + public boolean equalsWithoutHeader(BaseRow o) { + return equalsFrom(o, 1); + } + + private boolean equalsFrom(Object o, int startIndex) { + if (o != null && o instanceof BinaryRow) { + BinaryRow other = (BinaryRow) o; + return sizeInBytes == other.sizeInBytes && + SegmentsUtil.equals( + segments, offset + startIndex, + other.segments, other.offset + startIndex, sizeInBytes - startIndex); + } else { + return false; + } + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java index b1730ec3c3892..aafcb3feeab9c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/BinaryWriter.java @@ -18,12 +18,14 @@ package org.apache.flink.table.dataformat; import org.apache.flink.table.type.ArrayType; +import org.apache.flink.table.type.DateType; import org.apache.flink.table.type.DecimalType; import org.apache.flink.table.type.GenericType; import org.apache.flink.table.type.InternalType; import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.type.MapType; import org.apache.flink.table.type.RowType; +import org.apache.flink.table.type.TimestampType; import org.apache.flink.table.typeutils.BaseRowSerializer; /** @@ -98,11 +100,11 @@ static void write(BinaryWriter writer, int pos, Object o, InternalType type) { writer.writeString(pos, (BinaryString) o); } else if (type.equals(InternalTypes.CHAR)) { writer.writeChar(pos, (char) o); - } else if (type.equals(InternalTypes.DATE)) { + } else if (type instanceof DateType) { writer.writeInt(pos, (int) o); } else if (type.equals(InternalTypes.TIME)) { writer.writeInt(pos, (int) o); - } else if (type.equals(InternalTypes.TIMESTAMP)) { + } else if (type instanceof TimestampType) { writer.writeLong(pos, (long) o); } else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java index 101ff35d16253..ebc59bc3f3bda 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java @@ -146,9 +146,9 @@ public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeIn } else if (typeInfo instanceof BinaryMapTypeInfo) { return BinaryMapConverter.INSTANCE; } else if (typeInfo instanceof BaseRowTypeInfo) { - return BaseRowConverter.INSTANCE; + return new BaseRowConverter(typeInfo.getArity()); } else if (typeInfo.equals(BasicTypeInfo.BIG_DEC_TYPE_INFO)) { - return BaseRowConverter.INSTANCE; + return new BaseRowConverter(typeInfo.getArity()); } else if (typeInfo instanceof DecimalTypeInfo) { DecimalTypeInfo decimalType = (DecimalTypeInfo) typeInfo; return new DecimalConverter(decimalType.precision(), decimalType.scale()); @@ -993,14 +993,13 @@ E toExternalImpl(BaseRow 
row, int column) { public static final class BaseRowConverter extends IdentityConverter { private static final long serialVersionUID = -4470307402371540680L; + private int arity; - public static final BaseRowConverter INSTANCE = new BaseRowConverter(); - - private BaseRowConverter() {} + private BaseRowConverter(int arity) { this.arity = arity; } @Override BaseRow toExternalImpl(BaseRow row, int column) { - throw new RuntimeException("Not support yet!"); + return row.getRow(column, arity); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java index f1e080ce89d57..ab263c40c3adc 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/TypeGetterSetters.java @@ -18,12 +18,14 @@ package org.apache.flink.table.dataformat; import org.apache.flink.table.type.ArrayType; +import org.apache.flink.table.type.DateType; import org.apache.flink.table.type.DecimalType; import org.apache.flink.table.type.GenericType; import org.apache.flink.table.type.InternalType; import org.apache.flink.table.type.InternalTypes; import org.apache.flink.table.type.MapType; import org.apache.flink.table.type.RowType; +import org.apache.flink.table.type.TimestampType; /** * Provide type specialized getters and setters to reduce if/else and eliminate box and unbox. @@ -193,11 +195,11 @@ static Object get(TypeGetterSetters row, int ordinal, InternalType type) { return row.getString(ordinal); } else if (type.equals(InternalTypes.CHAR)) { return row.getChar(ordinal); - } else if (type.equals(InternalTypes.DATE)) { + } else if (type instanceof DateType) { return row.getInt(ordinal); } else if (type.equals(InternalTypes.TIME)) { return row.getInt(ordinal); - } else if (type.equals(InternalTypes.TIMESTAMP)) { + } else if (type instanceof TimestampType) { return row.getLong(ordinal); } else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java index 71dc6c2eacd7b..e50e5463c0327 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BinaryRowUtil.java @@ -18,7 +18,9 @@ package org.apache.flink.table.dataformat.util; +import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.core.memory.MemoryUtils; +import org.apache.flink.table.dataformat.BinaryRow; /** * Util for binary row. Many of the methods in this class are used in code generation.
@@ -29,6 +31,14 @@ public class BinaryRowUtil { public static final sun.misc.Unsafe UNSAFE = MemoryUtils.UNSAFE; public static final int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + public static final BinaryRow EMPTY_ROW = new BinaryRow(0); + + static { + int size = EMPTY_ROW.getFixedLengthPartSize(); + byte[] bytes = new byte[size]; + EMPTY_ROW.pointTo(MemorySegmentFactory.wrap(bytes), 0, size); + } + public static boolean byteArrayEquals(byte[] left, byte[] right, int length) { return byteArrayEquals( left, BYTE_ARRAY_BASE_OFFSET, right, BYTE_ARRAY_BASE_OFFSET, length); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java index d37aa31560c95..8f2c76363c008 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/bundle/AbstractMapBundleOperator.java @@ -57,7 +57,7 @@ public abstract class AbstractMapBundleOperator private static final long serialVersionUID = 5081841938324118594L; /** The map in heap to store elements. */ - private final transient Map bundle; + private transient Map bundle; /** The trigger that determines how many elements should be put into a bundle. */ private final BundleTrigger bundleTrigger; @@ -74,7 +74,6 @@ public abstract class AbstractMapBundleOperator MapBundleFunction function, BundleTrigger bundleTrigger) { chainingStrategy = ChainingStrategy.ALWAYS; - this.bundle = new HashMap<>(); this.function = checkNotNull(function, "function is null"); this.bundleTrigger = checkNotNull(bundleTrigger, "bundleTrigger is null"); } @@ -86,6 +85,7 @@ public void open() throws Exception { this.numOfElements = 0; this.collector = new StreamRecordCollector<>(output); + this.bundle = new HashMap<>(); bundleTrigger.registerCallback(this); // reset trigger diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java new file mode 100644 index 0000000000000..1c554bd5c1497 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateFunctionHelper.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +/** + * Utility for deduplicate function. + */ +class DeduplicateFunctionHelper { + + /** + * Processes element to deduplicate on keys, sends current element as last row, retracts previous element if + * needed. + * + * @param currentRow latest row received by deduplicate function + * @param generateRetraction whether need to send retract message to downstream + * @param state state of function + * @param out underlying collector + * @throws Exception + */ + static void processLastRow(BaseRow currentRow, boolean generateRetraction, ValueState state, + Collector out) throws Exception { + // Check message should be accumulate + Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); + if (generateRetraction) { + // state stores complete row if generateRetraction is true + BaseRow preRow = state.value(); + state.update(currentRow); + if (preRow != null) { + preRow.setHeader(BaseRowUtil.RETRACT_MSG); + out.collect(preRow); + } + } + out.collect(currentRow); + } + + /** + * Processes element to deduplicate on keys, sends current element if it is first row. + * + * @param currentRow latest row received by deduplicate function + * @param state state of function + * @param out underlying collector + * @throws Exception + */ + static void processFirstRow(BaseRow currentRow, ValueState state, Collector out) + throws Exception { + // Check message should be accumulate + Preconditions.checkArgument(BaseRowUtil.isAccumulateMsg(currentRow)); + // ignore record with timestamp bigger than preRow + if (state.value() != null) { + return; + } + state.update(true); + out.collect(currentRow); + } + + private DeduplicateFunctionHelper() { + + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java new file mode 100644 index 0000000000000..14feaf981fd32 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunction.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.util.Collector; + +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; + +/** + * This function is used to deduplicate on keys and keeps only first row. + */ +public class DeduplicateKeepFirstRowFunction + extends KeyedProcessFunctionWithCleanupState { + + private static final long serialVersionUID = 5865777137707602549L; + + // state stores a boolean flag to indicate whether key appears before. + private ValueState state; + + public DeduplicateKeepFirstRowFunction(long minRetentionTime, long maxRetentionTime) { + super(minRetentionTime, maxRetentionTime); + } + + @Override + public void open(Configuration configure) throws Exception { + super.open(configure); + initCleanupTimeState("DeduplicateFunctionKeepFirstRow"); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("existsState", Types.BOOLEAN); + state = getRuntimeContext().getState(stateDesc); + } + + @Override + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { + long currentTime = ctx.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(ctx, currentTime); + processFirstRow(input, state, out); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { + if (stateCleaningEnabled) { + cleanupState(state); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java new file mode 100644 index 0000000000000..3dcde66eb5ea6 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunction.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; + +/** + * This function is used to deduplicate on keys and keeps only last row. + */ +public class DeduplicateKeepLastRowFunction + extends KeyedProcessFunctionWithCleanupState { + + private static final long serialVersionUID = -291348892087180350L; + private final BaseRowTypeInfo rowTypeInfo; + private final boolean generateRetraction; + + // state stores complete row. + private ValueState state; + + public DeduplicateKeepLastRowFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo rowTypeInfo, + boolean generateRetraction) { + super(minRetentionTime, maxRetentionTime); + this.rowTypeInfo = rowTypeInfo; + this.generateRetraction = generateRetraction; + } + + @Override + public void open(Configuration configure) throws Exception { + super.open(configure); + if (generateRetraction) { + // state stores complete row if need generate retraction, otherwise do not need a state + initCleanupTimeState("DeduplicateFunctionKeepLastRow"); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("preRowState", rowTypeInfo); + state = getRuntimeContext().getState(stateDesc); + } + } + + @Override + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { + if (generateRetraction) { + long currentTime = ctx.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(ctx, currentTime); + } + processLastRow(input, generateRetraction, state, out); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { + if (stateCleaningEnabled) { + cleanupState(state); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java new file mode 100644 index 0000000000000..b97d1a670ac3b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.MapBundleFunction; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.util.Collector; + +import javax.annotation.Nullable; + +import java.util.Map; + +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processFirstRow; + +/** + * This function is used to get the first row for every key partition in miniBatch mode. + */ +public class MiniBatchDeduplicateKeepFirstRowFunction + extends MapBundleFunction { + + private static final long serialVersionUID = -7994602893547654994L; + + private final TypeSerializer typeSerializer; + + // state stores a boolean flag to indicate whether key appears before. + private ValueState state; + + public MiniBatchDeduplicateKeepFirstRowFunction(TypeSerializer typeSerializer) { + this.typeSerializer = typeSerializer; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("existsState", Types.BOOLEAN); + state = ctx.getRuntimeContext().getState(stateDesc); + } + + @Override + public BaseRow addInput(@Nullable BaseRow value, BaseRow input) { + if (value == null) { + // put the input into buffer + return typeSerializer.copy(input); + } else { + // the input is not first row, ignore it + return value; + } + } + + @Override + public void finishBundle( + Map buffer, Collector out) throws Exception { + for (Map.Entry entry : buffer.entrySet()) { + BaseRow currentKey = entry.getKey(); + BaseRow currentRow = entry.getValue(); + ctx.setCurrentKey(currentKey); + processFirstRow(currentRow, state, out); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java new file mode 100644 index 0000000000000..c1f2ec40cf5b8 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.MapBundleFunction; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import javax.annotation.Nullable; + +import java.util.Map; + +import static org.apache.flink.table.runtime.deduplicate.DeduplicateFunctionHelper.processLastRow; + +/** + * This function is used to get the last row for every key partition in miniBatch mode. + */ +public class MiniBatchDeduplicateKeepLastRowFunction + extends MapBundleFunction { + + private static final long serialVersionUID = -8981813609115029119L; + + private final BaseRowTypeInfo rowTypeInfo; + private final boolean generateRetraction; + private final TypeSerializer typeSerializer; + + // state stores complete row. + private ValueState state; + + public MiniBatchDeduplicateKeepLastRowFunction(BaseRowTypeInfo rowTypeInfo, boolean generateRetraction, + TypeSerializer typeSerializer) { + this.rowTypeInfo = rowTypeInfo; + this.generateRetraction = generateRetraction; + this.typeSerializer = typeSerializer; + } + + @Override + public void open(ExecutionContext ctx) throws Exception { + super.open(ctx); + ValueStateDescriptor stateDesc = new ValueStateDescriptor<>("preRowState", rowTypeInfo); + state = ctx.getRuntimeContext().getState(stateDesc); + } + + @Override + public BaseRow addInput(@Nullable BaseRow value, BaseRow input) { + // always put the input into buffer + return typeSerializer.copy(input); + } + + @Override + public void finishBundle( + Map buffer, Collector out) throws Exception { + for (Map.Entry entry : buffer.entrySet()) { + BaseRow currentKey = entry.getKey(); + BaseRow currentRow = entry.getValue(); + ctx.setCurrentKey(currentKey); + processLastRow(currentRow, generateRetraction, state, out); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java new file mode 100644 index 0000000000000..ddb823c60b003 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/CleanupState.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; + +/** + * Base interface for clean up state, both for {@link ProcessFunction} and {@link CoProcessFunction}. + */ +public interface CleanupState { + + default void registerProcessingCleanupTimer( + ValueState cleanupTimeState, + long currentTime, + long minRetentionTime, + long maxRetentionTime, + TimerService timerService) throws Exception { + + // last registered timer + Long curCleanupTime = cleanupTimeState.value(); + + // check if a cleanup timer is registered and + // that the current cleanup timer won't delete state we need to keep + if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) { + // we need to register a new (later) timer + long cleanupTime = currentTime + maxRetentionTime; + // register timer and remember clean-up time + timerService.registerProcessingTimeTimer(cleanupTime); + // delete expired timer + if (curCleanupTime != null) { + timerService.deleteProcessingTimeTimer(curCleanupTime); + } + cleanupTimeState.update(cleanupTime); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java new file mode 100644 index 0000000000000..c5797632f71c9 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/KeyedProcessFunctionWithCleanupState.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; + +/** + * A function that processes elements of a stream, and could cleanup state. + * @param Type of the key. + * @param Type of the input elements. + * @param Type of the output elements. 
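+ * State cleanup is active only when minRetentionTime > 1: subclasses are expected to call initCleanupTimeState() in open(), registerProcessingCleanupTimer() for every processed element, and cleanupState() when the cleanup timer fires.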
+ */ +public abstract class KeyedProcessFunctionWithCleanupState + extends KeyedProcessFunction implements CleanupState { + + private static final long serialVersionUID = 2084560869233898457L; + + protected final long minRetentionTime; + protected final long maxRetentionTime; + protected final boolean stateCleaningEnabled; + + // holds the latest registered cleanup timer + private ValueState cleanupTimeState; + + public KeyedProcessFunctionWithCleanupState(long minRetentionTime, long maxRetentionTime) { + this.minRetentionTime = minRetentionTime; + this.maxRetentionTime = maxRetentionTime; + this.stateCleaningEnabled = minRetentionTime > 1; + } + + protected void initCleanupTimeState(String stateName) { + if (stateCleaningEnabled) { + ValueStateDescriptor inputCntDescriptor = new ValueStateDescriptor<>(stateName, Types.LONG); + cleanupTimeState = getRuntimeContext().getState(inputCntDescriptor); + } + } + + protected void registerProcessingCleanupTimer(Context ctx, long currentTime) throws Exception { + if (stateCleaningEnabled) { + registerProcessingCleanupTimer( + cleanupTimeState, + currentTime, + minRetentionTime, + maxRetentionTime, + ctx.timerService() + ); + } + } + + protected boolean isProcessingTimeTimer(OnTimerContext ctx) { + return ctx.timeDomain() == TimeDomain.PROCESSING_TIME; + } + + protected void cleanupState(State... states) { + for (State state : states) { + state.clear(); + } + this.cleanupTimeState.clear(); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java new file mode 100644 index 0000000000000..2ad9e5dd2c393 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BaseRowKeySelector.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keyselector; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * BaseRowKeySelector takes an BaseRow and extracts the deterministic key for the BaseRow. 
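+ * Known implementations are {@link BinaryRowKeySelector}, which extracts the key via a code-generated projection, and {@link NullBinaryRowKeySelector}, which always returns an empty key.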
+ */ +public interface BaseRowKeySelector extends KeySelector, ResultTypeQueryable { + + BaseRowTypeInfo getProducedType(); + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java new file mode 100644 index 0000000000000..3d7887a57afcd --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/BinaryRowKeySelector.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keyselector; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.generated.Projection; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * A utility class which extracts key from BaseRow. The key type is BinaryRow. + */ +public class BinaryRowKeySelector implements BaseRowKeySelector { + + private static final long serialVersionUID = 5375355285015381919L; + + private final BaseRowTypeInfo keyRowType; + private final GeneratedProjection generatedProjection; + private transient Projection projection; + + public BinaryRowKeySelector(BaseRowTypeInfo keyRowType, GeneratedProjection generatedProjection) { + this.keyRowType = keyRowType; + this.generatedProjection = generatedProjection; + } + + @Override + public BaseRow getKey(BaseRow value) throws Exception { + if (projection == null) { + projection = generatedProjection.newInstance(Thread.currentThread().getContextClassLoader()); + } + return projection.apply(value).copy(); + } + + @Override + public BaseRowTypeInfo getProducedType() { + return keyRowType; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java new file mode 100644 index 0000000000000..795bdbf5a3796 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/keyselector/NullBinaryRowKeySelector.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.keyselector; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BinaryRowUtil; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +/** + * A utility class which key is always empty no matter what the input row is. + */ +public class NullBinaryRowKeySelector implements BaseRowKeySelector { + + private static final long serialVersionUID = -2079386198687082032L; + + private final BaseRowTypeInfo returnType = new BaseRowTypeInfo(); + + @Override + public BaseRow getKey(BaseRow value) throws Exception { + return BinaryRowUtil.EMPTY_ROW; + } + + @Override + public BaseRowTypeInfo getProducedType() { + return returnType; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java new file mode 100644 index 0000000000000..5c44133ded5a6 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AbstractRankFunction.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Counter; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.streaming.api.operators.KeyContext; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; +import org.apache.flink.table.type.InternalType; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Map; + +/** + * Base class for Rank Function. + */ +public abstract class AbstractRankFunction extends KeyedProcessFunctionWithCleanupState { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractRankFunction.class); + + private static final String RANK_UNSUPPORTED_MSG = "RANK() on streaming table is not supported currently"; + + private static final String DENSE_RANK_UNSUPPORTED_MSG = "DENSE_RANK() on streaming table is not supported currently"; + + private static final String WITHOUT_RANK_END_UNSUPPORTED_MSG = "Rank end is not specified. Currently rank only support TopN, which means the rank end must be specified."; + + // we set default topN size to 100 + private static final long DEFAULT_TOPN_SIZE = 100; + + /** + * The util to compare two BaseRow equals to each other. + * As different BaseRow can't be equals directly, we use a code generated util to handle this. + */ + private GeneratedRecordEqualiser generatedEqualiser; + protected RecordEqualiser equaliser; + + /** + * The util to compare two sortKey equals to each other. 
+ */ + private GeneratedRecordComparator generatedSortKeyComparator; + protected Comparator sortKeyComparator; + + private final boolean generateRetraction; + protected final boolean outputRankNumber; + protected final BaseRowTypeInfo inputRowType; + protected final KeySelector sortKeySelector; + + protected KeyContext keyContext; + private final boolean isConstantRankEnd; + private final long rankStart; + protected long rankEnd; + private final int rankEndIndex; + private ValueState rankEndState; + private Counter invalidCounter; + private JoinedRow outputRow; + + // metrics + protected long hitCount = 0L; + protected long requestCount = 0L; + + AbstractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + GeneratedRecordComparator generatedSortKeyComparator, BaseRowKeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber) { + super(minRetentionTime, maxRetentionTime); + // TODO support RANK and DENSE_RANK + switch (rankType) { + case ROW_NUMBER: + break; + case RANK: + LOG.error(RANK_UNSUPPORTED_MSG); + throw new UnsupportedOperationException(RANK_UNSUPPORTED_MSG); + case DENSE_RANK: + LOG.error(DENSE_RANK_UNSUPPORTED_MSG); + throw new UnsupportedOperationException(DENSE_RANK_UNSUPPORTED_MSG); + default: + LOG.error("Streaming tables do not support {}", rankType.name()); + throw new UnsupportedOperationException("Streaming tables do not support " + rankType.toString()); + } + + if (rankRange instanceof ConstantRankRange) { + ConstantRankRange constantRankRange = (ConstantRankRange) rankRange; + isConstantRankEnd = true; + rankStart = constantRankRange.getRankStart(); + rankEnd = constantRankRange.getRankEnd(); + rankEndIndex = -1; + } else if (rankRange instanceof VariableRankRange) { + VariableRankRange variableRankRange = (VariableRankRange) rankRange; + int rankEndIdx = variableRankRange.getRankEndIndex(); + InternalType rankEndIdxType = inputRowType.getInternalTypes()[rankEndIdx]; + if (!rankEndIdxType.equals(InternalTypes.LONG)) { + LOG.error("variable rank index column must be long type, while input type is {}", + rankEndIdxType.getClass().getName()); + throw new UnsupportedOperationException( + "variable rank index column must be long type, while input type is " + + rankEndIdxType.getClass().getName()); + } + rankEndIndex = rankEndIdx; + isConstantRankEnd = false; + rankStart = -1; + rankEnd = -1; + + } else { + LOG.error(WITHOUT_RANK_END_UNSUPPORTED_MSG); + throw new UnsupportedOperationException(WITHOUT_RANK_END_UNSUPPORTED_MSG); + } + this.generatedEqualiser = generatedEqualiser; + this.generatedSortKeyComparator = generatedSortKeyComparator; + this.generateRetraction = generateRetraction; + this.inputRowType = inputRowType; + this.outputRankNumber = outputRankNumber; + this.sortKeySelector = sortKeySelector; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + initCleanupTimeState("RankFunctionCleanupTime"); + outputRow = new JoinedRow(); + + if (!isConstantRankEnd) { + ValueStateDescriptor rankStateDesc = new ValueStateDescriptor<>("rankEnd", Types.LONG); + rankEndState = getRuntimeContext().getState(rankStateDesc); + } + // compile equaliser + equaliser = generatedEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedEqualiser = null; + // compile comparator + sortKeyComparator = 
generatedSortKeyComparator.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedSortKeyComparator = null; + invalidCounter = getRuntimeContext().getMetricGroup().counter("topn.invalidTopSize"); + } + + /** + * Gets default topN size. + * + * @return default topN size + */ + protected long getDefaultTopNSize() { + return isConstantRankEnd ? rankEnd : DEFAULT_TOPN_SIZE; + } + + /** + * Initialize rank end. + * + * @param row input record + * @return rank end + * @throws Exception + */ + protected long initRankEnd(BaseRow row) throws Exception { + if (isConstantRankEnd) { + return rankEnd; + } else { + Long rankEndValue = rankEndState.value(); + long curRankEnd = row.getLong(rankEndIndex); + if (rankEndValue == null) { + rankEnd = curRankEnd; + rankEndState.update(rankEnd); + return rankEnd; + } else { + rankEnd = rankEndValue; + if (rankEnd != curRankEnd) { + // increment the invalid counter when the current rank end not equal to previous rank end + invalidCounter.inc(); + } + return rankEnd; + } + } + } + + /** + * Checks whether the record should be put into the buffer. + * + * @param sortKey sortKey to test + * @param buffer buffer to add + * @return true if the record should be put into the buffer. + */ + protected boolean checkSortKeyInBufferRange(BaseRow sortKey, TopNBuffer buffer) { + Comparator comparator = buffer.getSortKeyComparator(); + Map.Entry> worstEntry = buffer.lastEntry(); + if (worstEntry == null) { + // return true if the buffer is empty. + return true; + } else { + BaseRow worstKey = worstEntry.getKey(); + int compare = comparator.compare(sortKey, worstKey); + if (compare < 0) { + return true; + } else { + return buffer.getCurrentTopNum() < getDefaultTopNSize(); + } + } + } + + protected void registerMetric(long heapSize) { + getRuntimeContext().getMetricGroup().>gauge( + "topn.cache.hitRate", + () -> requestCount == 0 ? 1.0 : Long.valueOf(hitCount).doubleValue() / requestCount); + + getRuntimeContext().getMetricGroup().>gauge( + "topn.cache.size", () -> heapSize); + } + + protected void collect(Collector out, BaseRow inputRow) { + BaseRowUtil.setAccumulate(inputRow); + out.collect(inputRow); + } + + /** + * This is similar to [[retract()]] but always send retraction message regardless of generateRetraction is true or + * not. + */ + protected void delete(Collector out, BaseRow inputRow) { + BaseRowUtil.setRetract(inputRow); + out.collect(inputRow); + } + + /** + * This is with-row-number version of above delete() method. 
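Editor's note: the checkSortKeyInBufferRange method above admits a record when it sorts before the current worst entry, or when the buffer has not yet reached the TopN size. Below is a self-contained sketch of the same decision using plain long sort keys instead of generated BaseRow comparators (an assumption made only for illustration, not the patch's API):

import java.util.Comparator;
import java.util.TreeMap;

public class BufferRangeCheckSketch {

    // mirrors checkSortKeyInBufferRange: accept if better than the worst key, or if there is still room
    static boolean inBufferRange(TreeMap<Long, String> buffer, Comparator<Long> cmp, long sortKey, long topN) {
        if (buffer.isEmpty()) {
            return true;
        }
        return cmp.compare(sortKey, buffer.lastKey()) < 0 || buffer.size() < topN;
    }

    public static void main(String[] args) {
        Comparator<Long> cmp = Comparator.naturalOrder(); // smaller sort key = better rank
        TreeMap<Long, String> buffer = new TreeMap<>(cmp);
        buffer.put(10L, "a");
        buffer.put(20L, "b");
        buffer.put(30L, "c");
        System.out.println(inBufferRange(buffer, cmp, 15L, 3)); // true: 15 beats the worst key 30
        System.out.println(inBufferRange(buffer, cmp, 40L, 3)); // false: worse than 30 and the buffer already holds 3
    }
}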
+ */ + protected void delete(Collector out, BaseRow inputRow, long rank) { + if (isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG)); + } + } + + protected void collect(Collector out, BaseRow inputRow, long rank) { + if (isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.ACCUMULATE_MSG)); + } + } + + protected void retract(Collector out, BaseRow inputRow, long rank) { + if (generateRetraction && isInRankRange(rank)) { + out.collect(createOutputRow(inputRow, rank, BaseRowUtil.RETRACT_MSG)); + } + } + + protected boolean isInRankEnd(long rank) { + return rank <= rankEnd; + } + + protected boolean isInRankRange(long rank) { + return rank <= rankEnd && rank >= rankStart; + } + + protected boolean hasOffset() { + // rank start is 1-based + return rankStart > 1; + } + + private BaseRow createOutputRow(BaseRow inputRow, long rank, byte header) { + if (outputRankNumber) { + GenericRow rankRow = new GenericRow(1); + rankRow.setField(0, rank); + + outputRow.replace(inputRow, rankRow); + outputRow.setHeader(header); + return outputRow; + } else { + inputRow.setHeader(header); + return inputRow; + } + } + + /** + * Sets keyContext to RankFunction. + * + * @param keyContext keyContext of current function. + */ + public void setKeyContext(KeyContext keyContext) { + this.keyContext = keyContext; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java new file mode 100644 index 0000000000000..73e7edb631a3c --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/AppendRankFunction.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; +import org.apache.flink.table.runtime.util.LRUMap; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * AppendRankFunction's input stream only contains append record. + */ +public class AppendRankFunction extends AbstractRankFunction { + + private static final long serialVersionUID = -4708453213104128010L; + + private static final Logger LOG = LoggerFactory.getLogger(AppendRankFunction.class); + + private final BaseRowTypeInfo sortKeyType; + private final TypeSerializer inputRowSer; + private final long cacheSize; + + // a map state stores mapping from sort key to records list which is in topN + private transient MapState> dataState; + + // the buffer stores mapping from sort key to records list, a heap mirror to dataState + private transient TopNBuffer buffer; + + // the kvSortedMap stores mapping from partition key to it's buffer + private transient Map kvSortedMap; + + public AppendRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + GeneratedRecordComparator sortKeyGeneratedRecordComparator, BaseRowKeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber, long cacheSize) { + super(minRetentionTime, maxRetentionTime, inputRowType, sortKeyGeneratedRecordComparator, sortKeySelector, + rankType, rankRange, generatedEqualiser, generateRetraction, outputRankNumber); + this.sortKeyType = sortKeySelector.getProducedType(); + this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); + this.cacheSize = cacheSize; + } + + public void open(Configuration parameters) throws Exception { + super.open(parameters); + int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopNSize())); + kvSortedMap = new LRUMap<>(lruCacheSize); + LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize); + + ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor<>( + "data-state-with-append", sortKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + // metrics + registerMetric(kvSortedMap.size() * getDefaultTopNSize()); + } + + @Override + public void processElement(BaseRow input, Context context, Collector out) throws Exception { + long currentTime = context.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(context, currentTime); + + initHeapStates(); + initRankEnd(input); + + BaseRow sortKey = sortKeySelector.getKey(input); + // check whether the 
sortKey is in the topN range + if (checkSortKeyInBufferRange(sortKey, buffer)) { + // insert sort key into buffer + buffer.put(sortKey, inputRowSer.copy(input)); + Collection inputs = buffer.get(sortKey); + // update data state + dataState.put(sortKey, (List) inputs); + if (outputRankNumber || hasOffset()) { + // the without-number-algorithm can't handle topN with offset, + // so use the with-number-algorithm to handle offset + processElementWithRowNumber(sortKey, input, out); + } else { + processElementWithoutRowNumber(input, out); + } + } + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (stateCleaningEnabled) { + // cleanup cache + kvSortedMap.remove(keyContext.getCurrentKey()); + cleanupState(dataState); + } + } + + private void initHeapStates() throws Exception { + requestCount += 1; + BaseRow currentKey = (BaseRow) keyContext.getCurrentKey(); + buffer = kvSortedMap.get(currentKey); + if (buffer == null) { + buffer = new TopNBuffer(sortKeyComparator, ArrayList::new); + kvSortedMap.put(currentKey, buffer); + // restore buffer + Iterator>> iter = dataState.iterator(); + if (iter != null) { + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow sortKey = entry.getKey(); + List values = entry.getValue(); + // the order is preserved + buffer.putAll(sortKey, values); + } + } + } else { + hitCount += 1; + } + } + + private void processElementWithRowNumber(BaseRow sortKey, BaseRow input, Collector out) throws Exception { + Iterator>> iterator = buffer.entrySet().iterator(); + long curRank = 0L; + boolean findsSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry> entry = iterator.next(); + Collection records = entry.getValue(); + // meet its own sort key + if (!findsSortKey && entry.getKey().equals(sortKey)) { + curRank += records.size(); + collect(out, input, curRank); + findsSortKey = true; + } else if (findsSortKey) { + Iterator recordsIter = records.iterator(); + while (recordsIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = recordsIter.next(); + retract(out, prevRow, curRank - 1); + collect(out, prevRow, curRank); + } + } else { + curRank += records.size(); + } + } + + // remove the records associated to the sort key which is out of topN + List toDeleteSortKeys = new ArrayList<>(); + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + BaseRow key = entry.getKey(); + dataState.remove(key); + toDeleteSortKeys.add(key); + } + for (BaseRow toDeleteKey : toDeleteSortKeys) { + buffer.removeAll(toDeleteKey); + } + } + + private void processElementWithoutRowNumber(BaseRow input, Collector out) throws Exception { + // remove retired element + if (buffer.getCurrentTopNum() > rankEnd) { + Map.Entry> lastEntry = buffer.lastEntry(); + BaseRow lastKey = lastEntry.getKey(); + List lastList = (List) lastEntry.getValue(); + // remove last one + BaseRow lastElement = lastList.remove(lastList.size() - 1); + if (lastList.isEmpty()) { + buffer.removeAll(lastKey); + dataState.remove(lastKey); + } else { + dataState.put(lastKey, lastList); + } + if (input.equals(lastElement)) { + return; + } else { + // lastElement shouldn't be null + delete(out, lastElement); + } + } + collect(out, input); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java new file mode 
100644 index 0000000000000..b4bb201308293 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRange.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** rankStart and rankEnd are both inclusive, and rankStart always starts from one. */ +public class ConstantRankRange implements RankRange { + + private static final long serialVersionUID = 9062345289888078376L; + private long rankStart; + private long rankEnd; + + public ConstantRankRange(long rankStart, long rankEnd) { + this.rankStart = rankStart; + this.rankEnd = rankEnd; + } + + public long getRankStart() { + return rankStart; + } + + public long getRankEnd() { + return rankEnd; + } + + @Override + public String toString(List<String> inputFieldNames) { + return toString(); + } + + @Override + public String toString() { + return "rankStart=" + rankStart + ", rankEnd=" + rankEnd; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java new file mode 100644 index 0000000000000..0d0ddd3dcc2fe --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/ConstantRankRangeWithoutEnd.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** ConstantRankRangeWithoutEnd is a RankRange which does not specify a rank end.
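Editor's note: the constant range corresponds to the rank predicate of a TopN query. A short usage sketch (an assumption that the class is used directly, e.g. from a test in the same package; the wrapper class name is hypothetical):

package org.apache.flink.table.runtime.rank;

// rankStart and rankEnd are 1-based and inclusive.
public class ConstantRankRangeExample {
    public static void main(String[] args) {
        ConstantRankRange top10 = new ConstantRankRange(1, 10);   // e.g. WHERE rownum <= 10
        ConstantRankRange page2 = new ConstantRankRange(11, 20);  // e.g. WHERE rownum BETWEEN 11 AND 20
        System.out.println(top10); // rankStart=1, rankEnd=10
        System.out.println(page2); // rankStart=11, rankEnd=20
    }
}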
*/ +public class ConstantRankRangeWithoutEnd implements RankRange { + + private static final long serialVersionUID = -1944057111062598696L; + + private final long rankStart; + + public ConstantRankRangeWithoutEnd(long rankStart) { + this.rankStart = rankStart; + } + + @Override + public String toString(List<String> inputFieldNames) { + return toString(); + } + + @Override + public String toString() { + return "rankStart=" + rankStart; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java new file mode 100644 index 0000000000000..e14b4bfff9dd9 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankRange.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.io.Serializable; +import java.util.List; + +/** + * RankRange for Rank, including the following 3 types: + * ConstantRankRange, ConstantRankRangeWithoutEnd, VariableRankRange. + */ +public interface RankRange extends Serializable { + + String toString(List<String> inputFieldNames); + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java new file mode 100644 index 0000000000000..d6573dd96b24b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RankType.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +/** + * An enumeration of rank types, describing how exactly the rank number is generated.
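Editor's note: to make the three numbering schemes concrete, for the sorted values 10, 20, 20, 30, ROW_NUMBER assigns 1, 2, 3, 4; RANK assigns 1, 2, 2, 4; DENSE_RANK assigns 1, 2, 2, 3. A small self-contained sketch of that arithmetic (illustrative only, not part of this patch):

import java.util.Arrays;
import java.util.List;

public class RankTypeSketch {
    public static void main(String[] args) {
        List<Integer> sorted = Arrays.asList(10, 20, 20, 30);
        int rowNumber = 0, rank = 0, denseRank = 0;
        Integer prev = null;
        for (Integer v : sorted) {
            rowNumber++;                       // always increments
            if (!v.equals(prev)) {
                rank = rowNumber;              // jumps past duplicates, leaving gaps
                denseRank++;                   // increments only on a new value, no gaps
            }
            System.out.printf("value=%d ROW_NUMBER=%d RANK=%d DENSE_RANK=%d%n", v, rowNumber, rank, denseRank);
            prev = v;
        }
    }
}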
+ */ +public enum RankType { + + /** + * Returns a unique sequential number for each row within the partition based on the order, + * starting at 1 for the first row in each partition and without repeating or skipping + * numbers in the ranking result of each partition. If there are duplicate values within the + * row set, the ranking numbers will be assigned arbitrarily. + */ + ROW_NUMBER, + + /** + * Returns a unique rank number for each distinct row within the partition based on the order, + * starting at 1 for the first row in each partition, with the same rank for duplicate values + * and leaving gaps between the ranks; this gap appears in the sequence after the duplicate + * values. + */ + RANK, + + /** + * is similar to the RANK by generating a unique rank number for each distinct row + * within the partition based on the order, starting at 1 for the first row in each partition, + * ranking the rows with equal values with the same rank number, except that it does not skip + * any rank, leaving no gaps between the ranks. + */ + DENSE_RANK +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java new file mode 100644 index 0000000000000..eb0250d83de25 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/RetractRankFunction.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.table.typeutils.SortedMapTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * RetractRankFunction's input stream can contain append, update, and delete records. + */ +public class RetractRankFunction extends AbstractRankFunction { + + private static final long serialVersionUID = 1365312180599454479L; + + private static final Logger LOG = LoggerFactory.getLogger(RetractRankFunction.class); + + // Message to indicate that state was cleared because of the state TTL restriction. The message can be written to the log. + private static final String STATE_CLEARED_WARN_MSG = "The state is cleared because of state ttl. " + + "This will result in incorrect result. You can increase the state ttl to avoid this."; + + private final BaseRowTypeInfo sortKeyType; + + // flag to log a warning and skip the retraction of a record that no longer exists in state, instead of failing; true by default.
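Editor's note: RetractRankFunction keeps a sorted map from sort key to record count (the treeMap state declared just below) next to the per-key record lists, so the rank of a key is one plus the number of records under all better keys; accumulate messages increment the count and retract messages decrement it. A minimal sketch of that bookkeeping with long sort keys (an assumption made only for illustration):

import java.util.SortedMap;
import java.util.TreeMap;

public class RetractCountSketch {
    public static void main(String[] args) {
        SortedMap<Long, Long> counts = new TreeMap<>(); // sortKey -> number of records with that key
        accumulate(counts, 10L);
        accumulate(counts, 10L);
        accumulate(counts, 20L);
        System.out.println(rankOfFirstRecordWith(counts, 20L)); // 3: two records sort before key 20
        retract(counts, 10L);
        System.out.println(rankOfFirstRecordWith(counts, 20L)); // 2
    }

    static void accumulate(SortedMap<Long, Long> counts, long sortKey) {
        counts.merge(sortKey, 1L, Long::sum);
    }

    static void retract(SortedMap<Long, Long> counts, long sortKey) {
        // removing the mapping entirely once the count reaches zero
        counts.computeIfPresent(sortKey, (k, c) -> c == 1 ? null : c - 1);
    }

    static long rankOfFirstRecordWith(SortedMap<Long, Long> counts, long sortKey) {
        // records with a strictly better (smaller) sort key come first
        return counts.headMap(sortKey).values().stream().mapToLong(Long::longValue).sum() + 1;
    }
}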
+ private final boolean lenient = true; + + // a map state stores mapping from sort key to records list + private transient MapState> dataState; + + // a sorted map stores mapping from sort key to records count + private transient ValueState> treeMap; + + public RetractRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + GeneratedRecordComparator generatedRecordComparator, BaseRowKeySelector sortKeySelector, + RankType rankType, RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, + boolean generateRetraction, boolean outputRankNumber) { + super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType, + rankRange, generatedEqualiser, generateRetraction, outputRankNumber); + this.sortKeyType = sortKeySelector.getProducedType(); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + ListTypeInfo valueTypeInfo = new ListTypeInfo<>(inputRowType); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor<>( + "data-state", sortKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + ValueStateDescriptor> valueStateDescriptor = new ValueStateDescriptor<>( + "sorted-map", + new SortedMapTypeInfo<>(sortKeyType, BasicTypeInfo.LONG_TYPE_INFO, sortKeyComparator)); + treeMap = getRuntimeContext().getState(valueStateDescriptor); + } + + @Override + public void processElement(BaseRow input, Context ctx, Collector out) throws Exception { + initRankEnd(input); + SortedMap sortedMap = treeMap.value(); + if (sortedMap == null) { + sortedMap = new TreeMap<>(sortKeyComparator); + } + BaseRow sortKey = sortKeySelector.getKey(input); + if (BaseRowUtil.isAccumulateMsg(input)) { + // update sortedMap + if (sortedMap.containsKey(sortKey)) { + sortedMap.put(sortKey, sortedMap.get(sortKey) + 1); + } else { + sortedMap.put(sortKey, 1L); + } + + // emit + emitRecordsWithRowNumber(sortedMap, sortKey, input, out); + + // update data state + List inputs = dataState.get(sortKey); + if (inputs == null) { + // the sort key is never seen + inputs = new ArrayList<>(); + } + inputs.add(input); + dataState.put(sortKey, inputs); + } else { + // emit updates first + retractRecordWithRowNumber(sortedMap, sortKey, input, out); + + // and then update sortedMap + if (sortedMap.containsKey(sortKey)) { + long count = sortedMap.get(sortKey) - 1; + if (count == 0) { + sortedMap.remove(sortKey); + } else { + sortedMap.put(sortKey, count); + } + } else { + if (sortedMap.isEmpty()) { + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + throw new RuntimeException("Can not retract a non-existent record: ${inputBaseRow.toString}. " + + "This should never happen."); + } + } + + } + treeMap.update(sortedMap); + } + + // ------------- ROW_NUMBER------------------------------- + + private void retractRecordWithRowNumber( + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) + throws Exception { + Iterator> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findsSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry entry = iterator.next(); + BaseRow key = entry.getKey(); + if (!findsSortKey && key.equals(sortKey)) { + List inputs = dataState.get(key); + if (inputs == null) { + // Skip the data if it's state is cleared because of state ttl. 
+ if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + Iterator inputIter = inputs.iterator(); + while (inputIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputIter.next(); + if (!findsSortKey && equaliser.equalsWithoutHeader(prevRow, inputRow)) { + delete(out, prevRow, curRank); + curRank -= 1; + findsSortKey = true; + inputIter.remove(); + } else if (findsSortKey) { + retract(out, prevRow, curRank + 1); + collect(out, prevRow, curRank); + } + } + if (inputs.isEmpty()) { + dataState.remove(key); + } else { + dataState.put(key, inputs); + } + } + } else if (findsSortKey) { + List inputs = dataState.get(key); + int i = 0; + while (i < inputs.size() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputs.get(i); + retract(out, prevRow, curRank + 1); + collect(out, prevRow, curRank); + i++; + } + } else { + curRank += entry.getValue(); + } + } + } + + private void emitRecordsWithRowNumber( + SortedMap sortedMap, BaseRow sortKey, BaseRow inputRow, Collector out) + throws Exception { + Iterator> iterator = sortedMap.entrySet().iterator(); + long curRank = 0L; + boolean findsSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry entry = iterator.next(); + BaseRow key = entry.getKey(); + if (!findsSortKey && key.equals(sortKey)) { + curRank += entry.getValue(); + collect(out, inputRow, curRank); + findsSortKey = true; + } else if (findsSortKey) { + List inputs = dataState.get(key); + if (inputs == null) { + // Skip the data if it's state is cleared because of state ttl. + if (lenient) { + LOG.warn(STATE_CLEARED_WARN_MSG); + } else { + throw new RuntimeException(STATE_CLEARED_WARN_MSG); + } + } else { + int i = 0; + while (i < inputs.size() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow prevRow = inputs.get(i); + retract(out, prevRow, curRank - 1); + collect(out, prevRow, curRank); + i++; + } + } + } else { + curRank += entry.getValue(); + } + } + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java new file mode 100644 index 0000000000000..8600d919c7e93 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/TopNBuffer.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.table.dataformat.BaseRow; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Supplier; + +/** + * TopNBuffer stores mapping from sort key to records list, sortKey is BaseRow type, each record is BaseRow type. + * TopNBuffer could also track rank number of each record. + */ +class TopNBuffer implements Serializable { + + private static final long serialVersionUID = 6824488508991990228L; + + private final Supplier> valueSupplier; + private final Comparator sortKeyComparator; + private int currentTopNum = 0; + private TreeMap> treeMap; + + TopNBuffer(Comparator sortKeyComparator, Supplier> valueSupplier) { + this.valueSupplier = valueSupplier; + this.sortKeyComparator = sortKeyComparator; + this.treeMap = new TreeMap(sortKeyComparator); + } + + /** + * Appends a record into the buffer. + * + * @param sortKey sort key with which the specified value is to be associated + * @param value record which is to be appended + * @return the size of the collection under the sortKey. + */ + public int put(BaseRow sortKey, BaseRow value) { + currentTopNum += 1; + // update treeMap + Collection collection = treeMap.get(sortKey); + if (collection == null) { + collection = valueSupplier.get(); + treeMap.put(sortKey, collection); + } + collection.add(value); + return collection.size(); + } + + /** + * Puts a record list into the buffer under the sortKey. + * Note: if buffer already contains sortKey, putAll will overwrite the previous value + * + * @param sortKey sort key with which the specified values are to be associated + * @param values record lists to be associated with the specified key + */ + void putAll(BaseRow sortKey, Collection values) { + treeMap.put(sortKey, values); + currentTopNum += values.size(); + } + + /** + * Gets the record list from the buffer under the sortKey. + * + * @param sortKey key to get + * @return the record list from the buffer under the sortKey + */ + public Collection get(BaseRow sortKey) { + return treeMap.get(sortKey); + } + + public void remove(BaseRow sortKey, BaseRow value) { + Collection list = treeMap.get(sortKey); + if (list != null) { + if (list.remove(value)) { + currentTopNum -= 1; + } + if (list.size() == 0) { + treeMap.remove(sortKey); + } + } + } + + /** + * Removes all record list from the buffer under the sortKey. + * + * @param sortKey key to remove + */ + void removeAll(BaseRow sortKey) { + Collection list = treeMap.get(sortKey); + if (list != null) { + currentTopNum -= list.size(); + treeMap.remove(sortKey); + } + } + + /** + * Removes the last record of the last Entry in the buffer. + * + * @return removed record + */ + BaseRow removeLast() { + Map.Entry> last = treeMap.lastEntry(); + BaseRow lastElement = null; + if (last != null) { + Collection list = last.getValue(); + lastElement = getLastElement(list); + if (lastElement != null) { + if (list.remove(lastElement)) { + currentTopNum -= 1; + } + if (list.size() == 0) { + treeMap.remove(last.getKey()); + } + } + } + return lastElement; + } + + /** + * Gets record which rank is given value. 
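Editor's note: conceptually the buffer is a TreeMap from sort key to a collection of records plus a running element count. The analogue below uses plain String keys instead of BaseRow (an assumption made only for illustration, not the patch's API) to show the accounting that put and removeLast maintain:

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Map;
import java.util.TreeMap;

public class TopNBufferAnalogue {
    private final TreeMap<String, Deque<String>> treeMap = new TreeMap<>();
    private int currentTopNum = 0;

    int put(String sortKey, String record) {
        Deque<String> records = treeMap.computeIfAbsent(sortKey, k -> new ArrayDeque<>());
        records.add(record);
        currentTopNum++;
        return records.size(); // size under this sort key, i.e. the "inner rank"
    }

    String removeLast() {
        Map.Entry<String, Deque<String>> last = treeMap.lastEntry();
        if (last == null) {
            return null;
        }
        String removed = last.getValue().pollLast(); // worst sort key, last inserted record
        currentTopNum--;
        if (last.getValue().isEmpty()) {
            treeMap.remove(last.getKey());
        }
        return removed;
    }

    public static void main(String[] args) {
        TopNBufferAnalogue buffer = new TopNBufferAnalogue();
        buffer.put("b", "row2");
        buffer.put("a", "row1");
        buffer.put("b", "row3");
        System.out.println(buffer.currentTopNum); // 3
        System.out.println(buffer.removeLast());  // row3
        System.out.println(buffer.currentTopNum); // 2
    }
}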
+ * + * @param rank rank value to search + * @return the record which rank is given value + */ + BaseRow getElement(int rank) { + int curRank = 0; + Iterator>> iter = treeMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + Collection list = entry.getValue(); + + Iterator listIter = list.iterator(); + while (listIter.hasNext()) { + BaseRow elem = listIter.next(); + curRank += 1; + if (curRank == rank) { + return elem; + } + } + } + return null; + } + + private BaseRow getLastElement(Collection list) { + BaseRow element = null; + if (list != null && !list.isEmpty()) { + Iterator iter = list.iterator(); + while (iter.hasNext()) { + element = iter.next(); + } + } + return element; + } + + /** + * Returns a {@link Set} view of the mappings contained in the buffer. + */ + Set>> entrySet() { + return treeMap.entrySet(); + } + + /** + * Returns the last Entry in the buffer. Returns null if the TreeMap is empty. + */ + Map.Entry> lastEntry() { + return treeMap.lastEntry(); + } + + /** + * Returns {@code true} if the buffer contains a mapping for the specified key. + * + * @param key key whose presence in the buffer is to be tested + * @return {@code true} if the buffer contains a mapping for the specified key + */ + boolean containsKey(BaseRow key) { + return treeMap.containsKey(key); + } + + /** + * Gets number of total records. + * + * @return the number of total records. + */ + int getCurrentTopNum() { + return currentTopNum; + } + + /** + * Gets sort key comparator used by buffer. + * + * @return sort key comparator used by buffer + */ + Comparator getSortKeyComparator() { + return sortKeyComparator; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java new file mode 100644 index 0000000000000..a78b00d5e11ea --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/UpdateRankFunction.java @@ -0,0 +1,486 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; +import org.apache.flink.table.runtime.util.LRUMap; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * A fast version of the rank process function which only holds the top-n data in state and keeps the sorted map in heap. + * This only works in some special scenarios: + * 1. the sort field collation is ascending and its monotonicity is decreasing, or the sort field collation is descending + * and its monotonicity is increasing + * 2. the input data has unique keys + * 3. the input stream cannot contain delete or retract records + */ +public class UpdateRankFunction extends AbstractRankFunction implements CheckpointedFunction { + + private static final long serialVersionUID = 6786508184355952780L; + + private static final Logger LOG = LoggerFactory.getLogger(UpdateRankFunction.class); + + private final BaseRowTypeInfo rowKeyType; + private final long cacheSize; + + // a map state stores mapping from row key to record which is in topN + // in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key + // the f1 is used to preserve the record order in the same sort_key + private transient MapState<BaseRow, Tuple2<BaseRow, Integer>> dataState; + + // a buffer stores mapping from sort key to rowKey list + private transient TopNBuffer buffer; + + // the kvSortedMap stores mapping from partition key to its buffer + private transient Map<BaseRow, TopNBuffer> kvSortedMap; + + // a HashMap stores mapping from rowKey to record, a heap mirror to dataState + private transient Map<BaseRow, RankRow> rowKeyMap; + + // the kvRowKeyMap stores the mapping from partitionKey to its rowKeyMap.
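Editor's note: the open() method further below sizes the LRU caches as Math.max(1, cacheSize / defaultTopNSize), so the heap holds roughly cacheSize records across all cached partition keys. A worked example of that arithmetic (the concrete cacheSize value is an assumption; the 100-record fallback TopN size comes from AbstractRankFunction):

public class CacheSizingSketch {
    public static void main(String[] args) {
        long cacheSize = 10_000L;     // assumed configured cache size, measured in records
        long defaultTopNSize = 100L;  // the constant rank end, or the 100-record fallback
        int lruCacheSize = Math.max(1, (int) (cacheSize / defaultTopNSize));
        // 100 partition keys are kept on heap, i.e. about 10_000 records in total
        System.out.println(lruCacheSize);
    }
}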
+ private transient LRUMap> kvRowKeyMap; + + private final TypeSerializer inputRowSer; + private final KeySelector rowKeySelector; + + public UpdateRankFunction(long minRetentionTime, long maxRetentionTime, BaseRowTypeInfo inputRowType, + BaseRowKeySelector rowKeySelector, GeneratedRecordComparator generatedRecordComparator, + BaseRowKeySelector sortKeySelector, RankType rankType, + RankRange rankRange, GeneratedRecordEqualiser generatedEqualiser, boolean generateRetraction, + boolean outputRankNumber, long cacheSize) { + super(minRetentionTime, maxRetentionTime, inputRowType, generatedRecordComparator, sortKeySelector, rankType, + rankRange, generatedEqualiser, generateRetraction, outputRankNumber); + this.rowKeyType = rowKeySelector.getProducedType(); + this.cacheSize = cacheSize; + this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig()); + this.rowKeySelector = rowKeySelector; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + int lruCacheSize = Math.max(1, (int) (cacheSize / getDefaultTopNSize())); + // make sure the cached map is in a fixed size, avoid OOM + kvSortedMap = new HashMap<>(lruCacheSize); + kvRowKeyMap = new LRUMap<>(lruCacheSize, new CacheRemovalListener()); + + LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize); + + TupleTypeInfo> valueTypeInfo = new TupleTypeInfo<>(inputRowType, Types.INT); + MapStateDescriptor> mapStateDescriptor = new MapStateDescriptor<>( + "data-state-with-update", rowKeyType, valueTypeInfo); + dataState = getRuntimeContext().getMapState(mapStateDescriptor); + + // metrics + registerMetric(kvSortedMap.size() * getDefaultTopNSize()); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (stateCleaningEnabled) { + BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); + // cleanup cache + kvRowKeyMap.remove(partitionKey); + kvSortedMap.remove(partitionKey); + cleanupState(dataState); + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + // nothing to do + } + + @Override + public void processElement( + BaseRow input, Context context, Collector out) throws Exception { + long currentTime = context.timerService().currentProcessingTime(); + // register state-cleanup timer + registerProcessingCleanupTimer(context, currentTime); + + initHeapStates(); + initRankEnd(input); + if (outputRankNumber || hasOffset()) { + // the without-number-algorithm can't handle topN with offset, + // so use the with-number-algorithm to handle offset + processElementWithRowNumber(input, out); + } else { + processElementWithoutRowNumber(input, out); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Iterator>> iter = kvRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow partitionKey = entry.getKey(); + Map currentRowKeyMap = entry.getValue(); + keyContext.setCurrentKey(partitionKey); + flushBufferToState(currentRowKeyMap); + } + } + + private void initHeapStates() throws Exception { + requestCount += 1; + BaseRow partitionKey = (BaseRow) keyContext.getCurrentKey(); + buffer = kvSortedMap.get(partitionKey); + rowKeyMap = kvRowKeyMap.get(partitionKey); + if (buffer == null) { + buffer = new TopNBuffer(sortKeyComparator, LinkedHashSet::new); + rowKeyMap = new HashMap<>(); + kvSortedMap.put(partitionKey, buffer); + 
kvRowKeyMap.put(partitionKey, rowKeyMap); + + // restore sorted map + Iterator>> iter = dataState.iterator(); + if (iter != null) { + // a temp map associate sort key to tuple2 + Map> tempSortedMap = new HashMap<>(); + while (iter.hasNext()) { + Map.Entry> entry = iter.next(); + BaseRow rowKey = entry.getKey(); + Tuple2 recordAndInnerRank = entry.getValue(); + BaseRow record = recordAndInnerRank.f0; + Integer innerRank = recordAndInnerRank.f1; + rowKeyMap.put(rowKey, new RankRow(record, innerRank, false)); + + // insert into temp sort map to preserve the record order in the same sort key + BaseRow sortKey = sortKeySelector.getKey(record); + TreeMap treeMap = tempSortedMap.get(sortKey); + if (treeMap == null) { + treeMap = new TreeMap<>(); + tempSortedMap.put(sortKey, treeMap); + } + treeMap.put(innerRank, rowKey); + } + + // build sorted map from the temp map + Iterator>> tempIter = tempSortedMap.entrySet().iterator(); + while (tempIter.hasNext()) { + Map.Entry> entry = tempIter.next(); + BaseRow sortKey = entry.getKey(); + TreeMap treeMap = entry.getValue(); + Iterator> treeMapIter = treeMap.entrySet().iterator(); + while (treeMapIter.hasNext()) { + Map.Entry treeMapEntry = treeMapIter.next(); + Integer innerRank = treeMapEntry.getKey(); + BaseRow recordRowKey = treeMapEntry.getValue(); + int size = buffer.put(sortKey, recordRowKey); + if (innerRank != size) { + LOG.warn("Failed to build sorted map from state, this may result in wrong result. " + + "The sort key is {}, partition key is {}, " + + "treeMap is {}. The expected inner rank is {}, but current size is {}.", + sortKey, partitionKey, treeMap, innerRank, size); + } + } + } + } + } else { + hitCount += 1; + } + } + + private void processElementWithRowNumber(BaseRow inputRow, Collector out) throws Exception { + BaseRow sortKey = sortKeySelector.getKey(inputRow); + BaseRow rowKey = rowKeySelector.getKey(inputRow); + if (rowKeyMap.containsKey(rowKey)) { + // it is an updated record which is in the topN, in this scenario, + // the new sort key must be higher than old sort key, this is guaranteed by rules + RankRow oldRow = rowKeyMap.get(rowKey); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.row); + if (oldSortKey.equals(sortKey)) { + // sort key is not changed, so the rank is the same, only output the row + Tuple2 rankAndInnerRank = rowNumber(sortKey, rowKey, buffer); + int rank = rankAndInnerRank.f0; + int innerRank = rankAndInnerRank.f1; + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), innerRank, true)); + retract(out, oldRow.row, rank); // retract old record + collect(out, inputRow, rank); + return; + } + + Tuple2 oldRankAndInnerRank = rowNumber(oldSortKey, rowKey, buffer); + int oldRank = oldRankAndInnerRank.f0; + // remove old sort key + buffer.remove(oldSortKey, rowKey); + // add new sort key + int size = buffer.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + // update inner rank of records under the old sort key + updateInnerRank(oldSortKey); + + // emit records + emitRecordsWithRowNumber(sortKey, inputRow, out, oldSortKey, oldRow, oldRank); + } else if (checkSortKeyInBufferRange(sortKey, buffer)) { + // it is an unique record but is in the topN, insert sort key into buffer + int size = buffer.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + + // emit records + emitRecordsWithRowNumber(sortKey, inputRow, out); + } + } + + private Tuple2 rowNumber(BaseRow sortKey, BaseRow rowKey, TopNBuffer buffer) { + 
Iterator>> iterator = buffer.entrySet().iterator(); + int curRank = 1; + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + BaseRow curKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + if (curKey.equals(sortKey)) { + Iterator rowKeysIter = rowKeys.iterator(); + int innerRank = 1; + while (rowKeysIter.hasNext()) { + if (rowKey.equals(rowKeysIter.next())) { + return Tuple2.of(curRank, innerRank); + } else { + innerRank += 1; + curRank += 1; + } + } + } else { + curRank += rowKeys.size(); + } + } + LOG.error("Failed to find the sortKey: {}, rowkey: {} in the buffer. This should never happen", sortKey, + rowKey); + throw new RuntimeException("Failed to find the sortKey, rowkey in the buffer. This should never happen"); + } + + private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collector out) throws Exception { + emitRecordsWithRowNumber(sortKey, inputRow, out, null, null, -1); + } + + private void emitRecordsWithRowNumber(BaseRow sortKey, BaseRow inputRow, Collector out, BaseRow oldSortKey, + RankRow oldRow, int oldRank) throws Exception { + + int oldInnerRank = oldRow == null ? -1 : oldRow.innerRank; + Iterator>> iterator = buffer.entrySet().iterator(); + int curRank = 0; + // whether we have found the sort key in the buffer + boolean findsSortKey = false; + while (iterator.hasNext() && isInRankEnd(curRank)) { + Map.Entry> entry = iterator.next(); + BaseRow curSortKey = entry.getKey(); + Collection rowKeys = entry.getValue(); + // meet its own sort key + if (!findsSortKey && curSortKey.equals(sortKey)) { + curRank += rowKeys.size(); + if (oldRow != null) { + retract(out, oldRow.row, oldRank); + } + collect(out, inputRow, curRank); + findsSortKey = true; + } else if (findsSortKey) { + if (oldSortKey == null) { + // this is a new row, emit updates for all rows in the topn + Iterator rowKeyIter = rowKeys.iterator(); + while (rowKeyIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + BaseRow rowKey = rowKeyIter.next(); + RankRow prevRow = rowKeyMap.get(rowKey); + retract(out, prevRow.row, curRank - 1); + collect(out, prevRow.row, curRank); + } + } else { + // current sort key is higher than old sort key, + // the rank of current record is changed, need to update the following rank + int compare = sortKeyComparator.compare(curSortKey, oldSortKey); + if (compare <= 0) { + Iterator rowKeyIter = rowKeys.iterator(); + int curInnerRank = 0; + while (rowKeyIter.hasNext() && isInRankEnd(curRank)) { + curRank += 1; + curInnerRank += 1; + if (compare == 0 && curInnerRank >= oldInnerRank) { + // match to the previous position + return; + } + + BaseRow rowKey = rowKeyIter.next(); + RankRow prevRow = rowKeyMap.get(rowKey); + retract(out, prevRow.row, curRank - 1); + collect(out, prevRow.row, curRank); + } + } else { + // current sort key is smaller than old sort key, the rank is not changed, so skip + return; + } + } + } else { + curRank += rowKeys.size(); + } + } + + // remove the records associated to the sort key which is out of topN + List toDeleteSortKeys = new ArrayList<>(); + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + Collection rowKeys = entry.getValue(); + Iterator rowKeyIter = rowKeys.iterator(); + while (rowKeyIter.hasNext()) { + BaseRow rowKey = rowKeyIter.next(); + rowKeyMap.remove(rowKey); + dataState.remove(rowKey); + } + toDeleteSortKeys.add(entry.getKey()); + } + for (BaseRow toDeleteKey : toDeleteSortKeys) { + buffer.removeAll(toDeleteKey); + } + } + + private void processElementWithoutRowNumber(BaseRow 
inputRow, Collector out) throws Exception { + BaseRow sortKey = sortKeySelector.getKey(inputRow); + BaseRow rowKey = rowKeySelector.getKey(inputRow); + if (rowKeyMap.containsKey(rowKey)) { + // it is an updated record which is in the topN, in this scenario, + // the new sort key must be higher than old sort key, this is guaranteed by rules + RankRow oldRow = rowKeyMap.get(rowKey); + BaseRow oldSortKey = sortKeySelector.getKey(oldRow.row); + if (!oldSortKey.equals(sortKey)) { + // remove old sort key + buffer.remove(oldSortKey, rowKey); + // add new sort key + int size = buffer.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + // update inner rank of records under the old sort key + updateInnerRank(oldSortKey); + } else { + // row content may change, so we need to update row in map + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.innerRank, true)); + } + // row content may change, so a retract is needed + retract(out, oldRow.row, oldRow.innerRank); + collect(out, inputRow); + } else if (checkSortKeyInBufferRange(sortKey, buffer)) { + // it is an unique record but is in the topN, insert sort key into buffer + int size = buffer.put(sortKey, rowKey); + rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true)); + collect(out, inputRow); + // remove retired element + if (buffer.getCurrentTopNum() > rankEnd) { + BaseRow lastRowKey = buffer.removeLast(); + if (lastRowKey != null) { + RankRow lastRow = rowKeyMap.remove(lastRowKey); + dataState.remove(lastRowKey); + // always send a retraction message + delete(out, lastRow.row); + } + } + } + } + + private void flushBufferToState(Map curRowKeyMap) throws Exception { + Iterator> iter = curRowKeyMap.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry entry = iter.next(); + BaseRow key = entry.getKey(); + RankRow rankRow = entry.getValue(); + if (rankRow.dirty) { + // should update state + dataState.put(key, Tuple2.of(rankRow.row, rankRow.innerRank)); + rankRow.dirty = false; + } + } + } + + private void updateInnerRank(BaseRow oldSortKey) { + Collection list = buffer.get(oldSortKey); + if (list != null) { + Iterator iter = list.iterator(); + int innerRank = 1; + while (iter.hasNext()) { + BaseRow rowKey = iter.next(); + RankRow row = rowKeyMap.get(rowKey); + if (row.innerRank != innerRank) { + row.innerRank = innerRank; + row.dirty = true; + } + innerRank += 1; + } + } + } + + private class CacheRemovalListener implements LRUMap.RemovalListener> { + + @Override + public void onRemoval(Map.Entry> eldest) { + BaseRow previousKey = (BaseRow) keyContext.getCurrentKey(); + BaseRow partitionKey = eldest.getKey(); + Map currentRowKeyMap = eldest.getValue(); + keyContext.setCurrentKey(partitionKey); + kvSortedMap.remove(partitionKey); + try { + flushBufferToState(currentRowKeyMap); + } catch (Throwable e) { + LOG.error("Fail to synchronize state!", e); + throw new RuntimeException(e); + } finally { + keyContext.setCurrentKey(previousKey); + } + } + } + + private class RankRow { + private final BaseRow row; + private int innerRank; + private boolean dirty; + + private RankRow(BaseRow row, int innerRank, boolean dirty) { + this.row = row; + this.innerRank = innerRank; + this.dirty = dirty; + } + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java new file mode 100644 index 
0000000000000..a065fe3ea1d2c --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/rank/VariableRankRange.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import java.util.List; + +/** + * changing rank limit depends on input. + */ +public class VariableRankRange implements RankRange { + + private static final long serialVersionUID = 5579785886506433955L; + private int rankEndIndex; + + public VariableRankRange(int rankEndIndex) { + this.rankEndIndex = rankEndIndex; + } + + public int getRankEndIndex() { + return rankEndIndex; + } + + @Override + public String toString(List inputFieldNames) { + return "rankEnd=" + inputFieldNames.get(rankEndIndex); + } + + @Override + public String toString() { + return "rankEnd=$" + rankEndIndex; + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java index 45ddd4ce33162..49cc0734dcf51 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/values/ValuesInputFormat.java @@ -41,7 +41,7 @@ public class ValuesInputFormat private static final Logger LOG = LoggerFactory.getLogger(ValuesInputFormat.class); private GeneratedInput> generatedInput; - private BaseRowTypeInfo returnType; + private final BaseRowTypeInfo returnType; private GenericInputFormat format; public ValuesInputFormat(GeneratedInput> generatedInput, BaseRowTypeInfo returnType) { @@ -54,7 +54,9 @@ public void open(GenericInputSplit split) { LOG.debug("Compiling GenericInputFormat: $name \n\n Code:\n$code", generatedInput.getClassName(), generatedInput.getCode()); LOG.debug("Instantiating GenericInputFormat."); + format = generatedInput.newInstance(getRuntimeContext().getUserCodeClassLoader()); + generatedInput = null; } @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java index a7b9199c962ca..12aaf6192b4bc 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java @@ -93,6 +93,8 @@ public class TypeConverters { internalTypeToInfo.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO); internalTypeToInfo.put(InternalTypes.DATE, 
BasicTypeInfo.INT_TYPE_INFO); internalTypeToInfo.put(InternalTypes.TIMESTAMP, BasicTypeInfo.LONG_TYPE_INFO); + internalTypeToInfo.put(InternalTypes.PROCTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO); + internalTypeToInfo.put(InternalTypes.ROWTIME_INDICATOR, BasicTypeInfo.LONG_TYPE_INFO); internalTypeToInfo.put(InternalTypes.TIME, BasicTypeInfo.INT_TYPE_INFO); internalTypeToInfo.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); internalTypeToInfo.put(InternalTypes.INTERVAL_MONTHS, BasicTypeInfo.INT_TYPE_INFO); @@ -111,6 +113,8 @@ public class TypeConverters { itToEti.put(InternalTypes.CHAR, BasicTypeInfo.CHAR_TYPE_INFO); itToEti.put(InternalTypes.DATE, SqlTimeTypeInfo.DATE); itToEti.put(InternalTypes.TIMESTAMP, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.PROCTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP); + itToEti.put(InternalTypes.ROWTIME_INDICATOR, SqlTimeTypeInfo.TIMESTAMP); itToEti.put(InternalTypes.TIME, SqlTimeTypeInfo.TIME); itToEti.put(InternalTypes.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); INTERNAL_TYPE_TO_EXTERNAL_TYPE_INFO = Collections.unmodifiableMap(itToEti); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java new file mode 100644 index 0000000000000..344fb978c5fdc --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapSerializer.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Map; + +/** + * Base serializer for {@link Map}. The serializer relies on a key serializer + * and a value serializer for the serialization of the map's key-value pairs. + * + *

The serialization format for the map is as follows: four bytes for the + * length of the map, followed by the serialized representation of each + * key-value pair. To allow null values, each value is prefixed by a null flag. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + * @param The type of the map. + */ +abstract class AbstractMapSerializer> extends TypeSerializer { + + private static final long serialVersionUID = 1L; + + /** The serializer for the keys in the map. */ + final TypeSerializer keySerializer; + + /** The serializer for the values in the map. */ + final TypeSerializer valueSerializer; + + /** + * Creates a map serializer that uses the given serializers to serialize the key-value pairs in the map. + * + * @param keySerializer The serializer for the keys in the map + * @param valueSerializer The serializer for the values in the map + */ + AbstractMapSerializer(TypeSerializer keySerializer, TypeSerializer valueSerializer) { + Preconditions.checkNotNull(keySerializer, "The key serializer must not be null"); + Preconditions.checkNotNull(valueSerializer, "The value serializer must not be null."); + this.keySerializer = keySerializer; + this.valueSerializer = valueSerializer; + } + + // ------------------------------------------------------------------------ + + /** + * Returns the serializer for the keys in the map. + * + * @return The serializer for the keys in the map. + */ + public TypeSerializer getKeySerializer() { + return keySerializer; + } + + /** + * Returns the serializer for the values in the map. + * + * @return The serializer for the values in the map. + */ + public TypeSerializer getValueSerializer() { + return valueSerializer; + } + + // ------------------------------------------------------------------------ + // Type Serializer implementation + // ------------------------------------------------------------------------ + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public M copy(M from) { + M newMap = createInstance(); + + for (Map.Entry entry : from.entrySet()) { + K newKey = entry.getKey() == null ? null : keySerializer.copy(entry.getKey()); + V newValue = entry.getValue() == null ? null : valueSerializer.copy(entry.getValue()); + + newMap.put(newKey, newValue); + } + + return newMap; + } + + @Override + public M copy(M from, M reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; // var length + } + + @Override + public void serialize(M map, DataOutputView target) throws IOException { + final int size = map.size(); + target.writeInt(size); + + for (Map.Entry entry : map.entrySet()) { + keySerializer.serialize(entry.getKey(), target); + + if (entry.getValue() == null) { + target.writeBoolean(true); + } else { + target.writeBoolean(false); + valueSerializer.serialize(entry.getValue(), target); + } + } + } + + @Override + public M deserialize(DataInputView source) throws IOException { + final int size = source.readInt(); + + final M map = createInstance(); + for (int i = 0; i < size; ++i) { + K key = keySerializer.deserialize(source); + + boolean isNull = source.readBoolean(); + V value = isNull ? 
null : valueSerializer.deserialize(source); + + map.put(key, value); + } + + return map; + } + + @Override + public M deserialize(M reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + final int size = source.readInt(); + target.writeInt(size); + + for (int i = 0; i < size; ++i) { + keySerializer.copy(source, target); + + boolean isNull = source.readBoolean(); + target.writeBoolean(isNull); + + if (!isNull) { + valueSerializer.copy(source, target); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AbstractMapSerializer that = (AbstractMapSerializer) o; + + return keySerializer.equals(that.keySerializer) && valueSerializer.equals(that.valueSerializer); + } + + @Override + public int hashCode() { + int result = keySerializer.hashCode(); + result = 31 * result + valueSerializer.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java new file mode 100644 index 0000000000000..2321d28494cbe --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/AbstractMapTypeInfo.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.util.Preconditions; + +import java.util.Map; + +/** + * Base type information for maps. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +abstract class AbstractMapTypeInfo> extends TypeInformation { + + private static final long serialVersionUID = 1L; + + /* The type information for the keys in the map*/ + final TypeInformation keyTypeInfo; + + /* The type information for the values in the map */ + final TypeInformation valueTypeInfo; + + /** + * Constructor with given type information for the keys and the values in + * the map. + * + * @param keyTypeInfo The type information for the keys in the map. + * @param valueTypeInfo The type information for the values in th map. 
+ */ + AbstractMapTypeInfo(TypeInformation keyTypeInfo, TypeInformation valueTypeInfo) { + Preconditions.checkNotNull(keyTypeInfo, "The type information for the keys cannot be null."); + Preconditions.checkNotNull(valueTypeInfo, "The type information for the values cannot be null."); + this.keyTypeInfo = keyTypeInfo; + this.valueTypeInfo = valueTypeInfo; + } + + /** + * Constructor with the classes of the keys and the values in the map. + * + * @param keyClass The class of the keys in the map. + * @param valueClass The class of the values in the map. + */ + AbstractMapTypeInfo(Class keyClass, Class valueClass) { + Preconditions.checkNotNull(keyClass, "The key class cannot be null."); + Preconditions.checkNotNull(valueClass, "The value class cannot be null."); + + this.keyTypeInfo = TypeInformation.of(keyClass); + this.valueTypeInfo = TypeInformation.of(valueClass); + } + + // ------------------------------------------------------------------------ + + /** + * Returns the type information for the keys in the map. + * + * @return The type information for the keys in the map. + */ + public TypeInformation getKeyTypeInfo() { + return keyTypeInfo; + } + + /** + * Returns the type information for the values in the map. + * + * @return The type information for the values in the map. + */ + public TypeInformation getValueTypeInfo() { + return valueTypeInfo; + } + + // ------------------------------------------------------------------------ + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 0; + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean isKeyType() { + return false; + } + + // ------------------------------------------------------------------------ + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AbstractMapTypeInfo that = (AbstractMapTypeInfo) o; + + return keyTypeInfo.equals(that.keyTypeInfo) && valueTypeInfo.equals(that.valueTypeInfo); + } + + @Override + public int hashCode() { + int result = keyTypeInfo.hashCode(); + result = 31 * result + valueTypeInfo.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java new file mode 100644 index 0000000000000..63f4f8f71eeac --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializer.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; + +import java.util.Comparator; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * A serializer for {@link SortedMap}. The serializer relies on a key serializer + * and a value serializer for the serialization of the map's key-value pairs. + * It also uses a comparator to ensure the order of the keys. + * + *

The serialization format for the map is as follows: four bytes for the + * length of the map, followed by the serialized representation of each + * key-value pair. To allow null values, each value is prefixed by a null flag. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +public final class SortedMapSerializer extends AbstractMapSerializer> { + + private static final long serialVersionUID = 1L; + + /** The comparator for the keys in the map. */ + private final Comparator comparator; + + /** + * Constructor with given comparator, and the serializers for the keys and + * values in the map. + * + * @param comparator The comparator for the keys in the map. + * @param keySerializer The serializer for the keys in the map. + * @param valueSerializer The serializer for the values in the map. + */ + public SortedMapSerializer( + Comparator comparator, + TypeSerializer keySerializer, + TypeSerializer valueSerializer) { + super(keySerializer, valueSerializer); + this.comparator = comparator; + } + + /** + * Returns the comparator for the keys in the map. + * + * @return The comparator for the keys in the map. + */ + public Comparator getComparator() { + return comparator; + } + + @Override + public TypeSerializer> duplicate() { + TypeSerializer keySerializer = getKeySerializer().duplicate(); + TypeSerializer valueSerializer = getValueSerializer().duplicate(); + + return new SortedMapSerializer<>(comparator, keySerializer, valueSerializer); + } + + @Override + public SortedMap createInstance() { + return new TreeMap<>(comparator); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + + SortedMapSerializer that = (SortedMapSerializer) o; + return comparator.equals(that.comparator); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = result * 31 + comparator.hashCode(); + return result; + } + + @Override + public String toString() { + return "SortedMapSerializer{" + + "comparator = " + comparator + + ", keySerializer = " + keySerializer + + ", valueSerializer = " + valueSerializer + + "}"; + } + + // -------------------------------------------------------------------------------------------- + // Serializer configuration snapshot + // -------------------------------------------------------------------------------------------- + + @Override + public TypeSerializerSnapshot> snapshotConfiguration() { + return new SortedMapSerializerSnapshot<>(this); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java new file mode 100644 index 0000000000000..d7a1dcf68ce1d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapSerializerSnapshot.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeutils.NestedSerializersSnapshotDelegate; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream; +import org.apache.flink.api.java.typeutils.runtime.DataOutputViewStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.InstantiationUtil; + +import java.io.IOException; +import java.util.Comparator; +import java.util.SortedMap; + +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Snapshot class for the {@link SortedMapSerializer}. + */ +public class SortedMapSerializerSnapshot implements TypeSerializerSnapshot> { + + private Comparator comparator; + + private NestedSerializersSnapshotDelegate nestedSerializersSnapshotDelegate; + + private static final int CURRENT_VERSION = 3; + + @SuppressWarnings("unused") + public SortedMapSerializerSnapshot() { + // this constructor is used when restoring from a checkpoint/savepoint. + } + + SortedMapSerializerSnapshot(SortedMapSerializer sortedMapSerializer) { + this.comparator = sortedMapSerializer.getComparator(); + TypeSerializer[] typeSerializers = + new TypeSerializer[] { sortedMapSerializer.getKeySerializer(), sortedMapSerializer.getValueSerializer() }; + this.nestedSerializersSnapshotDelegate = new NestedSerializersSnapshotDelegate(typeSerializers); + } + + @Override + public int getCurrentVersion() { + return CURRENT_VERSION; + } + + @Override + public void writeSnapshot(DataOutputView out) throws IOException { + checkState(comparator != null, "Comparator cannot be null."); + InstantiationUtil.serializeObject(new DataOutputViewStream(out), comparator); + nestedSerializersSnapshotDelegate.writeNestedSerializerSnapshots(out); + } + + @Override + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException { + try { + comparator = InstantiationUtil.deserializeObject( + new DataInputViewStream(in), userCodeClassLoader); + } catch (ClassNotFoundException e) { + throw new IOException(e); + } + this.nestedSerializersSnapshotDelegate = NestedSerializersSnapshotDelegate.readNestedSerializerSnapshots( + in, + userCodeClassLoader); + } + + @Override + public SortedMapSerializer restoreSerializer() { + TypeSerializer[] nestedSerializers = nestedSerializersSnapshotDelegate.getRestoredNestedSerializers(); + @SuppressWarnings("unchecked") + TypeSerializer keySerializer = (TypeSerializer) nestedSerializers[0]; + + @SuppressWarnings("unchecked") + TypeSerializer valueSerializer = (TypeSerializer) nestedSerializers[1]; + + return new SortedMapSerializer(comparator, keySerializer, valueSerializer); + } + + @Override + public TypeSerializerSchemaCompatibility> resolveSchemaCompatibility( + TypeSerializer> newSerializer) { + if (!(newSerializer instanceof SortedMapSerializer)) { + 
return TypeSerializerSchemaCompatibility.incompatible(); + } + SortedMapSerializer newSortedMapSerializer = (SortedMapSerializer) newSerializer; + if (!comparator.equals(newSortedMapSerializer.getComparator())) { + return TypeSerializerSchemaCompatibility.incompatible(); + } else { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java new file mode 100644 index 0000000000000..0e3cef7297044 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/typeutils/SortedMapTypeInfo.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Comparator; +import java.util.SortedMap; + +/** + * The type information for sorted maps. + * + * @param The type of the keys in the map. + * @param The type of the values in the map. + */ +@PublicEvolving +public class SortedMapTypeInfo extends AbstractMapTypeInfo> { + + private static final long serialVersionUID = 1L; + + /** The comparator for the keys in the map. 
*/ + private final Comparator comparator; + + public SortedMapTypeInfo( + TypeInformation keyTypeInfo, + TypeInformation valueTypeInfo, + Comparator comparator) { + super(keyTypeInfo, valueTypeInfo); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + public SortedMapTypeInfo( + Class keyClass, + Class valueClass, + Comparator comparator) { + super(keyClass, valueClass); + + Preconditions.checkNotNull(comparator, "The comparator cannot be null."); + this.comparator = comparator; + } + + public SortedMapTypeInfo(Class keyClass, Class valueClass) { + super(keyClass, valueClass); + + Preconditions.checkArgument(Comparable.class.isAssignableFrom(keyClass), + "The key class must be comparable when no comparator is given."); + this.comparator = new ComparableComparator<>(); + } + + // ------------------------------------------------------------------------ + + @SuppressWarnings("unchecked") + @Override + public Class> getTypeClass() { + return (Class>) (Class) SortedMap.class; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + TypeSerializer keyTypeSerializer = keyTypeInfo.createSerializer(config); + TypeSerializer valueTypeSerializer = valueTypeInfo.createSerializer(config); + + return new SortedMapSerializer<>(comparator, keyTypeSerializer, valueTypeSerializer); + } + + @Override + public boolean canEqual(Object obj) { + return null != obj && getClass() == obj.getClass(); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + + SortedMapTypeInfo that = (SortedMapTypeInfo) o; + + return comparator.equals(that.comparator); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + comparator.hashCode(); + return result; + } + + @Override + public String toString() { + return "SortedMapTypeInfo{" + + "comparator=" + comparator + + ", keyTypeInfo=" + getKeyTypeInfo() + + ", valueTypeInfo=" + getValueTypeInfo() + + "}"; + } + + //-------------------------------------------------------------------------- + + /** + * The default comparator for comparable types. + */ + private static class ComparableComparator implements Comparator, Serializable { + private static final long serialVersionUID = 1L; + + @SuppressWarnings("unchecked") + public int compare(K obj1, K obj2) { + return ((Comparable) obj1).compareTo(obj2); + } + + @Override + public boolean equals(Object o) { + return (o == this) || (o != null && o.getClass() == getClass()); + } + + @Override + public int hashCode() { + return "ComparableComparator".hashCode(); + } + + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java new file mode 100644 index 0000000000000..f23acd727a00d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepFirstRowFunctionTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; + +/** + * Tests for {@link DeduplicateKeepFirstRowFunction}. + */ +public class DeduplicateKeepFirstRowFunctionTest { + + private Time minTime = Time.milliseconds(10); + private Time maxTime = Time.milliseconds(20); + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + + private OneInputStreamOperatorTestHarness createTestHarness( + DeduplicateKeepFirstRowFunction func) + throws Exception { + KeyedProcessOperator operator = new KeyedProcessOperator<>(func); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void test() throws Exception { + DeduplicateKeepFirstRowFunction func = new DeduplicateKeepFirstRowFunction(minTime.toMilliseconds(), + maxTime.toMilliseconds()); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + // Keep FirstRow in deduplicate will not send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunctionTest.java new 
file mode 100644 index 0000000000000..f6c3419185f43 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/DeduplicateKeepLastRowFunctionTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link DeduplicateKeepLastRowFunction}. 
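+ * + * <p>Keep-last-row emits the most recent row per key; when generateRetraction is + * enabled, a retraction of the previously emitted row is sent before the new row.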
+ */ +public class DeduplicateKeepLastRowFunctionTest { + + private Time minTime = Time.milliseconds(10); + private Time maxTime = Time.milliseconds(20); + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + + private DeduplicateKeepLastRowFunction createFunction(boolean generateRetraction) { + return new DeduplicateKeepLastRowFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), inputRowType, + generateRetraction); + } + + private OneInputStreamOperatorTestHarness createTestHarness( + DeduplicateKeepLastRowFunction func) + throws Exception { + KeyedProcessOperator operator = new KeyedProcessOperator<>(func); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void testWithoutGenerateRetraction() throws Exception { + DeduplicateKeepLastRowFunction func = createFunction(false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testWithGenerateRetraction() throws Exception { + DeduplicateKeepLastRowFunction func = createFunction(true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 1L, 13)); + testHarness.close(); + + // Keep LastRow in deduplicate may send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(retractRecord("book", 1L, 12)); + expectedOutputOutput.add(record("book", 1L, 13)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java new file mode 100644 index 0000000000000..af3dbde80c9ce --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepFirstRowFunctionTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; + +/** + * Tests for {@link MiniBatchDeduplicateKeepFirstRowFunction}. + */ +public class MiniBatchDeduplicateKeepFirstRowFunctionTest { + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + + private TypeSerializer typeSerializer = inputRowType.createSerializer(new ExecutionConfig()); + + private OneInputStreamOperatorTestHarness createTestHarness( + MiniBatchDeduplicateKeepFirstRowFunction func) + throws Exception { + CountBundleTrigger> trigger = new CountBundleTrigger<>(3); + KeyedMapBundleOperator op = new KeyedMapBundleOperator(func, trigger); + return new KeyedOneInputStreamOperatorTestHarness<>(op, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void testKeepFirstRowWithGenerateRetraction() throws Exception { + MiniBatchDeduplicateKeepFirstRowFunction func = new MiniBatchDeduplicateKeepFirstRowFunction(typeSerializer); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + + // output is empty because bundle not trigger yet. 
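+ // (the CountBundleTrigger above is created with a count of 3, so the bundle is + // only finished and flushed downstream once the third element arrives)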
+ Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + // Keep FirstRow in deduplicate will not send retraction + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + testHarness.close(); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunctionTest.java new file mode 100644 index 0000000000000..efbf33552da8f --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/deduplicate/MiniBatchDeduplicateKeepLastRowFunctionTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.deduplicate; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.runtime.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link MiniBatchDeduplicateKeepLastRowFunction}. 
+ */ +public class MiniBatchDeduplicateKeepLastRowFunctionTest { + + private BaseRowTypeInfo inputRowType = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.LONG, + InternalTypes.INT); + + private int rowKeyIdx = 1; + private BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx }, + inputRowType.getInternalTypes()); + + + private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( + inputRowType.getFieldTypes(), + new GenericRowRecordSortComparator(rowKeyIdx, inputRowType.getInternalTypes()[rowKeyIdx])); + + private TypeSerializer typeSerializer = inputRowType.createSerializer(new ExecutionConfig()); + + private MiniBatchDeduplicateKeepLastRowFunction createFunction(boolean generateRetraction) { + return new MiniBatchDeduplicateKeepLastRowFunction(inputRowType, generateRetraction, typeSerializer); + } + + private OneInputStreamOperatorTestHarness createTestHarness( + MiniBatchDeduplicateKeepLastRowFunction func) + throws Exception { + CountBundleTrigger> trigger = new CountBundleTrigger<>(3); + KeyedMapBundleOperator op = new KeyedMapBundleOperator(func, trigger); + return new KeyedOneInputStreamOperatorTestHarness<>(op, rowKeySelector, rowKeySelector.getProducedType()); + } + + @Test + public void testWithoutGenerateRetraction() throws Exception { + MiniBatchDeduplicateKeepLastRowFunction func = createFunction(false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + testHarness.processElement(record("book", 2L, 11)); + // output is empty because bundle not trigger yet. + Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 11)); + + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 3L, 11)); + testHarness.close(); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testWithGenerateRetraction() throws Exception { + MiniBatchDeduplicateKeepLastRowFunction func = createFunction(true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + testHarness.processElement(record("book", 2L, 11)); + // output is empty because bundle not trigger yet. 
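+ // (within a bundle only the last row per key survives: this bundle holds + // (book,1L,10), (book,2L,11) and (book,1L,13), so flushing emits only the last + // two and (book,1L,10) is never sent downstream)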
+ Assert.assertTrue(testHarness.getOutput().isEmpty()); + + testHarness.processElement(record("book", 1L, 13)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 1L, 13)); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 11)); + + // this will send retract message to downstream + expectedOutputOutput.add(retractRecord("book", 1L, 13)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(retractRecord("book", 2L, 11)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 3L, 11)); + testHarness.close(); + assertor.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java new file mode 100644 index 0000000000000..49b9366c71f51 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/AppendRankFunctionTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link AppendRankFunction}. 
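+ * + * <p>The testVariableRankRange case below uses VariableRankRange(1), i.e. the rank end is + * read from field 1 of each input row, so the "book" rows (2L) keep a top-2 while the + * "fruit" rows (1L) keep only a top-1.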
+ */ +public class AppendRankFunctionTest extends BaseRankFunctionTest { + + @Override + protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber) { + return new AppendRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, + generateRetraction, outputRankNumber, cacheSize); + } + + @Test + public void testVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(1), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("fruit", 1L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33)); + expectedOutputOutput.add(record("fruit", 1L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java new file mode 100644 index 0000000000000..deff2ac2ef98d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/BaseRankFunctionTest.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedRecordComparator; +import org.apache.flink.table.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.generated.RecordComparator; +import org.apache.flink.table.generated.RecordEqualiser; +import org.apache.flink.table.runtime.sort.IntRecordComparator; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.runtime.util.BaseRowRecordEqualiser; +import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.deleteRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Base Tests for all subclass of {@link AbstractRankFunction}. + */ +abstract class BaseRankFunctionTest { + + Time minTime = Time.milliseconds(10); + Time maxTime = Time.milliseconds(20); + long cacheSize = 10000L; + + BaseRowTypeInfo inputRowType = new BaseRowTypeInfo( + InternalTypes.STRING, + InternalTypes.LONG, + InternalTypes.INT); + + GeneratedRecordComparator sortKeyComparator = new GeneratedRecordComparator("", "", new Object[0]) { + + private static final long serialVersionUID = 1434685115916728955L; + + @Override + public RecordComparator newInstance(ClassLoader classLoader) { + + return IntRecordComparator.INSTANCE; + } + }; + + private int sortKeyIdx = 2; + + BinaryRowKeySelector sortKeySelector = new BinaryRowKeySelector(new int[] { sortKeyIdx }, + inputRowType.getInternalTypes()); + + GeneratedRecordEqualiser generatedEqualiser = new GeneratedRecordEqualiser("", "", new Object[0]) { + + private static final long serialVersionUID = 8932460173848746733L; + + @Override + public RecordEqualiser newInstance(ClassLoader classLoader) { + return new BaseRowRecordEqualiser(); + } + }; + + private int partitionKeyIdx = 0; + + private BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[] { partitionKeyIdx }, + inputRowType.getInternalTypes()); + + private BaseRowTypeInfo outputTypeWithoutRowNumber = inputRowType; + + private BaseRowTypeInfo outputTypeWithRowNumber = new BaseRowTypeInfo( + InternalTypes.STRING, + InternalTypes.LONG, + InternalTypes.INT, + InternalTypes.LONG); + + BaseRowHarnessAssertor assertorWithoutRowNumber = new BaseRowHarnessAssertor( + outputTypeWithoutRowNumber.getFieldTypes(), + new GenericRowRecordSortComparator(sortKeyIdx, outputTypeWithoutRowNumber.getInternalTypes()[sortKeyIdx])); + + BaseRowHarnessAssertor assertorWithRowNumber = new BaseRowHarnessAssertor( + outputTypeWithRowNumber.getFieldTypes(), + new GenericRowRecordSortComparator(sortKeyIdx, outputTypeWithRowNumber.getInternalTypes()[sortKeyIdx])); + + // rowKey only used in UpdateRankFunction + private int rowKeyIdx = 1; + 
BinaryRowKeySelector rowKeySelector = new BinaryRowKeySelector(new int[] { rowKeyIdx },
+ inputRowType.getInternalTypes());
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testInvalidVariableRankRangeWithIntType() throws Exception {
+ createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(2), true, false);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testNotSupportRank() throws Exception {
+ createRankFunction(RankType.RANK, new ConstantRankRange(1, 10), true, true);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testNotSupportDenseRank() throws Exception {
+ createRankFunction(RankType.DENSE_RANK, new ConstantRankRange(1, 10), true, true);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testNotSupportWithoutRankEnd() throws Exception {
+ createRankFunction(RankType.ROW_NUMBER, new ConstantRankRangeWithoutEnd(1), true, true);
+ }
+
+ @Test
+ public void testDisableGenerateRetraction() throws Exception {
+ AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false,
+ false);
+ OneInputStreamOperatorTestHarness testHarness = createTestHarness(func);
+ testHarness.open();
+ testHarness.processElement(record("book", 1L, 12));
+ testHarness.processElement(record("book", 2L, 19));
+ testHarness.processElement(record("book", 3L, 19));
+ testHarness.processElement(record("book", 4L, 11));
+ testHarness.processElement(record("book", 5L, 11));
+ testHarness.processElement(record("fruit", 4L, 33));
+ testHarness.processElement(record("fruit", 3L, 44));
+ testHarness.processElement(record("fruit", 5L, 22));
+ testHarness.close();
+
+ // Note: a delete message is sent even when generating retraction is disabled, as long as rankNumber is not part of the output.
+ List expectedOutputOutput = new ArrayList<>();
+ expectedOutputOutput.add(record("book", 1L, 12));
+ expectedOutputOutput.add(record("book", 2L, 19));
+ expectedOutputOutput.add(deleteRecord("book", 2L, 19));
+ expectedOutputOutput.add(record("book", 4L, 11));
+ expectedOutputOutput.add(deleteRecord("book", 1L, 12));
+ expectedOutputOutput.add(record("book", 5L, 11));
+ expectedOutputOutput.add(record("fruit", 4L, 33));
+ expectedOutputOutput.add(record("fruit", 3L, 44));
+ expectedOutputOutput.add(deleteRecord("fruit", 3L, 44));
+ expectedOutputOutput.add(record("fruit", 5L, 22));
+ assertorWithoutRowNumber
+ .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput());
+ }
+
+ @Test
+ public void testDisableGenerateRetractionAndOutputRankNumber() throws Exception {
+ AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false,
+ true);
+ OneInputStreamOperatorTestHarness testHarness = createTestHarness(func);
+ testHarness.open();
+ testHarness.processElement(record("book", 1L, 12));
+ testHarness.processElement(record("book", 2L, 19));
+ testHarness.processElement(record("book", 4L, 11));
+ testHarness.processElement(record("book", 5L, 11));
+ testHarness.processElement(record("fruit", 4L, 33));
+ testHarness.processElement(record("fruit", 3L, 44));
+ testHarness.processElement(record("fruit", 5L, 22));
+ testHarness.close();
+
+ // Note: no retract message is sent when generating retraction is disabled and rankNumber is part of the output,
+ // because partition key + rankNumber forms a unique key for downstream operators.
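+ // For example, when ("book", 4L, 11) arrives it takes rank 1 and ("book", 1L, 12) moves to rank 2;
+ // downstream simply overwrites the previously emitted rows keyed on (partition key, rankNumber).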
+ List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + assertorWithRowNumber.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testOutputRankNumberWithConstantRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(retractRecord("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testConstantRankRangeWithOffset() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(2, 2), true, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void 
testOutputRankNumberWithVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(1), true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 12, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("book", 2L, 12, 2L)); + expectedOutputOutput.add(record("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 22, 1L)); + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testConstantRankRangeWithoutOffset() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 5L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + + // do a snapshot, data could be recovered from state + OperatorSubtaskState snapshot = testHarness.snapshot(0L, 0); + testHarness.close(); + expectedOutputOutput.clear(); + + func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, false); + testHarness = createTestHarness(func); + testHarness.setup(); + testHarness.initializeState(snapshot); + testHarness.open(); + testHarness.processElement(record("book", 1L, 10)); + testHarness.close(); + + expectedOutputOutput.add(retractRecord("book", 1L, 12)); + expectedOutputOutput.add(record("book", 1L, 10)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + OneInputStreamOperatorTestHarness createTestHarness( + AbstractRankFunction rankFunction) + throws Exception { + KeyedProcessOperator operator = new KeyedProcessOperator<>(rankFunction); + rankFunction.setKeyContext(operator); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, 
keySelector, keySelector.getProducedType()); + } + + protected abstract AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber); + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java new file mode 100644 index 0000000000000..ed38f631267ea --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/RetractRankFunctionTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.deleteRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link RetractRankFunction}. 
+ */ +public class RetractRankFunctionTest extends BaseRankFunctionTest { + + @Override + protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber) { + return new RetractRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, sortKeyComparator, sortKeySelector, rankType, rankRange, generatedEqualiser, + generateRetraction, outputRankNumber); + } + + @Test + public void testProcessRetractMessageWithNotGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(retractRecord("book", 1L, 12)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(deleteRecord("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + expectedOutputOutput.add(record("fruit", 5L, 22, 1L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 2L)); + assertorWithRowNumber.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testProcessRetractMessageWithGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, + true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(retractRecord("book", 1L, 12)); + testHarness.processElement(record("book", 5L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 1L)); + expectedOutputOutput.add(record("book", 4L, 11, 1L)); + expectedOutputOutput.add(record("book", 1L, 12, 2L)); + expectedOutputOutput.add(retractRecord("book", 1L, 12, 2L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 5L, 11, 2L)); + expectedOutputOutput.add(record("fruit", 4L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 3L, 44, 2L)); + 
expectedOutputOutput.add(retractRecord("fruit", 4L, 33, 1L));
+ expectedOutputOutput.add(retractRecord("fruit", 3L, 44, 2L));
+ expectedOutputOutput.add(record("fruit", 5L, 22, 1L));
+ expectedOutputOutput.add(record("fruit", 4L, 33, 2L));
+ assertorWithRowNumber.assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput());
+ }
+
+ // TODO RetractRankFunction could send fewer retraction messages when it does not need to retract row_number
+ @Override
+ @Test
+ public void testConstantRankRangeWithoutOffset() throws Exception {
+ AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true,
+ false);
+ OneInputStreamOperatorTestHarness testHarness = createTestHarness(func);
+ testHarness.open();
+ testHarness.processElement(record("book", 1L, 12));
+ testHarness.processElement(record("book", 2L, 19));
+ testHarness.processElement(record("book", 4L, 11));
+ testHarness.processElement(record("fruit", 4L, 33));
+ testHarness.processElement(record("fruit", 3L, 44));
+ testHarness.processElement(record("fruit", 5L, 22));
+
+ List expectedOutputOutput = new ArrayList<>();
+ expectedOutputOutput.add(record("book", 1L, 12));
+ expectedOutputOutput.add(record("book", 2L, 19));
+ expectedOutputOutput.add(retractRecord("book", 1L, 12));
+ expectedOutputOutput.add(record("book", 1L, 12));
+ expectedOutputOutput.add(retractRecord("book", 2L, 19));
+ expectedOutputOutput.add(record("book", 4L, 11));
+ expectedOutputOutput.add(record("fruit", 4L, 33));
+ expectedOutputOutput.add(record("fruit", 3L, 44));
+ expectedOutputOutput.add(retractRecord("fruit", 4L, 33));
+ expectedOutputOutput.add(retractRecord("fruit", 3L, 44));
+ expectedOutputOutput.add(record("fruit", 4L, 33));
+ expectedOutputOutput.add(record("fruit", 5L, 22));
+ assertorWithoutRowNumber
+ .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput());
+
+ // do a snapshot, data could be recovered from state
+ OperatorSubtaskState snapshot = testHarness.snapshot(0L, 0);
+ testHarness.close();
+ expectedOutputOutput.clear();
+
+ func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), true, false);
+ testHarness = createTestHarness(func);
+ testHarness.setup();
+ testHarness.initializeState(snapshot);
+ testHarness.open();
+ testHarness.processElement(record("book", 1L, 10));
+
+ expectedOutputOutput.add(retractRecord("book", 1L, 12));
+ expectedOutputOutput.add(retractRecord("book", 4L, 11));
+ expectedOutputOutput.add(record("book", 4L, 11));
+ expectedOutputOutput.add(record("book", 1L, 10));
+ assertorWithoutRowNumber
+ .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput());
+ testHarness.close();
+ }
+
+ // TODO RetractRankFunction could send fewer retraction messages when it does not need to retract row_number
+ @Test
+ public void testVariableRankRange() throws Exception {
+ AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new VariableRankRange(1), true, false);
+ OneInputStreamOperatorTestHarness testHarness = createTestHarness(func);
+ testHarness.open();
+ testHarness.processElement(record("book", 2L, 12));
+ testHarness.processElement(record("book", 2L, 19));
+ testHarness.processElement(record("book", 2L, 11));
+ testHarness.processElement(record("fruit", 1L, 33));
+ testHarness.processElement(record("fruit", 1L, 44));
+ testHarness.processElement(record("fruit", 1L, 22));
+ testHarness.close();
+
+ List expectedOutputOutput = new ArrayList<>();
+ 
expectedOutputOutput.add(record("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 12)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("fruit", 1L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33)); + expectedOutputOutput.add(record("fruit", 1L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + // TODO + @Test + public void testDisableGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, new ConstantRankRange(1, 2), false, + false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 1L, 12)); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 4L, 11)); + testHarness.processElement(record("fruit", 4L, 33)); + testHarness.processElement(record("fruit", 3L, 44)); + testHarness.processElement(record("fruit", 5L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(record("book", 1L, 12)); + expectedOutputOutput.add(record("book", 4L, 11)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 3L, 44)); + expectedOutputOutput.add(record("fruit", 4L, 33)); + expectedOutputOutput.add(record("fruit", 5L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java new file mode 100644 index 0000000000000..92b76cb8eb955 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/rank/UpdateRankFunctionTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.rank; + +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.deleteRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; + +/** + * Tests for {@link UpdateRankFunction}. + */ +public class UpdateRankFunctionTest extends BaseRankFunctionTest { + + @Override + protected AbstractRankFunction createRankFunction(RankType rankType, RankRange rankRange, + boolean generateRetraction, boolean outputRankNumber) { + return new UpdateRankFunction(minTime.toMilliseconds(), maxTime.toMilliseconds(), + inputRowType, rowKeySelector, sortKeyComparator, sortKeySelector, rankType, rankRange, + generatedEqualiser, generateRetraction, outputRankNumber, cacheSize); + } + + @Test + public void testVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new VariableRankRange(1), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 18)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 2L, 18)); + expectedOutputOutput.add(record("fruit", 1L, 44)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 44)); + expectedOutputOutput.add(record("fruit", 1L, 33)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33)); + expectedOutputOutput.add(record("fruit", 1L, 22)); + assertorWithoutRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Override + @Test + public void testOutputRankNumberWithVariableRankRange() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new VariableRankRange(1), true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 2L, 12)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("fruit", 1L, 44)); + testHarness.processElement(record("fruit", 1L, 33)); + testHarness.processElement(record("fruit", 1L, 22)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 1L)); + expectedOutputOutput.add(record("book", 2L, 12, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 12, 1L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 44, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 44, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(retractRecord("fruit", 1L, 33, 1L)); + expectedOutputOutput.add(record("fruit", 1L, 22, 1L)); + 
assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenOutputRankNumber() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), true, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 3L, 16, 1L)); + expectedOutputOutput.add(retractRecord("book", 3L, 16, 1L)); + expectedOutputOutput.add(record("book", 3L, 16, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(retractRecord("book", 3L, 16, 2L)); + expectedOutputOutput.add(record("book", 3L, 15, 2L)); + expectedOutputOutput.add(retractRecord("book", 3L, 15, 2L)); + expectedOutputOutput.add(retractRecord("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("book", 2L, 11, 2L)); + expectedOutputOutput.add(record("book", 4L, 2, 1L)); + expectedOutputOutput.add(retractRecord("book", 2L, 11, 2L)); + expectedOutputOutput.add(retractRecord("book", 4L, 2, 1L)); + expectedOutputOutput.add(record("book", 2L, 1, 1L)); + expectedOutputOutput.add(record("book", 4L, 2, 2L)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenOutputRankNumberAndNotGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), false, true); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19, 1L)); + expectedOutputOutput.add(record("book", 2L, 19, 2L)); + expectedOutputOutput.add(record("book", 3L, 16, 1L)); + expectedOutputOutput.add(record("book", 3L, 16, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 1L)); + expectedOutputOutput.add(record("book", 3L, 15, 2L)); + expectedOutputOutput.add(record("book", 2L, 11, 2L)); + expectedOutputOutput.add(record("book", 4L, 2, 1L)); + expectedOutputOutput.add(record("book", 2L, 1, 1L)); + expectedOutputOutput.add(record("book", 4L, 2, 2L)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenNotOutputRankNumber() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new 
ConstantRankRange(1, 2), true, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(record("book", 3L, 16)); + expectedOutputOutput.add(retractRecord("book", 2L, 19)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(retractRecord("book", 3L, 16)); + expectedOutputOutput.add(record("book", 3L, 15)); + expectedOutputOutput.add(record("book", 4L, 2)); + expectedOutputOutput.add(retractRecord("book", 3L, 15)); + expectedOutputOutput.add(record("book", 2L, 1)); + expectedOutputOutput.add(retractRecord("book", 2L, 11)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } + + @Test + public void testSortKeyChangesWhenNotOutputRankNumberAndNotGenerateRetraction() throws Exception { + AbstractRankFunction func = createRankFunction(RankType.ROW_NUMBER, + new ConstantRankRange(1, 2), false, false); + OneInputStreamOperatorTestHarness testHarness = createTestHarness(func); + testHarness.open(); + testHarness.processElement(record("book", 2L, 19)); + testHarness.processElement(record("book", 3L, 16)); + testHarness.processElement(record("book", 2L, 11)); + testHarness.processElement(record("book", 3L, 15)); + testHarness.processElement(record("book", 4L, 2)); + testHarness.processElement(record("book", 2L, 1)); + testHarness.close(); + + List expectedOutputOutput = new ArrayList<>(); + expectedOutputOutput.add(record("book", 2L, 19)); + expectedOutputOutput.add(record("book", 3L, 16)); + expectedOutputOutput.add(record("book", 2L, 11)); + expectedOutputOutput.add(record("book", 3L, 15)); + expectedOutputOutput.add(record("book", 4L, 2)); + expectedOutputOutput.add(deleteRecord("book", 3L, 15)); + expectedOutputOutput.add(record("book", 2L, 1)); + + assertorWithRowNumber + .assertOutputEqualsSorted("output wrong.", expectedOutputOutput, testHarness.getOutput()); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java index 12ac7df525c97..6b5abcd2e4857 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/sort/IntRecordComparator.java @@ -41,4 +41,8 @@ public int compare(BaseRow o1, BaseRow o2) { return 0; } + @Override + public boolean equals(Object obj) { + return obj instanceof IntRecordComparator; + } } diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java new file mode 100644 index 0000000000000..fd867f07bb274 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BaseRowRecordEqualiser.java @@ 
-0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util; + +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.generated.RecordEqualiser; + +/** + * A utility class to check whether two BaseRow are equal. + * Note: Only support to compare two BinaryRows or two GenericRows. + */ +public class BaseRowRecordEqualiser implements RecordEqualiser { + + @Override + public boolean equals(BaseRow row1, BaseRow row2) { + if (row1 instanceof BinaryRow && row2 instanceof BinaryRow) { + return row1.equals(row2); + } else if (row1 instanceof GenericRow && row2 instanceof GenericRow) { + return row1.equals(row2); + } else { + throw new UnsupportedOperationException(); + } + + } + + @Override + public boolean equalsWithoutHeader(BaseRow row1, BaseRow row2) { + if (row1 instanceof BinaryRow && row2 instanceof BinaryRow) { + return ((BinaryRow) row1).equalsWithoutHeader(row2); + } else if (row1 instanceof GenericRow && row2 instanceof GenericRow) { + return ((GenericRow) row1).equalsWithoutHeader(row2); + } else { + throw new UnsupportedOperationException(); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java index 3b5b9e8822983..46475662a0b51 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/BinaryRowKeySelector.java @@ -18,21 +18,21 @@ package org.apache.flink.table.runtime.util; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.dataformat.BinaryRow; import org.apache.flink.table.dataformat.BinaryRowWriter; import org.apache.flink.table.dataformat.BinaryWriter; import org.apache.flink.table.dataformat.TypeGetterSetters; +import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector; import org.apache.flink.table.type.InternalType; import org.apache.flink.table.typeutils.BaseRowTypeInfo; /** - * A utility class which will extract key from BaseRow. + * A utility class which extracts key from BaseRow. 
 */
-public class BinaryRowKeySelector implements KeySelector, ResultTypeQueryable {
+public class BinaryRowKeySelector implements BaseRowKeySelector {
+
+ private static final long serialVersionUID = -2327761762415377059L;
 
  private final int[] keyFields;
  private final InternalType[] inputFieldTypes;
@@ -67,7 +67,7 @@ public BaseRow getKey(BaseRow value) throws Exception {
  }
 
  @Override
- public TypeInformation getProducedType() {
+ public BaseRowTypeInfo getProducedType() {
  return new BaseRowTypeInfo(keyFieldTypes);
  }
 }
diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java
new file mode 100644
index 0000000000000..c572f8b9e2e38
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/GenericRowRecordSortComparator.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.util;
+
+import org.apache.flink.table.dataformat.GenericRow;
+import org.apache.flink.table.dataformat.TypeGetterSetters;
+import org.apache.flink.table.type.InternalType;
+
+import java.io.Serializable;
+import java.util.Comparator;
+
+/**
+ * A utility class to compare two GenericRows based on the sort key value.
+ * Note: only sort keys whose values implement Comparable are supported.
+ */ +public class GenericRowRecordSortComparator implements Comparator, Serializable { + + private static final long serialVersionUID = -4988371592272863772L; + + private final int sortKeyIdx; + private final InternalType sortKeyType; + + public GenericRowRecordSortComparator(int sortKeyIdx, InternalType sortKeyType) { + this.sortKeyIdx = sortKeyIdx; + this.sortKeyType = sortKeyType; + } + + @Override + public int compare(GenericRow row1, GenericRow row2) { + byte header1 = row1.getHeader(); + byte header2 = row2.getHeader(); + if (header1 != header2) { + return header1 - header2; + } else { + Object key1 = TypeGetterSetters.get(row1, sortKeyIdx, sortKeyType); + Object key2 = TypeGetterSetters.get(row2, sortKeyIdx, sortKeyType); + if (key1 instanceof Comparable) { + return ((Comparable) key1).compareTo(key2); + } else { + throw new UnsupportedOperationException(); + } + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java new file mode 100644 index 0000000000000..7cf7d5b5e90f5 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util; + +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.util.BaseRowUtil; + +import static org.apache.flink.table.dataformat.BinaryString.fromString; + +/** + * Utilities to generate a StreamRecord which encapsulates value of BaseRow type. + */ +public class StreamRecordUtils { + + public static StreamRecord record(String key, Object... fields) { + return new StreamRecord<>(baserow(key, fields)); + } + + public static StreamRecord retractRecord(String key, Object... fields) { + BaseRow row = baserow(key, fields); + BaseRowUtil.setRetract(row); + return new StreamRecord<>(row); + } + + public static StreamRecord deleteRecord(String key, Object... fields) { + BaseRow row = baserow(key, fields); + BaseRowUtil.setRetract(row); + return new StreamRecord<>(row); + } + + public static BaseRow baserow(String key, Object... 
fields) { + Object[] objects = new Object[fields.length + 1]; + objects[0] = fromString(key); + System.arraycopy(fields, 0, objects, 1, fields.length); + return GenericRow.of(objects); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java index 54863a6d235d9..4961892c2eb8d 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorContractTest.java @@ -44,8 +44,8 @@ import java.util.Arrays; import java.util.Collections; -import static org.apache.flink.table.runtime.window.WindowTestUtils.baserow; -import static org.apache.flink.table.runtime.window.WindowTestUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.baserow; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java index 0004fc0ba418e..5f4deb0818d2e 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowOperatorTest.java @@ -35,6 +35,7 @@ import org.apache.flink.table.runtime.context.ExecutionContext; import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; import org.apache.flink.table.runtime.util.BinaryRowKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; import org.apache.flink.table.runtime.window.assigners.SessionWindowAssigner; import org.apache.flink.table.runtime.window.assigners.TumblingWindowAssigner; import org.apache.flink.table.runtime.window.assigners.WindowAssigner; @@ -47,17 +48,15 @@ import org.junit.Test; -import java.io.Serializable; import java.time.Duration; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import static org.apache.flink.table.dataformat.BinaryString.fromString; -import static org.apache.flink.table.runtime.window.WindowTestUtils.record; -import static org.apache.flink.table.runtime.window.WindowTestUtils.retractRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.retractRecord; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -83,15 +82,15 @@ public class WindowOperatorTest { InternalTypes.LONG, InternalTypes.LONG); - private InternalType[] aggResultTypes = new InternalType[]{InternalTypes.LONG, InternalTypes.LONG}; - private InternalType[] accTypes = new InternalType[]{InternalTypes.LONG, InternalTypes.LONG}; - private InternalType[] windowTypes = new InternalType[]{InternalTypes.LONG, InternalTypes.LONG, InternalTypes.LONG}; + private InternalType[] 
aggResultTypes = new InternalType[] { InternalTypes.LONG, InternalTypes.LONG }; + private InternalType[] accTypes = new InternalType[] { InternalTypes.LONG, InternalTypes.LONG }; + private InternalType[] windowTypes = new InternalType[] { InternalTypes.LONG, InternalTypes.LONG, InternalTypes.LONG }; private GenericRowEqualiser equaliser = new GenericRowEqualiser(accTypes, windowTypes); - private BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[]{0}, inputFieldTypes); + private BinaryRowKeySelector keySelector = new BinaryRowKeySelector(new int[] { 0 }, inputFieldTypes); private TypeInformation keyType = keySelector.getProducedType(); private BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( outputType.getFieldTypes(), - new GenericRowResultSortComparator()); + new GenericRowRecordSortComparator(0, InternalTypes.STRING)); @Test public void testEventTimeSlidingWindows() throws Exception { @@ -709,7 +708,7 @@ public void testProcessingTimeSessionWindows() throws Throwable { OneInputStreamOperatorTestHarness testHarness = createTestHarness(operator); BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor( - outputType.getFieldTypes(), new GenericRowResultSortComparator()); + outputType.getFieldTypes(), new GenericRowRecordSortComparator(0, InternalTypes.STRING)); ConcurrentLinkedQueue expectedOutputOutput = new ConcurrentLinkedQueue<>(); @@ -953,7 +952,7 @@ public void testCleanupTimerWithEmptyReduceStateForTumblingWindows() throws Exce public void testTumblingCountWindow() throws Exception { closeCalled.set(0); final int windowSize = 3; - InternalType[] windowTypes = new InternalType[]{InternalTypes.LONG}; + InternalType[] windowTypes = new InternalType[] { InternalTypes.LONG }; WindowOperator operator = WindowOperatorBuilder.builder() .withInputFields(inputFieldTypes) @@ -1023,7 +1022,7 @@ public void testSlidingCountWindow() throws Exception { closeCalled.set(0); final int windowSize = 5; final int windowSlide = 3; - InternalType[] windowTypes = new InternalType[]{InternalTypes.LONG}; + InternalType[] windowTypes = new InternalType[] { InternalTypes.LONG }; WindowOperator operator = WindowOperatorBuilder.builder() .withInputFields(inputFieldTypes) @@ -1127,19 +1126,6 @@ public SessionWindowAssigner withProcessingTime() { } } - // schema: String, Long, Long, Long, Long - private static class GenericRowResultSortComparator implements Comparator, Serializable { - - private static final long serialVersionUID = -4988371592272863772L; - - @Override - public int compare(GenericRow row1, GenericRow row2) { - String key1 = row1.getString(0).toString(); - String key2 = row2.getString(0).toString(); - return key1.compareTo(key2); - } - } - // sum, count, window_start, window_end private static class SumAndCountAggTimeWindow extends SumAndCountAgg { diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java index ac8bfe3ede364..4abdc5d4b13da 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/window/WindowTestUtils.java @@ -18,17 +18,11 @@ package org.apache.flink.table.runtime.window; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.api.window.TimeWindow; -import 
org.apache.flink.table.dataformat.BaseRow; -import org.apache.flink.table.dataformat.GenericRow; -import org.apache.flink.table.dataformat.util.BaseRowUtil; import org.hamcrest.Matcher; import org.hamcrest.Matchers; -import static org.apache.flink.table.dataformat.BinaryString.fromString; - /** * Utilities that are useful for working with Window tests. */ @@ -38,20 +32,4 @@ static Matcher timeWindow(long start, long end) { return Matchers.equalTo(new TimeWindow(start, end)); } - static StreamRecord record(String key, Object... fields) { - return new StreamRecord<>(baserow(key, fields)); - } - - static StreamRecord retractRecord(String key, Object... fields) { - BaseRow row = baserow(key, fields); - BaseRowUtil.setRetract(row); - return new StreamRecord<>(row); - } - - static BaseRow baserow(String key, Object... fields) { - Object[] objects = new Object[fields.length + 1]; - objects[0] = fromString(key); - System.arraycopy(fields, 0, objects, 1, fields.length); - return GenericRow.of(objects); - } }