Skip to content

Commit

Permalink
[FLINK-25476][table-planner] support CHAR type in function MAX and MIN
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyangzhong committed Dec 18, 2023
1 parent 3d4d396 commit 7a43b00
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFu
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction}
import org.apache.flink.table.planner.functions.utils.AggSqlFunction
import org.apache.flink.table.runtime.functions.aggregate.{BuiltInAggregateFunction, CollectAggFunction, FirstValueAggFunction, FirstValueWithRetractAggFunction, JsonArrayAggFunction, JsonObjectAggFunction, LagAggFunction, LastValueAggFunction, LastValueWithRetractAggFunction, ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction, MaxWithRetractAggFunction, MinWithRetractAggFunction}
import org.apache.flink.table.runtime.functions.aggregate._
import org.apache.flink.table.runtime.functions.aggregate.BatchApproxCountDistinctAggFunctions._
import org.apache.flink.table.types.logical._
import org.apache.flink.table.types.logical.LogicalTypeRoot._
Expand Down Expand Up @@ -273,8 +273,8 @@ class AggFunctionFactory(
val valueType = argTypes(0)
if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL |
TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE |
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | CHAR |
DECIMAL | TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE |
TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
new MinWithRetractAggFunction(argTypes(0))
case t =>
Expand Down Expand Up @@ -382,8 +382,8 @@ class AggFunctionFactory(
val valueType = argTypes(0)
if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL |
TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE |
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | CHAR |
DECIMAL | TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE |
TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
new MaxWithRetractAggFunction(argTypes(0))
case t =>
Expand All @@ -407,7 +407,7 @@ class AggFunctionFactory(
new MaxAggFunction.DoubleMaxAggFunction
case BOOLEAN =>
new MaxAggFunction.BooleanMaxAggFunction
case VARCHAR =>
case VARCHAR | CHAR =>
new MaxAggFunction.StringMaxAggFunction
case DATE =>
new MaxAggFunction.DateMaxAggFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.flink.table.runtime.functions.aggregate.MaxWithRetractAggFunction.MaxWithRetractAccumulator;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BooleanType;
import org.apache.flink.table.types.logical.CharType;
import org.apache.flink.table.types.logical.DateType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.DoubleType;
Expand Down Expand Up @@ -317,6 +318,31 @@ protected List<StringData> getExpectedResults() {
}
}

/** Test for {@link CharType}. */
public static final class CharMaxWithRetractAggFunctionTest
extends MaxWithRetractAggFunctionTestBase<Character> {

@Override
protected List<List<Character>> getInputValueSets() {
return Arrays.asList(
Arrays.asList('b', 'c', null, 'd', 'e', null),
Arrays.asList(null, null),
Arrays.asList(null, 'w'),
Arrays.asList('d', 'a'));
}

@Override
protected List<Character> getExpectedResults() {
return Arrays.asList('e', null, 'w', 'd');
}

@Override
protected AggregateFunction<Character, MaxWithRetractAccumulator<Character>>
getAggregator() {
return new MaxWithRetractAggFunction<>(DataTypes.CHAR(1).getLogicalType());
}
}

/** Test for {@link TimestampType}. */
@Nested
final class TimestampMaxWithRetractAggFunctionTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.flink.table.runtime.functions.aggregate.MinWithRetractAggFunction.MinWithRetractAccumulator;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BooleanType;
import org.apache.flink.table.types.logical.CharType;
import org.apache.flink.table.types.logical.DateType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.DoubleType;
Expand Down Expand Up @@ -317,6 +318,31 @@ protected List<StringData> getExpectedResults() {
}
}

/** Test for {@link CharType}. */
public static final class CharMinWithRetractAggFunctionTest
extends MinWithRetractAggFunctionTestBase<Character> {

@Override
protected List<List<Character>> getInputValueSets() {
return Arrays.asList(
Arrays.asList('b', 'c', null, 'd', 'e', null),
Arrays.asList(null, null),
Arrays.asList(null, 'w'),
Arrays.asList('d', 'a'));
}

@Override
protected List<Character> getExpectedResults() {
return Arrays.asList('b', null, 'w', 'a');
}

@Override
protected AggregateFunction<Character, MinWithRetractAccumulator<Character>>
getAggregator() {
return new MinWithRetractAggFunction<>(DataTypes.CHAR(1).getLogicalType());
}
}

/** Test for {@link TimestampType}. */
@Nested
final class TimestampMinWithRetractAggFunctionTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,39 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testMinMaxWithChar(): Unit = {
val data =
List(
rowOf(1, "a"),
rowOf(1, "b"),
rowOf(2, "d"),
rowOf(2, "c")
)
val dataId = TestValuesTableFactory.registerData(data)
tEnv.executeSql(s"""
|CREATE TABLE src(
| `id` INT,
| `char` CHAR(1)
|) WITH (
| 'connector' = 'values',
| 'data-id' = '$dataId'
|)
|""".stripMargin)

val sql =
"""
|select `id`, count(*), min(`char`), max(`char`) from src group by `id`
""".stripMargin

val sink = new TestingRetractSink()
tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink)
env.execute()

val expected = List("1,2,a,b", "2,2,c,d")
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testCollectOnClusteredFields(): Unit = {
val data = List(
Expand Down

0 comments on commit 7a43b00

Please sign in to comment.