From 95a18657233ba38bb794dac51ed1eaa31c367f66 Mon Sep 17 00:00:00 2001 From: Sebastian Liu Date: Thu, 28 Jan 2021 21:36:00 +0800 Subject: [PATCH 1/7] [FLINK-20895] support local aggregate push down in blink planner --- .../api/config/OptimizerConfigOptions.java | 10 + .../source/AggregatePushDownSpec.java | 94 ++++++ .../source/SourceAbilityContext.java | 1 + .../abilities/source/SourceAbilitySpec.java | 3 +- ...shLocalAggIntoTableSourceScanRuleBase.java | 274 ++++++++++++++++++ ...calAggWithSortIntoTableSourceScanRule.java | 87 ++++++ ...AggWithoutSortIntoTableSourceScanRule.java | 80 +++++ .../batch/BatchPhysicalTableSourceScan.scala | 6 + .../plan/rules/FlinkBatchRuleSets.scala | 4 +- .../plan/schema/TableSourceTable.scala | 19 ++ .../planner/plan/utils/AggregateUtil.scala | 20 +- .../factories/TestValuesTableFactory.java | 157 +++++++++- ...shLocalAggIntoTableSourceScanRuleTest.java | 132 +++++++++ ...ushLocalAggIntoTableSourceScanRuleTest.xml | 190 ++++++++++++ .../agg/LocalAggregatePushDownITCase.scala | 142 +++++++++ .../planner/runtime/utils/TestData.scala | 22 +- 16 files changed, 1215 insertions(+), 26 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java create mode 100644 flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml create mode 100644 flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java index 685dd585df3df..7a1ad7a72615a 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java @@ -93,6 +93,16 @@ public class OptimizerConfigOptions { + TABLE_OPTIMIZER_REUSE_SUB_PLAN_ENABLED.key() + " is true."); + @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH) + public static final ConfigOption TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED = + key("table.optimizer.source.aggregate-pushdown-enabled") + .booleanType() + .defaultValue(false) + .withDescription( + "When it is true, the optimizer will push down the local aggregates into " + + "the TableSource which implements SupportsAggregatePushDown. 
" + + "Default value is false."); + @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING) public static final ConfigOption TABLE_OPTIMIZER_SOURCE_PREDICATE_PUSHDOWN_ENABLED = key("table.optimizer.source.predicate-pushdown-enabled") diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java new file mode 100644 index 0000000000000..d2085fb02e7dd --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.abilities.source; + +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.expressions.AggregateExpression; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.utils.TypeConversions; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A sub-class of {@link SourceAbilitySpec} that can not only serialize/deserialize the aggregation + * to/from JSON, but also can push the filter into a {@link SupportsAggregatePushDown}. 
+ */ +@JsonTypeName("AggregatePushDown") +public class AggregatePushDownSpec extends SourceAbilitySpecBase { + + public static final String FIELD_NAME_GROUPING_SETS = "groupingSets"; + + public static final String FIELD_NAME_AGGREGATE_EXPRESSIONS = "aggregateExpressions"; + + @JsonProperty(FIELD_NAME_GROUPING_SETS) + private final List groupingSets; + + @JsonProperty(FIELD_NAME_AGGREGATE_EXPRESSIONS) + private final List aggregateExpressions; + + @JsonCreator + public AggregatePushDownSpec( + @JsonProperty(FIELD_NAME_GROUPING_SETS) List groupingSets, + @JsonProperty(FIELD_NAME_AGGREGATE_EXPRESSIONS) + List aggregateExpressions, + @JsonProperty(FIELD_NAME_PRODUCED_TYPE) RowType producedType) { + super(producedType); + this.groupingSets = new ArrayList<>(checkNotNull(groupingSets)); + this.aggregateExpressions = new ArrayList<>(checkNotNull(aggregateExpressions)); + } + + @Override + public void apply(DynamicTableSource tableSource, SourceAbilityContext context) { + checkArgument(getProducedType().isPresent()); + apply(groupingSets, aggregateExpressions, getProducedType().get(), tableSource); + } + + @Override + public String getDigests(SourceAbilityContext context) { + return null; + } + + public static boolean apply( + List groupingSets, + List aggregateExpressions, + RowType producedType, + DynamicTableSource tableSource) { + if (tableSource instanceof SupportsAggregatePushDown) { + DataType producedDataType = TypeConversions.fromLogicalToDataType(producedType); + return ((SupportsAggregatePushDown) tableSource) + .applyAggregates(groupingSets, aggregateExpressions, producedDataType); + } else { + throw new TableException( + String.format( + "%s does not support SupportsAggregatePushDown.", + tableSource.getClass().getName())); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java index 1fbb61a468f88..e3431c706518b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java @@ -40,6 +40,7 @@ *
  • project push down (SupportsProjectionPushDown) *
  • partition push down (SupportsPartitionPushDown) *
  • watermark push down (SupportsWatermarkPushDown) + *
  • aggregate push down (SupportsAggregatePushDown) *
  • reading metadata (SupportsReadingMetadata) * */ diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java index 92326f0ac52c4..453ee4cedfcbb 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java @@ -40,7 +40,8 @@ @JsonSubTypes.Type(value = ProjectPushDownSpec.class), @JsonSubTypes.Type(value = ReadingMetadataSpec.class), @JsonSubTypes.Type(value = WatermarkPushDownSpec.class), - @JsonSubTypes.Type(value = SourceWatermarkSpec.class) + @JsonSubTypes.Type(value = SourceWatermarkSpec.class), + @JsonSubTypes.Type(value = AggregatePushDownSpec.class) }) @Internal public interface SourceAbilitySpec { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java new file mode 100644 index 0000000000000..49f2cc0b0532b --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.expressions.AggregateExpression; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction; +import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; +import org.apache.flink.table.planner.plan.stats.FlinkStatistic; +import org.apache.flink.table.planner.plan.utils.AggregateInfo; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.utils.TypeConversions; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptRuleOperand; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import scala.Tuple2; +import scala.collection.JavaConverters; + +/** + * Planner rule that tries to push a local aggregator into an {@link BatchPhysicalTableSourceScan} + * which table is a {@link TableSourceTable}. And the table source in the table is a {@link + * SupportsAggregatePushDown}. + * + *

+ * <p>The aggregate push down does not support a number of more complex statements at present:
+ *
+ * <ul>
+ *   <li>complex grouping operations such as ROLLUP, CUBE, or GROUPING SETS
+ *   <li>expressions inside the aggregation function call, such as sum(a * b)
+ *   <li>aggregations with ordering
+ *   <li>aggregations with filter
+ * </ul>
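+ *
+ * <p>For example, given a hypothetical table {@code T(a INT, b BIGINT, c STRING)} (illustrative
+ * only), a local aggregate produced by {@code SELECT c, SUM(b) FROM T GROUP BY c} is eligible
+ * for push down, while queries such as the following are not pushed down:
+ *
+ * <pre>{@code
+ * SELECT c, SUM(a * b) FROM T GROUP BY c                      -- expression inside the call
+ * SELECT c, COUNT(DISTINCT a) FROM T GROUP BY c               -- distinct aggregate
+ * SELECT c, COUNT(a) FILTER (WHERE a > 0) FROM T GROUP BY c   -- aggregate with filter
+ * }</pre>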
    + */ +public abstract class PushLocalAggIntoTableSourceScanRuleBase extends RelOptRule { + + public PushLocalAggIntoTableSourceScanRuleBase(RelOptRuleOperand operand, String description) { + super(operand, description); + } + + protected boolean isMatch( + RelOptRuleCall call, + BatchPhysicalGroupAggregateBase aggregate, + BatchPhysicalTableSourceScan tableSourceScan) { + TableConfig tableConfig = ShortcutUtils.unwrapContext(call.getPlanner()).getTableConfig(); + if (!tableConfig + .getConfiguration() + .getBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED)) { + return false; + } + + if (aggregate.isFinal() || aggregate.getAggCallList().size() < 1) { + return false; + } + List aggCallList = + JavaConverters.seqAsJavaListConverter(aggregate.getAggCallList()).asJava(); + for (AggregateCall aggCall : aggCallList) { + if (aggCall.isDistinct() + || aggCall.isApproximate() + || aggCall.getArgList().size() > 1 + || aggCall.hasFilter() + || !aggCall.getCollation().getFieldCollations().isEmpty()) { + return false; + } + } + TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable(); + // we can not push aggregates twice + return tableSourceTable != null + && tableSourceTable.tableSource() instanceof SupportsAggregatePushDown + && Arrays.stream(tableSourceTable.abilitySpecs()) + .noneMatch(spec -> spec instanceof AggregatePushDownSpec); + } + + protected void pushLocalAggregateIntoScan( + RelOptRuleCall call, + BatchPhysicalGroupAggregateBase localAgg, + BatchPhysicalTableSourceScan oldScan) { + RelDataType originalInputRowType = oldScan.deriveRowType(); + AggregateInfoList aggInfoList = + AggregateUtil.transformToBatchAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(localAgg.getInput().getRowType()), + localAgg.getAggCallList(), + null, + null); + if (aggInfoList.aggInfos().length == 0) { + // no agg function need to be pushed down + return; + } + + List groupingSets = Collections.singletonList(localAgg.grouping()); + List aggExpressions = + buildAggregateExpressions(originalInputRowType, aggInfoList); + RelDataType relDataType = localAgg.deriveRowType(); + + TableSourceTable oldTableSourceTable = oldScan.tableSourceTable(); + DynamicTableSource newTableSource = oldScan.tableSource().copy(); + + boolean isPushDownSuccess = + AggregatePushDownSpec.apply( + groupingSets, + aggExpressions, + (RowType) FlinkTypeFactory.toLogicalType(relDataType), + newTableSource); + + if (!isPushDownSuccess) { + // aggregate push down failed, just return without changing any nodes. 
+ return; + } + + FlinkStatistic newFlinkStatistic = getNewFlinkStatistic(oldTableSourceTable); + String[] newExtraDigests = + getNewExtraDigests(originalInputRowType, localAgg.grouping(), aggExpressions); + AggregatePushDownSpec aggregatePushDownSpec = + new AggregatePushDownSpec( + groupingSets, + aggExpressions, + (RowType) FlinkTypeFactory.toLogicalType(relDataType)); + TableSourceTable newTableSourceTable = + oldTableSourceTable + .copy( + newTableSource, + newFlinkStatistic, + new SourceAbilitySpec[] {aggregatePushDownSpec}) + .copy(relDataType); + BatchPhysicalTableSourceScan newScan = + oldScan.copy(oldScan.getTraitSet(), newTableSourceTable); + BatchPhysicalExchange oldExchange = call.rel(0); + BatchPhysicalExchange newExchange = + oldExchange.copy(oldExchange.getTraitSet(), newScan, oldExchange.getDistribution()); + call.transformTo(newExchange); + } + + private List buildAggregateExpressions( + RelDataType originalInputRowType, AggregateInfoList aggInfoList) { + List aggExpressions = new ArrayList<>(); + for (AggregateInfo aggInfo : aggInfoList.aggInfos()) { + List arguments = new ArrayList<>(1); + for (int argIndex : aggInfo.argIndexes()) { + DataType argType = + TypeConversions.fromLogicalToDataType( + FlinkTypeFactory.toLogicalType( + originalInputRowType + .getFieldList() + .get(argIndex) + .getType())); + FieldReferenceExpression field = + new FieldReferenceExpression( + originalInputRowType.getFieldNames().get(argIndex), + argType, + argIndex, + argIndex); + arguments.add(field); + } + if (aggInfo.function() instanceof AvgAggFunction) { + Tuple2 sum0AndCountFunction = + AggregateUtil.deriveSumAndCountFromAvg(aggInfo.function()); + AggregateExpression sum0Expression = + new AggregateExpression( + sum0AndCountFunction._1(), + arguments, + null, + aggInfo.externalResultType(), + aggInfo.agg().isDistinct(), + aggInfo.agg().isApproximate(), + aggInfo.agg().ignoreNulls()); + aggExpressions.add(sum0Expression); + AggregateExpression countExpression = + new AggregateExpression( + sum0AndCountFunction._2(), + arguments, + null, + aggInfo.externalResultType(), + aggInfo.agg().isDistinct(), + aggInfo.agg().isApproximate(), + aggInfo.agg().ignoreNulls()); + aggExpressions.add(countExpression); + } else { + AggregateExpression aggregateExpression = + new AggregateExpression( + aggInfo.function(), + arguments, + null, + aggInfo.externalResultType(), + aggInfo.agg().isDistinct(), + aggInfo.agg().isApproximate(), + aggInfo.agg().ignoreNulls()); + aggExpressions.add(aggregateExpression); + } + } + return aggExpressions; + } + + private FlinkStatistic getNewFlinkStatistic(TableSourceTable tableSourceTable) { + FlinkStatistic oldStatistic = tableSourceTable.getStatistic(); + FlinkStatistic newStatistic; + if (oldStatistic == FlinkStatistic.UNKNOWN()) { + newStatistic = oldStatistic; + } else { + // Remove tableStats after all of aggregate have been pushed down + newStatistic = + FlinkStatistic.builder().statistic(oldStatistic).tableStats(null).build(); + } + return newStatistic; + } + + private String[] getNewExtraDigests( + RelDataType originalInputRowType, + int[] grouping, + List aggregateExpressions) { + String extraDigest; + String groupingStr = "null"; + if (grouping.length > 0) { + groupingStr = + Arrays.stream(grouping) + .mapToObj(index -> originalInputRowType.getFieldNames().get(index)) + .collect(Collectors.joining(",")); + } + String aggFunctionsStr = "null"; + if (aggregateExpressions.size() > 0) { + aggFunctionsStr = + aggregateExpressions.stream() + 
.map(AggregateExpression::asSummaryString) + .collect(Collectors.joining(",")); + } + extraDigest = + "aggregates=[grouping=[" + + groupingStr + + "], aggFunctions=[" + + aggFunctionsStr + + "]]"; + return new String[] {extraDigest}; + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java new file mode 100644 index 0000000000000..3f11585743e9d --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSort; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; + +import org.apache.calcite.plan.RelOptRuleCall; + +/** + * Planner rule that tries to push a local sort aggregate which with sort into a {@link + * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in + * the table is a {@link SupportsAggregatePushDown}. + * + *

+ * <p>When {@code OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} is
+ * true, we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- Sort (exists if group keys are not empty)
+ *    +- BatchPhysicalExchange (hash by group keys if group keys are not empty, else singleton)
+ *       +- BatchPhysicalLocalSortAggregate (local)
+ *          +- Sort (exists if group keys are not empty)
+ *             +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- Sort (exists if group keys are not empty)
+ *    +- BatchPhysicalExchange (hash by group keys if group keys are not empty, else singleton)
+ *       +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
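+ *
+ * <p>For instance (an illustrative query, assuming the planner has chosen a sort-based local
+ * aggregate), {@code SELECT max(price), name FROM inventory GROUP BY name} over a source that
+ * implements {@link SupportsAggregatePushDown} can be rewritten this way.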
    + */ +public class PushLocalAggWithSortIntoTableSourceScanRule + extends PushLocalAggIntoTableSourceScanRuleBase { + public static final PushLocalAggWithSortIntoTableSourceScanRule INSTANCE = + new PushLocalAggWithSortIntoTableSourceScanRule(); + + public PushLocalAggWithSortIntoTableSourceScanRule() { + super( + operand( + BatchPhysicalExchange.class, + operand( + BatchPhysicalLocalSortAggregate.class, + operand( + BatchPhysicalSort.class, + operand(BatchPhysicalTableSourceScan.class, none())))), + "PushLocalAggWithSortIntoTableSourceScanRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + BatchPhysicalGroupAggregateBase localAggregate = call.rel(1); + BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); + return isMatch(call, localAggregate, tableSourceScan); + } + + @Override + public void onMatch(RelOptRuleCall call) { + BatchPhysicalGroupAggregateBase localSortAgg = call.rel(1); + BatchPhysicalTableSourceScan oldScan = call.rel(3); + pushLocalAggregateIntoScan(call, localSortAgg, oldScan); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java new file mode 100644 index 0000000000000..51f1e6f05c3d4 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; + +import org.apache.calcite.plan.RelOptRuleCall; + +/** + * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link + * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in + * the table is a {@link SupportsAggregatePushDown}. + * + *

+ * <p>When {@code OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} is
+ * true, we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys are not empty, else singleton)
+ *    +- BatchPhysicalGroupAggregateBase (local)
+ *       +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys are not empty, else singleton)
+ *    +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
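+ *
+ * <p>For instance (an illustrative query, assuming a hash-based local aggregate is chosen),
+ * {@code SELECT sum(amount), name FROM inventory GROUP BY name} over a source that implements
+ * {@link SupportsAggregatePushDown} can be rewritten this way.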
    + */ +public class PushLocalAggWithoutSortIntoTableSourceScanRule + extends PushLocalAggIntoTableSourceScanRuleBase { + public static final PushLocalAggWithoutSortIntoTableSourceScanRule INSTANCE = + new PushLocalAggWithoutSortIntoTableSourceScanRule(); + + public PushLocalAggWithoutSortIntoTableSourceScanRule() { + super( + operand( + BatchPhysicalExchange.class, + operand( + BatchPhysicalGroupAggregateBase.class, + operand(BatchPhysicalTableSourceScan.class, none()))), + "PushLocalAggWithoutSortIntoTableSourceScanRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + BatchPhysicalGroupAggregateBase localAggregate = call.rel(1); + BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); + return isMatch(call, localAggregate, tableSourceScan); + } + + @Override + public void onMatch(RelOptRuleCall call) { + BatchPhysicalGroupAggregateBase localHashAgg = call.rel(1); + BatchPhysicalTableSourceScan oldScan = call.rel(2); + pushLocalAggregateIntoScan(call, localHashAgg, oldScan); + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala index 3021002c83449..8082a452fcc71 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala @@ -49,6 +49,12 @@ class BatchPhysicalTableSourceScan( new BatchPhysicalTableSourceScan(cluster, traitSet, getHints, tableSourceTable) } + def copy( + traitSet: RelTraitSet, + tableSourceTable: TableSourceTable): BatchPhysicalTableSourceScan = { + new BatchPhysicalTableSourceScan(cluster, traitSet, getHints, tableSourceTable) + } + override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = { val rowCnt = mq.getRowCount(this) if (rowCnt == null) { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 83fa93b04b81b..269034fd11593 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -448,6 +448,8 @@ object FlinkBatchRuleSets { */ val PHYSICAL_REWRITE: RuleSet = RuleSets.ofList( EnforceLocalHashAggRule.INSTANCE, - EnforceLocalSortAggRule.INSTANCE + EnforceLocalSortAggRule.INSTANCE, + PushLocalAggWithoutSortIntoTableSourceScanRule.INSTANCE, + PushLocalAggWithSortIntoTableSourceScanRule.INSTANCE ) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala index cc96c108e104b..58dcddb0fff58 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala @@ -131,4 +131,23 @@ class TableSourceTable( 
flinkContext, abilitySpecs ++ newAbilitySpecs) } + + /** + * Creates a copy of this table, changing the rowType + * + * @param newRowType new row type + * @return New TableSourceTable instance with new row type + */ + def copy(newRowType: RelDataType): TableSourceTable = { + new TableSourceTable( + relOptSchema, + tableIdentifier, + newRowType, + statistic, + tableSource, + isStreamingMode, + catalogTable, + flinkContext, + abilitySpecs) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 7294d24d7f4d0..fd5e1c0486895 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -27,7 +27,9 @@ import org.apache.flink.table.planner.JLong import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, FlinkTypeSystem} import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.expressions._ -import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction +import org.apache.flink.table.planner.functions.aggfunctions.{CountAggFunction, DeclarativeAggregateFunction, Sum0AggFunction} +import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction.{ByteAvgAggFunction, DoubleAvgAggFunction, FloatAvgAggFunction, IntAvgAggFunction, LongAvgAggFunction, ShortAvgAggFunction} +import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction.{ByteSum0AggFunction, DoubleSum0AggFunction, FloatSum0AggFunction, IntSum0AggFunction, LongSum0AggFunction, ShortSum0AggFunction} import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext import org.apache.flink.table.planner.functions.sql.{FlinkSqlOperatorTable, SqlFirstLastValueAggFunction, SqlListAggFunction} @@ -278,6 +280,22 @@ object AggregateUtil extends Enumeration { isBounded = false) } + def deriveSumAndCountFromAvg( + avgAggFunction: UserDefinedFunction): (Sum0AggFunction, CountAggFunction) = { + avgAggFunction match { + case _: ByteAvgAggFunction => (new ByteSum0AggFunction, new CountAggFunction) + case _: ShortAvgAggFunction => (new ShortSum0AggFunction, new CountAggFunction) + case _: IntAvgAggFunction => (new IntSum0AggFunction, new CountAggFunction) + case _: LongAvgAggFunction => (new LongSum0AggFunction, new CountAggFunction) + case _: FloatAvgAggFunction => (new FloatSum0AggFunction, new CountAggFunction) + case _: DoubleAvgAggFunction => (new DoubleSum0AggFunction, new CountAggFunction) + case _ => { + throw new TableException(s"Avg aggregate function does not support: ''$avgAggFunction''" + + s"Please re-check the function or data type.") + } + } + } + def transformToBatchAggregateFunctions( inputRowType: RowType, aggregateCalls: Seq[AggregateCall], diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index d71787621b1f8..de958ea2090ad 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -54,6 +54,7 @@ import org.apache.flink.table.connector.source.ScanTableSource; import org.apache.flink.table.connector.source.SourceFunctionProvider; import org.apache.flink.table.connector.source.TableFunctionProvider; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown; import org.apache.flink.table.connector.source.abilities.SupportsLimitPushDown; import org.apache.flink.table.connector.source.abilities.SupportsPartitionPushDown; @@ -62,11 +63,14 @@ import org.apache.flink.table.connector.source.abilities.SupportsSourceWatermark; import org.apache.flink.table.connector.source.abilities.SupportsWatermarkPushDown; import org.apache.flink.table.data.RowData; +import org.apache.flink.table.expressions.AggregateExpression; +import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.factories.DynamicTableSinkFactory; import org.apache.flink.table.factories.DynamicTableSourceFactory; import org.apache.flink.table.factories.FactoryUtil; import org.apache.flink.table.functions.AsyncTableFunction; +import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AppendingOutputFormat; import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AppendingSinkFunction; @@ -74,6 +78,12 @@ import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.KeyedUpsertingSinkFunction; import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.RetractingSinkFunction; import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.TestValuesLookupFunction; +import org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.MaxAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.MinAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.SumAggFunction; import org.apache.flink.table.planner.runtime.utils.FailingCollectionSource; import org.apache.flink.table.planner.utils.FilterUtils; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; @@ -95,6 +105,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -686,7 +697,8 @@ private static class TestValuesScanTableSource SupportsFilterPushDown, SupportsLimitPushDown, SupportsPartitionPushDown, - SupportsReadingMetadata { + SupportsReadingMetadata, + SupportsAggregatePushDown { protected DataType producedDataType; protected final ChangelogMode changelogMode; @@ -705,6 +717,9 @@ private static class TestValuesScanTableSource protected final Map readableMetadata; protected @Nullable int[] projectedMetadataFields; + private @Nullable int[] groupingSet; + private List aggregateExpressions; + private TestValuesScanTableSource( DataType producedDataType, ChangelogMode changelogMode, @@ -736,6 +751,8 @@ private TestValuesScanTableSource( 
this.allPartitions = allPartitions; this.readableMetadata = readableMetadata; this.projectedMetadataFields = projectedMetadataFields; + this.groupingSet = null; + this.aggregateExpressions = Collections.emptyList(); } @Override @@ -862,32 +879,115 @@ public String asSummaryString() { } protected Collection convertToRowData(DataStructureConverter converter) { - List result = new ArrayList<>(); + List resultBuffer = new ArrayList<>(); List> keys = allPartitions.isEmpty() ? Collections.singletonList(Collections.emptyMap()) : allPartitions; int numRetained = 0; + boolean overLimit = false; for (Map partition : keys) { for (Row row : data.get(partition)) { - if (result.size() >= limit) { - return result; + if (resultBuffer.size() >= limit) { + overLimit = true; + break; } boolean isRetained = FilterUtils.isRetainedAfterApplyingFilterPredicates( filterPredicates, getValueGetter(row)); if (isRetained) { final Row projectedRow = projectRow(row); - final RowData rowData = (RowData) converter.toInternal(projectedRow); - if (rowData != null) { - if (numRetained >= numElementToSkip) { - rowData.setRowKind(row.getKind()); - result.add(rowData); - } - numRetained++; - } + resultBuffer.add(projectedRow); + } + } + if (overLimit) { + break; + } + } + // simulate aggregate operation + if (!aggregateExpressions.isEmpty()) { + resultBuffer = applyAggregatesToRows(resultBuffer); + } + List result = new ArrayList<>(); + for (Row row : resultBuffer) { + final RowData rowData = (RowData) converter.toInternal(row); + if (rowData != null) { + if (numRetained >= numElementToSkip) { + rowData.setRowKind(row.getKind()); + result.add(rowData); + } + numRetained++; + } + } + return result; + } + + private List applyAggregatesToRows(List rows) { + if (groupingSet != null && groupingSet.length > 0) { + // has group by, group firstly + Map> buffer = new HashMap<>(); + for (Row row : rows) { + Row bufferKey = new Row(groupingSet.length); + for (int i = 0; i < groupingSet.length; i++) { + bufferKey.setField(i, row.getField(groupingSet[i])); + } + if (buffer.containsKey(bufferKey)) { + buffer.get(bufferKey).add(row); + } else { + buffer.put(bufferKey, new ArrayList<>(Collections.singletonList(row))); } } + List result = new ArrayList<>(); + for (Map.Entry> entry : buffer.entrySet()) { + result.add(Row.join(entry.getKey(), accumulateRows(entry.getValue()))); + } + return result; + } else { + return Collections.singletonList(accumulateRows(rows)); + } + } + + // can only apply sum/sum0/avg function for long type fields for testing + private Row accumulateRows(List rows) { + Row result = new Row(aggregateExpressions.size()); + for (int i = 0; i < aggregateExpressions.size(); i++) { + FunctionDefinition aggFunction = + aggregateExpressions.get(i).getFunctionDefinition(); + List arguments = aggregateExpressions.get(i).getArgs(); + if (aggFunction instanceof MinAggFunction) { + int argIndex = arguments.get(0).getFieldIndex(); + Row minRow = + rows.stream() + .min(Comparator.comparing(row -> row.getFieldAs(argIndex))) + .get(); + result.setField(i, minRow.getField(argIndex)); + } else if (aggFunction instanceof MaxAggFunction) { + int argIndex = arguments.get(0).getFieldIndex(); + Row maxRow = + rows.stream() + .max(Comparator.comparing(row -> row.getFieldAs(argIndex))) + .get(); + result.setField(i, maxRow.getField(argIndex)); + } else if (aggFunction instanceof SumAggFunction) { + int argIndex = arguments.get(0).getFieldIndex(); + Object finalSum = + rows.stream() + .filter(row -> row.getField(argIndex) != null) + .mapToLong(row 
-> row.getFieldAs(argIndex)) + .sum(); + result.setField(i, finalSum); + } else if (aggFunction instanceof Sum0AggFunction) { + int argIndex = arguments.get(0).getFieldIndex(); + Object finalSum0 = + rows.stream() + .filter(row -> row.getField(argIndex) != null) + .mapToLong(row -> row.getFieldAs(argIndex)) + .sum(); + result.setField(i, finalSum0); + } else if (aggFunction instanceof CountAggFunction + || aggFunction instanceof Count1AggFunction) { + result.setField(i, (long) rows.size()); + } } return result; } @@ -953,6 +1053,39 @@ public void applyPartitions(List> remainingPartitions) { } } + @Override + public boolean applyAggregates( + List groupingSets, + List aggregateExpressions, + DataType producedDataType) { + // this TestValuesScanTableSource only support simple group type ar present. + if (groupingSets.size() > 1) { + return false; + } + List aggExpressions = new ArrayList<>(); + for (AggregateExpression aggExpression : aggregateExpressions) { + FunctionDefinition functionDefinition = aggExpression.getFunctionDefinition(); + if (!(functionDefinition instanceof MinAggFunction + || functionDefinition instanceof MaxAggFunction + || functionDefinition instanceof SumAggFunction + || functionDefinition instanceof Sum0AggFunction + || functionDefinition instanceof CountAggFunction + || functionDefinition instanceof Count1AggFunction)) { + return false; + } + if (aggExpression.getFilterExpression().isPresent() + || aggExpression.isApproximate() + || aggExpression.isDistinct()) { + return false; + } + aggExpressions.add(aggExpression); + } + this.groupingSet = groupingSets.get(0); + this.aggregateExpressions = aggExpressions; + this.producedDataType = producedDataType; + return true; + } + @Override public void applyLimit(long limit) { this.limit = limit; diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java new file mode 100644 index 0000000000000..4ef1e888b1f0c --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.utils.BatchTableTestUtil; +import org.apache.flink.table.planner.utils.TableTestBase; + +import org.junit.Before; +import org.junit.Test; + +/** + * Test for {@link PushLocalAggWithoutSortIntoTableSourceScanRule} and {@link + * PushLocalAggWithSortIntoTableSourceScanRule}. + */ +public class PushLocalAggIntoTableSourceScanRuleTest extends TableTestBase { + protected BatchTableTestUtil util = batchTestUtil(new TableConfig()); + + @Before + public void setup() { + TableConfig tableConfig = util.tableEnv().getConfig(); + tableConfig + .getConfiguration() + .setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, + true); + String ddl = + "CREATE TABLE inventory (\n" + + " id INT,\n" + + " name STRING,\n" + + " amount INT,\n" + + " price DOUBLE,\n" + + " type STRING\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'filterable-fields' = 'id',\n" + + " 'bounded' = 'true'\n" + + ")"; + util.tableEnv().executeSql(ddl); + } + + @Test + public void testCanPushDownWithGroup() { + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " name,\n" + + " type\n" + + "FROM inventory\n" + + " group by name, type"); + } + + @Test + public void testCanPushDownWithoutGroup() { + util.verifyRelPlan( + "SELECT\n" + + " min(id),\n" + + " max(amount),\n" + + " max(name),\n" + + " sum(price),\n" + + " avg(price),\n" + + " count(id)\n" + + "FROM inventory"); + } + + @Test + public void testCannotPushDownWithColumnExpression() { + util.verifyRelPlan( + "SELECT\n" + + " min(amount + price),\n" + + " max(amount),\n" + + " sum(price),\n" + + " count(id),\n" + + " name\n" + + "FROM inventory\n" + + " group by name"); + } + + @Test + public void testCannotPushDownWithUnsupportedAggFunction() { + util.verifyRelPlan( + "SELECT\n" + + " min(id),\n" + + " max(amount),\n" + + " sum(price),\n" + + " count(distinct id),\n" + + " name\n" + + "FROM inventory\n" + + " group by name"); + } + + @Test + public void testCannotPushDownWithWindowAggFunction() { + util.verifyRelPlan( + "SELECT\n" + + " id,\n" + + " amount,\n" + + " sum(price) over (partition by name),\n" + + " name\n" + + "FROM inventory"); + } + + @Test + public void testCannotPushDownWithFilter() { + util.verifyRelPlan( + "SELECT\n" + + " min(id),\n" + + " max(amount),\n" + + " sum(price),\n" + + " count(id) FILTER(WHERE id > 100),\n" + + " name\n" + + "FROM inventory\n" + + " group by name"); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml new file mode 100644 index 0000000000000..2a5061a174ef4 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml @@ -0,0 +1,190 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + (COUNT($3) OVER (PARTITION BY $1), 0), $SUM0($3) OVER (PARTITION BY $1), null:DOUBLE)], name=[$1]) ++- LogicalTableScan(table=[[default_catalog, default_database, inventory]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:DOUBLE) AS EXPR$2, name]) 
++- OverAggregate(partitionBy=[name], window#0=[COUNT(price) AS w0$o0, $SUM0(price) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[id, name, amount, price, w0$o0, w0$o1]) + +- Sort(orderBy=[name ASC]) + +- Exchange(distribution=[hash[name]]) + +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[id, name, amount, price]]], fields=[id, name, amount, price]) +]]> + + + + + 100), + name +FROM inventory + group by name]]> + + + ($0, 100))]) + +- LogicalTableScan(table=[[default_catalog, default_database, inventory]]) +]]> + + + (id, 100)) AS $f4]) + +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, id, amount, price]]], fields=[name, id, amount, price]) +]]> + + + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala new file mode 100644 index 0000000000000..7f1c087be5498 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.runtime.batch.sql.agg + +import org.apache.flink.table.api.config.OptimizerConfigOptions +import org.apache.flink.table.planner.factories.TestValuesTableFactory +import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row +import org.apache.flink.table.planner.runtime.utils.{BatchTestBase, TestData} +import org.junit.{Before, Test} + +class LocalAggregatePushDownITCase extends BatchTestBase { + + @Before + override def before(): Unit = { + super.before() + env.setParallelism(1) // set sink parallelism to 1 + conf.getConfiguration.setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, true) + val testDataId = TestValuesTableFactory.registerData(TestData.personData) + val ddl = + s""" + |CREATE TABLE AggregatableTable ( + | id int, + | age int, + | name string, + | height int, + | gender string, + | deposit bigint + |) WITH ( + | 'connector' = 'values', + | 'data-id' = '$testDataId', + | 'bounded' = 'true' + |) + """.stripMargin + tEnv.executeSql(ddl) + + } + + @Test + def testAggregateWithGroupBy(): Unit = { + checkResult( + """ + |SELECT + | min(age) as min_age, + | max(height), + | avg(deposit), + | sum(deposit), + | count(1), + | gender + |FROM + | AggregatableTable + |GROUP BY gender + |ORDER BY min_age + |""".stripMargin, + Seq( + row(19, 180, 126, 630, 5, "f"), + row(23, 182, 220, 1320, 6, "m")) + ) + } + + @Test + def testAggregateWithMultiGroupBy(): Unit = { + checkResult( + """ + |SELECT + | min(age), + | max(height), + | avg(deposit), + | sum(deposit), + | count(1), + | gender, + | age + |FROM + | AggregatableTable + |GROUP BY gender, age + |""".stripMargin, + Seq( + row(19, 172, 50, 50, 1, "f", 19), + row(20, 180, 200, 200, 1, "f", 20), + row(23, 182, 250, 750, 3, "m", 23), + row(25, 171, 126, 380, 3, "f", 25), + row(27, 175, 300, 300, 1, "m", 27), + row(28, 165, 170, 170, 1, "m", 28), + row(34, 170, 100, 100, 1, "m", 34)) + ) + } + + @Test + def testAggregateWithoutGroupBy(): Unit = { + checkResult( + """ + |SELECT + | min(age), + | max(height), + | avg(deposit), + | sum(deposit), + | count(*) + |FROM + | AggregatableTable + |""".stripMargin, + Seq( + row(19, 182, 177, 1950, 11)) + ) + } + + @Test + def testAggregateCanNotPushDown(): Unit = { + checkResult( + """ + |SELECT + | min(age), + | max(height), + | avg(deposit), + | sum(deposit), + | count(distinct age), + | gender + |FROM + | AggregatableTable + |GROUP BY gender + |""".stripMargin, + Seq( + row(19, 180, 126, 630, 3, "f"), + row(23, 182, 220, 1320, 4, "m")) + ) + } +} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala index 538c5a82e4919..1e70ad1e92a35 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala @@ -414,17 +414,17 @@ object TestData { // person test data lazy val personData: Seq[Row] = Seq( - row(1, 23, "tom", 172, "m"), - row(2, 21, "mary", 161, "f"), - row(3, 18, "jack", 182, "m"), - row(4, 25, "rose", 165, "f"), - row(5, 27, "danny", 175, "m"), - row(6, 31, "tommas", 172, "m"), - row(7, 19, "olivia", 172, "f"), - row(8, 34, "stef", 170, "m"), - row(9, 32, "emma", 171, "f"), - row(10, 28, "benji", 165, "m"), - row(11, 20, "eva", 180, "f") + row(1, 23, "tom", 172, "m", 
200L), + row(2, 25, "mary", 161, "f", 100L), + row(3, 23, "jack", 182, "m", 150L), + row(4, 25, "rose", 165, "f", 100L), + row(5, 27, "danny", 175, "m", 300L), + row(6, 23, "tommas", 172, "m", 400L), + row(7, 19, "olivia", 172, "f", 50L), + row(8, 34, "stef", 170, "m", 100L), + row(9, 25, "emma", 171, "f", 180L), + row(10, 28, "benji", 165, "m", 170L), + row(11, 20, "eva", 180, "f", 200L) ) val nullablesOfPersonData = Array(true, true, true, true, true) From f468ecaf5f8e7360b899ba927103bc634507c90d Mon Sep 17 00:00:00 2001 From: "Yu, Peng" Date: Thu, 23 Sep 2021 22:23:07 +0800 Subject: [PATCH 2/7] [FLINK-20895] make changes according to review comments --- .../abilities/SupportsAggregatePushDown.java | 2 - .../expressions/AggregateExpression.java | 5 - .../source/AggregatePushDownSpec.java | 133 ++++++- .../batch/PushLocalAggIntoScanRuleBase.java | 151 ++++++++ ...shLocalAggIntoTableSourceScanRuleBase.java | 274 -------------- ...java => PushLocalHashAggIntoScanRule.java} | 27 +- ...PushLocalSortAggWithSortIntoScanRule.java} | 18 +- ...shLocalSortAggWithoutSortIntoScanRule.java | 80 +++++ .../batch/BatchPhysicalTableSourceScan.scala | 17 +- .../plan/rules/FlinkBatchRuleSets.scala | 5 +- .../plan/schema/TableSourceTable.scala | 11 +- .../planner/plan/utils/AggregateUtil.scala | 7 +- .../factories/TestValuesTableFactory.java | 17 +- ...shLocalAggIntoTableSourceScanRuleTest.java | 166 ++++++++- ...ushLocalAggIntoTableSourceScanRuleTest.xml | 263 +++++++++++++- .../agg/LocalAggregatePushDownITCase.scala | 333 ++++++++++++++++-- .../planner/runtime/utils/TestData.scala | 22 +- 17 files changed, 1137 insertions(+), 394 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java rename flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/{PushLocalAggWithoutSortIntoTableSourceScanRule.java => PushLocalHashAggIntoScanRule.java} (75%) rename flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/{PushLocalAggWithSortIntoTableSourceScanRule.java => PushLocalSortAggWithSortIntoScanRule.java} (84%) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java index 218c2a41af2ca..67f645ac15d50 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java @@ -124,8 +124,6 @@ * *

    Regardless of whether this interface is implemented, a final aggregation is always applied in a * subsequent operation after the source. - * - *

    Note: currently, the {@link SupportsAggregatePushDown} is not supported by planner. */ @PublicEvolving public interface SupportsAggregatePushDown { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java index ce111796db771..897ab8aa985c6 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java @@ -19,7 +19,6 @@ package org.apache.flink.table.expressions; import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.types.DataType; import org.apache.flink.util.Preconditions; @@ -47,9 +46,6 @@ *

  • {@code approximate} indicates whether this is an approximate aggregate function. *
  • {@code ignoreNulls} indicates whether this aggregate function ignores null values. * - * - *
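As an illustration of how these properties are consumed, a minimal sketch of a SupportsAggregatePushDown#applyAggregates implementation is shown below. It is modeled on the TestValuesTableFactory changes later in this patch; the member fields are placeholders, not part of the patch:

    @Override
    public boolean applyAggregates(
            List<int[]> groupingSets,
            List<AggregateExpression> aggregateExpressions,
            DataType producedDataType) {
        for (AggregateExpression agg : aggregateExpressions) {
            // Reject anything this source cannot evaluate itself; the planner
            // then simply keeps the local aggregate in the plan.
            if (agg.isDistinct()
                    || agg.isApproximate()
                    || agg.getFilterExpression().isPresent()) {
                return false;
            }
        }
        this.groupingSet = groupingSets.get(0); // simple grouping only
        this.aggregateExpressions = aggregateExpressions;
        this.producedDataType = producedDataType;
        return true;
    }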

    Note: currently, the {@link AggregateExpression} is only used in {@link - * SupportsAggregatePushDown}. */ @PublicEvolving public class AggregateExpression implements ResolvedExpression { @@ -107,7 +103,6 @@ public List getArgs() { return args; } - @Nullable public Optional getFilterExpression() { return Optional.ofNullable(filterExpression); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java index d2085fb02e7dd..24c7f061be09a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java @@ -22,6 +22,14 @@ import org.apache.flink.table.connector.source.DynamicTableSource; import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.expressions.AggregateExpression; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction; +import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction; +import org.apache.flink.table.planner.plan.utils.AggregateInfo; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.utils.TypeConversions; @@ -30,56 +38,100 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.commons.lang3.ArrayUtils; + import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; + +import scala.Tuple2; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; /** * A sub-class of {@link SourceAbilitySpec} that can not only serialize/deserialize the aggregation - * to/from JSON, but also can push the filter into a {@link SupportsAggregatePushDown}. + * to/from JSON, but also can push the local aggregate into a {@link SupportsAggregatePushDown}. 
*/ @JsonTypeName("AggregatePushDown") public class AggregatePushDownSpec extends SourceAbilitySpecBase { + public static final String FIELD_NAME_INPUT_TYPE = "inputType"; + public static final String FIELD_NAME_GROUPING_SETS = "groupingSets"; - public static final String FIELD_NAME_AGGREGATE_EXPRESSIONS = "aggregateExpressions"; + public static final String FIELD_NAME_AGGREGATE_CALLS = "aggregateCalls"; + + @JsonProperty(FIELD_NAME_INPUT_TYPE) + private final RowType inputType; @JsonProperty(FIELD_NAME_GROUPING_SETS) private final List groupingSets; - @JsonProperty(FIELD_NAME_AGGREGATE_EXPRESSIONS) - private final List aggregateExpressions; + @JsonProperty(FIELD_NAME_AGGREGATE_CALLS) + private final List aggregateCalls; @JsonCreator public AggregatePushDownSpec( + @JsonProperty(FIELD_NAME_INPUT_TYPE) RowType inputType, @JsonProperty(FIELD_NAME_GROUPING_SETS) List groupingSets, - @JsonProperty(FIELD_NAME_AGGREGATE_EXPRESSIONS) - List aggregateExpressions, + @JsonProperty(FIELD_NAME_AGGREGATE_CALLS) List aggregateCalls, @JsonProperty(FIELD_NAME_PRODUCED_TYPE) RowType producedType) { super(producedType); + + this.inputType = inputType; this.groupingSets = new ArrayList<>(checkNotNull(groupingSets)); - this.aggregateExpressions = new ArrayList<>(checkNotNull(aggregateExpressions)); + this.aggregateCalls = aggregateCalls; } @Override public void apply(DynamicTableSource tableSource, SourceAbilityContext context) { checkArgument(getProducedType().isPresent()); - apply(groupingSets, aggregateExpressions, getProducedType().get(), tableSource); + apply(inputType, groupingSets, aggregateCalls, getProducedType().get(), tableSource); } @Override public String getDigests(SourceAbilityContext context) { - return null; + String extraDigest; + String groupingStr = ""; + int[] grouping = ArrayUtils.addAll(groupingSets.get(0), groupingSets.get(1)); + if (grouping.length > 0) { + groupingStr = + Arrays.stream(grouping) + .mapToObj(index -> inputType.getFieldNames().get(index)) + .collect(Collectors.joining(",")); + } + String aggFunctionsStr = ""; + + List aggregateExpressions = + buildAggregateExpressions(inputType, aggregateCalls); + if (aggregateExpressions.size() > 0) { + aggFunctionsStr = + aggregateExpressions.stream() + .map(AggregateExpression::asSummaryString) + .collect(Collectors.joining(",")); + } + extraDigest = + "aggregates=[grouping=[" + + groupingStr + + "], aggFunctions=[" + + aggFunctionsStr + + "]]"; + return extraDigest; } public static boolean apply( + RowType inputType, List groupingSets, - List aggregateExpressions, + List aggregateCalls, RowType producedType, DynamicTableSource tableSource) { + List aggregateExpressions = + buildAggregateExpressions(inputType, aggregateCalls); + if (tableSource instanceof SupportsAggregatePushDown) { DataType producedDataType = TypeConversions.fromLogicalToDataType(producedType); return ((SupportsAggregatePushDown) tableSource) @@ -91,4 +143,65 @@ public static boolean apply( tableSource.getClass().getName())); } } + + private static List buildAggregateExpressions( + RowType inputType, List aggregateCalls) { + AggregateInfoList aggInfoList = + AggregateUtil.transformToBatchAggregateInfoList( + inputType, JavaScalaConversionUtil.toScala(aggregateCalls), null, null); + if (aggInfoList.aggInfos().length == 0) { + // no agg function need to be pushed down + return Collections.emptyList(); + } + + List aggExpressions = new ArrayList<>(); + for (AggregateInfo aggInfo : aggInfoList.aggInfos()) { + List arguments = new ArrayList<>(1); + for (int argIndex : 
aggInfo.argIndexes()) { + DataType argType = + TypeConversions.fromLogicalToDataType( + inputType.getFields().get(argIndex).getType()); + FieldReferenceExpression field = + new FieldReferenceExpression( + inputType.getFieldNames().get(argIndex), argType, 0, argIndex); + arguments.add(field); + } + if (aggInfo.function() instanceof AvgAggFunction) { + Tuple2 sum0AndCountFunction = + AggregateUtil.deriveSumAndCountFromAvg(aggInfo.function()); + AggregateExpression sum0Expression = + new AggregateExpression( + sum0AndCountFunction._1(), + arguments, + null, + aggInfo.externalResultType(), + aggInfo.agg().isDistinct(), + aggInfo.agg().isApproximate(), + aggInfo.agg().ignoreNulls()); + aggExpressions.add(sum0Expression); + AggregateExpression countExpression = + new AggregateExpression( + sum0AndCountFunction._2(), + arguments, + null, + aggInfo.externalResultType(), + aggInfo.agg().isDistinct(), + aggInfo.agg().isApproximate(), + aggInfo.agg().ignoreNulls()); + aggExpressions.add(countExpression); + } else { + AggregateExpression aggregateExpression = + new AggregateExpression( + aggInfo.function(), + arguments, + null, + aggInfo.externalResultType(), + aggInfo.agg().isDistinct(), + aggInfo.agg().isApproximate(), + aggInfo.agg().ignoreNulls()); + aggExpressions.add(aggregateExpression); + } + } + return aggExpressions; + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java new file mode 100644 index 0000000000000..6837e612bd764 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; +import org.apache.flink.table.planner.plan.stats.FlinkStatistic; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptRuleOperand; +import org.apache.calcite.rel.core.AggregateCall; + +import java.util.Arrays; +import java.util.List; + +/** + * Planner rule that tries to push a local aggregator into an {@link BatchPhysicalTableSourceScan} + * which table is a {@link TableSourceTable}. And the table source in the table is a {@link + * SupportsAggregatePushDown}. + * + *
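For orientation, a rough usage sketch of this rule family (the table definition is illustrative and mirrors the values connector used by the tests in this patch):

    TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());
    // The rewrite only fires when this optimizer option is enabled.
    tEnv.getConfig()
            .getConfiguration()
            .setBoolean(
                    OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED,
                    true);
    tEnv.executeSql(
            "CREATE TABLE inventory (id BIGINT, amount BIGINT, name STRING) "
                    + "WITH ('connector' = 'values', 'bounded' = 'true')");
    // With a SupportsAggregatePushDown source, the local aggregate is folded
    // into the scan, which shows up as an aggregates=[...] digest in the plan.
    System.out.println(
            tEnv.explainSql("SELECT name, SUM(amount) FROM inventory GROUP BY name"));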

    The aggregate push down does not support a number of more complex statements at present (examples are sketched after this list): + * + *

      + *
    • complex grouping operations such as ROLLUP, CUBE, or GROUPING SETS. + *
    • expressions inside the aggregation function call, such as sum(a * b). + *
    • aggregations with ordering. + *
    • aggregations with filter. + *
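For example, queries of the shapes above keep their local aggregate in the plan instead of handing it to the source (SQL sketched as Java strings; the inventory table mirrors the rule tests in this patch):

    // Complex grouping: ROLLUP/CUBE/GROUPING SETS is not pushed down.
    String rollup = "SELECT name, SUM(amount) FROM inventory GROUP BY ROLLUP (name)";
    // Expression inside the call: the argument is not a plain field reference.
    String exprArg = "SELECT SUM(amount * price) FROM inventory";
    // Filtered aggregation: FILTER clauses are rejected by isMatch below.
    String filtered = "SELECT SUM(amount) FILTER (WHERE id > 100) FROM inventory GROUP BY name";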
    + */ +public abstract class PushLocalAggIntoScanRuleBase extends RelOptRule { + + public PushLocalAggIntoScanRuleBase(RelOptRuleOperand operand, String description) { + super(operand, description); + } + + protected boolean isMatch( + RelOptRuleCall call, + BatchPhysicalGroupAggregateBase aggregate, + BatchPhysicalTableSourceScan tableSourceScan) { + TableConfig tableConfig = ShortcutUtils.unwrapContext(call.getPlanner()).getTableConfig(); + if (!tableConfig + .getConfiguration() + .getBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED)) { + return false; + } + + if (aggregate.isFinal() || aggregate.getAggCallList().isEmpty()) { + return false; + } + List aggCallList = + JavaScalaConversionUtil.toJava(aggregate.getAggCallList()); + for (AggregateCall aggCall : aggCallList) { + if (aggCall.isDistinct() + || aggCall.isApproximate() + || aggCall.getArgList().size() > 1 + || aggCall.hasFilter() + || !aggCall.getCollation().getFieldCollations().isEmpty()) { + return false; + } + } + TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable(); + // we can not push aggregates twice + return tableSourceTable != null + && tableSourceTable.tableSource() instanceof SupportsAggregatePushDown + && Arrays.stream(tableSourceTable.abilitySpecs()) + .noneMatch(spec -> spec instanceof AggregatePushDownSpec); + } + + protected void pushLocalAggregateIntoScan( + RelOptRuleCall call, + BatchPhysicalGroupAggregateBase localAgg, + BatchPhysicalTableSourceScan oldScan) { + RowType inputType = FlinkTypeFactory.toLogicalRowType(oldScan.getRowType()); + List groupingSets = Arrays.asList(localAgg.grouping(), localAgg.auxGrouping()); + List aggCallList = JavaScalaConversionUtil.toJava(localAgg.getAggCallList()); + RowType producedType = FlinkTypeFactory.toLogicalRowType(localAgg.getRowType()); + + TableSourceTable oldTableSourceTable = oldScan.tableSourceTable(); + DynamicTableSource newTableSource = oldScan.tableSource().copy(); + + boolean isPushDownSuccess = + AggregatePushDownSpec.apply( + inputType, groupingSets, aggCallList, producedType, newTableSource); + + if (!isPushDownSuccess) { + // aggregate push down failed, just return without changing any nodes. 
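+            // (No nodes are rewritten in this case: the copied table source is
+            // simply discarded and the scan keeps producing its original schema.)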
+ return; + } + + FlinkStatistic newFlinkStatistic = getNewFlinkStatistic(oldTableSourceTable); + AggregatePushDownSpec aggregatePushDownSpec = + new AggregatePushDownSpec(inputType, groupingSets, aggCallList, producedType); + + TableSourceTable newTableSourceTable = + oldTableSourceTable + .copy( + newTableSource, + newFlinkStatistic, + new SourceAbilitySpec[] {aggregatePushDownSpec}) + .copy(localAgg.getRowType()); + BatchPhysicalTableSourceScan newScan = + oldScan.copy(oldScan.getTraitSet(), newTableSourceTable); + BatchPhysicalExchange oldExchange = call.rel(0); + BatchPhysicalExchange newExchange = + oldExchange.copy(oldExchange.getTraitSet(), newScan, oldExchange.getDistribution()); + call.transformTo(newExchange); + } + + private FlinkStatistic getNewFlinkStatistic(TableSourceTable tableSourceTable) { + FlinkStatistic oldStatistic = tableSourceTable.getStatistic(); + FlinkStatistic newStatistic; + if (oldStatistic == FlinkStatistic.UNKNOWN()) { + newStatistic = oldStatistic; + } else { + // Remove tableStats after all of aggregate have been pushed down + newStatistic = + FlinkStatistic.builder().statistic(oldStatistic).tableStats(null).build(); + } + return newStatistic; + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java deleted file mode 100644 index 49f2cc0b0532b..0000000000000 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleBase.java +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.planner.plan.rules.physical.batch; - -import org.apache.flink.table.api.TableConfig; -import org.apache.flink.table.api.config.OptimizerConfigOptions; -import org.apache.flink.table.connector.source.DynamicTableSource; -import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; -import org.apache.flink.table.expressions.AggregateExpression; -import org.apache.flink.table.expressions.FieldReferenceExpression; -import org.apache.flink.table.planner.calcite.FlinkTypeFactory; -import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction; -import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction; -import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction; -import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec; -import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec; -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; -import org.apache.flink.table.planner.plan.schema.TableSourceTable; -import org.apache.flink.table.planner.plan.stats.FlinkStatistic; -import org.apache.flink.table.planner.plan.utils.AggregateInfo; -import org.apache.flink.table.planner.plan.utils.AggregateInfoList; -import org.apache.flink.table.planner.plan.utils.AggregateUtil; -import org.apache.flink.table.planner.utils.ShortcutUtils; -import org.apache.flink.table.types.DataType; -import org.apache.flink.table.types.logical.RowType; -import org.apache.flink.table.types.utils.TypeConversions; - -import org.apache.calcite.plan.RelOptRule; -import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.plan.RelOptRuleOperand; -import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.type.RelDataType; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; - -import scala.Tuple2; -import scala.collection.JavaConverters; - -/** - * Planner rule that tries to push a local aggregator into an {@link BatchPhysicalTableSourceScan} - * which table is a {@link TableSourceTable}. And the table source in the table is a {@link - * SupportsAggregatePushDown}. - * - *

    The aggregate push down does not support a number of more complex statements at present: - * - *

      - *
    • complex grouping operations such as ROLLUP, CUBE, or GROUPING SETS. - *
    • expressions inside the aggregation function call: such as sum(a * b). - *
    • aggregations with ordering. - *
    • aggregations with filter. - *
    - */ -public abstract class PushLocalAggIntoTableSourceScanRuleBase extends RelOptRule { - - public PushLocalAggIntoTableSourceScanRuleBase(RelOptRuleOperand operand, String description) { - super(operand, description); - } - - protected boolean isMatch( - RelOptRuleCall call, - BatchPhysicalGroupAggregateBase aggregate, - BatchPhysicalTableSourceScan tableSourceScan) { - TableConfig tableConfig = ShortcutUtils.unwrapContext(call.getPlanner()).getTableConfig(); - if (!tableConfig - .getConfiguration() - .getBoolean( - OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED)) { - return false; - } - - if (aggregate.isFinal() || aggregate.getAggCallList().size() < 1) { - return false; - } - List aggCallList = - JavaConverters.seqAsJavaListConverter(aggregate.getAggCallList()).asJava(); - for (AggregateCall aggCall : aggCallList) { - if (aggCall.isDistinct() - || aggCall.isApproximate() - || aggCall.getArgList().size() > 1 - || aggCall.hasFilter() - || !aggCall.getCollation().getFieldCollations().isEmpty()) { - return false; - } - } - TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable(); - // we can not push aggregates twice - return tableSourceTable != null - && tableSourceTable.tableSource() instanceof SupportsAggregatePushDown - && Arrays.stream(tableSourceTable.abilitySpecs()) - .noneMatch(spec -> spec instanceof AggregatePushDownSpec); - } - - protected void pushLocalAggregateIntoScan( - RelOptRuleCall call, - BatchPhysicalGroupAggregateBase localAgg, - BatchPhysicalTableSourceScan oldScan) { - RelDataType originalInputRowType = oldScan.deriveRowType(); - AggregateInfoList aggInfoList = - AggregateUtil.transformToBatchAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(localAgg.getInput().getRowType()), - localAgg.getAggCallList(), - null, - null); - if (aggInfoList.aggInfos().length == 0) { - // no agg function need to be pushed down - return; - } - - List groupingSets = Collections.singletonList(localAgg.grouping()); - List aggExpressions = - buildAggregateExpressions(originalInputRowType, aggInfoList); - RelDataType relDataType = localAgg.deriveRowType(); - - TableSourceTable oldTableSourceTable = oldScan.tableSourceTable(); - DynamicTableSource newTableSource = oldScan.tableSource().copy(); - - boolean isPushDownSuccess = - AggregatePushDownSpec.apply( - groupingSets, - aggExpressions, - (RowType) FlinkTypeFactory.toLogicalType(relDataType), - newTableSource); - - if (!isPushDownSuccess) { - // aggregate push down failed, just return without changing any nodes. 
- return; - } - - FlinkStatistic newFlinkStatistic = getNewFlinkStatistic(oldTableSourceTable); - String[] newExtraDigests = - getNewExtraDigests(originalInputRowType, localAgg.grouping(), aggExpressions); - AggregatePushDownSpec aggregatePushDownSpec = - new AggregatePushDownSpec( - groupingSets, - aggExpressions, - (RowType) FlinkTypeFactory.toLogicalType(relDataType)); - TableSourceTable newTableSourceTable = - oldTableSourceTable - .copy( - newTableSource, - newFlinkStatistic, - new SourceAbilitySpec[] {aggregatePushDownSpec}) - .copy(relDataType); - BatchPhysicalTableSourceScan newScan = - oldScan.copy(oldScan.getTraitSet(), newTableSourceTable); - BatchPhysicalExchange oldExchange = call.rel(0); - BatchPhysicalExchange newExchange = - oldExchange.copy(oldExchange.getTraitSet(), newScan, oldExchange.getDistribution()); - call.transformTo(newExchange); - } - - private List buildAggregateExpressions( - RelDataType originalInputRowType, AggregateInfoList aggInfoList) { - List aggExpressions = new ArrayList<>(); - for (AggregateInfo aggInfo : aggInfoList.aggInfos()) { - List arguments = new ArrayList<>(1); - for (int argIndex : aggInfo.argIndexes()) { - DataType argType = - TypeConversions.fromLogicalToDataType( - FlinkTypeFactory.toLogicalType( - originalInputRowType - .getFieldList() - .get(argIndex) - .getType())); - FieldReferenceExpression field = - new FieldReferenceExpression( - originalInputRowType.getFieldNames().get(argIndex), - argType, - argIndex, - argIndex); - arguments.add(field); - } - if (aggInfo.function() instanceof AvgAggFunction) { - Tuple2 sum0AndCountFunction = - AggregateUtil.deriveSumAndCountFromAvg(aggInfo.function()); - AggregateExpression sum0Expression = - new AggregateExpression( - sum0AndCountFunction._1(), - arguments, - null, - aggInfo.externalResultType(), - aggInfo.agg().isDistinct(), - aggInfo.agg().isApproximate(), - aggInfo.agg().ignoreNulls()); - aggExpressions.add(sum0Expression); - AggregateExpression countExpression = - new AggregateExpression( - sum0AndCountFunction._2(), - arguments, - null, - aggInfo.externalResultType(), - aggInfo.agg().isDistinct(), - aggInfo.agg().isApproximate(), - aggInfo.agg().ignoreNulls()); - aggExpressions.add(countExpression); - } else { - AggregateExpression aggregateExpression = - new AggregateExpression( - aggInfo.function(), - arguments, - null, - aggInfo.externalResultType(), - aggInfo.agg().isDistinct(), - aggInfo.agg().isApproximate(), - aggInfo.agg().ignoreNulls()); - aggExpressions.add(aggregateExpression); - } - } - return aggExpressions; - } - - private FlinkStatistic getNewFlinkStatistic(TableSourceTable tableSourceTable) { - FlinkStatistic oldStatistic = tableSourceTable.getStatistic(); - FlinkStatistic newStatistic; - if (oldStatistic == FlinkStatistic.UNKNOWN()) { - newStatistic = oldStatistic; - } else { - // Remove tableStats after all of aggregate have been pushed down - newStatistic = - FlinkStatistic.builder().statistic(oldStatistic).tableStats(null).build(); - } - return newStatistic; - } - - private String[] getNewExtraDigests( - RelDataType originalInputRowType, - int[] grouping, - List aggregateExpressions) { - String extraDigest; - String groupingStr = "null"; - if (grouping.length > 0) { - groupingStr = - Arrays.stream(grouping) - .mapToObj(index -> originalInputRowType.getFieldNames().get(index)) - .collect(Collectors.joining(",")); - } - String aggFunctionsStr = "null"; - if (aggregateExpressions.size() > 0) { - aggFunctionsStr = - aggregateExpressions.stream() - 
.map(AggregateExpression::asSummaryString) - .collect(Collectors.joining(",")); - } - extraDigest = - "aggregates=[grouping=[" - + groupingStr - + "], aggFunctions=[" - + aggFunctionsStr - + "]]"; - return new String[] {extraDigest}; - } -} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java similarity index 75% rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java rename to flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java index 51f1e6f05c3d4..bfb3ee20e8375 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithoutSortIntoTableSourceScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java @@ -18,9 +18,10 @@ package org.apache.flink.table.planner.plan.rules.physical.batch; +import org.apache.flink.table.api.config.OptimizerConfigOptions; import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalHashAggregate; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; import org.apache.flink.table.planner.plan.schema.TableSourceTable; @@ -29,15 +30,15 @@ /** * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in - * the table is a {@link SupportsAggregatePushDown}. + * the table is a {@link SupportsAggregatePushDown}. The {@link + * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. * - *

    When the {@code OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} is - * true, we have the original physical plan: + *

    Suppose we have the original physical plan: * *

    {@code
      * BatchPhysicalHashAggregate (global)
      * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
    - *    +- BatchPhysicalGroupAggregateBase (local)
    + *    +- BatchPhysicalLocalHashAggregate (local)
      *       +- BatchPhysicalTableSourceScan
      * }
    * @@ -49,31 +50,29 @@ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down) * } */ -public class PushLocalAggWithoutSortIntoTableSourceScanRule - extends PushLocalAggIntoTableSourceScanRuleBase { - public static final PushLocalAggWithoutSortIntoTableSourceScanRule INSTANCE = - new PushLocalAggWithoutSortIntoTableSourceScanRule(); +public class PushLocalHashAggIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalHashAggIntoScanRule INSTANCE = new PushLocalHashAggIntoScanRule(); - public PushLocalAggWithoutSortIntoTableSourceScanRule() { + public PushLocalHashAggIntoScanRule() { super( operand( BatchPhysicalExchange.class, operand( - BatchPhysicalGroupAggregateBase.class, + BatchPhysicalLocalHashAggregate.class, operand(BatchPhysicalTableSourceScan.class, none()))), - "PushLocalAggWithoutSortIntoTableSourceScanRule"); + "PushLocalHashAggIntoScanRule"); } @Override public boolean matches(RelOptRuleCall call) { - BatchPhysicalGroupAggregateBase localAggregate = call.rel(1); + BatchPhysicalLocalHashAggregate localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); return isMatch(call, localAggregate, tableSourceScan); } @Override public void onMatch(RelOptRuleCall call) { - BatchPhysicalGroupAggregateBase localHashAgg = call.rel(1); + BatchPhysicalLocalHashAggregate localHashAgg = call.rel(1); BatchPhysicalTableSourceScan oldScan = call.rel(2); pushLocalAggregateIntoScan(call, localHashAgg, oldScan); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java similarity index 84% rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java rename to flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java index 3f11585743e9d..f012ce51c27ec 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggWithSortIntoTableSourceScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch; +import org.apache.flink.table.api.config.OptimizerConfigOptions; import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; @@ -31,10 +32,10 @@ /** * Planner rule that tries to push a local sort aggregate which with sort into a {@link * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in - * the table is a {@link SupportsAggregatePushDown}. + * the table is a {@link SupportsAggregatePushDown}. The {@link + * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. * - *

    When the {@code OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} is - * true, we have the original physical plan: + *

    Suppose we have the original physical plan: * *

    {@code
      * BatchPhysicalSortAggregate (global)
    @@ -54,12 +55,11 @@
      *       +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
      * }
    */ -public class PushLocalAggWithSortIntoTableSourceScanRule - extends PushLocalAggIntoTableSourceScanRuleBase { - public static final PushLocalAggWithSortIntoTableSourceScanRule INSTANCE = - new PushLocalAggWithSortIntoTableSourceScanRule(); +public class PushLocalSortAggWithSortIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalSortAggWithSortIntoScanRule INSTANCE = + new PushLocalSortAggWithSortIntoScanRule(); - public PushLocalAggWithSortIntoTableSourceScanRule() { + public PushLocalSortAggWithSortIntoScanRule() { super( operand( BatchPhysicalExchange.class, @@ -68,7 +68,7 @@ public PushLocalAggWithSortIntoTableSourceScanRule() { operand( BatchPhysicalSort.class, operand(BatchPhysicalTableSourceScan.class, none())))), - "PushLocalAggWithSortIntoTableSourceScanRule"); + "PushLocalSortAggWithSortIntoScanRule"); } @Override diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java new file mode 100644 index 0000000000000..ca1e442a66c40 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; + +import org.apache.calcite.plan.RelOptRuleCall; + +/** + * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link + * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in + * the table is a {@link SupportsAggregatePushDown}. The {@link + * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. + * + *

    Suppose we have the original physical plan: + * + *

    {@code
    + * BatchPhysicalHashAggregate (global)
    + * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
    + *    +- BatchPhysicalLocalSortAggregate (local)
    + *       +- BatchPhysicalTableSourceScan
    + * }
    + * + *

    This physical plan will be rewritten to: + * + *

    {@code
    + * BatchPhysicalHashAggregate (global)
    + * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
    + *    +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
    + * }
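Note that the planner typically produces these sort-based local aggregates only when hash aggregation is unavailable; the tests accompanying this patch force that path before exercising the sort-agg rules, roughly as follows (a sketch reusing the option key from those tests):

    // Disable hash aggregation so that the planner produces
    // BatchPhysicalLocalSortAggregate and the sort-agg rules can match.
    tEnv.getConfig()
            .getConfiguration()
            .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg");
    // ... run the query or verify the plan, then reset:
    tEnv.getConfig()
            .getConfiguration()
            .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "");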
    + */ +public class PushLocalSortAggWithoutSortIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalSortAggWithoutSortIntoScanRule INSTANCE = + new PushLocalSortAggWithoutSortIntoScanRule(); + + public PushLocalSortAggWithoutSortIntoScanRule() { + super( + operand( + BatchPhysicalExchange.class, + operand( + BatchPhysicalLocalSortAggregate.class, + operand(BatchPhysicalTableSourceScan.class, none()))), + "PushLocalSortAggWithoutSortIntoScanRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + BatchPhysicalLocalSortAggregate localAggregate = call.rel(1); + BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); + return isMatch(call, localAggregate, tableSourceScan); + } + + @Override + public void onMatch(RelOptRuleCall call) { + BatchPhysicalLocalSortAggregate localHashAgg = call.rel(1); + BatchPhysicalTableSourceScan oldScan = call.rel(2); + pushLocalAggregateIntoScan(call, localHashAgg, oldScan); + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala index 8082a452fcc71..66bff2fe4dbf3 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala @@ -18,6 +18,12 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch +import java.util + +import org.apache.calcite.plan._ +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.hint.RelHint +import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.nodes.exec.ExecNode import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecTableSourceScan @@ -26,13 +32,6 @@ import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalT import org.apache.flink.table.planner.plan.schema.TableSourceTable import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil -import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.hint.RelHint -import org.apache.calcite.rel.metadata.RelMetadataQuery - -import java.util - /** * Batch physical RelNode to read data from an external source defined by a * bounded [[org.apache.flink.table.connector.source.ScanTableSource]]. 
@@ -50,8 +49,8 @@ class BatchPhysicalTableSourceScan( } def copy( - traitSet: RelTraitSet, - tableSourceTable: TableSourceTable): BatchPhysicalTableSourceScan = { + traitSet: RelTraitSet, + tableSourceTable: TableSourceTable): BatchPhysicalTableSourceScan = { new BatchPhysicalTableSourceScan(cluster, traitSet, getHints, tableSourceTable) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 269034fd11593..d7d8ea4b76797 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -449,7 +449,8 @@ object FlinkBatchRuleSets { val PHYSICAL_REWRITE: RuleSet = RuleSets.ofList( EnforceLocalHashAggRule.INSTANCE, EnforceLocalSortAggRule.INSTANCE, - PushLocalAggWithoutSortIntoTableSourceScanRule.INSTANCE, - PushLocalAggWithSortIntoTableSourceScanRule.INSTANCE + PushLocalHashAggIntoScanRule.INSTANCE, + PushLocalSortAggWithSortIntoScanRule.INSTANCE, + PushLocalSortAggWithoutSortIntoScanRule.INSTANCE ) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala index 58dcddb0fff58..7f2db9c972c01 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala @@ -18,6 +18,11 @@ package org.apache.flink.table.planner.plan.schema +import java.util + +import com.google.common.collect.ImmutableList +import org.apache.calcite.plan.RelOptSchema +import org.apache.calcite.rel.`type`.RelDataType import org.apache.flink.table.catalog.{ObjectIdentifier, ResolvedCatalogTable} import org.apache.flink.table.connector.source.DynamicTableSource import org.apache.flink.table.planner.calcite.FlinkContext @@ -25,12 +30,6 @@ import org.apache.flink.table.planner.connectors.DynamicSourceUtils import org.apache.flink.table.planner.plan.abilities.source.{SourceAbilityContext, SourceAbilitySpec} import org.apache.flink.table.planner.plan.stats.FlinkStatistic -import com.google.common.collect.ImmutableList -import org.apache.calcite.plan.RelOptSchema -import org.apache.calcite.rel.`type`.RelDataType - -import java.util - /** * A [[FlinkPreparingTableBase]] implementation which defines the context variables * required to translate the Calcite [[org.apache.calcite.plan.RelOptTable]] to the Flink specific diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index fd5e1c0486895..513d139505116 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -280,8 +280,8 @@ object AggregateUtil extends Enumeration { isBounded = false) } - def deriveSumAndCountFromAvg( - avgAggFunction: UserDefinedFunction): (Sum0AggFunction, CountAggFunction) = { + 
def deriveSumAndCountFromAvg(avgAggFunction: UserDefinedFunction + ): (Sum0AggFunction, CountAggFunction) = { avgAggFunction match { case _: ByteAvgAggFunction => (new ByteSum0AggFunction, new CountAggFunction) case _: ShortAvgAggFunction => (new ShortSum0AggFunction, new CountAggFunction) @@ -289,10 +289,9 @@ object AggregateUtil extends Enumeration { case _: LongAvgAggFunction => (new LongSum0AggFunction, new CountAggFunction) case _: FloatAvgAggFunction => (new FloatSum0AggFunction, new CountAggFunction) case _: DoubleAvgAggFunction => (new DoubleSum0AggFunction, new CountAggFunction) - case _ => { + case _ => throw new TableException(s"Avg aggregate function does not support: ''$avgAggFunction''" + s"Please re-check the function or data type.") - } } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index de958ea2090ad..55a1cde874cb5 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -88,6 +88,7 @@ import org.apache.flink.table.planner.utils.FilterUtils; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.BigIntType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.utils.LogicalTypeParser; import org.apache.flink.table.types.utils.DataTypeUtils; @@ -1059,7 +1060,8 @@ public boolean applyAggregates( List aggregateExpressions, DataType producedDataType) { // this TestValuesScanTableSource only support simple group type ar present. - if (groupingSets.size() > 1) { + // auxGrouping is not supported. 
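+            // groupingSets.get(0) holds the grouping keys; when the pushed local
+            // aggregate also had auxGrouping keys, they arrive as groupingSets.get(1)
+            // (see PushLocalAggIntoScanRuleBase#pushLocalAggregateIntoScan).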
+        if (groupingSets.size() > 1 && groupingSets.get(1).length > 0) {
             return false;
         }
         List<AggregateExpression> aggExpressions = new ArrayList<>();
@@ -1078,6 +1080,19 @@ public boolean applyAggregates(
                     || aggExpression.isDistinct()) {
                 return false;
             }
+
+            // only the Long data type is supported in this unit test, except count()
+            if (aggExpression.getArgs().stream()
+                    .anyMatch(
+                            field ->
+                                    !(field.getOutputDataType().getLogicalType()
+                                                    instanceof BigIntType)
+                                            && !(functionDefinition instanceof CountAggFunction
+                                                    || functionDefinition
+                                                            instanceof Count1AggFunction))) {
+                return false;
+            }
+
             aggExpressions.add(aggExpression);
         }
         this.groupingSet = groupingSets.get(0);
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java
index 4ef1e888b1f0c..dcd1d636c0927 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java
@@ -18,8 +18,11 @@
 
 package org.apache.flink.table.planner.plan.rules.physical.batch;
 
+import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
 import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.functions.aggfunctions.CollectAggFunction;
 import org.apache.flink.table.planner.utils.BatchTableTestUtil;
 import org.apache.flink.table.planner.utils.TableTestBase;
 
@@ -27,8 +30,8 @@ import org.junit.Test;
 
 /**
- * Test for {@link PushLocalAggWithoutSortIntoTableSourceScanRule} and {@link
- * PushLocalAggWithSortIntoTableSourceScanRule}.
+ * Test for {@link PushLocalHashAggIntoScanRule}, {@link PushLocalSortAggWithSortIntoScanRule} and
+ * {@link PushLocalSortAggWithoutSortIntoScanRule}. 
*/ public class PushLocalAggIntoTableSourceScanRuleTest extends TableTestBase { protected BatchTableTestUtil util = batchTestUtil(new TableConfig()); @@ -43,10 +46,10 @@ public void setup() { true); String ddl = "CREATE TABLE inventory (\n" - + " id INT,\n" + + " id BIGINT,\n" + " name STRING,\n" - + " amount INT,\n" - + " price DOUBLE,\n" + + " amount BIGINT,\n" + + " price BIGINT,\n" + " type STRING\n" + ") WITH (\n" + " 'connector' = 'values',\n" @@ -57,7 +60,7 @@ public void setup() { } @Test - public void testCanPushDownWithGroup() { + public void testCanPushDownLocalHashAggWithGroup() { util.verifyRelPlan( "SELECT\n" + " sum(amount),\n" @@ -68,18 +71,165 @@ public void testCanPushDownWithGroup() { } @Test - public void testCanPushDownWithoutGroup() { + public void testDisablePushDownLocalAgg() { + // disable push down local agg + util.getTableEnv() + .getConfig() + .getConfiguration() + .setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, + false); + + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " name,\n" + + " type\n" + + "FROM inventory\n" + + " group by name, type"); + + // reset config + util.getTableEnv() + .getConfig() + .getConfiguration() + .setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, + true); + } + + @Test + public void testCanPushDownLocalHashAggWithoutGroup() { util.verifyRelPlan( "SELECT\n" + " min(id),\n" + " max(amount),\n" - + " max(name),\n" + " sum(price),\n" + " avg(price),\n" + " count(id)\n" + "FROM inventory"); } + @Test + public void testCanPushDownLocalSortAggWithoutSort() { + // enable sort agg + util.getTableEnv() + .getConfig() + .getConfiguration() + .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg"); + + util.verifyRelPlan( + "SELECT\n" + + " min(id),\n" + + " max(amount),\n" + + " sum(price),\n" + + " avg(price),\n" + + " count(id)\n" + + "FROM inventory"); + + // reset config + util.getTableEnv() + .getConfig() + .getConfiguration() + .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, ""); + } + + @Test + public void testCanPushDownLocalSortAggWithSort() { + // enable sort agg + util.getTableEnv() + .getConfig() + .getConfiguration() + .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg"); + + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " name,\n" + + " type\n" + + "FROM inventory\n" + + " group by name, type"); + + // reset config + util.getTableEnv() + .getConfig() + .getConfiguration() + .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, ""); + } + + @Test + public void testCanPushDownLocalAggWithAuxGrouping() { + util.verifyRelPlan( + "SELECT\n" + + " name,\n" + + " a,\n" + + " p,\n" + + " count(*)\n" + + "FROM (\n" + + " SELECT\n" + + " name,\n" + + " sum(amount) as a,\n" + + " max(price) as p\n" + + " FROM inventory\n" + + " group by name\n" + + ") t\n" + + " group by name, a, p"); + } + + @Test + public void testCanPushDownLocalAggAfterFilterPushDown() { + + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " name,\n" + + " type\n" + + "FROM inventory\n" + + " where id = 123\n" + + " group by name, type"); + } + + @Test + public void testCannotPushDownLocalAggAfterLimitPushDown() { + + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " name,\n" + + " type\n" + + "FROM (\n" + + " SELECT\n" + + " *\n" + + " FROM inventory\n" + + " LIMIT 100\n" + + ") t\n" + + " group by name, type"); + } + + @Test + public void testCannotPushDownLocalAggWithUDAF() 
{ + // add udf + util.addTemporarySystemFunction( + "udaf_collect", new CollectAggFunction<>(DataTypes.BIGINT().getLogicalType())); + + util.verifyRelPlan( + "SELECT\n" + + " udaf_collect(amount),\n" + + " name,\n" + + " type\n" + + "FROM inventory\n" + + " group by name, type"); + } + + @Test + public void testCannotPushDownLocalAggWithUnsupportedDataTypes() { + util.verifyRelPlan( + "SELECT\n" + + " max(name),\n" + + " type\n" + + "FROM inventory\n" + + " group by type"); + } + @Test public void testCannotPushDownWithColumnExpression() { util.verifyRelPlan( diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml index 2a5061a174ef4..7295012dd7da0 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. --> - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -94,7 +329,7 @@ Calc(select=[EXPR$0, EXPR$1, EXPR$2, EXPR$3, name]) +- Exchange(distribution=[hash[name]]) +- LocalHashAggregate(groupBy=[name], select=[name, Partial_MIN($f1) AS min$0, Partial_MAX(amount) AS max$1, Partial_SUM(price) AS sum$2, Partial_COUNT(id) AS count$3]) +- Calc(select=[name, +(amount, price) AS $f1, amount, price, id]) - +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, amount, price, id]]], fields=[name, amount, price, id]) + +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, amount, price, id], metadata=[]]], fields=[name, amount, price, id]) ]]> @@ -128,7 +363,7 @@ Calc(select=[EXPR$0, EXPR$1, EXPR$2, EXPR$3, name]) +- Exchange(distribution=[hash[name, id, $e]]) +- LocalHashAggregate(groupBy=[name, id, $e], select=[name, id, $e, Partial_MIN(id_0) AS min$0, Partial_MAX(amount) AS max$1, Partial_SUM(price) AS sum$2]) +- Expand(projects=[{name, id, amount, price, 0 AS $e, id AS id_0}, {name, null AS id, amount, price, 1 AS $e, id AS id_0}]) - +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, id, amount, price]]], fields=[name, id, amount, price]) + +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, id, amount, price], metadata=[]]], fields=[name, id, amount, price]) ]]> @@ -143,17 +378,17 @@ FROM inventory]]> (COUNT($3) OVER (PARTITION BY $1), 0), $SUM0($3) OVER (PARTITION BY $1), null:DOUBLE)], name=[$1]) +LogicalProject(id=[$0], amount=[$2], EXPR$2=[CASE(>(COUNT($3) OVER (PARTITION BY $1), 0), $SUM0($3) OVER (PARTITION BY $1), null:BIGINT)], name=[$1]) +- LogicalTableScan(table=[[default_catalog, default_database, inventory]]) ]]> (w0$o0, 0:BIGINT), w0$o1, null:DOUBLE) AS EXPR$2, name]) +Calc(select=[id, amount, CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT) AS EXPR$2, name]) +- OverAggregate(partitionBy=[name], window#0=[COUNT(price) AS w0$o0, $SUM0(price) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING 
AND UNBOUNDED FOLLOWING], select=[id, name, amount, price, w0$o0, w0$o1]) +- Sort(orderBy=[name ASC]) +- Exchange(distribution=[hash[name]]) - +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[id, name, amount, price]]], fields=[id, name, amount, price]) + +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[id, name, amount, price], metadata=[]]], fields=[id, name, amount, price]) ]]> @@ -183,7 +418,7 @@ Calc(select=[EXPR$0, EXPR$1, EXPR$2, EXPR$3, name]) +- Exchange(distribution=[hash[name]]) +- LocalHashAggregate(groupBy=[name], select=[name, Partial_MIN(id) AS min$0, Partial_MAX(amount) AS max$1, Partial_SUM(price) AS sum$2, Partial_COUNT(id) FILTER $f4 AS count$3]) +- Calc(select=[name, id, amount, price, IS TRUE(>(id, 100)) AS $f4]) - +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, id, amount, price]]], fields=[name, id, amount, price]) + +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[name, id, amount, price], metadata=[]]], fields=[name, id, amount, price]) ]]> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala index 7f1c087be5498..7e20bc9af1f7c 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala @@ -16,10 +16,12 @@ * limitations under the License. 
*/ -package org.apache.flink.table.planner.runtime.batch.sql.agg +package org.apache.flink.table.planner.plan.batch.sql.agg -import org.apache.flink.table.api.config.OptimizerConfigOptions +import org.apache.flink.table.api.DataTypes +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} import org.apache.flink.table.planner.factories.TestValuesTableFactory +import org.apache.flink.table.planner.functions.aggfunctions.CollectAggFunction import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row import org.apache.flink.table.planner.runtime.utils.{BatchTestBase, TestData} import org.junit.{Before, Test} @@ -41,10 +43,12 @@ class LocalAggregatePushDownITCase extends BatchTestBase { | name string, | height int, | gender string, - | deposit bigint + | deposit bigint, + | points bigint |) WITH ( | 'connector' = 'values', | 'data-id' = '$testDataId', + | 'filterable-fields' = 'id', | 'bounded' = 'true' |) """.stripMargin @@ -53,29 +57,240 @@ class LocalAggregatePushDownITCase extends BatchTestBase { } @Test - def testAggregateWithGroupBy(): Unit = { + def testPushDownLocalHashAggWithGroup(): Unit = { checkResult( """ |SELECT - | min(age) as min_age, - | max(height), - | avg(deposit), + | avg(deposit) as avg_dep, + | sum(deposit), + | count(1), + | gender + |FROM + | AggregatableTable + |GROUP BY gender + |ORDER BY avg_dep + |""".stripMargin, + Seq( + row(126, 630, 5, "f"), + row(220, 1320, 6, "m")) + ) + } + + @Test + def testDisablePushDownLocalAgg(): Unit = { + // disable push down local agg + tEnv.getConfig.getConfiguration.setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, + false) + + checkResult( + """ + |SELECT + | avg(deposit) as avg_dep, | sum(deposit), | count(1), | gender |FROM | AggregatableTable |GROUP BY gender - |ORDER BY min_age + |ORDER BY avg_dep + |""".stripMargin, + Seq( + row(126, 630, 5, "f"), + row(220, 1320, 6, "m")) + ) + + // reset config + tEnv.getConfig.getConfiguration.setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, + true) + } + + @Test + def testPushDownLocalHashAggWithOutGroup(): Unit = { + checkResult( + """ + |SELECT + | avg(deposit), + | sum(deposit), + | count(*) + |FROM + | AggregatableTable + |""".stripMargin, + Seq( + row(177, 1950, 11)) + ) + } + + @Test + def testPushDownLocalSortAggWithoutSort(): Unit = { + // enable sort agg + tEnv.getConfig.getConfiguration.setString( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg") + + checkResult( + """ + |SELECT + | avg(deposit), + | sum(deposit), + | count(*) + |FROM + | AggregatableTable + |""".stripMargin, + Seq( + row(177, 1950, 11)) + ) + + // reset config + tEnv.getConfig.getConfiguration.setString( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "") + } + + @Test + def testPushDownLocalSortAggWithSort(): Unit = { + // enable sort agg + tEnv.getConfig.getConfiguration.setString( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg") + + checkResult( + """ + |SELECT + | avg(deposit), + | sum(deposit), + | count(1), + | gender, + | age + |FROM + | AggregatableTable + |GROUP BY gender, age + |""".stripMargin, + Seq( + row(50, 50, 1, "f", 19), + row(200, 200, 1, "f", 20), + row(250, 750, 3, "m", 23), + row(126, 380, 3, "f", 25), + row(300, 300, 1, "m", 27), + row(170, 170, 1, "m", 28), + row(100, 100, 1, "m", 34)) + ) + + // reset config + tEnv.getConfig.getConfiguration.setString( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "") + } + + @Test 
+ def testAggWithAuxGrouping(): Unit = { + checkResult( + """ + |SELECT + | name, + | d, + | p, + | count(*) + |FROM ( + | SELECT + | name, + | sum(deposit) as d, + | max(points) as p + | FROM AggregatableTable + | GROUP BY name + |) t + |GROUP BY name, d, p |""".stripMargin, Seq( - row(19, 180, 126, 630, 5, "f"), - row(23, 182, 220, 1320, 6, "m")) + row("tom", 200, 1000, 1), + row("mary", 100, 1000, 1), + row("jack", 150, 1300, 1), + row("rose", 100, 500, 1), + row("danny", 300, 300, 1), + row("tommas", 400, 4000, 1), + row("olivia", 50, 9000, 1), + row("stef", 100, 1900, 1), + row("emma", 180, 800, 1), + row("benji", 170, 11000, 1), + row("eva", 200, 1000, 1)) ) } @Test - def testAggregateWithMultiGroupBy(): Unit = { + def testPushDownLocalAggAfterFilterPushDown(): Unit = { + checkResult( + """ + |SELECT + | avg(deposit), + | sum(deposit), + | count(1), + | gender, + | age + |FROM + | AggregatableTable + |WHERE age <= 20 + |GROUP BY gender, age + |""".stripMargin, + Seq( + row(50, 50, 1, "f", 19), + row(200, 200, 1, "f", 20)) + ) + } + + @Test + def testLocalAggWithLimit(): Unit = { + checkResult( + """ + |SELECT + | avg(deposit) as avg_dep, + | sum(deposit), + | count(1), + | gender + |FROM + | ( + | SELECT * FROM AggregatableTable + | LIMIT 10 + | ) t + |GROUP BY gender + |ORDER BY avg_dep + |""".stripMargin, + Seq( + row(107, 430, 4, "f"), + row(220, 1320, 6, "m")) + ) + } + + @Test + def testLocalAggWithUDAF(): Unit = { + // add UDAF + tEnv.createTemporarySystemFunction( + "udaf_collect", + new CollectAggFunction(DataTypes.BIGINT().getLogicalType)) + + checkResult( + """ + |SELECT + | udaf_collect(deposit), + | count(1), + | gender, + | age + |FROM + | AggregatableTable + |GROUP BY gender, age + |""".stripMargin, + Seq( + row("{100=1}", 1, "m", 34), + row("{100=2, 180=1}", 3, "f", 25), + row("{170=1}", 1, "m", 28), + row("{200=1}", 1, "f", 20), + row("{300=1}", 1, "m", 27), + row("{400=1, 150=1, 200=1}", 3, "m", 23), + row("{50=1}", 1, "f", 19)) + ) + } + + @Test + def testLocalAggWithUnsupportedDataTypes(): Unit = { + // only agg on Long columns and count are supported to be pushed down + // in {@link TestValuesTableFactory} + checkResult( """ |SELECT @@ -102,41 +317,109 @@ class LocalAggregatePushDownITCase extends BatchTestBase { } @Test - def testAggregateWithoutGroupBy(): Unit = { + def testLocalAggWithColumnExpression1(): Unit = { checkResult( """ |SELECT - | min(age), - | max(height), | avg(deposit), + | sum(deposit + points), + | count(1), + | gender, + | age + |FROM + | AggregatableTable + |GROUP BY gender, age + |""".stripMargin, + Seq( + row(250, 7050, 3, "m", 23), + row(126, 2680, 3, "f", 25), + row(300, 600, 1, "m", 27), + row(50, 9050, 1, "f", 19), + row(100, 2000, 1, "m", 34), + row(170, 11170, 1, "m", 28), + row(200, 1200, 1, "f", 20)) + ) + } + + @Test + def testLocalAggWithColumnExpression2(): Unit = { + + checkResult( + """ + |SELECT + | avg(distinct deposit), | sum(deposit), - | count(*) + | count(1), + | gender, + | age |FROM | AggregatableTable + |GROUP BY gender, age |""".stripMargin, Seq( - row(19, 182, 177, 1950, 11)) + row(50, 50, 1, "f", 19), + row(200, 200, 1, "f", 20), + row(140, 380, 3, "f", 25), + row(250, 750, 3, "m", 23), + row(300, 300, 1, "m", 27), + row(170, 170, 1, "m", 28), + row(100, 100, 1, "m", 34)) ) } @Test - def testAggregateCanNotPushDown(): Unit = { + def testLocalAggWithWindow(): Unit = { + + checkResult( + """ + |SELECT + | avg(deposit) over (partition by gender), + | gender, + | age + |FROM + | AggregatableTable + |""".stripMargin, + 
Seq( + row(126, "f", 19), + row(126, "f", 20), + row(220, "m", 23), + row(220, "m", 23), + row(220, "m", 23), + row(126, "f", 25), + row(126, "f", 25), + row(126, "f", 25), + row(220, "m", 27), + row(220, "m", 28), + row(220, "m", 34) + ) + ) + } + + + @Test + def testLocalAggWithFilter(): Unit = { + checkResult( """ |SELECT - | min(age), - | max(height), | avg(deposit), - | sum(deposit), - | count(distinct age), - | gender + | sum(deposit) FILTER(WHERE deposit > 100), + | count(1), + | gender, + | age |FROM | AggregatableTable - |GROUP BY gender + |GROUP BY gender, age |""".stripMargin, Seq( - row(19, 180, 126, 630, 3, "f"), - row(23, 182, 220, 1320, 4, "m")) + row(100, null, 1, "m", 34), + row(126, 180, 3, "f", 25), + row(170, 170, 1, "m", 28), + row(200, 200, 1, "f", 20), + row(250, 750, 3, "m", 23), + row(300, 300, 1, "m", 27), + row(50, null, 1, "f", 19)) ) } + } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala index 1e70ad1e92a35..e85f168d184c5 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala @@ -414,17 +414,17 @@ object TestData { // person test data lazy val personData: Seq[Row] = Seq( - row(1, 23, "tom", 172, "m", 200L), - row(2, 25, "mary", 161, "f", 100L), - row(3, 23, "jack", 182, "m", 150L), - row(4, 25, "rose", 165, "f", 100L), - row(5, 27, "danny", 175, "m", 300L), - row(6, 23, "tommas", 172, "m", 400L), - row(7, 19, "olivia", 172, "f", 50L), - row(8, 34, "stef", 170, "m", 100L), - row(9, 25, "emma", 171, "f", 180L), - row(10, 28, "benji", 165, "m", 170L), - row(11, 20, "eva", 180, "f", 200L) + row(1, 23, "tom", 172, "m", 200L, 1000L), + row(2, 25, "mary", 161, "f", 100L, 1000L), + row(3, 23, "jack", 182, "m", 150L, 1300L), + row(4, 25, "rose", 165, "f", 100L, 500L), + row(5, 27, "danny", 175, "m", 300L, 300L), + row(6, 23, "tommas", 172, "m", 400L, 4000L), + row(7, 19, "olivia", 172, "f", 50L, 9000L), + row(8, 34, "stef", 170, "m", 100L, 1900L), + row(9, 25, "emma", 171, "f", 180L, 800L), + row(10, 28, "benji", 165, "m", 170L, 11000L), + row(11, 20, "eva", 180, "f", 200L, 1000L) ) val nullablesOfPersonData = Array(true, true, true, true, true) From 4ceb389053dca74fe4097dd8c8721e6efe52eaa9 Mon Sep 17 00:00:00 2001 From: "Yu, Peng" Date: Fri, 24 Sep 2021 11:44:09 +0800 Subject: [PATCH 3/7] [FLINK-20895] code and comment optimizations --- .../optimizer_config_configuration.html | 6 ++++ .../api/config/OptimizerConfigOptions.java | 5 ++- .../source/AggregatePushDownSpec.java | 36 ++++++++----------- .../batch/PushLocalAggIntoScanRuleBase.java | 11 +++--- .../batch/PushLocalHashAggIntoScanRule.java | 4 +-- .../PushLocalSortAggWithSortIntoScanRule.java | 4 +-- ...shLocalSortAggWithoutSortIntoScanRule.java | 4 +-- 7 files changed, 32 insertions(+), 38 deletions(-) diff --git a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html index f82b46a598890..61fc8ce678eea 100644 --- a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html +++ b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html @@ -59,6 +59,12 @@ Boolean When it is true, the optimizer will try to find out duplicated sub-plans and reuse them. + +
+        <tr>
+            <td><h5>table.optimizer.source.aggregate-pushdown-enabled</h5><br> <span class="label label-primary">Batch</span></td>
+            <td style="word-wrap: break-word;">true</td>
+            <td>Boolean</td>
+            <td>When it is true, the optimizer will push down the local aggregates into the TableSource which implements SupportsAggregatePushDown.</td>
+        </tr>
    table.optimizer.source.predicate-pushdown-enabled

    Batch Streaming true diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java index 7a1ad7a72615a..1e01256a18096 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java @@ -97,11 +97,10 @@ public class OptimizerConfigOptions { public static final ConfigOption TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED = key("table.optimizer.source.aggregate-pushdown-enabled") .booleanType() - .defaultValue(false) + .defaultValue(true) .withDescription( "When it is true, the optimizer will push down the local aggregates into " - + "the TableSource which implements SupportsAggregatePushDown. " - + "Default value is false."); + + "the TableSource which implements SupportsAggregatePushDown."); @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING) public static final ConfigOption TABLE_OPTIMIZER_SOURCE_PREDICATE_PUSHDOWN_ENABLED = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java index 24c7f061be09a..296523c8a5ddd 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java @@ -95,32 +95,24 @@ public void apply(DynamicTableSource tableSource, SourceAbilityContext context) @Override public String getDigests(SourceAbilityContext context) { - String extraDigest; - String groupingStr = ""; int[] grouping = ArrayUtils.addAll(groupingSets.get(0), groupingSets.get(1)); - if (grouping.length > 0) { - groupingStr = - Arrays.stream(grouping) - .mapToObj(index -> inputType.getFieldNames().get(index)) - .collect(Collectors.joining(",")); - } - String aggFunctionsStr = ""; + String groupingStr = + Arrays.stream(grouping) + .mapToObj(index -> inputType.getFieldNames().get(index)) + .collect(Collectors.joining(",")); List aggregateExpressions = buildAggregateExpressions(inputType, aggregateCalls); - if (aggregateExpressions.size() > 0) { - aggFunctionsStr = - aggregateExpressions.stream() - .map(AggregateExpression::asSummaryString) - .collect(Collectors.joining(",")); - } - extraDigest = - "aggregates=[grouping=[" - + groupingStr - + "], aggFunctions=[" - + aggFunctionsStr - + "]]"; - return extraDigest; + String aggFunctionsStr = + aggregateExpressions.stream() + .map(AggregateExpression::asSummaryString) + .collect(Collectors.joining(",")); + + return "aggregates=[grouping=[" + + groupingStr + + "], aggFunctions=[" + + aggFunctionsStr + + "]]"; } public static boolean apply( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java index 6837e612bd764..9311b315c68c0 100644 --- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java @@ -44,7 +44,7 @@ /** * Planner rule that tries to push a local aggregator into an {@link BatchPhysicalTableSourceScan} - * which table is a {@link TableSourceTable}. And the table source in the table is a {@link + * whose table is a {@link TableSourceTable} with a source supporting {@link * SupportsAggregatePushDown}. * *

    The aggregate push down does not support a number of more complex statements at present: @@ -138,14 +138,11 @@ protected void pushLocalAggregateIntoScan( private FlinkStatistic getNewFlinkStatistic(TableSourceTable tableSourceTable) { FlinkStatistic oldStatistic = tableSourceTable.getStatistic(); - FlinkStatistic newStatistic; if (oldStatistic == FlinkStatistic.UNKNOWN()) { - newStatistic = oldStatistic; + return oldStatistic; } else { - // Remove tableStats after all of aggregate have been pushed down - newStatistic = - FlinkStatistic.builder().statistic(oldStatistic).tableStats(null).build(); + // Remove tableStats after all aggregates have been pushed down + return FlinkStatistic.builder().statistic(oldStatistic).tableStats(null).build(); } - return newStatistic; } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java index bfb3ee20e8375..61cdfc9ff17ed 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java @@ -29,8 +29,8 @@ /** * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link - * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in - * the table is a {@link SupportsAggregatePushDown}. The {@link + * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting + * {@link SupportsAggregatePushDown}. The {@link * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. * *

    Suppose we have the original physical plan: diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java index f012ce51c27ec..949d8242bc194 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java @@ -31,8 +31,8 @@ /** * Planner rule that tries to push a local sort aggregate which with sort into a {@link - * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in - * the table is a {@link SupportsAggregatePushDown}. The {@link + * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting + * {@link SupportsAggregatePushDown}. The {@link * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. * *

    Suppose we have the original physical plan: diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java index ca1e442a66c40..dfe8eca22b1e3 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java @@ -29,8 +29,8 @@ /** * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link - * BatchPhysicalTableSourceScan} which table is a {@link TableSourceTable}. And the table source in - * the table is a {@link SupportsAggregatePushDown}. The {@link + * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting + * {@link SupportsAggregatePushDown}. The {@link * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. * *

    Suppose we have the original physical plan: From 1dbeacec22c3c6635937db8e42009fe85abe34e9 Mon Sep 17 00:00:00 2001 From: "Yu, Peng" Date: Sat, 25 Sep 2021 10:26:50 +0800 Subject: [PATCH 4/7] [FLINK-20895] fix test errors in CI pipelines --- .../factories/TestValuesTableFactory.java | 70 +++++++++++-------- .../table/planner/utils/TableTestBase.scala | 9 ++- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index 55a1cde874cb5..45381a449a9fb 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -880,46 +880,54 @@ public String asSummaryString() { } protected Collection convertToRowData(DataStructureConverter converter) { - List resultBuffer = new ArrayList<>(); + List result = new ArrayList<>(); List> keys = allPartitions.isEmpty() ? Collections.singletonList(Collections.emptyMap()) : allPartitions; + int numRetained = 0; - boolean overLimit = false; for (Map partition : keys) { - for (Row row : data.get(partition)) { - if (resultBuffer.size() >= limit) { - overLimit = true; - break; - } - boolean isRetained = - FilterUtils.isRetainedAfterApplyingFilterPredicates( - filterPredicates, getValueGetter(row)); - if (isRetained) { - final Row projectedRow = projectRow(row); - resultBuffer.add(projectedRow); - } - } - if (overLimit) { - break; + Collection rowsInPartition = data.get(partition); + + // handle predicates and projection + List rowsRetained = + rowsInPartition.stream() + .filter( + row -> + FilterUtils.isRetainedAfterApplyingFilterPredicates( + filterPredicates, getValueGetter(row))) + .map( + row -> { + Row projectedRow = projectRow(row); + projectedRow.setKind(row.getKind()); + return projectedRow; + }) + .collect(Collectors.toList()); + + // handle aggregates + if (!aggregateExpressions.isEmpty()) { + rowsRetained = applyAggregatesToRows(rowsRetained); } - } - // simulate aggregate operation - if (!aggregateExpressions.isEmpty()) { - resultBuffer = applyAggregatesToRows(resultBuffer); - } - List result = new ArrayList<>(); - for (Row row : resultBuffer) { - final RowData rowData = (RowData) converter.toInternal(row); - if (rowData != null) { - if (numRetained >= numElementToSkip) { - rowData.setRowKind(row.getKind()); - result.add(rowData); + + // handle row data + for (Row row : rowsRetained) { + final RowData rowData = (RowData) converter.toInternal(row); + if (rowData != null) { + if (numRetained >= numElementToSkip) { + rowData.setRowKind(row.getKind()); + result.add(rowData); + } + numRetained++; + } + + // handle limit. No aggregates will be pushed down when there is a limit. + if (result.size() >= limit) { + return result; } - numRetained++; } } + return result; } @@ -1059,7 +1067,7 @@ public boolean applyAggregates( List groupingSets, List aggregateExpressions, DataType producedDataType) { - // this TestValuesScanTableSource only support simple group type ar present. + // This TestValuesScanTableSource only supports simple group type ar present. // auxGrouping is not supported. 
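+        // groupingSets carries [grouping, auxGrouping]; a non-empty second entry means
+        // auxGrouping fields are present, which this source rejects.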
if (groupingSets.size() > 1 && groupingSets.get(1).length > 0) { return false; diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala index d27af9c1c6697..c5204291a442b 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala @@ -22,7 +22,6 @@ import _root_.java.util import java.io.{File, IOException} import java.nio.file.{Files, Paths} import java.time.Duration - import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.rel.RelNode import org.apache.calcite.sql.parser.SqlParserPos @@ -42,7 +41,7 @@ import org.apache.flink.table.api.bridge.java.internal.{StreamTableEnvironmentIm import org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStreamTableEnv} import org.apache.flink.table.api.bridge.scala.internal.{StreamTableEnvironmentImpl => ScalaStreamTableEnvImpl} import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment => ScalaStreamTableEnv} -import org.apache.flink.table.api.config.ExecutionConfigOptions +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} import org.apache.flink.table.api.internal.{TableEnvironmentImpl, TableEnvironmentInternal, TableImpl} import org.apache.flink.table.catalog.{CatalogManager, FunctionCatalog, GenericInMemoryCatalog, ObjectIdentifier} import org.apache.flink.table.data.RowData @@ -1002,6 +1001,12 @@ abstract class TableTestUtil( .getConfiguration .set(ExecutionOptions.BATCH_SHUFFLE_MODE, BatchShuffleMode.ALL_EXCHANGES_PIPELINED) + // Disable push down aggregates to avoid conflicts with existing test cases that verify plans. 
+ tableEnv.getConfig + .getConfiguration + .setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, false) + private val env: StreamExecutionEnvironment = getPlanner.getExecEnv override def getTableEnv: TableEnvironment = tableEnv From 289a257aef3351c1ffa07db5a435862cd98899d3 Mon Sep 17 00:00:00 2001 From: "Yu, Peng" Date: Mon, 25 Oct 2021 09:36:45 +0800 Subject: [PATCH 5/7] [FLINK-20895] add more rules and tests for local agg push down --- .../source/AggregatePushDownSpec.java | 18 +- .../batch/PushLocalAggIntoScanRuleBase.java | 139 +++++- .../batch/PushLocalHashAggIntoScanRule.java | 4 +- .../PushLocalHashAggWithCalcIntoScanRule.java | 92 ++++ ...java => PushLocalSortAggIntoScanRule.java} | 17 +- .../PushLocalSortAggWithCalcIntoScanRule.java | 94 ++++ ...calSortAggWithSortAndCalcIntoScanRule.java | 103 +++++ .../PushLocalSortAggWithSortIntoScanRule.java | 8 +- .../batch/BatchPhysicalTableSourceScan.scala | 13 +- .../plan/rules/FlinkBatchRuleSets.scala | 5 +- .../plan/schema/TableSourceTable.scala | 23 +- .../planner/plan/utils/AggregateUtil.scala | 8 +- .../factories/TestValuesTableFactory.java | 157 +++++-- ...shLocalAggIntoTableSourceScanRuleTest.java | 117 ++++- .../sql/agg/LocalAggregatePushDownITCase.java | 295 ++++++++++++ ...ushLocalAggIntoTableSourceScanRuleTest.xml | 124 ++++- .../agg/LocalAggregatePushDownITCase.scala | 425 ------------------ .../planner/runtime/utils/TestData.scala | 22 +- 18 files changed, 1099 insertions(+), 565 deletions(-) create mode 100755 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java rename flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/{PushLocalSortAggWithoutSortIntoScanRule.java => PushLocalSortAggIntoScanRule.java} (83%) mode change 100644 => 100755 create mode 100755 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java create mode 100755 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java create mode 100755 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java delete mode 100644 flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.scala diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java index 296523c8a5ddd..fab8c21f416dd 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java @@ -39,7 +39,6 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.commons.lang3.ArrayUtils; import java.util.ArrayList; import java.util.Arrays; @@ -90,12 +89,18 @@ public AggregatePushDownSpec( @Override public void apply(DynamicTableSource tableSource, SourceAbilityContext context) { 
checkArgument(getProducedType().isPresent()); - apply(inputType, groupingSets, aggregateCalls, getProducedType().get(), tableSource); + apply( + inputType, + groupingSets, + aggregateCalls, + getProducedType().get(), + tableSource, + context); } @Override public String getDigests(SourceAbilityContext context) { - int[] grouping = ArrayUtils.addAll(groupingSets.get(0), groupingSets.get(1)); + int[] grouping = groupingSets.get(0); String groupingStr = Arrays.stream(grouping) .mapToObj(index -> inputType.getFieldNames().get(index)) @@ -120,7 +125,10 @@ public static boolean apply( List groupingSets, List aggregateCalls, RowType producedType, - DynamicTableSource tableSource) { + DynamicTableSource tableSource, + SourceAbilityContext context) { + assert context.isBatchMode(); + List aggregateExpressions = buildAggregateExpressions(inputType, aggregateCalls); @@ -160,7 +168,7 @@ private static List buildAggregateExpressions( } if (aggInfo.function() instanceof AvgAggFunction) { Tuple2 sum0AndCountFunction = - AggregateUtil.deriveSumAndCountFromAvg(aggInfo.function()); + AggregateUtil.deriveSumAndCountFromAvg((AvgAggFunction) aggInfo.function()); AggregateExpression sum0Expression = new AggregateExpression( sum0AndCountFunction._1(), diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java index 9311b315c68c0..58461e42651f0 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java @@ -24,7 +24,10 @@ import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.ProjectPushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.SourceAbilityContext; import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; @@ -38,9 +41,18 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptRuleOperand; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexSlot; +import org.apache.commons.lang3.ArrayUtils; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; /** * Planner rule that tries to push a local aggregator into an {@link BatchPhysicalTableSourceScan} @@ -62,7 +74,7 @@ public PushLocalAggIntoScanRuleBase(RelOptRuleOperand operand, String descriptio super(operand, description); } - protected boolean isMatch( + protected boolean 
checkMatchesAggregatePushDown( RelOptRuleCall call, BatchPhysicalGroupAggregateBase aggregate, BatchPhysicalTableSourceScan tableSourceScan) { @@ -100,9 +112,26 @@ protected void pushLocalAggregateIntoScan( RelOptRuleCall call, BatchPhysicalGroupAggregateBase localAgg, BatchPhysicalTableSourceScan oldScan) { + pushLocalAggregateIntoScan(call, localAgg, oldScan, null); + } + + protected void pushLocalAggregateIntoScan( + RelOptRuleCall call, + BatchPhysicalGroupAggregateBase localAgg, + BatchPhysicalTableSourceScan oldScan, + int[] calcRefFields) { RowType inputType = FlinkTypeFactory.toLogicalRowType(oldScan.getRowType()); - List groupingSets = Arrays.asList(localAgg.grouping(), localAgg.auxGrouping()); + List groupingSets = + Collections.singletonList( + ArrayUtils.addAll(localAgg.grouping(), localAgg.auxGrouping())); List aggCallList = JavaScalaConversionUtil.toJava(localAgg.getAggCallList()); + + // map arg index in aggregate to field index in scan through referred fields by calc. + if (calcRefFields != null) { + groupingSets = translateGroupingArgIndex(groupingSets, calcRefFields); + aggCallList = translateAggCallArgIndex(aggCallList, calcRefFields); + } + RowType producedType = FlinkTypeFactory.toLogicalRowType(localAgg.getRowType()); TableSourceTable oldTableSourceTable = oldScan.tableSourceTable(); @@ -110,24 +139,38 @@ protected void pushLocalAggregateIntoScan( boolean isPushDownSuccess = AggregatePushDownSpec.apply( - inputType, groupingSets, aggCallList, producedType, newTableSource); + inputType, + groupingSets, + aggCallList, + producedType, + newTableSource, + SourceAbilityContext.from(oldScan)); if (!isPushDownSuccess) { // aggregate push down failed, just return without changing any nodes. return; } - FlinkStatistic newFlinkStatistic = getNewFlinkStatistic(oldTableSourceTable); + // create new source table with new spec and statistic. AggregatePushDownSpec aggregatePushDownSpec = new AggregatePushDownSpec(inputType, groupingSets, aggCallList, producedType); + Set groupColumns = + Arrays.stream(groupingSets.get(0)) + .boxed() + .map(idx -> inputType.getFieldNames().get(idx)) + .collect(Collectors.toSet()); + FlinkStatistic newFlinkStatistic = getNewFlinkStatistic(oldTableSourceTable, groupColumns); + TableSourceTable newTableSourceTable = oldTableSourceTable .copy( newTableSource, - newFlinkStatistic, + localAgg.getRowType(), new SourceAbilitySpec[] {aggregatePushDownSpec}) - .copy(localAgg.getRowType()); + .copy(newFlinkStatistic); + + // transform to new nodes. 
BatchPhysicalTableSourceScan newScan = oldScan.copy(oldScan.getTraitSet(), newTableSourceTable); BatchPhysicalExchange oldExchange = call.rel(0); @@ -136,13 +179,85 @@ protected void pushLocalAggregateIntoScan( call.transformTo(newExchange); } - private FlinkStatistic getNewFlinkStatistic(TableSourceTable tableSourceTable) { + private FlinkStatistic getNewFlinkStatistic( + TableSourceTable tableSourceTable, Set groupColumns) { FlinkStatistic oldStatistic = tableSourceTable.getStatistic(); - if (oldStatistic == FlinkStatistic.UNKNOWN()) { - return oldStatistic; - } else { - // Remove tableStats after all aggregates have been pushed down - return FlinkStatistic.builder().statistic(oldStatistic).tableStats(null).build(); + + // Create new unique keys if there are group columns + Set> uniqueKeys = null; + if (!groupColumns.isEmpty()) { + uniqueKeys = new HashSet<>(); + uniqueKeys.add(groupColumns); } + + // Remove tableStats after all aggregates have been pushed down + return FlinkStatistic.builder() + .statistic(oldStatistic) + .uniqueKeys(uniqueKeys) + .tableStats(null) + .build(); + } + + protected boolean checkNoProjectionPushDown(BatchPhysicalTableSourceScan tableSourceScan) { + TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable(); + return tableSourceTable != null + && Arrays.stream(tableSourceTable.abilitySpecs()) + .noneMatch(spec -> spec instanceof ProjectPushDownSpec); + } + + /** + * Currently, we only supports to push down aggregate above calc which has input ref only. + * + * @param calc BatchPhysicalCalc + * @return true if OK to be pushed down + */ + protected boolean checkCalcInputRefOnly(BatchPhysicalCalc calc) { + RexProgram program = calc.getProgram(); + + // check if condition exists. All filters should have been pushed down. 
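+        // A remaining condition would mean that not all filters were pushed into the
+        // source, and pushing the aggregate past an unpushed filter would change results.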
+ if (program.getCondition() != null) { + return false; + } + + return program.getExprList().stream().allMatch(RexInputRef.class::isInstance) + && !program.getProjectList().isEmpty(); + } + + protected int[] getRefFiledIndexFromCalc(BatchPhysicalCalc calc) { + return calc.getProgram().getProjectList().stream() + .map(RexSlot::getIndex) + .mapToInt(x -> x) + .toArray(); + } + + protected List translateGroupingArgIndex(List groupingSets, int[] refFields) { + List newGroupingSets = new ArrayList<>(); + groupingSets.forEach( + grouping -> { + int[] newGrouping = new int[grouping.length]; + for (int i = 0; i < grouping.length; i++) { + int argIndex = grouping[i]; + newGrouping[i] = refFields[argIndex]; + } + newGroupingSets.add(newGrouping); + }); + + return newGroupingSets; + } + + protected List translateAggCallArgIndex( + List aggCallList, int[] refFields) { + List newAggCallList = new ArrayList<>(); + aggCallList.forEach( + aggCall -> { + List argList = new ArrayList<>(); + for (int i = 0; i < aggCall.getArgList().size(); i++) { + int argIndex = aggCall.getArgList().get(i); + argList.add(refFields[argIndex]); + } + newAggCallList.add(aggCall.copy(argList, aggCall.filterArg, aggCall.collation)); + }); + + return newAggCallList; } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java index 61cdfc9ff17ed..efe3fb3547f39 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java @@ -28,7 +28,7 @@ import org.apache.calcite.plan.RelOptRuleCall; /** - * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link + * Planner rule that tries to push a local hash aggregate which without sort into a {@link * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting * {@link SupportsAggregatePushDown}. The {@link * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. @@ -67,7 +67,7 @@ public PushLocalHashAggIntoScanRule() { public boolean matches(RelOptRuleCall call) { BatchPhysicalLocalHashAggregate localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); - return isMatch(call, localAggregate, tableSourceScan); + return checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); } @Override diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java new file mode 100755 index 0000000000000..318f49f77c65e --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalHashAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; + +import org.apache.calcite.plan.RelOptRuleCall; + +/** + * Planner rule that tries to push a local hash aggregate which with calc into a {@link + * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting + * {@link SupportsAggregatePushDown}. The {@link + * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. + * + *

+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
+ *    +- BatchPhysicalLocalHashAggregate (local)
+ *       +- BatchPhysicalCalc (field projection only)
+ *          +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
+ *    +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
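+ *
+ * <p>For example (with an illustrative table {@code inventory(id, name, amount, price)}
+ * whose connector implements SupportsAggregatePushDown), a query shaped like
+ * {@code SELECT name, SUM(amount) FROM inventory GROUP BY name} can be planned with
+ * such a Calc that only projects {@code name} and {@code amount} below the local
+ * aggregate.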
    + */ +public class PushLocalHashAggWithCalcIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalHashAggWithCalcIntoScanRule INSTANCE = + new PushLocalHashAggWithCalcIntoScanRule(); + + public PushLocalHashAggWithCalcIntoScanRule() { + super( + operand( + BatchPhysicalExchange.class, + operand( + BatchPhysicalLocalHashAggregate.class, + operand( + BatchPhysicalCalc.class, + operand(BatchPhysicalTableSourceScan.class, none())))), + "PushLocalHashAggWithCalcIntoScanRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + BatchPhysicalLocalHashAggregate localHashAgg = call.rel(1); + BatchPhysicalCalc calc = call.rel(2); + BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); + + return checkCalcInputRefOnly(calc) + && checkNoProjectionPushDown(tableSourceScan) + && checkMatchesAggregatePushDown(call, localHashAgg, tableSourceScan); + } + + @Override + public void onMatch(RelOptRuleCall call) { + BatchPhysicalLocalHashAggregate localHashAgg = call.rel(1); + BatchPhysicalCalc calc = call.rel(2); + BatchPhysicalTableSourceScan oldScan = call.rel(3); + + int[] calcRefFields = getRefFiledIndexFromCalc(calc); + + pushLocalAggregateIntoScan(call, localHashAgg, oldScan, calcRefFields); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java old mode 100644 new mode 100755 similarity index 83% rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java rename to flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java index dfe8eca22b1e3..4ab6903dd34e2 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithoutSortIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java @@ -28,7 +28,7 @@ import org.apache.calcite.plan.RelOptRuleCall; /** - * Planner rule that tries to push a local hash or sort aggregate which without sort into a {@link + * Planner rule that tries to push a local sort aggregate which without sort into a {@link * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting * {@link SupportsAggregatePushDown}. The {@link * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. @@ -36,7 +36,7 @@ *

 * <p>Suppose we have the original physical plan:
 *
 * <pre>{@code
    - * BatchPhysicalHashAggregate (global)
    + * BatchPhysicalSortAggregate (global)
      * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
      *    +- BatchPhysicalLocalSortAggregate (local)
      *       +- BatchPhysicalTableSourceScan
    @@ -45,30 +45,29 @@
 *
 * <p>This physical plan will be rewritten to:
 *
 * <pre>{@code
    - * BatchPhysicalHashAggregate (global)
    + * BatchPhysicalSortAggregate (global)
      * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
      *    +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
 * }</pre>
    */ -public class PushLocalSortAggWithoutSortIntoScanRule extends PushLocalAggIntoScanRuleBase { - public static final PushLocalSortAggWithoutSortIntoScanRule INSTANCE = - new PushLocalSortAggWithoutSortIntoScanRule(); +public class PushLocalSortAggIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalSortAggIntoScanRule INSTANCE = new PushLocalSortAggIntoScanRule(); - public PushLocalSortAggWithoutSortIntoScanRule() { + public PushLocalSortAggIntoScanRule() { super( operand( BatchPhysicalExchange.class, operand( BatchPhysicalLocalSortAggregate.class, operand(BatchPhysicalTableSourceScan.class, none()))), - "PushLocalSortAggWithoutSortIntoScanRule"); + "PushLocalSortAggIntoScanRule"); } @Override public boolean matches(RelOptRuleCall call) { BatchPhysicalLocalSortAggregate localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); - return isMatch(call, localAggregate, tableSourceScan); + return checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); } @Override diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java new file mode 100755 index 0000000000000..9ed72892bd6a2 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; +import org.apache.flink.table.planner.plan.utils.RexNodeExtractor; + +import org.apache.calcite.plan.RelOptRuleCall; + +/** + * Planner rule that tries to push a local sort aggregate which without sort into a {@link + * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting + * {@link SupportsAggregatePushDown}. The {@link + * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. + * + *

+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
+ *    +- BatchPhysicalLocalSortAggregate (local)
+ *       +- BatchPhysicalCalc (field projection only)
+ *          +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
+ *    +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
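+ *
+ * <p>This is the sort-aggregate counterpart of PushLocalHashAggWithCalcIntoScanRule;
+ * the planner produces this plan shape e.g. when hash aggregates are disabled via
+ * {@code table.exec.disabled-operators}, which is how the ITCases in this patch
+ * exercise the sort-aggregate rules.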
    + */ +public class PushLocalSortAggWithCalcIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalSortAggWithCalcIntoScanRule INSTANCE = + new PushLocalSortAggWithCalcIntoScanRule(); + + public PushLocalSortAggWithCalcIntoScanRule() { + super( + operand( + BatchPhysicalExchange.class, + operand( + BatchPhysicalLocalSortAggregate.class, + operand( + BatchPhysicalCalc.class, + operand(BatchPhysicalTableSourceScan.class, none())))), + "PushLocalSortAggWithCalcIntoScanRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + BatchPhysicalLocalSortAggregate localAggregate = call.rel(1); + BatchPhysicalCalc calc = call.rel(2); + BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); + + return checkCalcInputRefOnly(calc) + && checkNoProjectionPushDown(tableSourceScan) + && checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + } + + @Override + public void onMatch(RelOptRuleCall call) { + BatchPhysicalLocalSortAggregate localHashAgg = call.rel(1); + BatchPhysicalCalc calc = call.rel(2); + BatchPhysicalTableSourceScan oldScan = call.rel(3); + + int[] calcRefFields = + RexNodeExtractor.extractRefInputFields(calc.getProgram().getExprList()); + + pushLocalAggregateIntoScan(call, localHashAgg, oldScan, calcRefFields); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java new file mode 100755 index 0000000000000..68c31ea104167 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSort; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; +import org.apache.flink.table.planner.plan.schema.TableSourceTable; +import org.apache.flink.table.planner.plan.utils.RexNodeExtractor; + +import org.apache.calcite.plan.RelOptRuleCall; + +/** + * Planner rule that tries to push a local sort aggregate which with sort into a {@link + * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting + * {@link SupportsAggregatePushDown}. The {@link + * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true. + * + *

+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ *    +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
+ *       +- BatchPhysicalLocalSortAggregate (local)
+ *          +- BatchPhysicalSort (exists if group keys are not empty)
+ *             +- BatchPhysicalCalc (field projection only)
+ *                +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ *    +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
+ *       +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
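+ *
+ * <p>Note that the local BatchPhysicalSort is removed together with the local
+ * aggregate and the calc: once the aggregation happens inside the source, the scan
+ * output no longer needs to be pre-sorted by the group keys before the exchange.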
    + */ +public class PushLocalSortAggWithSortAndCalcIntoScanRule extends PushLocalAggIntoScanRuleBase { + public static final PushLocalSortAggWithSortAndCalcIntoScanRule INSTANCE = + new PushLocalSortAggWithSortAndCalcIntoScanRule(); + + public PushLocalSortAggWithSortAndCalcIntoScanRule() { + super( + operand( + BatchPhysicalExchange.class, + operand( + BatchPhysicalLocalSortAggregate.class, + operand( + BatchPhysicalSort.class, + operand( + BatchPhysicalCalc.class, + operand( + BatchPhysicalTableSourceScan.class, + none()))))), + "PushLocalSortAggWithSortAndCalcIntoScanRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + BatchPhysicalGroupAggregateBase localAggregate = call.rel(1); + BatchPhysicalCalc calc = call.rel(3); + BatchPhysicalTableSourceScan tableSourceScan = call.rel(4); + + return checkCalcInputRefOnly(calc) + && checkNoProjectionPushDown(tableSourceScan) + && checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + } + + @Override + public void onMatch(RelOptRuleCall call) { + BatchPhysicalGroupAggregateBase localSortAgg = call.rel(1); + BatchPhysicalCalc calc = call.rel(3); + BatchPhysicalTableSourceScan oldScan = call.rel(4); + + int[] calcRefFields = + RexNodeExtractor.extractRefInputFields(calc.getProgram().getExprList()); + + pushLocalAggregateIntoScan(call, localSortAgg, oldScan, calcRefFields); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java index 949d8242bc194..bfcc353f50bfe 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java @@ -39,10 +39,10 @@ * *
 * <pre>{@code
      * BatchPhysicalSortAggregate (global)
    - * +- Sort (exists if group keys are not empty)
    + * +- BatchPhysicalSort (exists if group keys are not empty)
      *    +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
      *       +- BatchPhysicalLocalSortAggregate (local)
    - *          +- Sort (exists if group keys are not empty)
    + *          +- BatchPhysicalSort (exists if group keys are not empty)
      *             +- BatchPhysicalTableSourceScan
 * }</pre>
    * @@ -50,7 +50,7 @@ * *
 * <pre>{@code
      * BatchPhysicalSortAggregate (global)
    - * +- Sort (exists if group keys are not empty)
    + * +- BatchPhysicalSort (exists if group keys are not empty)
      *    +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
      *       +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
 * }</pre>
    @@ -75,7 +75,7 @@ public PushLocalSortAggWithSortIntoScanRule() { public boolean matches(RelOptRuleCall call) { BatchPhysicalGroupAggregateBase localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); - return isMatch(call, localAggregate, tableSourceScan); + return checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); } @Override diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala index 66bff2fe4dbf3..0a3563d397e76 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala @@ -18,12 +18,6 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch -import java.util - -import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.hint.RelHint -import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.nodes.exec.ExecNode import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecTableSourceScan @@ -32,6 +26,13 @@ import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalT import org.apache.flink.table.planner.plan.schema.TableSourceTable import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil +import org.apache.calcite.plan._ +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.hint.RelHint +import org.apache.calcite.rel.metadata.RelMetadataQuery + +import java.util + /** * Batch physical RelNode to read data from an external source defined by a * bounded [[org.apache.flink.table.connector.source.ScanTableSource]]. 
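For connector developers, these planner rules only take effect when the scan's source
accepts the push-down, so the connector side is worth illustrating. The following is a
minimal sketch of a SupportsAggregatePushDown implementation, loosely modeled on the
TestValuesTableFactory changes in this patch series; the class name MyAggregatingSource
is hypothetical and the runtime provider is stubbed, so read it as an outline of the
contract rather than a reference implementation:

    import java.util.List;

    import org.apache.flink.table.connector.ChangelogMode;
    import org.apache.flink.table.connector.source.DynamicTableSource;
    import org.apache.flink.table.connector.source.ScanTableSource;
    import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
    import org.apache.flink.table.expressions.AggregateExpression;
    import org.apache.flink.table.types.DataType;

    /** Hypothetical source that can serve pre-aggregated rows for simple group-bys. */
    public class MyAggregatingSource implements ScanTableSource, SupportsAggregatePushDown {

        private List<int[]> groupingSets;
        private List<AggregateExpression> aggregateExpressions;
        private DataType producedDataType;

        @Override
        public boolean applyAggregates(
                List<int[]> groupingSets,
                List<AggregateExpression> aggregateExpressions,
                DataType producedDataType) {
            // Returning false is always safe: the planner keeps the local aggregate
            // node in the plan and the query still runs, just without the push-down.
            for (AggregateExpression aggExpr : aggregateExpressions) {
                if (aggExpr.getFilterExpression().isPresent() || aggExpr.isDistinct()) {
                    return false;
                }
            }
            // Remember what to compute; producedDataType is the local aggregate's
            // output row type (group keys first, then the aggregate result fields).
            this.groupingSets = groupingSets;
            this.aggregateExpressions = aggregateExpressions;
            this.producedDataType = producedDataType;
            return true;
        }

        @Override
        public ChangelogMode getChangelogMode() {
            return ChangelogMode.insertOnly();
        }

        @Override
        public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) {
            // A real source would build a provider that emits rows matching
            // producedDataType, pre-aggregated by groupingSets/aggregateExpressions.
            throw new UnsupportedOperationException("sketch only");
        }

        @Override
        public DynamicTableSource copy() {
            MyAggregatingSource copy = new MyAggregatingSource();
            copy.groupingSets = groupingSets;
            copy.aggregateExpressions = aggregateExpressions;
            copy.producedDataType = producedDataType;
            return copy;
        }

        @Override
        public String asSummaryString() {
            return "MyAggregatingSource";
        }
    }

As in the TestValuesTableFactory changes above, rejecting unsupported shapes (auxGrouping,
filtered or distinct aggregates, unsupported argument types) by returning false keeps the
plan correct while limiting the push-down to the cases the source can actually compute.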
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index d7d8ea4b76797..28db5c8143eb3 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -450,7 +450,10 @@ object FlinkBatchRuleSets { EnforceLocalHashAggRule.INSTANCE, EnforceLocalSortAggRule.INSTANCE, PushLocalHashAggIntoScanRule.INSTANCE, + PushLocalHashAggWithCalcIntoScanRule.INSTANCE, + PushLocalSortAggIntoScanRule.INSTANCE, PushLocalSortAggWithSortIntoScanRule.INSTANCE, - PushLocalSortAggWithoutSortIntoScanRule.INSTANCE + PushLocalSortAggWithCalcIntoScanRule.INSTANCE, + PushLocalSortAggWithSortAndCalcIntoScanRule.INSTANCE ) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala index 7f2db9c972c01..2ef50b4114d4a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala @@ -18,11 +18,6 @@ package org.apache.flink.table.planner.plan.schema -import java.util - -import com.google.common.collect.ImmutableList -import org.apache.calcite.plan.RelOptSchema -import org.apache.calcite.rel.`type`.RelDataType import org.apache.flink.table.catalog.{ObjectIdentifier, ResolvedCatalogTable} import org.apache.flink.table.connector.source.DynamicTableSource import org.apache.flink.table.planner.calcite.FlinkContext @@ -30,6 +25,12 @@ import org.apache.flink.table.planner.connectors.DynamicSourceUtils import org.apache.flink.table.planner.plan.abilities.source.{SourceAbilityContext, SourceAbilitySpec} import org.apache.flink.table.planner.plan.stats.FlinkStatistic +import com.google.common.collect.ImmutableList +import org.apache.calcite.plan.RelOptSchema +import org.apache.calcite.rel.`type`.RelDataType + +import java.util + /** * A [[FlinkPreparingTableBase]] implementation which defines the context variables * required to translate the Calcite [[org.apache.calcite.plan.RelOptTable]] to the Flink specific @@ -132,17 +133,17 @@ class TableSourceTable( } /** - * Creates a copy of this table, changing the rowType + * Creates a copy of this table, changing the statistic * - * @param newRowType new row type - * @return New TableSourceTable instance with new row type + * @param newStatistic new table statistic + * @return New TableSourceTable instance with new statistic */ - def copy(newRowType: RelDataType): TableSourceTable = { + def copy(newStatistic: FlinkStatistic): TableSourceTable = { new TableSourceTable( relOptSchema, tableIdentifier, - newRowType, - statistic, + rowType, + newStatistic, tableSource, isStreamingMode, catalogTable, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 513d139505116..8c820930fbfaf 100644 --- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -27,7 +27,7 @@ import org.apache.flink.table.planner.JLong import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, FlinkTypeSystem} import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.expressions._ -import org.apache.flink.table.planner.functions.aggfunctions.{CountAggFunction, DeclarativeAggregateFunction, Sum0AggFunction} +import org.apache.flink.table.planner.functions.aggfunctions.{AvgAggFunction, CountAggFunction, DeclarativeAggregateFunction, Sum0AggFunction} import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction.{ByteAvgAggFunction, DoubleAvgAggFunction, FloatAvgAggFunction, IntAvgAggFunction, LongAvgAggFunction, ShortAvgAggFunction} import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction.{ByteSum0AggFunction, DoubleSum0AggFunction, FloatSum0AggFunction, IntSum0AggFunction, LongSum0AggFunction, ShortSum0AggFunction} import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction @@ -52,7 +52,6 @@ import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks import org.apache.flink.table.types.logical.{LogicalTypeRoot, _} import org.apache.flink.table.types.utils.DataTypeUtils - import org.apache.calcite.rel.`type`._ import org.apache.calcite.rel.core.Aggregate.AggCallBinding import org.apache.calcite.rel.core.{Aggregate, AggregateCall} @@ -64,7 +63,6 @@ import org.apache.calcite.tools.RelBuilder import java.time.Duration import java.util - import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable @@ -280,8 +278,8 @@ object AggregateUtil extends Enumeration { isBounded = false) } - def deriveSumAndCountFromAvg(avgAggFunction: UserDefinedFunction - ): (Sum0AggFunction, CountAggFunction) = { + def deriveSumAndCountFromAvg( + avgAggFunction: AvgAggFunction): (Sum0AggFunction, CountAggFunction) = { avgAggFunction match { case _: ByteAvgAggFunction => (new ByteSum0AggFunction, new CountAggFunction) case _: ShortAvgAggFunction => (new ShortSum0AggFunction, new CountAggFunction) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index 45381a449a9fb..fcf8ea9929a73 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -298,6 +298,9 @@ private static RowKind parseRowKind(String rowKindShortString) { private static final ConfigOption SINK_EXPECTED_MESSAGES_NUM = ConfigOptions.key("sink-expected-messages-num").intType().defaultValue(-1); + private static final ConfigOption DISABLE_PROJECTION_PUSH_DOWN = + ConfigOptions.key("disable-projection-push-down").booleanType().defaultValue(false); + private static final ConfigOption NESTED_PROJECTION_SUPPORTED = ConfigOptions.key("nested-projection-supported").booleanType().defaultValue(false); @@ -373,6 +376,7 @@ public DynamicTableSource createDynamicTableSource(Context context) { boolean isAsync = 
helper.getOptions().get(ASYNC_ENABLED); String lookupFunctionClass = helper.getOptions().get(LOOKUP_FUNCTION_CLASS); boolean disableLookup = helper.getOptions().get(DISABLE_LOOKUP); + boolean disableProjectionPushDown = helper.getOptions().get(DISABLE_PROJECTION_PUSH_DOWN); boolean nestedProjectionSupported = helper.getOptions().get(NESTED_PROJECTION_SUPPORTED); boolean enableWatermarkPushDown = helper.getOptions().get(ENABLE_WATERMARK_PUSH_DOWN); boolean failingSource = helper.getOptions().get(FAILING_SOURCE); @@ -410,6 +414,25 @@ public DynamicTableSource createDynamicTableSource(Context context) { partition2Rows.put(Collections.emptyMap(), data); } + if (disableProjectionPushDown) { + return new TestValuesScanTableSourceWithoutProjectionPushDown( + producedDataType, + changelogMode, + isBounded, + runtimeSource, + failingSource, + partition2Rows, + nestedProjectionSupported, + null, + Collections.emptyList(), + filterableFieldsSet, + numElementToSkip, + Long.MAX_VALUE, + partitions, + readableMetadata, + null); + } + if (disableLookup) { if (enableWatermarkPushDown) { return new TestValuesScanTableSourceWithWatermarkPushDown( @@ -553,6 +576,7 @@ public Set> optionalOptions() { SINK_INSERT_ONLY, RUNTIME_SINK, SINK_EXPECTED_MESSAGES_NUM, + DISABLE_PROJECTION_PUSH_DOWN, NESTED_PROJECTION_SUPPORTED, FILTERABLE_FIELDS, PARTITION_LIST, @@ -691,10 +715,9 @@ private static Map convertToMetadataMap( // Table sources // -------------------------------------------------------------------------------------------- - /** Values {@link ScanTableSource} for testing. */ - private static class TestValuesScanTableSource + /** Values {@link ScanTableSource} for testing that disables projection push down. */ + private static class TestValuesScanTableSourceWithoutProjectionPushDown implements ScanTableSource, - SupportsProjectionPushDown, SupportsFilterPushDown, SupportsLimitPushDown, SupportsPartitionPushDown, @@ -721,7 +744,7 @@ private static class TestValuesScanTableSource private @Nullable int[] groupingSet; private List aggregateExpressions; - private TestValuesScanTableSource( + private TestValuesScanTableSourceWithoutProjectionPushDown( DataType producedDataType, ChangelogMode changelogMode, boolean bounded, @@ -820,17 +843,6 @@ public boolean isBounded() { } } - @Override - public boolean supportsNestedProjection() { - return nestedProjectionSupported; - } - - @Override - public void applyProjection(int[][] projectedFields) { - this.producedDataType = DataTypeUtils.projectRow(producedDataType, projectedFields); - this.projectedPhysicalFields = projectedFields; - } - @Override public Result applyFilters(List filters) { List acceptedFilters = new ArrayList<>(); @@ -856,7 +868,7 @@ private Function> getValueGetter(Row row) { @Override public DynamicTableSource copy() { - return new TestValuesScanTableSource( + return new TestValuesScanTableSourceWithoutProjectionPushDown( producedDataType, changelogMode, bounded, @@ -886,13 +898,22 @@ protected Collection convertToRowData(DataStructureConverter converter) ? 
Collections.singletonList(Collections.emptyMap()) : allPartitions; - int numRetained = 0; + int numSkipped = 0; for (Map partition : keys) { Collection rowsInPartition = data.get(partition); + // handle element skipping + int numToSkipInPartition = 0; + if (numSkipped < numElementToSkip) { + numToSkipInPartition = + Math.min(rowsInPartition.size(), numElementToSkip - numSkipped); + } + numSkipped += numToSkipInPartition; + // handle predicates and projection List rowsRetained = rowsInPartition.stream() + .skip(numToSkipInPartition) .filter( row -> FilterUtils.isRetainedAfterApplyingFilterPredicates( @@ -914,11 +935,8 @@ filterPredicates, getValueGetter(row))) for (Row row : rowsRetained) { final RowData rowData = (RowData) converter.toInternal(row); if (rowData != null) { - if (numRetained >= numElementToSkip) { - rowData.setRowKind(row.getKind()); - result.add(rowData); - } - numRetained++; + rowData.setRowKind(row.getKind()); + result.add(rowData); } // handle limit. No aggregates will be pushed down when there is a limit. @@ -968,15 +986,15 @@ private Row accumulateRows(List rows) { Row minRow = rows.stream() .min(Comparator.comparing(row -> row.getFieldAs(argIndex))) - .get(); - result.setField(i, minRow.getField(argIndex)); + .orElse(null); + result.setField(i, minRow != null ? minRow.getField(argIndex) : null); } else if (aggFunction instanceof MaxAggFunction) { int argIndex = arguments.get(0).getFieldIndex(); Row maxRow = rows.stream() .max(Comparator.comparing(row -> row.getFieldAs(argIndex))) - .get(); - result.setField(i, maxRow.getField(argIndex)); + .orElse(null); + result.setField(i, maxRow != null ? maxRow.getField(argIndex) : null); } else if (aggFunction instanceof SumAggFunction) { int argIndex = arguments.get(0).getFieldIndex(); Object finalSum = @@ -984,7 +1002,9 @@ private Row accumulateRows(List rows) { .filter(row -> row.getField(argIndex) != null) .mapToLong(row -> row.getFieldAs(argIndex)) .sum(); - result.setField(i, finalSum); + + boolean allNull = rows.stream().noneMatch(r -> r.getField(argIndex) != null); + result.setField(i, allNull ? null : finalSum); } else if (aggFunction instanceof Sum0AggFunction) { int argIndex = arguments.get(0).getFieldIndex(); Object finalSum0 = @@ -993,8 +1013,11 @@ private Row accumulateRows(List rows) { .filter(row -> row.getField(argIndex) != null) .mapToLong(row -> row.getFieldAs(argIndex)) .sum(); result.setField(i, finalSum0); - } else if (aggFunction instanceof CountAggFunction - || aggFunction instanceof Count1AggFunction) { + } else if (aggFunction instanceof CountAggFunction) { + int argIndex = arguments.get(0).getFieldIndex(); + long count = rows.stream().filter(r -> r.getField(argIndex) != null).count(); + result.setField(i, count); + } else if (aggFunction instanceof Count1AggFunction) { result.setField(i, (long) rows.size()); } } @@ -1067,9 +1090,8 @@ public boolean applyAggregates( List groupingSets, List aggregateExpressions, DataType producedDataType) { - // This TestValuesScanTableSource only supports simple group type ar present. - // auxGrouping is not supported. - if (groupingSets.size() > 1 && groupingSets.get(1).length > 0) { + // This TestValuesScanTableSource only supports single group aggregate at present. + if (groupingSets.size() > 1) { return false; } List aggExpressions = new ArrayList<>(); @@ -1129,6 +1151,77 @@ public void applyReadableMetadata( } } + /** Values {@link ScanTableSource} for testing that supports projection push down.
*/ + private static class TestValuesScanTableSource + extends TestValuesScanTableSourceWithoutProjectionPushDown + implements SupportsProjectionPushDown { + + private TestValuesScanTableSource( + DataType producedDataType, + ChangelogMode changelogMode, + boolean bounded, + String runtimeSource, + boolean failingSource, + Map, Collection> data, + boolean nestedProjectionSupported, + @Nullable int[][] projectedPhysicalFields, + List filterPredicates, + Set filterableFields, + int numElementToSkip, + long limit, + List> allPartitions, + Map readableMetadata, + @Nullable int[] projectedMetadataFields) { + super( + producedDataType, + changelogMode, + bounded, + runtimeSource, + failingSource, + data, + nestedProjectionSupported, + projectedPhysicalFields, + filterPredicates, + filterableFields, + numElementToSkip, + limit, + allPartitions, + readableMetadata, + projectedMetadataFields); + } + + @Override + public DynamicTableSource copy() { + return new TestValuesScanTableSource( + producedDataType, + changelogMode, + bounded, + runtimeSource, + failingSource, + data, + nestedProjectionSupported, + projectedPhysicalFields, + filterPredicates, + filterableFields, + numElementToSkip, + limit, + allPartitions, + readableMetadata, + projectedMetadataFields); + } + + @Override + public boolean supportsNestedProjection() { + return nestedProjectionSupported; + } + + @Override + public void applyProjection(int[][] projectedFields) { + this.producedDataType = DataTypeUtils.projectRow(producedDataType, projectedFields); + this.projectedPhysicalFields = projectedFields; + } + } + /** Values {@link ScanTableSource} for testing that supports watermark push down. */ private static class TestValuesScanTableSourceWithWatermarkPushDown extends TestValuesScanTableSource diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java index dcd1d636c0927..f01725b07dd46 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java @@ -30,8 +30,8 @@ import org.junit.Test; /** - * Test for {@link PushLocalHashAggIntoScanRule}, {@link PushLocalSortAggWithSortIntoScanRule} and - * {@link PushLocalSortAggWithoutSortIntoScanRule}. + * Test for rules that extend {@link PushLocalAggIntoScanRuleBase} to push down local aggregates + * into table source. 
*/ public class PushLocalAggIntoTableSourceScanRuleTest extends TableTestBase { protected BatchTableTestUtil util = batchTestUtil(new TableConfig()); @@ -53,10 +53,62 @@ public void setup() { + " type STRING\n" + ") WITH (\n" + " 'connector' = 'values',\n" - + " 'filterable-fields' = 'id',\n" + + " 'filterable-fields' = 'id;type',\n" + " 'bounded' = 'true'\n" + ")"; util.tableEnv().executeSql(ddl); + + String ddl2 = + "CREATE TABLE inventory_meta (\n" + + " id BIGINT,\n" + + " name STRING,\n" + + " amount BIGINT,\n" + + " price BIGINT,\n" + + " type STRING,\n" + + " metadata_1 BIGINT METADATA,\n" + + " metadata_2 STRING METADATA,\n" + + " PRIMARY KEY (`id`) NOT ENFORCED\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'filterable-fields' = 'id;type',\n" + + " 'readable-metadata' = 'metadata_1:BIGINT, metadata_2:STRING',\n" + + " 'bounded' = 'true'\n" + + ")"; + util.tableEnv().executeSql(ddl2); + + // partitioned table + String ddl3 = + "CREATE TABLE inventory_part (\n" + + " id BIGINT,\n" + + " name STRING,\n" + + " amount BIGINT,\n" + + " price BIGINT,\n" + + " type STRING\n" + + ") PARTITIONED BY (type)\n" + + "WITH (\n" + + " 'connector' = 'values',\n" + + " 'filterable-fields' = 'id;type',\n" + + " 'partition-list' = 'type:a;type:b',\n" + + " 'bounded' = 'true'\n" + + ")"; + util.tableEnv().executeSql(ddl3); + + // disable projection push down + String ddl4 = + "CREATE TABLE inventory_no_proj (\n" + + " id BIGINT,\n" + + " name STRING,\n" + + " amount BIGINT,\n" + + " price BIGINT,\n" + + " type STRING\n" + + ")\n" + + "WITH (\n" + + " 'connector' = 'values',\n" + + " 'filterable-fields' = 'id;type',\n" + + " 'disable-projection-push-down' = 'true',\n" + + " 'bounded' = 'true'\n" + + ")"; + util.tableEnv().executeSql(ddl4); } @Test @@ -157,37 +209,64 @@ public void testCanPushDownLocalSortAggWithSort() { } @Test - public void testCanPushDownLocalAggWithAuxGrouping() { + public void testCanPushDownLocalAggAfterFilterPushDown() { + util.verifyRelPlan( "SELECT\n" + + " sum(amount),\n" + " name,\n" - + " a,\n" - + " p,\n" - + " count(*)\n" - + "FROM (\n" - + " SELECT\n" - + " name,\n" - + " sum(amount) as a,\n" - + " max(price) as p\n" - + " FROM inventory\n" - + " group by name\n" - + ") t\n" - + " group by name, a, p"); + + " type\n" + + "FROM inventory\n" + + " where id = 123\n" + + " group by name, type"); } @Test - public void testCanPushDownLocalAggAfterFilterPushDown() { + public void testCanPushDownLocalAggWithMetadata() { + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " max(metadata_1),\n" + + " name,\n" + + " type\n" + + "FROM inventory_meta\n" + + " where id = 123\n" + + " group by name, type"); + } + @Test + public void testCanPushDownLocalAggWithPartition() { + util.verifyRelPlan( + "SELECT\n" + + " sum(amount),\n" + + " type,\n" + + " name\n" + + "FROM inventory_part\n" + + " where type in ('a', 'b') and id = 123\n" + + " group by type, name"); + } + + @Test + public void testCanPushDownLocalAggWithoutProjectionPushDown() { util.verifyRelPlan( "SELECT\n" + " sum(amount),\n" + " name,\n" + " type\n" - + "FROM inventory\n" + + "FROM inventory_no_proj\n" + " where id = 123\n" + " group by name, type"); } + @Test + public void testCannotPushDownLocalAggWithAuxGrouping() { + util.verifyRelPlan( + "SELECT\n" + + " id, name, count(*)\n" + + "FROM inventory_meta\n" + + " group by id, name, abs(amount)"); + } + @Test public void testCannotPushDownLocalAggAfterLimitPushDown() { @@ -268,7 +347,7 @@ public void testCannotPushDownWithWindowAggFunction() { } @Test 
- public void testCannotPushDownWithFilter() { + public void testCannotPushDownWithArgFilter() { util.verifyRelPlan( "SELECT\n" + " min(id),\n" diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java new file mode 100755 index 0000000000000..763dbd79577bc --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.batch.sql.agg; + +import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.factories.TestValuesTableFactory; +import org.apache.flink.table.planner.runtime.utils.BatchTestBase; +import org.apache.flink.table.planner.runtime.utils.TestData; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.types.Row; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; + +/** Test for local aggregate push down. 
*/ +public class LocalAggregatePushDownITCase extends BatchTestBase { + + @Before + public void before() { + super.before(); + env().setParallelism(1); // set sink parallelism to 1 + + String testDataId = TestValuesTableFactory.registerData(TestData.personData()); + String ddl = + "CREATE TABLE AggregatableTable (\n" + + " id int,\n" + + " age int,\n" + + " name string,\n" + + " height int,\n" + + " gender string,\n" + + " deposit bigint,\n" + + " points bigint,\n" + + " metadata_1 BIGINT METADATA,\n" + + " metadata_2 STRING METADATA\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'data-id' = '" + + testDataId + + "',\n" + + " 'filterable-fields' = 'id;age',\n" + + " 'readable-metadata' = 'metadata_1:BIGINT, metadata_2:STRING',\n" + + " 'bounded' = 'true'\n" + + ")"; + tEnv().executeSql(ddl); + + // partitioned table + String testDataId2 = TestValuesTableFactory.registerData(TestData.personData()); + String ddl2 = + "CREATE TABLE AggregatableTable_Part (\n" + + " id int,\n" + + " age int,\n" + + " name string,\n" + + " height int,\n" + + " gender string,\n" + + " deposit bigint,\n" + + " points bigint,\n" + + " distance BIGINT,\n" + + " type STRING\n" + + ") PARTITIONED BY (type)\n" + + "WITH (\n" + + " 'connector' = 'values',\n" + + " 'data-id' = '" + + testDataId + + "',\n" + + " 'filterable-fields' = 'id;age',\n" + + " 'partition-list' = 'type:A;type:B;type:C;type:D',\n" + + " 'bounded' = 'true'\n" + + ")"; + tEnv().executeSql(ddl2); + + // partitioned table + String testDataId3 = TestValuesTableFactory.registerData(TestData.personData()); + String ddl3 = + "CREATE TABLE AggregatableTable_No_Proj (\n" + + " id int,\n" + + " age int,\n" + + " name string,\n" + + " height int,\n" + + " gender string,\n" + + " deposit bigint,\n" + + " points bigint,\n" + + " distance BIGINT,\n" + + " type STRING\n" + + ")\n" + + "WITH (\n" + + " 'connector' = 'values',\n" + + " 'data-id' = '" + + testDataId + + "',\n" + + " 'filterable-fields' = 'id;age',\n" + + " 'disable-projection-push-down' = 'true',\n" + + " 'bounded' = 'true'\n" + + ")"; + tEnv().executeSql(ddl3); + } + + @Test + public void testPushDownLocalHashAggWithGroup() { + checkResult( + "SELECT\n" + + " avg(deposit) as avg_dep,\n" + + " sum(deposit),\n" + + " count(1),\n" + + " gender\n" + + "FROM\n" + + " AggregatableTable\n" + + "GROUP BY gender\n" + + "ORDER BY avg_dep", + JavaScalaConversionUtil.toScala( + Arrays.asList(Row.of(126, 630, 5, "f"), Row.of(220, 1320, 6, "m"))), + false); + } + + @Test + public void testDisablePushDownLocalAgg() { + // disable push down local agg + tEnv().getConfig() + .getConfiguration() + .setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, + false); + + checkResult( + "SELECT\n" + + " avg(deposit) as avg_dep,\n" + + " sum(deposit),\n" + + " count(1),\n" + + " gender\n" + + "FROM\n" + + " AggregatableTable\n" + + "GROUP BY gender\n" + + "ORDER BY avg_dep", + JavaScalaConversionUtil.toScala( + Arrays.asList(Row.of(126, 630, 5, "f"), Row.of(220, 1320, 6, "m"))), + false); + } + + @Test + public void testPushDownLocalHashAggWithoutGroup() { + checkResult( + "SELECT\n" + + " avg(deposit),\n" + + " sum(deposit),\n" + + " count(*)\n" + + "FROM\n" + + " AggregatableTable", + JavaScalaConversionUtil.toScala(Collections.singletonList(Row.of(177, 1950, 11))), + false); + } + + @Test + public void testPushDownLocalSortAggWithoutSort() { + // enable sort agg + tEnv().getConfig() + .getConfiguration() + .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, 
"HashAgg"); + + checkResult( + "SELECT\n" + + " avg(deposit),\n" + + " sum(deposit),\n" + + " count(*)\n" + + "FROM\n" + + " AggregatableTable", + JavaScalaConversionUtil.toScala(Collections.singletonList(Row.of(177, 1950, 11))), + false); + } + + @Test + public void testPushDownLocalSortAggWithSort() { + // enable sort agg + tEnv().getConfig() + .getConfiguration() + .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg"); + + checkResult( + "SELECT\n" + + " avg(deposit),\n" + + " sum(deposit),\n" + + " count(1),\n" + + " gender,\n" + + " age\n" + + "FROM\n" + + " AggregatableTable\n" + + "GROUP BY gender, age", + JavaScalaConversionUtil.toScala( + Arrays.asList( + Row.of(50, 50, 1, "f", 19), + Row.of(200, 200, 1, "f", 20), + Row.of(250, 750, 3, "m", 23), + Row.of(126, 380, 3, "f", 25), + Row.of(300, 300, 1, "m", 27), + Row.of(170, 170, 1, "m", 28), + Row.of(100, 100, 1, "m", 34))), + false); + } + + @Test + public void testPushDownLocalAggAfterFilterPushDown() { + checkResult( + "SELECT\n" + + " avg(deposit),\n" + + " sum(deposit),\n" + + " count(1),\n" + + " gender,\n" + + " age\n" + + "FROM\n" + + " AggregatableTable\n" + + "WHERE age <= 20\n" + + "GROUP BY gender, age", + JavaScalaConversionUtil.toScala( + Arrays.asList(Row.of(50, 50, 1, "f", 19), Row.of(200, 200, 1, "f", 20))), + false); + } + + @Test + public void testPushDownLocalAggWithMetadata() { + checkResult( + "SELECT\n" + + " sum(metadata_1),\n" + + " metadata_2\n" + + "FROM\n" + + " AggregatableTable\n" + + "GROUP BY metadata_2", + JavaScalaConversionUtil.toScala( + Arrays.asList( + Row.of(156, 'C'), + Row.of(183, 'A'), + Row.of(51, 'D'), + Row.of(70, 'B'))), + false); + } + + @Test + public void testPushDownLocalAggWithPartition() { + checkResult( + "SELECT\n" + + " sum(deposit),\n" + + " count(1),\n" + + " type,\n" + + " name\n" + + "FROM\n" + + " AggregatableTable_Part\n" + + "WHERE type in ('A', 'C')" + + "GROUP BY type, name", + JavaScalaConversionUtil.toScala( + Arrays.asList( + Row.of(150, 1, "C", "jack"), + Row.of(180, 1, "A", "emma"), + Row.of(200, 1, "A", "tom"), + Row.of(200, 1, "C", "eva"), + Row.of(300, 1, "C", "danny"), + Row.of(400, 1, "A", "tommas"), + Row.of(50, 1, "C", "olivia"))), + false); + } + + @Test + public void testPushDownLocalAggWithoutProjectionPushDown() { + checkResult( + "SELECT\n" + + " avg(deposit),\n" + + " sum(deposit),\n" + + " count(1),\n" + + " gender,\n" + + " age\n" + + "FROM\n" + + " AggregatableTable_No_Proj\n" + + "WHERE age <= 20\n" + + "GROUP BY gender, age", + JavaScalaConversionUtil.toScala( + Arrays.asList(Row.of(50, 50, 1, "f", 19), Row.of(200, 200, 1, "f", 20))), + false); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml index 7295012dd7da0..f379f0426df6b 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml @@ -146,37 +146,29 @@ Calc(select=[EXPR$0, name, type]) ]]> - + + id, name, count(*) +FROM inventory_meta + group by id, name, abs(amount)]]> @@ -205,6 +197,92 @@ Calc(select=[EXPR$0, name, type]) +- 
HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type, Final_SUM(sum$0) AS EXPR$0]) +- Exchange(distribution=[hash[name, type]]) +- TableSourceScan(table=[[default_catalog, default_database, inventory, filter=[=(id, 123:BIGINT)], project=[name, type, amount], metadata=[], aggregates=[grouping=[name,type], aggFunctions=[LongSumAggFunction(amount)]]]], fields=[name, type, sum$0]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -392,7 +470,7 @@ Calc(select=[id, amount, CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT) AS EXPR$2, ]]> - + 100), - | count(1), - | gender, - | age - |FROM - | AggregatableTable - |GROUP BY gender, age - |""".stripMargin, - Seq( - row(100, null, 1, "m", 34), - row(126, 180, 3, "f", 25), - row(170, 170, 1, "m", 28), - row(200, 200, 1, "f", 20), - row(250, 750, 3, "m", 23), - row(300, 300, 1, "m", 27), - row(50, null, 1, "f", 19)) - ) - } - -} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala index e85f168d184c5..9fe60baf2656c 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala @@ -414,17 +414,17 @@ object TestData { // person test data lazy val personData: Seq[Row] = Seq( - row(1, 23, "tom", 172, "m", 200L, 1000L), - row(2, 25, "mary", 161, "f", 100L, 1000L), - row(3, 23, "jack", 182, "m", 150L, 1300L), - row(4, 25, "rose", 165, "f", 100L, 500L), - row(5, 27, "danny", 175, "m", 300L, 300L), - row(6, 23, "tommas", 172, "m", 400L, 4000L), - row(7, 19, "olivia", 172, "f", 50L, 9000L), - row(8, 34, "stef", 170, "m", 100L, 1900L), - row(9, 25, "emma", 171, "f", 180L, 800L), - row(10, 28, "benji", 165, "m", 170L, 11000L), - row(11, 20, "eva", 180, "f", 200L, 1000L) + row(1, 23, "tom", 172, "m", 200L, 1000L, 15L, "A"), + row(2, 25, "mary", 161, "f", 100L, 1000L, 25L, "B"), + row(3, 23, "jack", 182, "m", 150L, 1300L, 35L, "C"), + row(4, 25, "rose", 165, "f", 100L, 500L, 45L, "B"), + row(5, 27, "danny", 175, "m", 300L, 300L, 54L, "C"), + row(6, 23, "tommas", 172, "m", 400L, 4000L, 53L, "A"), + row(7, 19, "olivia", 172, "f", 50L, 9000L, 52L, "C"), + row(8, 34, "stef", 170, "m", 100L, 1900L, 51L, "D"), + row(9, 25, "emma", 171, "f", 180L, 800L, 115L, "A"), + row(10, 28, "benji", 165, "m", 170L, 11000L, 0L, "B"), + row(11, 20, "eva", 180, "f", 200L, 1000L, 15L, "C") ) val nullablesOfPersonData = Array(true, true, true, true, true) From 0b0584f2260c1706c02eef5966302bf824021047 Mon Sep 17 00:00:00 2001 From: "Yu, Peng" Date: Thu, 28 Oct 2021 20:18:36 +0800 Subject: [PATCH 6/7] [FLINK-20895] optimize code and add test cases for auxGrouping --- .../source/AggregatePushDownSpec.java | 2 +- .../batch/PushLocalAggIntoScanRuleBase.java | 57 ++++++------------- .../batch/PushLocalHashAggIntoScanRule.java | 2 +- .../PushLocalHashAggWithCalcIntoScanRule.java | 8 +-- .../batch/PushLocalSortAggIntoScanRule.java | 2 +- .../PushLocalSortAggWithCalcIntoScanRule.java | 10 ++-- ...calSortAggWithSortAndCalcIntoScanRule.java | 12 ++-- .../PushLocalSortAggWithSortIntoScanRule.java | 2 +- .../factories/TestValuesTableFactory.java | 10 ++-- ...shLocalAggIntoTableSourceScanRuleTest.java | 12 +++- .../sql/agg/LocalAggregatePushDownITCase.java | 31 ++++++++-- ...ushLocalAggIntoTableSourceScanRuleTest.xml | 20 
+++---- 12 files changed, 83 insertions(+), 85 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java index fab8c21f416dd..f8fe88720914f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java @@ -127,7 +127,7 @@ public static boolean apply( RowType producedType, DynamicTableSource tableSource, SourceAbilityContext context) { - assert context.isBatchMode(); + assert context.isBatchMode() && groupingSets.size() == 1; List aggregateExpressions = buildAggregateExpressions(inputType, aggregateCalls); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java index 58461e42651f0..fe9a8584b02aa 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java @@ -33,6 +33,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; import org.apache.flink.table.planner.plan.schema.TableSourceTable; import org.apache.flink.table.planner.plan.stats.FlinkStatistic; +import org.apache.flink.table.planner.plan.utils.RexNodeExtractor; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.types.logical.RowType; @@ -42,16 +43,14 @@ import org.apache.calcite.plan.RelOptRuleOperand; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexProgram; -import org.apache.calcite.rex.RexSlot; import org.apache.commons.lang3.ArrayUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashSet; import java.util.List; -import java.util.Set; import java.util.stream.Collectors; /** @@ -74,7 +73,7 @@ public PushLocalAggIntoScanRuleBase(RelOptRuleOperand operand, String descriptio super(operand, description); } - protected boolean checkMatchesAggregatePushDown( + protected boolean canPushDown( RelOptRuleCall call, BatchPhysicalGroupAggregateBase aggregate, BatchPhysicalTableSourceScan tableSourceScan) { @@ -155,20 +154,13 @@ protected void pushLocalAggregateIntoScan( AggregatePushDownSpec aggregatePushDownSpec = new AggregatePushDownSpec(inputType, groupingSets, aggCallList, producedType); - Set groupColumns = - Arrays.stream(groupingSets.get(0)) - .boxed() - .map(idx -> inputType.getFieldNames().get(idx)) - .collect(Collectors.toSet()); - FlinkStatistic newFlinkStatistic = getNewFlinkStatistic(oldTableSourceTable, groupColumns); - TableSourceTable newTableSourceTable = oldTableSourceTable .copy( newTableSource, localAgg.getRowType(), new SourceAbilitySpec[] {aggregatePushDownSpec}) - .copy(newFlinkStatistic); + 
.copy(FlinkStatistic.UNKNOWN()); // transform to new nodes. BatchPhysicalTableSourceScan newScan = @@ -179,26 +171,7 @@ protected void pushLocalAggregateIntoScan( call.transformTo(newExchange); } - private FlinkStatistic getNewFlinkStatistic( - TableSourceTable tableSourceTable, Set groupColumns) { - FlinkStatistic oldStatistic = tableSourceTable.getStatistic(); - - // Create new unique keys if there are group columns - Set> uniqueKeys = null; - if (!groupColumns.isEmpty()) { - uniqueKeys = new HashSet<>(); - uniqueKeys.add(groupColumns); - } - - // Remove tableStats after all aggregates have been pushed down - return FlinkStatistic.builder() - .statistic(oldStatistic) - .uniqueKeys(uniqueKeys) - .tableStats(null) - .build(); - } - - protected boolean checkNoProjectionPushDown(BatchPhysicalTableSourceScan tableSourceScan) { + protected boolean isProjectionNotPushedDown(BatchPhysicalTableSourceScan tableSourceScan) { TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable(); return tableSourceTable != null && Arrays.stream(tableSourceTable.abilitySpecs()) @@ -211,7 +184,7 @@ protected boolean checkNoProjectionPushDown(BatchPhysicalTableSourceScan tableSo * @param calc BatchPhysicalCalc * @return true if OK to be pushed down */ - protected boolean checkCalcInputRefOnly(BatchPhysicalCalc calc) { + protected boolean isInputRefOnly(BatchPhysicalCalc calc) { RexProgram program = calc.getProgram(); // check if condition exists. All filters should have been pushed down. @@ -219,15 +192,19 @@ protected boolean checkCalcInputRefOnly(BatchPhysicalCalc calc) { return false; } - return program.getExprList().stream().allMatch(RexInputRef.class::isInstance) - && !program.getProjectList().isEmpty(); + return !program.getProjectList().isEmpty() + && program.getProjectList().stream() + .map(calc.getProgram()::expandLocalRef) + .allMatch(RexInputRef.class::isInstance); } - protected int[] getRefFiledIndexFromCalc(BatchPhysicalCalc calc) { - return calc.getProgram().getProjectList().stream() - .map(RexSlot::getIndex) - .mapToInt(x -> x) - .toArray(); + protected int[] getRefFiledIndex(BatchPhysicalCalc calc) { + List projects = + calc.getProgram().getProjectList().stream() + .map(calc.getProgram()::expandLocalRef) + .collect(Collectors.toList()); + + return RexNodeExtractor.extractRefInputFields(projects); } protected List translateGroupingArgIndex(List groupingSets, int[] refFields) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java index efe3fb3547f39..0678162929ed7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java @@ -67,7 +67,7 @@ public PushLocalHashAggIntoScanRule() { public boolean matches(RelOptRuleCall call) { BatchPhysicalLocalHashAggregate localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); - return checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + return canPushDown(call, localAggregate, tableSourceScan); } @Override diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java index 318f49f77c65e..87f47c5154181 100755 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java @@ -74,9 +74,9 @@ public boolean matches(RelOptRuleCall call) { BatchPhysicalCalc calc = call.rel(2); BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); - return checkCalcInputRefOnly(calc) - && checkNoProjectionPushDown(tableSourceScan) - && checkMatchesAggregatePushDown(call, localHashAgg, tableSourceScan); + return isInputRefOnly(calc) + && isProjectionNotPushedDown(tableSourceScan) + && canPushDown(call, localHashAgg, tableSourceScan); } @Override @@ -85,7 +85,7 @@ public void onMatch(RelOptRuleCall call) { BatchPhysicalCalc calc = call.rel(2); BatchPhysicalTableSourceScan oldScan = call.rel(3); - int[] calcRefFields = getRefFiledIndexFromCalc(calc); + int[] calcRefFields = getRefFiledIndex(calc); pushLocalAggregateIntoScan(call, localHashAgg, oldScan, calcRefFields); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java index 4ab6903dd34e2..ca101ca3cd8cb 100755 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java @@ -67,7 +67,7 @@ public PushLocalSortAggIntoScanRule() { public boolean matches(RelOptRuleCall call) { BatchPhysicalLocalSortAggregate localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(2); - return checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + return canPushDown(call, localAggregate, tableSourceScan); } @Override diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java index 9ed72892bd6a2..e56e3aa9dd866 100755 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java @@ -25,7 +25,6 @@ import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; import org.apache.flink.table.planner.plan.schema.TableSourceTable; -import org.apache.flink.table.planner.plan.utils.RexNodeExtractor; import org.apache.calcite.plan.RelOptRuleCall; @@ -75,9 +74,9 @@ public boolean matches(RelOptRuleCall call) { BatchPhysicalCalc calc 
= call.rel(2); BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); - return checkCalcInputRefOnly(calc) - && checkNoProjectionPushDown(tableSourceScan) - && checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + return isInputRefOnly(calc) + && isProjectionNotPushedDown(tableSourceScan) + && canPushDown(call, localAggregate, tableSourceScan); } @Override @@ -86,8 +85,7 @@ public void onMatch(RelOptRuleCall call) { BatchPhysicalCalc calc = call.rel(2); BatchPhysicalTableSourceScan oldScan = call.rel(3); - int[] calcRefFields = - RexNodeExtractor.extractRefInputFields(calc.getProgram().getExprList()); + int[] calcRefFields = getRefFiledIndex(calc); pushLocalAggregateIntoScan(call, localHashAgg, oldScan, calcRefFields); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java index 68c31ea104167..d9c340a00edad 100755 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java @@ -27,12 +27,11 @@ import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSort; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan; import org.apache.flink.table.planner.plan.schema.TableSourceTable; -import org.apache.flink.table.planner.plan.utils.RexNodeExtractor; import org.apache.calcite.plan.RelOptRuleCall; /** - * Planner rule that tries to push a local sort aggregate which with sort into a {@link + * Planner rule that tries to push a local sort aggregate with sort and calc into a {@link * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable} with a source supporting * {@link SupportsAggregatePushDown}. The {@link * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} need to be true.
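Reviewer note: all of the PushLocal*IntoScanRule variants are gated on the option referenced in the javadoc above, and the sort-agg rules only fire when the planner actually produces sort aggregates. A sketch of how to exercise them from user code, mirroring what LocalAggregatePushDownITCase does (no tables registered here; assumes a batch environment):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.ExecutionConfigOptions;
    import org.apache.flink.table.api.config.OptimizerConfigOptions;

    public class EnableAggPushDown {
        public static void main(String[] args) {
            TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());

            // The rules only match when this option is set to true.
            tEnv.getConfig()
                    .getConfiguration()
                    .setBoolean(
                            OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED,
                            true);

            // Disabling HashAgg forces sort aggregates, so the PushLocalSortAgg*
            // rules (rather than the hash variants) get a chance to fire.
            tEnv.getConfig()
                    .getConfiguration()
                    .setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg");
        }
    }

With a source table registered, running EXPLAIN on an aggregate query then shows whether the local aggregate was folded into the TableSourceScan, as in the plan XML regenerated by the last patch of this series.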
@@ -84,9 +83,9 @@ public boolean matches(RelOptRuleCall call) { BatchPhysicalCalc calc = call.rel(3); BatchPhysicalTableSourceScan tableSourceScan = call.rel(4); - return checkCalcInputRefOnly(calc) - && checkNoProjectionPushDown(tableSourceScan) - && checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + return isInputRefOnly(calc) + && isProjectionNotPushedDown(tableSourceScan) + && canPushDown(call, localAggregate, tableSourceScan); } @Override @@ -95,8 +94,7 @@ public void onMatch(RelOptRuleCall call) { BatchPhysicalCalc calc = call.rel(3); BatchPhysicalTableSourceScan oldScan = call.rel(4); - int[] calcRefFields = - RexNodeExtractor.extractRefInputFields(calc.getProgram().getExprList()); + int[] calcRefFields = getRefFiledIndex(calc); pushLocalAggregateIntoScan(call, localSortAgg, oldScan, calcRefFields); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java index bfcc353f50bfe..9d952b253afd7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java @@ -75,7 +75,7 @@ public PushLocalSortAggWithSortIntoScanRule() { public boolean matches(RelOptRuleCall call) { BatchPhysicalGroupAggregateBase localAggregate = call.rel(1); BatchPhysicalTableSourceScan tableSourceScan = call.rel(3); - return checkMatchesAggregatePushDown(call, localAggregate, tableSourceScan); + return canPushDown(call, localAggregate, tableSourceScan); } @Override diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index fcf8ea9929a73..404d05a3c3e19 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -298,8 +298,8 @@ private static RowKind parseRowKind(String rowKindShortString) { private static final ConfigOption SINK_EXPECTED_MESSAGES_NUM = ConfigOptions.key("sink-expected-messages-num").intType().defaultValue(-1); - private static final ConfigOption DISABLE_PROJECTION_PUSH_DOWN = - ConfigOptions.key("disable-projection-push-down").booleanType().defaultValue(false); + private static final ConfigOption ENABLE_PROJECTION_PUSH_DOWN = + ConfigOptions.key("enable-projection-push-down").booleanType().defaultValue(true); private static final ConfigOption NESTED_PROJECTION_SUPPORTED = ConfigOptions.key("nested-projection-supported").booleanType().defaultValue(false); @@ -376,7 +376,7 @@ public DynamicTableSource createDynamicTableSource(Context context) { boolean isAsync = helper.getOptions().get(ASYNC_ENABLED); String lookupFunctionClass = helper.getOptions().get(LOOKUP_FUNCTION_CLASS); boolean disableLookup = helper.getOptions().get(DISABLE_LOOKUP); - boolean disableProjectionPushDown = helper.getOptions().get(DISABLE_PROJECTION_PUSH_DOWN); + boolean enableProjectionPushDown = 
helper.getOptions().get(ENABLE_PROJECTION_PUSH_DOWN); boolean nestedProjectionSupported = helper.getOptions().get(NESTED_PROJECTION_SUPPORTED); boolean enableWatermarkPushDown = helper.getOptions().get(ENABLE_WATERMARK_PUSH_DOWN); boolean failingSource = helper.getOptions().get(FAILING_SOURCE); @@ -414,7 +414,7 @@ public DynamicTableSource createDynamicTableSource(Context context) { partition2Rows.put(Collections.emptyMap(), data); } - if (disableProjectionPushDown) { + if (!enableProjectionPushDown) { return new TestValuesScanTableSourceWithoutProjectionPushDown( producedDataType, changelogMode, @@ -576,7 +576,7 @@ public Set> optionalOptions() { SINK_INSERT_ONLY, RUNTIME_SINK, SINK_EXPECTED_MESSAGES_NUM, - DISABLE_PROJECTION_PUSH_DOWN, + ENABLE_PROJECTION_PUSH_DOWN, NESTED_PROJECTION_SUPPORTED, FILTERABLE_FIELDS, PARTITION_LIST, diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java index f01725b07dd46..1312f8b6d2568 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java @@ -105,7 +105,7 @@ public void setup() { + "WITH (\n" + " 'connector' = 'values',\n" + " 'filterable-fields' = 'id;type',\n" - + " 'disable-projection-push-down' = 'true',\n" + + " 'enable-projection-push-down' = 'false',\n" + " 'bounded' = 'true'\n" + ")"; util.tableEnv().executeSql(ddl4); @@ -259,12 +259,18 @@ public void testCanPushDownLocalAggWithoutProjectionPushDown() { } @Test - public void testCannotPushDownLocalAggWithAuxGrouping() { + public void testCanPushDownLocalAggWithAuxGrouping() { + // enable two-phase aggregate, otherwise there is no local aggregate + util.getTableEnv() + .getConfig() + .getConfiguration() + .setString(OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY, "TWO_PHASE"); + util.verifyRelPlan( "SELECT\n" + " id, name, count(*)\n" + "FROM inventory_meta\n" - + " group by id, name, abs(amount)"); + + " group by id, name"); } @Test diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java index 763dbd79577bc..b6522c0d99db1 100755 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java @@ -51,7 +51,8 @@ public void before() { + " deposit bigint,\n" + " points bigint,\n" + " metadata_1 BIGINT METADATA,\n" - + " metadata_2 STRING METADATA\n" + + " metadata_2 STRING METADATA,\n" + + " PRIMARY KEY (`id`) NOT ENFORCED\n" + ") WITH (\n" + " 'connector' = 'values',\n" + " 'data-id' = '" @@ -64,7 +65,6 @@ public void before() { tEnv().executeSql(ddl); // partitioned table - String testDataId2 = TestValuesTableFactory.registerData(TestData.personData()); String ddl2 = "CREATE TABLE AggregatableTable_Part (\n" + " id int,\n" @@ -89,7 +89,6 @@ public void 
before() { tEnv().executeSql(ddl2); // partitioned table - String testDataId3 = TestValuesTableFactory.registerData(TestData.personData()); String ddl3 = "CREATE TABLE AggregatableTable_No_Proj (\n" + " id int,\n" + " age int,\n" @@ -108,7 +107,7 @@ public void before() { + testDataId + "',\n" + " 'filterable-fields' = 'id;age',\n" - + " 'disable-projection-push-down' = 'true',\n" + + " 'enable-projection-push-down' = 'false',\n" + " 'bounded' = 'true'\n" + ")"; tEnv().executeSql(ddl3); @@ -292,4 +291,28 @@ public void testPushDownLocalAggWithoutProjectionPushDown() { Arrays.asList(Row.of(50, 50, 1, "f", 19), Row.of(200, 200, 1, "f", 20))), false); } + + @Test + public void testPushDownLocalAggWithAuxGrouping() { + // enable two-phase aggregate, otherwise there is no local aggregate + tEnv().getConfig() + .getConfiguration() + .setString(OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY, "TWO_PHASE"); + + checkResult( + "SELECT\n" + + " id,\n" + + " name,\n" + + " count(*)\n" + + "FROM\n" + + " AggregatableTable\n" + + "WHERE id > 8\n" + + "GROUP BY id, name", + JavaScalaConversionUtil.toScala( + Arrays.asList( + Row.of(9, "emma", 1), + Row.of(10, "benji", 1), + Row.of(11, "eva", 1))), + false); + } } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml index f379f0426df6b..fc4e0d9ed0e87 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml @@ -146,29 +146,25 @@ Calc(select=[EXPR$0, name, type]) ]]> - + + group by id, name]]> From aec0bc1492f607606f53ce208dc5c46aa07d2706 Mon Sep 17 00:00:00 2001 From: "Yu, Peng" Date: Fri, 29 Oct 2021 10:46:36 +0800 Subject: [PATCH 7/7] [FLINK-20895] enable push down local aggregate for existing unit tests --- .../flink/table/planner/plan/batch/sql/RankTest.xml | 3 +-- .../table/planner/plan/batch/sql/TableSourceTest.xml | 3 +-- .../apache/flink/table/planner/utils/TableTestBase.scala | 9 ++------- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml index c25de396b3c30..b587212b88a55 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml @@ -40,8 +40,7 @@ Sink(table=[default_catalog.default_database.sink], fields=[name, eat, cnt]) +- Exchange(distribution=[hash[name]]) +- HashAggregate(isMerge=[true], groupBy=[name, eat], select=[name, eat, Final_SUM(sum$0) AS cnt]) +- Exchange(distribution=[hash[name, eat]]) - +- LocalHashAggregate(groupBy=[name, eat], select=[name, eat, Partial_SUM(age) AS sum$0]) - +- TableSourceScan(table=[[default_catalog, default_database, test_source]], fields=[name, eat, age]) + +- TableSourceScan(table=[[default_catalog, default_database, test_source, aggregates=[grouping=[name,eat],
aggFunctions=[LongSumAggFunction(age)]]]], fields=[name, eat, sum$0]) ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml index f28e54631f59b..794c35401447c 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml @@ -133,8 +133,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala index c5204291a442b..d27af9c1c6697 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala @@ -22,6 +22,7 @@ import _root_.java.util import java.io.{File, IOException} import java.nio.file.{Files, Paths} import java.time.Duration + import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.rel.RelNode import org.apache.calcite.sql.parser.SqlParserPos @@ -41,7 +42,7 @@ import org.apache.flink.table.api.bridge.java.internal.{StreamTableEnvironmentIm import org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStreamTableEnv} import org.apache.flink.table.api.bridge.scala.internal.{StreamTableEnvironmentImpl => ScalaStreamTableEnvImpl} import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment => ScalaStreamTableEnv} -import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} +import org.apache.flink.table.api.config.ExecutionConfigOptions import org.apache.flink.table.api.internal.{TableEnvironmentImpl, TableEnvironmentInternal, TableImpl} import org.apache.flink.table.catalog.{CatalogManager, FunctionCatalog, GenericInMemoryCatalog, ObjectIdentifier} import org.apache.flink.table.data.RowData @@ -1001,12 +1002,6 @@ abstract class TableTestUtil( .getConfiguration .set(ExecutionOptions.BATCH_SHUFFLE_MODE, BatchShuffleMode.ALL_EXCHANGES_PIPELINED) - // Disable push down aggregates to avoid conflicts with existing test cases that verify plans. - tableEnv.getConfig - .getConfiguration - .setBoolean( - OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED, false) - private val env: StreamExecutionEnvironment = getPlanner.getExecEnv override def getTableEnv: TableEnvironment = tableEnv
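One closing note on the accumulateRows change in TestValuesTableFactory earlier in the series: COUNT(col) and COUNT(*) are now accumulated separately, matching SQL null semantics. A condensed, self-contained illustration of the distinction (class and variable names are illustrative, not from the patch):

    import org.apache.flink.types.Row;

    import java.util.Arrays;
    import java.util.List;

    public class CountSemanticsSketch {
        public static void main(String[] args) {
            List<Row> rows =
                    Arrays.asList(Row.of("a", 1L), Row.of("b", (Object) null), Row.of("c", 3L));
            int argIndex = 1;

            // CountAggFunction: COUNT(col) ignores rows whose argument is null.
            long countCol = rows.stream().filter(r -> r.getField(argIndex) != null).count();

            // Count1AggFunction: COUNT(*) / COUNT(1) counts every row.
            long countStar = rows.size();

            System.out.println(countCol + " vs " + countStar); // prints "2 vs 3"
        }
    }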