From 517567348b0ec0c23ef0c1dcc05c54a91d5c5671 Mon Sep 17 00:00:00 2001 From: Fabian Hueske Date: Thu, 15 Mar 2018 21:04:00 +0100 Subject: [PATCH 1/4] [FLINK-8903] [table] Fix VAR_SAMP, VAR_POP, STDEV_SAMP, STDEV_POP functions on GROUP BY windows. --- .../rules/AggregateReduceFunctionsRule.java | 590 ++++++++++++++++++ .../nodes/logical/FlinkLogicalAggregate.scala | 4 +- .../logical/FlinkLogicalWindowAggregate.scala | 16 + .../table/plan/rules/FlinkRuleSets.scala | 1 + .../WindowAggregateReduceFunctionsRule.scala | 75 +++ .../runtime/aggregate/AggregateUtil.scala | 4 +- .../table/api/batch/sql/GroupWindowTest.scala | 49 ++ .../api/batch/table/GroupWindowTest.scala | 45 ++ .../api/stream/sql/GroupWindowTest.scala | 46 ++ .../api/stream/table/GroupWindowTest.scala | 45 ++ 10 files changed, 871 insertions(+), 4 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala diff --git a/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java new file mode 100644 index 0000000000000..76b340382d581 --- /dev/null +++ b/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java @@ -0,0 +1,590 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptRuleOperand; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.CompositeList; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.calcite.util.Util; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/* + * THIS FILE HAS BEEN COPIED FROM THE APACHE CALCITE PROJECT TO MAKE IT MORE EXTENSIBLE. + * + * Modification: + * - Added newCalcRel() method to be able to add fields to the projection. + */ + +/** + * Planner rule that reduces aggregate functions in + * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms. + * + *

Rewrites: + * <ul> + * <li>AVG(x) → SUM(x) / COUNT(x)</li> + * <li>STDDEV_POP(x) → SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))</li> + * <li>STDDEV_SAMP(x) → SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)</li> + * <li>VAR_POP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)</li> + * <li>VAR_SAMP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END</li> + * </ul> + *

+ * + *

Since many of these rewrites introduce multiple occurrences of simpler + * forms like {@code COUNT(x)}, the rule gathers common sub-expressions as it + * goes. + */ +public class AggregateReduceFunctionsRule extends RelOptRule { + //~ Static fields/initializers --------------------------------------------- + + /** The singleton. */ + public static final AggregateReduceFunctionsRule INSTANCE = + new AggregateReduceFunctionsRule(operand(LogicalAggregate.class, any()), + RelFactories.LOGICAL_BUILDER); + + //~ Constructors ----------------------------------------------------------- + + /** Creates an AggregateReduceFunctionsRule. */ + public AggregateReduceFunctionsRule(RelOptRuleOperand operand, + RelBuilderFactory relBuilderFactory) { + super(operand, relBuilderFactory, null); + } + + //~ Methods ---------------------------------------------------------------- + + @Override public boolean matches(RelOptRuleCall call) { + if (!super.matches(call)) { + return false; + } + Aggregate oldAggRel = (Aggregate) call.rels[0]; + return containsAvgStddevVarCall(oldAggRel.getAggCallList()); + } + + public void onMatch(RelOptRuleCall ruleCall) { + Aggregate oldAggRel = (Aggregate) ruleCall.rels[0]; + reduceAggs(ruleCall, oldAggRel); + } + + /** + * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*. + * + * @param aggCallList List of aggregate calls + */ + private boolean containsAvgStddevVarCall(List aggCallList) { + for (AggregateCall call : aggCallList) { + if (isReducible(call.getAggregation().getKind())) { + return true; + } + } + return false; + } + + /** + * Returns whether the aggregate call is a reducible function + */ + private boolean isReducible(final SqlKind kind) { + if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)) { + return true; + } + switch (kind) { + case SUM: + return true; + } + return false; + } + + /** + * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in + * the aggregates list to. + * + *

It handles newly generated common subexpressions since this was done + * at the sql2rel stage. + */ + private void reduceAggs( + RelOptRuleCall ruleCall, + Aggregate oldAggRel) { + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + + List oldCalls = oldAggRel.getAggCallList(); + final int groupCount = oldAggRel.getGroupCount(); + final int indicatorCount = oldAggRel.getIndicatorCount(); + + final List newCalls = Lists.newArrayList(); + final Map aggCallMapping = Maps.newHashMap(); + + final List projList = Lists.newArrayList(); + + // pass through group key (+ indicators if present) + for (int i = 0; i < groupCount + indicatorCount; ++i) { + projList.add( + rexBuilder.makeInputRef( + getFieldType(oldAggRel, i), + i)); + } + + // List of input expressions. If a particular aggregate needs more, it + // will add an expression to the end, and we will create an extra + // project. + final RelBuilder relBuilder = ruleCall.builder(); + relBuilder.push(oldAggRel.getInput()); + final List inputExprs = new ArrayList<>(relBuilder.fields()); + + // create new agg function calls and rest of project list together + for (AggregateCall oldCall : oldCalls) { + projList.add( + reduceAgg( + oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs)); + } + + final int extraArgCount = + inputExprs.size() - relBuilder.peek().getRowType().getFieldCount(); + if (extraArgCount > 0) { + relBuilder.project(inputExprs, + CompositeList.of( + relBuilder.peek().getRowType().getFieldNames(), + Collections.nCopies(extraArgCount, null))); + } + newAggregateRel(relBuilder, oldAggRel, newCalls); + newCalcRel(relBuilder, oldAggRel, projList); + ruleCall.transformTo(relBuilder.build()); + } + + private RexNode reduceAgg( + Aggregate oldAggRel, + AggregateCall oldCall, + List newCalls, + Map aggCallMapping, + List inputExprs) { + final SqlKind kind = oldCall.getAggregation().getKind(); + if (isReducible(kind)) { + switch (kind) { + case SUM: + // replace original SUM(x) with + // case COUNT(x) when 0 then null else SUM0(x) end + return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping); + case AVG: + // replace original AVG(x) with SUM(x) / COUNT(x) + return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs); + case STDDEV_POP: + // replace original STDDEV_POP(x) with + // SQRT( + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / COUNT(x)) + return reduceStddev(oldAggRel, oldCall, true, true, newCalls, + aggCallMapping, inputExprs); + case STDDEV_SAMP: + // replace original STDDEV_POP(x) with + // SQRT( + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END) + return reduceStddev(oldAggRel, oldCall, false, true, newCalls, + aggCallMapping, inputExprs); + case VAR_POP: + // replace original VAR_POP(x) with + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / COUNT(x) + return reduceStddev(oldAggRel, oldCall, true, false, newCalls, + aggCallMapping, inputExprs); + case VAR_SAMP: + // replace original VAR_POP(x) with + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END + return reduceStddev(oldAggRel, oldCall, false, false, newCalls, + aggCallMapping, inputExprs); + default: + throw Util.unexpected(kind); + } + } else { + // anything else: preserve original call + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + final int nGroups = oldAggRel.getGroupCount(); + List oldArgTypes = + SqlTypeUtil.projectTypes( + oldAggRel.getInput().getRowType(), oldCall.getArgList()); + 
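+ // Non-reducible calls are forwarded unchanged: addAggCall() registers the call in + // aggCallMapping (reusing an identical call that was already added) and returns an + // input reference to its position in the output of the new aggregate.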
return rexBuilder.addAggCall(oldCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + oldArgTypes); + } + } + + private AggregateCall createAggregateCallWithBinding( + RelDataTypeFactory typeFactory, + SqlAggFunction aggFunction, + RelDataType operandType, + Aggregate oldAggRel, + AggregateCall oldCall, + int argOrdinal) { + final Aggregate.AggCallBinding binding = + new Aggregate.AggCallBinding(typeFactory, aggFunction, + ImmutableList.of(operandType), oldAggRel.getGroupCount(), + oldCall.filterArg >= 0); + return AggregateCall.create(aggFunction, + oldCall.isDistinct(), + oldCall.isApproximate(), + ImmutableIntList.of(argOrdinal), + oldCall.filterArg, + aggFunction.inferReturnType(binding), + null); + } + + private RexNode reduceAvg( + Aggregate oldAggRel, + AggregateCall oldCall, + List newCalls, + Map aggCallMapping, + List inputExprs) { + final int nGroups = oldAggRel.getGroupCount(); + final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + final int iAvgInput = oldCall.getArgList().get(0); + final RelDataType avgInputType = + getFieldType( + oldAggRel.getInput(), + iAvgInput); + final AggregateCall sumCall = + AggregateCall.create(SqlStdOperatorTable.SUM, + oldCall.isDistinct(), + oldCall.isApproximate(), + oldCall.getArgList(), + oldCall.filterArg, + oldAggRel.getGroupCount(), + oldAggRel.getInput(), + null, + null); + final AggregateCall countCall = + AggregateCall.create(SqlStdOperatorTable.COUNT, + oldCall.isDistinct(), + oldCall.isApproximate(), + oldCall.getArgList(), + oldCall.filterArg, + oldAggRel.getGroupCount(), + oldAggRel.getInput(), + null, + null); + + // NOTE: these references are with respect to the output + // of newAggRel + RexNode numeratorRef = + rexBuilder.addAggCall(sumCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(avgInputType)); + final RexNode denominatorRef = + rexBuilder.addAggCall(countCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(avgInputType)); + + final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); + final RelDataType avgType = typeFactory.createTypeWithNullability( + oldCall.getType(), numeratorRef.getType().isNullable()); + numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true); + final RexNode divideRef = + rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef); + return rexBuilder.makeCast(oldCall.getType(), divideRef); + } + + private RexNode reduceSum( + Aggregate oldAggRel, + AggregateCall oldCall, + List newCalls, + Map aggCallMapping) { + final int nGroups = oldAggRel.getGroupCount(); + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + int arg = oldCall.getArgList().get(0); + RelDataType argType = + getFieldType( + oldAggRel.getInput(), + arg); + final AggregateCall sumZeroCall = + AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(), + oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, + oldAggRel.getGroupCount(), oldAggRel.getInput(), null, + oldCall.name); + final AggregateCall countCall = + AggregateCall.create(SqlStdOperatorTable.COUNT, + oldCall.isDistinct(), + oldCall.isApproximate(), + oldCall.getArgList(), + oldCall.filterArg, + oldAggRel.getGroupCount(), + oldAggRel, + null, + null); + + // NOTE: these references are with respect to the output + // of newAggRel + RexNode sumZeroRef = + rexBuilder.addAggCall(sumZeroCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + 
ImmutableList.of(argType)); + if (!oldCall.getType().isNullable()) { + // If SUM(x) is not nullable, the validator must have determined that + // nulls are impossible (because the group is never empty and x is never + // null). Therefore we translate to SUM0(x). + return sumZeroRef; + } + RexNode countRef = + rexBuilder.addAggCall(countCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(argType)); + return rexBuilder.makeCall(SqlStdOperatorTable.CASE, + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)), + rexBuilder.makeCast(sumZeroRef.getType(), rexBuilder.constantNull()), + sumZeroRef); + } + + private RexNode reduceStddev( + Aggregate oldAggRel, + AggregateCall oldCall, + boolean biased, + boolean sqrt, + List newCalls, + Map aggCallMapping, + List inputExprs) { + // stddev_pop(x) ==> + // power( + // (sum(x * x) - sum(x) * sum(x) / count(x)) + // / count(x), + // .5) + // + // stddev_samp(x) ==> + // power( + // (sum(x * x) - sum(x) * sum(x) / count(x)) + // / nullif(count(x) - 1, 0), + // .5) + final int nGroups = oldAggRel.getGroupCount(); + final RelOptCluster cluster = oldAggRel.getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + + assert oldCall.getArgList().size() == 1 : oldCall.getArgList(); + final int argOrdinal = oldCall.getArgList().get(0); + final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal); + final RelDataType oldCallType = + typeFactory.createTypeWithNullability(oldCall.getType(), + argOrdinalType.isNullable()); + + final RexNode argRef = + rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true); + final int argRefOrdinal = lookupOrAdd(inputExprs, argRef); + + final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + argRef, argRef); + final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared); + + final AggregateCall sumArgSquaredAggCall = + createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM, + argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal); + + final RexNode sumArgSquared = + rexBuilder.addAggCall(sumArgSquaredAggCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(sumArgSquaredAggCall.getType())); + + final AggregateCall sumArgAggCall = + AggregateCall.create(SqlStdOperatorTable.SUM, + oldCall.isDistinct(), + oldCall.isApproximate(), + ImmutableIntList.of(argOrdinal), + oldCall.filterArg, + oldAggRel.getGroupCount(), + oldAggRel.getInput(), + null, + null); + + final RexNode sumArg = + rexBuilder.addAggCall(sumArgAggCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(sumArgAggCall.getType())); + final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true); + final RexNode sumSquaredArg = + rexBuilder.makeCall( + SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast); + + final AggregateCall countArgAggCall = + AggregateCall.create(SqlStdOperatorTable.COUNT, + oldCall.isDistinct(), + oldCall.isApproximate(), + oldCall.getArgList(), + oldCall.filterArg, + oldAggRel.getGroupCount(), + oldAggRel, + null, + null); + + final RexNode countArg = + rexBuilder.addAggCall(countArgAggCall, + nGroups, + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(argOrdinalType)); + + final RexNode avgSumSquaredArg = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg); + + final 
RexNode diff = + rexBuilder.makeCall( + SqlStdOperatorTable.MINUS, + sumArgSquared, avgSumSquaredArg); + + final RexNode denominator; + if (biased) { + denominator = countArg; + } else { + final RexLiteral one = + rexBuilder.makeExactLiteral(BigDecimal.ONE); + final RexNode nul = + rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull()); + final RexNode countMinusOne = + rexBuilder.makeCall( + SqlStdOperatorTable.MINUS, countArg, one); + final RexNode countEqOne = + rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, countArg, one); + denominator = + rexBuilder.makeCall( + SqlStdOperatorTable.CASE, + countEqOne, nul, countMinusOne); + } + + final RexNode div = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, diff, denominator); + + RexNode result = div; + if (sqrt) { + final RexNode half = + rexBuilder.makeExactLiteral(new BigDecimal("0.5")); + result = + rexBuilder.makeCall( + SqlStdOperatorTable.POWER, div, half); + } + + return rexBuilder.makeCast( + oldCall.getType(), result); + } + + /** + * Finds the ordinal of an element in a list, or adds it. + * + * @param list List + * @param element Element to lookup or add + * @param Element type + * @return Ordinal of element in list + */ + private static int lookupOrAdd(List list, T element) { + int ordinal = list.indexOf(element); + if (ordinal == -1) { + ordinal = list.size(); + list.add(element); + } + return ordinal; + } + + /** + * Do a shallow clone of oldAggRel and update aggCalls. Could be refactored + * into Aggregate and subclasses - but it's only needed for some + * subclasses. + * + * @param relBuilder Builder of relational expressions; at the top of its + * stack is its input + * @param oldAggregate LogicalAggregate to clone. + * @param newCalls New list of AggregateCalls + */ + protected void newAggregateRel(RelBuilder relBuilder, + Aggregate oldAggregate, + List newCalls) { + relBuilder.aggregate( + relBuilder.groupKey(oldAggregate.getGroupSet(), + oldAggregate.getGroupSets()), + newCalls); + } + + protected void newCalcRel(RelBuilder relBuilder, + Aggregate oldAggregate, + List exprs) { + relBuilder.project(exprs, oldAggregate.getRowType().getFieldNames()); + } + + private RelDataType getFieldType(RelNode relNode, int i) { + final RelDataTypeField inputField = + relNode.getRowType().getFieldList().get(i); + return inputField.getType(); + } +} + +// End AggregateReduceFunctionsRule.java diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala index e1e93c7c583b9..03e4e1f2e5e74 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala @@ -30,7 +30,7 @@ import org.apache.calcite.sql.SqlKind import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.table.plan.nodes.FlinkConventions -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ class FlinkLogicalAggregate( cluster: RelOptCluster, @@ -74,7 +74,7 @@ private class FlinkLogicalAggregateConverter // we do not support these functions natively // they have to be converted using the AggregateReduceFunctionsRule - val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall { + val supported = 
agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall { case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false case _ => true } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala index 3e605e895dceb..d87d4113815c8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala @@ -26,6 +26,7 @@ import org.apache.calcite.rel.convert.ConverterRule import org.apache.calcite.rel.core.{Aggregate, AggregateCall} import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelShuttle} +import org.apache.calcite.sql.SqlKind import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.calcite.FlinkTypeFactory @@ -33,6 +34,8 @@ import org.apache.flink.table.plan.logical.LogicalWindow import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate import org.apache.flink.table.plan.nodes.FlinkConventions +import scala.collection.JavaConverters._ + class FlinkLogicalWindowAggregate( window: LogicalWindow, namedProperties: Seq[NamedWindowProperty], @@ -103,6 +106,19 @@ class FlinkLogicalWindowAggregateConverter FlinkConventions.LOGICAL, "FlinkLogicalWindowAggregateConverter") { + override def matches(call: RelOptRuleCall): Boolean = { + val agg = call.rel(0).asInstanceOf[LogicalWindowAggregate] + + // we do not support these functions natively + // they have to be converted using the WindowAggregateReduceFunctionsRule + val supported = agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall { + case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + case _ => true + } + + !agg.containsDistinctCall() && supported + } + override def convert(rel: RelNode): RelNode = { val agg = rel.asInstanceOf[LogicalWindowAggregate] val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index d3ad2ac5654dc..9f3b8e99ece61 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -93,6 +93,7 @@ object FlinkRuleSets { // reduce aggregate functions like AVG, STDDEV_POP etc. 
AggregateReduceFunctionsRule.INSTANCE, + WindowAggregateReduceFunctionsRule.INSTANCE, // remove unnecessary sort rule SortRemoveRule.INSTANCE, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala new file mode 100644 index 0000000000000..4ca2b335478d6 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.common + +import java.util + +import org.apache.calcite.plan.RelOptRule +import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories} +import org.apache.calcite.rel.logical.LogicalAggregate +import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate + +/** + * Rule to convert complex aggregation functions into simpler ones. + * Have a look at [[AggregateReduceFunctionsRule]] for details. 
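+ * + * The rule matches on [[LogicalWindowAggregate]] instead of [[LogicalAggregate]]. It overrides + * newAggregateRel() to wrap the decomposed aggregation into a new [[LogicalWindowAggregate]], and + * newCalcRel() to forward the named window properties in the final projection.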
+ */ +class WindowAggregateReduceFunctionsRule extends AggregateReduceFunctionsRule( + RelOptRule.operand(classOf[LogicalWindowAggregate], RelOptRule.any()), + RelFactories.LOGICAL_BUILDER) { + + override def newAggregateRel( + relBuilder: RelBuilder, + oldAgg: Aggregate, + newCalls: util.List[AggregateCall]): Unit = { + + // create a LogicalAggregate with simpler aggregation functions + super.newAggregateRel(relBuilder, oldAgg, newCalls) + // pop LogicalAggregate from RelBuilder + val newAgg = relBuilder.build().asInstanceOf[LogicalAggregate] + + // create a new LogicalWindowAggregate (based on the new LogicalAggregate) and push it on the + // RelBuilder + val oldWindowAgg = oldAgg.asInstanceOf[LogicalWindowAggregate] + relBuilder.push(LogicalWindowAggregate.create( + oldWindowAgg.getWindow, + oldWindowAgg.getNamedProperties, + newAgg)) + } + + override def newCalcRel( + relBuilder: RelBuilder, + oldAgg: Aggregate, + exprs: util.List[RexNode]): Unit = { + + // add all named properties of the window to the selection + val oldWindowAgg = oldAgg.asInstanceOf[LogicalWindowAggregate] + oldWindowAgg.getNamedProperties.foreach(np => exprs.add(relBuilder.field(np.name))) + + // create a LogicalCalc that computes the complex aggregates and forwards the window properties + relBuilder.project(exprs, oldAgg.getRowType.getFieldNames) + } + +} + +object WindowAggregateReduceFunctionsRule { + val INSTANCE = new WindowAggregateReduceFunctionsRule +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index df9b1c5520467..ce0a9c96e336e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -1259,7 +1259,7 @@ object AggregateUtil { } } - case _: SqlAvgAggFunction => + case a: SqlAvgAggFunction if a.kind == SqlKind.AVG => aggregates(index) = sqlTypeName match { case TINYINT => new ByteAvgAggFunction @@ -1413,7 +1413,7 @@ object AggregateUtil { accTypes(index) = udagg.accType case unSupported: SqlAggFunction => - throw new TableException(s"unsupported Function: '${unSupported.getName}'") + throw new TableException(s"Unsupported Function: '${unSupported.getName}'") } } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala index 8d06bcd84db39..b1369e2a9289a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala @@ -304,4 +304,53 @@ class GroupWindowTest extends TableTestBase { util.verifySql(sql, expected) } + + @Test + def testDecomposableAggFunctions() = { + val util = batchTestUtil() + util.addTable[(Int, String, Long, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime) + + val sql = + "SELECT " + + " VAR_POP(c), VAR_SAMP(c), STDDEV_POP(c), STDDEV_SAMP(c), " + + " TUMBLE_START(rowtime, INTERVAL '15' MINUTE), " + + " TUMBLE_END(rowtime, INTERVAL '15' MINUTE)" + + "FROM MyTable " + + "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetWindowAggregate", + unaryNode( + 
"DataSetCalc", + batchTableNode(0), + term("select", "rowtime", "c", + "*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5") + ), + term("window", TumblingGroupWindow('w$, 'rowtime, 900000.millis)), + term("select", + "SUM($f2) AS $f0", + "SUM(c) AS $f1", + "COUNT(c) AS $f2", + "SUM($f3) AS $f3", + "SUM($f4) AS $f4", + "SUM($f5) AS $f5", + "start('w$) AS w$start", + "end('w$) AS w$end", + "rowtime('w$) AS w$rowtime") + ), + term("select", + "CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS EXPR$0", + "CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS EXPR$1", + "CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS EXPR$2", + "CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " + + "AS EXPR$3", + "CAST(w$start) AS EXPR$4", + "CAST(w$end) AS EXPR$5") + ) + + util.verifySql(sql, expected) + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala index ad44e09c68aad..27c1d7f6c324a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala @@ -449,4 +449,49 @@ class GroupWindowTest extends TableTestBase { util.verifyTable(windowedTable, expected) } + + @Test + def testDecomposableAggFunctions(): Unit = { + val util = batchTestUtil() + val table = util.addTable[(Long, Int, String, Long)]('rowtime, 'a, 'b, 'c) + + val windowedTable = table + .window(Tumble over 15.minutes on 'rowtime as 'w) + .groupBy('w) + .select('c.varPop, 'c.varSamp, 'c.stddevPop, 'c.stddevSamp, 'w.start, 'w.end) + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetWindowAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "c", "rowtime", + "*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5") + ), + term("window", TumblingGroupWindow('w, 'rowtime, 900000.millis)), + term("select", + "SUM($f2) AS $f0", + "SUM(c) AS $f1", + "COUNT(c) AS $f2", + "SUM($f3) AS $f3", + "SUM($f4) AS $f4", + "SUM($f5) AS $f5", + "start('w) AS TMP_4", + "end('w) AS TMP_5") + ), + term("select", + "CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS TMP_0", + "CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS TMP_1", + "CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS TMP_2", + "CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " + + "AS TMP_3", + "TMP_4", + "TMP_5") + ) + + util.verifyTable(windowedTable, expected) + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala index d7d5f1e07b947..d29283456d9b5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala @@ -260,4 +260,50 @@ class GroupWindowTest extends TableTestBase { streamUtil.verifySql(sql, expected) } + + @Test + def testDecomposableAggFunctions() = { + + val sql = + "SELECT " + + " VAR_POP(c), VAR_SAMP(c), STDDEV_POP(c), STDDEV_SAMP(c), " + + " TUMBLE_START(rowtime, INTERVAL '15' MINUTE), " + + " TUMBLE_END(rowtime, INTERVAL '15' 
MINUTE)" + + "FROM MyTable " + + "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)" + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupWindowAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "rowtime", "c", + "*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5") + ), + term("window", TumblingGroupWindow('w$, 'rowtime, 900000.millis)), + term("select", + "SUM($f2) AS $f0", + "SUM(c) AS $f1", + "COUNT(c) AS $f2", + "SUM($f3) AS $f3", + "SUM($f4) AS $f4", + "SUM($f5) AS $f5", + "start('w$) AS w$start", + "end('w$) AS w$end", + "rowtime('w$) AS w$rowtime", + "proctime('w$) AS w$proctime") + ), + term("select", + "CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS EXPR$0", + "CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS EXPR$1", + "CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS EXPR$2", + "CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " + + "AS EXPR$3", + "w$start AS EXPR$4", + "w$end AS EXPR$5") + ) + streamUtil.verifySql(sql, expected) + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala index 260726ba495b9..a59ad8382a0ed 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala @@ -782,4 +782,49 @@ class GroupWindowTest extends TableTestBase { util.verifyTable(windowedTable, expected) } + + @Test + def testDecomposableAggFunctions(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String, Long)]('rowtime.rowtime, 'a, 'b, 'c) + + val windowedTable = table + .window(Tumble over 15.minutes on 'rowtime as 'w) + .groupBy('w) + .select('c.varPop, 'c.varSamp, 'c.stddevPop, 'c.stddevSamp, 'w.start, 'w.end) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupWindowAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "c", "rowtime", + "*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5") + ), + term("window", TumblingGroupWindow('w, 'rowtime, 900000.millis)), + term("select", + "SUM($f2) AS $f0", + "SUM(c) AS $f1", + "COUNT(c) AS $f2", + "SUM($f3) AS $f3", + "SUM($f4) AS $f4", + "SUM($f5) AS $f5", + "start('w) AS TMP_4", + "end('w) AS TMP_5") + ), + term("select", + "CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS TMP_0", + "CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS TMP_1", + "CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS TMP_2", + "CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " + + "AS TMP_3", + "TMP_4", + "TMP_5") + ) + + util.verifyTable(windowedTable, expected) + } } From 50e234f5ede702b0d6da45673b1a91e609b1d81f Mon Sep 17 00:00:00 2001 From: Fabian Hueske Date: Fri, 16 Mar 2018 19:36:10 +0100 Subject: [PATCH 2/4] addressed feedback --- .../rel/rules/AggregateReduceFunctionsRule.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java index 76b340382d581..ce466e199c262 100644 --- 
a/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java +++ b/flink-libraries/flink-table/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java @@ -54,6 +54,9 @@ /* * THIS FILE HAS BEEN COPIED FROM THE APACHE CALCITE PROJECT TO MAKE IT MORE EXTENSIBLE. * + * We have opened an issue to port this change to Apache Calcite (CALCITE-2216). + * Once CALCITE-2216 is fixed and included in a release, we can remove the copied class. + * * Modification: * - Added newCalcRel() method to be able to add fields to the projection. */ @@ -574,6 +577,15 @@ protected void newAggregateRel(RelBuilder relBuilder, newCalls); } + /** + * Add a calc with the expressions to compute the original agg calls from the + * decomposed ones. + * + * @param relBuilder Builder of relational expressions; at the top of its + * stack is its input + * @param oldAggregate The original LogicalAggregate that is replaced. + * @param exprs The expressions to compute the original agg calls. + */ protected void newCalcRel(RelBuilder relBuilder, Aggregate oldAggregate, List exprs) { From 9925a18a2d160197c0be9bb74d185d2284b28d68 Mon Sep 17 00:00:00 2001 From: Fabian Hueske Date: Sat, 17 Mar 2018 00:08:39 +0100 Subject: [PATCH 3/4] addressed feedback --- .../table/plan/nodes/logical/FlinkLogicalAggregate.scala | 5 ++++- .../plan/nodes/logical/FlinkLogicalWindowAggregate.scala | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala index 03e4e1f2e5e74..17b6f1b6f8756 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalAggregate.scala @@ -75,7 +75,10 @@ private class FlinkLogicalAggregateConverter // we do not support these functions natively // they have to be converted using the AggregateReduceFunctionsRule val supported = agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall { - case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + // we support AVG + case SqlKind.AVG => true + // but none of the other AVG agg functions + case k if SqlKind.AVG_AGG_FUNCTIONS.contains(k) => false case _ => true } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala index d87d4113815c8..5fb716ce288a9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala @@ -112,7 +112,10 @@ class FlinkLogicalWindowAggregateConverter // we do not support these functions natively // they have to be converted using the WindowAggregateReduceFunctionsRule val supported = agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall { - case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + // we support AVG + case SqlKind.AVG => true + // but none of the other AVG agg functions + case k if 
SqlKind.AVG_AGG_FUNCTIONS.contains(k) => false case _ => true } From dfcb89021d23c8db77784f21484a112eb22749ff Mon Sep 17 00:00:00 2001 From: Fabian Hueske Date: Tue, 20 Mar 2018 11:31:15 +0100 Subject: [PATCH 4/4] addressed feedback --- .../plan/nodes/logical/FlinkLogicalWindowAggregate.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala index 5fb716ce288a9..f2576f4c85378 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalWindowAggregate.scala @@ -111,15 +111,13 @@ class FlinkLogicalWindowAggregateConverter // we do not support these functions natively // they have to be converted using the WindowAggregateReduceFunctionsRule - val supported = agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall { + agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall { // we support AVG case SqlKind.AVG => true // but none of the other AVG agg functions case k if SqlKind.AVG_AGG_FUNCTIONS.contains(k) => false case _ => true } - - !agg.containsDistinctCall() && supported } override def convert(rel: RelNode): RelNode = {
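
Reviewer note (not part of the patch): the rewrites above rely on the algebraic identity VAR_POP(x) = (SUM(x*x) - SUM(x)*SUM(x)/COUNT(x)) / COUNT(x), with the sampled variants dividing by COUNT(x) - 1 instead (NULL when COUNT(x) = 1, matching SQL semantics for a single row). Below is a minimal, self-contained Scala sketch that double-checks the decomposed forms against the textbook definitions on a small data set; it is purely illustrative and does not touch any Flink or Calcite API.

object VarianceDecompositionCheck {
  def main(args: Array[String]): Unit = {
    val xs = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)

    // The simple aggregates that the rule decomposes into.
    val count = xs.size.toDouble
    val sum = xs.sum
    val sumSq = xs.map(x => x * x).sum

    // Decomposed forms as produced by AggregateReduceFunctionsRule.
    val varPop = (sumSq - sum * sum / count) / count
    val varSamp = (sumSq - sum * sum / count) / (count - 1) // count > 1 here, so no NULL guard needed
    val stddevPop = math.sqrt(varPop)
    val stddevSamp = math.sqrt(varSamp)

    // Textbook definitions for comparison.
    val mean = sum / count
    val sqDev = xs.map(x => (x - mean) * (x - mean)).sum
    assert(math.abs(varPop - sqDev / count) < 1e-9)
    assert(math.abs(varSamp - sqDev / (count - 1)) < 1e-9)

    println(s"VAR_POP=$varPop VAR_SAMP=$varSamp STDDEV_POP=$stddevPop STDDEV_SAMP=$stddevSamp")
  }
}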