From 085859ec19f59d2f8e73632e96fae59f7821c7d0 Mon Sep 17 00:00:00 2001 From: xuyang Date: Mon, 17 Jan 2022 12:27:02 +0800 Subject: [PATCH] [FLINK-25476][table-planner] Support CHAR type in built-in function MAX and MIN This closes #18375 --- .../plan/utils/AggFunctionFactory.scala | 12 +-- .../runtime/stream/sql/AggregateITCase.scala | 77 +++++++++++++++++++ 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index cbabddae03a44..861f537f6c24c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFu import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction} import org.apache.flink.table.planner.functions.utils.AggSqlFunction -import org.apache.flink.table.runtime.functions.aggregate.{BuiltInAggregateFunction, CollectAggFunction, FirstValueAggFunction, FirstValueWithRetractAggFunction, JsonArrayAggFunction, JsonObjectAggFunction, LagAggFunction, LastValueAggFunction, LastValueWithRetractAggFunction, ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction, MaxWithRetractAggFunction, MinWithRetractAggFunction} +import org.apache.flink.table.runtime.functions.aggregate._ import org.apache.flink.table.runtime.functions.aggregate.BatchApproxCountDistinctAggFunctions._ import org.apache.flink.table.types.logical._ import org.apache.flink.table.types.logical.LogicalTypeRoot._ @@ -273,8 +273,8 @@ class AggFunctionFactory( val valueType = argTypes(0) if (aggCallNeedRetractions(index)) { valueType.getTypeRoot match { - case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL | - TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE | + case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | CHAR | + DECIMAL | TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => new MinWithRetractAggFunction(argTypes(0)) case t => @@ -382,8 +382,8 @@ class AggFunctionFactory( val valueType = argTypes(0) if (aggCallNeedRetractions(index)) { valueType.getTypeRoot match { - case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL | - TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE | + case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | CHAR | + DECIMAL | TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => new MaxWithRetractAggFunction(argTypes(0)) case t => @@ -407,7 +407,7 @@ class AggFunctionFactory( new MaxAggFunction.DoubleMaxAggFunction case BOOLEAN => new MaxAggFunction.BooleanMaxAggFunction - case VARCHAR => + case VARCHAR | CHAR => new MaxAggFunction.StringMaxAggFunction case DATE => new MaxAggFunction.DateMaxAggFunction diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala index 99d5e2a5f265b..0d11b69909ad3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala @@ -1332,6 +1332,83 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted) } + @TestTemplate + def testMinMaxWithChar(): Unit = { + val data = + List( + rowOf(1, "a", "gg"), + rowOf(1, "b", "hh"), + rowOf(2, "d", "j"), + rowOf(2, "c", "i") + ) + val dataId = TestValuesTableFactory.registerData(data) + tEnv.executeSql(s""" + |CREATE TABLE src( + | `id` INT, + | `char1` CHAR(1), + | `char2` CHAR(2) + |) WITH ( + | 'connector' = 'values', + | 'data-id' = '$dataId' + |) + |""".stripMargin) + + val sql = + """ + |select `id`, count(*), min(`char1`), max(`char1`), min(`char2`), max(`char2`) from src group by `id` + """.stripMargin + + val sink = new TestingRetractSink() + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,2,a,b,gg,hh", "2,2,c,d,i,j") + assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted) + } + + @TestTemplate + def testRetractMinMaxWithChar(): Unit = { + val data = + List( + changelogRow("+I", Int.box(1), "a", "ee"), + changelogRow("+I", Int.box(1), "b", "ff"), + changelogRow("+I", Int.box(1), "c", "gg"), + changelogRow("-D", Int.box(1), "c", "gg"), + changelogRow("-D", Int.box(1), "a", "ee"), + changelogRow("+I", Int.box(2), "a", "e"), + changelogRow("+I", Int.box(2), "b", "f"), + changelogRow("+I", Int.box(2), "c", "g"), + changelogRow("-U", Int.box(2), "b", "f"), + changelogRow("+U", Int.box(2), "d", "h"), + changelogRow("-U", Int.box(2), "a", "e"), + changelogRow("+U", Int.box(2), "b", "f") + ) + val dataId = TestValuesTableFactory.registerData(data) + tEnv.executeSql(s""" + |CREATE TABLE src( + | `id` INT, + | `char1` CHAR(1), + | `char2` CHAR(2) + |) WITH ( + | 'connector' = 'values', + | 'data-id' = '$dataId', + | 'changelog-mode' = 'I,UA,UB,D' + |) + |""".stripMargin) + + val sql = + """ + |select `id`, count(*), min(`char1`), max(`char1`), min(`char2`), max(`char2`) from src group by `id` + """.stripMargin + + val sink = new TestingRetractSink() + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,b,b,ff,ff", "2,3,b,d,f,h") + assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted) + } + @TestTemplate def testCollectOnClusteredFields(): Unit = { val data = List(