Skip to content

Commit

Permalink
[FLINK-13523][table-planner-blink] address feedbacks for 1941bc6
Browse files Browse the repository at this point in the history
  • Loading branch information
docete committed Aug 8, 2019
1 parent 48c66a9 commit 7d17148
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 27 deletions.
Expand Up @@ -28,6 +28,7 @@
import java.math.BigDecimal;

import static org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedRef;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.div;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.equalTo;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
Expand All @@ -36,6 +37,7 @@
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;

/**
* built-in avg aggregate function.
Expand Down Expand Up @@ -103,7 +105,9 @@ public Expression[] mergeExpressions() {
*/
@Override
public Expression getValueExpression() {
return ifThenElse(equalTo(count, literal(0L)), nullOf(getResultType()), div(sum, count));
Expression ifTrue = nullOf(getResultType());
Expression ifFalse = cast(div(sum, count), typeLiteral(getResultType()));
return ifThenElse(equalTo(count, literal(0L)), ifTrue, ifFalse);
}

/**
Expand Down
Expand Up @@ -45,6 +45,9 @@
* defining {@link #initialValuesExpressions}, {@link #accumulateExpressions},
* {@link #mergeExpressions} and {@link #getValueExpression}.
*
* <p> Note: Developer of DeclarativeAggregateFunction should guarantee that the inferred type
* of {@link #getValueExpression} is the same as {@link #getResultType()}
*
* <p>See an full example: {@link AvgAggFunction}.
*/
public abstract class DeclarativeAggregateFunction extends UserDefinedFunction {
Expand Down
Expand Up @@ -204,13 +204,8 @@ class DeclarativeAggCodeGen(
}

def getValue(generator: ExprCodeGenerator): GeneratedExpression = {
val expr = function.getValueExpression
val resolvedGetValueExpression = function.getValueExpression
.accept(ResolveReference())
val resolvedGetValueExpression = ApiExpressionUtils.unresolvedCall(
BuiltInFunctionDefinitions.CAST,
expr,
ApiExpressionUtils.typeLiteral(aggInfo.externalResultType)
)
generator.generateExpression(resolvedGetValueExpression.accept(rexNodeGen))
}

Expand Down
Expand Up @@ -24,7 +24,7 @@ import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.table.dataformat.{BaseRow, GenericRow}
import org.apache.flink.table.expressions.utils.ApiExpressionUtils
import org.apache.flink.table.expressions.{Expression, ExpressionVisitor, FieldReferenceExpression, TypeLiteralExpression, UnresolvedCallExpression, UnresolvedReferenceExpression, ValueLiteralExpression, _}
import org.apache.flink.table.functions.{AggregateFunction, BuiltInFunctionDefinitions, UserDefinedFunction}
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.STREAM_RECORD
import org.apache.flink.table.planner.codegen._
Expand Down Expand Up @@ -524,12 +524,8 @@ object AggCodeGenHelper {
val aggExprs = aggregates.zipWithIndex.map {
case (agg: DeclarativeAggregateFunction, aggIndex) =>
val idx = auxGrouping.length + aggIndex
val expr = agg.getValueExpression.accept(ResolveReference(
agg.getValueExpression.accept(ResolveReference(
ctx, isMerge, agg, idx, argsMapping, aggBufferTypes))
ApiExpressionUtils.unresolvedCall(
BuiltInFunctionDefinitions.CAST,
expr,
ApiExpressionUtils.typeLiteral(agg.getResultType))
case (agg: AggregateFunction[_, _], aggIndex) =>
val idx = auxGrouping.length + aggIndex
(agg, idx)
Expand Down
Expand Up @@ -23,7 +23,7 @@ import org.apache.flink.metrics.Gauge
import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, GenericRow, JoinedRow}
import org.apache.flink.table.expressions.utils.ApiExpressionUtils
import org.apache.flink.table.expressions.{Expression, ExpressionVisitor, FieldReferenceExpression, TypeLiteralExpression, UnresolvedCallExpression, UnresolvedReferenceExpression, ValueLiteralExpression, _}
import org.apache.flink.table.functions.{AggregateFunction, BuiltInFunctionDefinitions, UserDefinedFunction}
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull}
import org.apache.flink.table.planner.codegen._
import org.apache.flink.table.planner.codegen.agg.batch.AggCodeGenHelper.buildAggregateArgsMapping
Expand Down Expand Up @@ -292,13 +292,8 @@ object HashAggCodeGenHelper {
val getAggValueExprs = aggregates.zipWithIndex.map {
case (agg: DeclarativeAggregateFunction, aggIndex) =>
val idx = auxGrouping.length + aggIndex
val expr = agg.getValueExpression.accept(
agg.getValueExpression.accept(
ResolveReference(ctx, isMerge, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
ApiExpressionUtils.unresolvedCall(
BuiltInFunctionDefinitions.CAST,
expr,
ApiExpressionUtils.typeLiteral(agg.getResultType)
)
}.map(_.accept(new RexNodeConverter(builder))).map(exprCodegen.generateExpression)

val getValueExprs = getAuxGroupingExprs ++ getAggValueExprs
Expand Down
Expand Up @@ -39,8 +39,6 @@ import org.apache.calcite.util.{ImmutableBitSet, ImmutableIntList}
import java.math.{BigDecimal => JBigDecimal}
import java.util

import org.apache.calcite.sql.`type`.SqlTypeName

import scala.collection.JavaConversions._

/**
Expand Down Expand Up @@ -304,19 +302,22 @@ class SplitAggregateRule extends RelOptRule(
aggGroupCount + index + avgAggCount + 1,
finalAggregate.getRowType)
avgAggCount += 1
// TODO
// Make a guarantee that the final aggregation returns NULL if underlying count is ZERO.
// We use SUM0 for underlying sum, which may run into ZERO / ZERO,
// and division by zero exception occurs.
// @see Glossary#SQL2011 SQL:2011 Part 2 Section 6.27
val equals = relBuilder.call(
FlinkSqlOperatorTable.EQUALS,
countInputRef,
relBuilder.getRexBuilder.makeBigintLiteral(JBigDecimal.valueOf(0)))
val falseT = relBuilder.call(FlinkSqlOperatorTable.DIVIDE, sumInputRef, countInputRef)
val trueT = relBuilder.cast(
val ifTrue = relBuilder.cast(
relBuilder.getRexBuilder.constantNull(), aggCall.`type`.getSqlTypeName)
val ifFalse = relBuilder.call(FlinkSqlOperatorTable.DIVIDE, sumInputRef, countInputRef)
relBuilder.call(
FlinkSqlOperatorTable.IF,
equals,
trueT,
falseT)
ifTrue,
ifFalse)
} else {
RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
}
Expand Down

0 comments on commit 7d17148

Please sign in to comment.