@@ -0,0 +1,186 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions.hive;

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.logical.LogicalType;

import java.math.BigDecimal;

import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.div;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.equalTo;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.hiveAggDecimalPlus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.tryCast;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
import static org.apache.flink.table.types.logical.DecimalType.MAX_PRECISION;
import static org.apache.flink.table.types.logical.LogicalTypeRoot.DECIMAL;
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision;
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale;

/** Built-in Hive avg aggregate function. */
public class HiveAvgAggFunction extends HiveDeclarativeAggregateFunction {

private final UnresolvedReferenceExpression sum = unresolvedRef("sum");
private final UnresolvedReferenceExpression count = unresolvedRef("count");
private DataType resultType;
private DataType sumType;

@Override
public int operandCount() {
return 1;
}

@Override
public UnresolvedReferenceExpression[] aggBufferAttributes() {
return new UnresolvedReferenceExpression[] {sum, count};
}

@Override
public DataType[] getAggBufferTypes() {
return new DataType[] {getResultType(), DataTypes.BIGINT()};
}

@Override
public DataType getResultType() {
return resultType;
}

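// The buffer starts from a zero literal of the sum type (BIGINT, DOUBLE, or DECIMAL)
// and a count of 0L.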
@Override
public Expression[] initialValuesExpressions() {
switch (resultType.getLogicalType().getTypeRoot()) {
case DECIMAL:
return new Expression[] {literal(BigDecimal.ZERO, sumType.notNull()), literal(0L)};
case DOUBLE:
case FLOAT:
return new Expression[] {literal(0D), literal(0L)};
default:
return new Expression[] {
/* sum = */ literal(0L, sumType.notNull()), /* count = */ literal(0L)
};
}
}

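// The operand is tryCast to the result type; a NULL value (including a failed cast)
// updates neither sum nor count, so ignored rows do not skew the average.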
@Override
public Expression[] accumulateExpressions() {
Expression tryCastOperand = tryCast(operand(0), typeLiteral(getResultType()));
return new Expression[] {
/* sum = */ ifThenElse(isNull(tryCastOperand), sum, adjustedPlus(sum, tryCastOperand)),
/* count = */ ifThenElse(isNull(tryCastOperand), count, plus(count, literal(1L))),
};
}

@Override
public Expression[] retractExpressions() {
throw new TableException("Avg aggregate function does not support retraction.");
}

@Override
public Expression[] mergeExpressions() {
return new Expression[] {
/* sum = */ adjustedPlus(sum, mergeOperand(sum)),
/* count = */ plus(count, mergeOperand(count))
};
}

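// Final value: sum / count cast to the result type, or NULL if no rows were accumulated.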
@Override
public Expression getValueExpression() {
Expression ifTrue = nullOf(getResultType());
Expression ifFalse = cast(div(sum, count), typeLiteral(getResultType()));
return ifThenElse(equalTo(count, literal(0L)), ifTrue, ifFalse);
}

@Override
public void setArguments(CallContext callContext) {
if (resultType == null) {
// check the argument type first
checkArgumentType(callContext.getArgumentDataTypes().get(0).getLogicalType());
resultType = initResultType(callContext.getArgumentDataTypes().get(0));
sumType = resultType;
}
}

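// Result type mapping: integral inputs yield BIGINT, floating-point and string
// inputs yield DOUBLE, and DECIMAL(p, s) yields DECIMAL(min(p + 10, MAX_PRECISION), s).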
private DataType initResultType(DataType argsType) {
switch (argsType.getLogicalType().getTypeRoot()) {
case TINYINT:
case SMALLINT:
case INTEGER:
case BIGINT:
return DataTypes.BIGINT();
case FLOAT:
case DOUBLE:
case CHAR:
case VARCHAR:
return DataTypes.DOUBLE();
case DECIMAL:
int precision =
Math.min(MAX_PRECISION, getPrecision(argsType.getLogicalType()) + 10);
return DataTypes.DECIMAL(precision, getScale(argsType.getLogicalType()));
case TIMESTAMP_WITHOUT_TIME_ZONE:
throw new TableException(
String.format(
"Native hive avg aggregate function does not support type: %s. Please set option '%s' to false.",
argsType, TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED.key()));
default:
throw new TableException(
String.format(
"Only numeric type arguments are accepted but %s is passed.",
argsType));
}
}

private void checkArgumentType(LogicalType logicalType) {
switch (logicalType.getTypeRoot()) {
case TINYINT:
case SMALLINT:
case INTEGER:
case BIGINT:
case FLOAT:
case DOUBLE:
case DECIMAL:
return;
default:
throw new TableException(
String.format(
"Only numeric type arguments are accepted but %s is passed.",
logicalType.getTypeRoot()));
}
}

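// For DECIMAL results, sums are combined with the Hive-specific decimal plus;
// all other types use the standard plus.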
private UnresolvedCallExpression adjustedPlus(Expression arg1, Expression arg2) {
if (getResultType().getLogicalType().is(DECIMAL)) {
return hiveAggDecimalPlus(arg1, arg2);
} else {
return plus(arg1, arg2);
}
}
}
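For reference, a plain-Java paraphrase of the behavior the expressions above encode, assuming a DOUBLE result type; this is an illustrative sketch, not code from the change:

import java.util.Arrays;
import java.util.List;

class HiveAvgSketch {
    // Mirrors the sum/count buffer declared in aggBufferAttributes().
    static Double avg(List<Double> inputs) {
        double sum = 0.0; // initialValuesExpressions: sum starts at zero
        long count = 0L;  // initialValuesExpressions: count starts at zero
        for (Double value : inputs) { // value stands in for tryCast(operand(0), resultType)
            if (value != null) {      // accumulateExpressions: NULL updates neither sum nor count
                sum += value;
                count++;
            }
        }
        // getValueExpression: NULL on empty input, otherwise sum / count
        return count == 0L ? null : sum / count;
    }

    public static void main(String[] args) {
        System.out.println(avg(Arrays.asList(1.2, 2.3, null, 3.5))); // average of the non-null values
    }
}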
@@ -26,6 +26,7 @@
import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory;
import org.apache.flink.table.factories.FunctionDefinitionFactory;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.hive.HiveAvgAggFunction;
import org.apache.flink.table.functions.hive.HiveMinAggFunction;
import org.apache.flink.table.functions.hive.HiveSumAggFunction;
import org.apache.flink.table.module.Module;
@@ -86,7 +87,7 @@ public class HiveModule implements Module {
"tumble_start")));

static final Set<String> BUILTIN_NATIVE_AGG_FUNC =
Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "min")));
Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "min", "avg")));

private final HiveFunctionDefinitionFactory factory;
private final String hiveVersion;
@@ -209,6 +210,8 @@ private Optional<FunctionDefinition> getBuiltInNativeAggFunction(String name) {
case "min":
// We override Hive's min function with a native implementation to support hash-agg
return Optional.of(new HiveMinAggFunction());
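// Likewise, override Hive's avg function with the native implementation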
case "avg":
return Optional.of(new HiveAvgAggFunction());
default:
throw new UnsupportedOperationException(
String.format(
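As context for the module change above, a minimal sketch of how the native avg is switched on; the table and column names are hypothetical, and setup is abbreviated:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.module.hive.HiveModule;

import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;

class NativeAvgExample {
    public static void main(String[] args) {
        TableEnvironment tableEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());
        // With the option enabled, "avg" resolves through getBuiltInNativeAggFunction
        // to HiveAvgAggFunction; set it to false to fall back to Hive's own avg UDAF.
        tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
        tableEnv.loadModule("hive", new HiveModule());
        tableEnv.executeSql("select avg(price) from orders").print(); // hypothetical table/column
    }
}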
@@ -220,18 +220,6 @@ public void testMinAggFunction() throws Exception {
tableEnv.executeSql("select min(b) from test_min").collect());
assertThat(result7.toString()).isEqualTo("[+I[false]]");

// test min with timestamp type
List<Row> result8 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select min(ts) from test_min").collect());
assertThat(result8.toString()).isEqualTo("[+I[2021-08-04T16:26:33.400]]");

// test min with date type
List<Row> result9 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select min(dt) from test_min").collect());
assertThat(result9.toString()).isEqualTo("[+I[2021-08-01]]");

// test min with binary type
List<Row> result10 =
CollectionUtil.iteratorToList(
@@ -270,6 +258,116 @@ public void testMinAggFunction() throws Exception {
tableEnv.executeSql("drop table test_min_not_support_type");
}

@Test
public void testAvgAggFunction() throws Exception {
tableEnv.executeSql(
"create table test_avg(a bigint, b boolean, x string, y string, z int, d decimal(10,5), e float, f double, ts timestamp, dt date, bar binary)");
tableEnv.executeSql(
"insert into test_avg values "
+ "(1, true, NULL, '2', 1, 1.11, 1.2, 1.3, '2021-08-04 16:26:33.4','2021-08-04', 'data1'), "
+ "(3, false, NULL, 'b', 2, 2.22, 2.3, 2.4, '2021-08-06 16:26:33.4','2021-08-07', 'data2'), "
+ "(2, false, NULL, '4', 1, 3.33, 3.5, 3.6, '2021-08-08 16:26:33.4','2021-08-08', 'data3'), "
+ "(2, true, NULL, NULL, 4, 4.45, 4.7, 4.8, '2021-08-10 16:26:33.4','2021-08-01', 'data4')")
.await();

// test avg with bigint type
List<Row> result1 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select avg(a) from test_avg").collect());
assertThat(result1.toString()).isEqualTo("[+I[2]]");

// test avg with int type
List<Row> result2 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select avg(z) from test_avg").collect());
assertThat(result2.toString()).isEqualTo("[+I[2]]");

// test avg with decimal type
List<Row> result3 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select avg(d) from test_avg").collect());
assertThat(result3.toString()).isEqualTo("[+I[2.77750]]");

// test avg with float type
List<Row> result4 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select avg(e) from test_avg").collect());
assertThat(result4.toString()).isEqualTo("[+I[2.924999952316284]]");

// test avg with double type
List<Row> result5 =
CollectionUtil.iteratorToList(
tableEnv.executeSql("select avg(f) from test_avg").collect());
assertThat(result5.toString()).isEqualTo("[+I[3.0250000000000004]]");

// test avg with unsupported data type
// test avg with string type
String expectedStringMessage =
"Only numeric type arguments are accepted but VARCHAR is passed.";
assertSqlException(
"select avg(x) from test_avg", TableException.class, expectedStringMessage);

// test avg with timestamp type
String expectedTimestampMessage =
"Only numeric type arguments are accepted but TIMESTAMP_WITHOUT_TIME_ZONE is passed.";
assertSqlException(
"select avg(ts) from test_avg", TableException.class, expectedTimestampMessage);

// test avg with date type
String expectedDateMessage =
"Only numeric or string type arguments are accepted but date is passed.";
assertSqlException(
"select avg(dt) from test_avg",
UDFArgumentTypeException.class,
expectedDateMessage);

// test avg with boolean type
String expectedBooleanMessage =
"Only numeric or string type arguments are accepted but boolean is passed.";
assertSqlException(
"select avg(b) from test_avg",
UDFArgumentTypeException.class,
expectedBooleanMessage);

// test avg with binary type
String expectedBinaryMessage =
"Only numeric or string type arguments are accepted but binary is passed.";
assertSqlException(
"select avg(bar) from test_avg",
UDFArgumentTypeException.class,
expectedBinaryMessage);

// test avg with unsupported complex data type
tableEnv.executeSql(
"create table test_avg_not_support_type(a array<int>,m map<int, string>,s struct<f1:int,f2:string>)");
// test avg with row type
String expectedRowMessage =
"Only primitive type arguments are accepted but struct<f1:int,f2:string> is passed.";
assertSqlException(
"select avg(s) from test_avg_not_support_type",
UDFArgumentTypeException.class,
expectedRowMessage);

// test avg with array type
String expectedArrayMessage =
"Only primitive type arguments are accepted but array<int> is passed.";
assertSqlException(
"select avg(a) from test_avg_not_support_type",
UDFArgumentTypeException.class,
expectedArrayMessage);

// test avg with map type
String expectedMapMessage =
"Only primitive type arguments are accepted but map<int,string> is passed.";
assertSqlException(
"select avg(m) from test_avg_not_support_type",
UDFArgumentTypeException.class,
expectedMapMessage);

tableEnv.executeSql("drop table test_avg");
tableEnv.executeSql("drop table test_avg_not_support_type");
}

private void assertSqlException(
String sql, Class<?> expectedExceptionClz, String expectedMessage) {
assertThatThrownBy(() -> tableEnv.executeSql(sql))
@@ -93,6 +93,19 @@ public void testMinAggFunctionPlan() {
.isEqualTo(readFromResource("/explain/testMinAggFunctionFallbackPlan.out"));
}

@Test
public void testAvgAggFunctionPlan() {
// test explain
String actualPlan = explainSql("select x, avg(y) from foo group by x");
assertThat(actualPlan).isEqualTo(readFromResource("/explain/testAvgAggFunctionPlan.out"));

// test fallback to hive avg udaf
tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
String actualSortAggPlan = explainSql("select x, avg(y) from foo group by x");
assertThat(actualSortAggPlan)
.isEqualTo(readFromResource("/explain/testAvgAggFunctionFallbackPlan.out"));
}

private String explainSql(String sql) {
return (String)
CollectionUtil.iteratorToList(tableEnv.executeSql("explain " + sql).collect())
@@ -0,0 +1,21 @@
== Abstract Syntax Tree ==
LogicalProject(x=[$0], _o__c1=[$1])
+- LogicalAggregate(group=[{0}], agg#0=[avg($1)])
+- LogicalProject($f0=[$0], $f1=[$1])
+- LogicalTableScan(table=[[test-catalog, default, foo]])

== Optimized Physical Plan ==
SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_avg($f1) AS $f1])
+- Sort(orderBy=[x ASC])
+- Exchange(distribution=[hash[x]])
+- LocalSortAggregate(groupBy=[x], select=[x, Partial_avg(y) AS $f1])
+- Sort(orderBy=[x ASC])
+- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])

== Optimized Execution Plan ==
SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_avg($f1) AS $f1])
+- Sort(orderBy=[x ASC])
+- Exchange(distribution=[hash[x]])
+- LocalSortAggregate(groupBy=[x], select=[x, Partial_avg(y) AS $f1])
+- Sort(orderBy=[x ASC])
+- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
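The plan above shows the two-phase execution of the native function: LocalSortAggregate runs Partial_avg to build the (sum, count) buffer declared in aggBufferAttributes(), and after the hash exchange the final SortAggregate combines buffers via mergeExpressions() and emits getValueExpression().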