Skip to content

Commit

Permalink
[FLINK-19364][python] Add Batch Physical Pandas Group Window Aggregat…
Browse files Browse the repository at this point in the history
…e Rule and RelNode

This closes #13460.
  • Loading branch information
HuangXingBo authored and dianfu committed Sep 24, 2020
1 parent da80206 commit 51368b8
Show file tree
Hide file tree
Showing 8 changed files with 702 additions and 2 deletions.
95 changes: 95 additions & 0 deletions flink-python/pyflink/table/tests/test_pandas_udaf.py
Expand Up @@ -17,6 +17,7 @@
################################################################################
import unittest

from pyflink.table import expressions as expr
from pyflink.table.types import DataTypes
from pyflink.table.udf import udaf, udf, AggregateFunction
from pyflink.testing import source_sink_utils
Expand Down Expand Up @@ -94,6 +95,100 @@ def test_group_aggregate_with_aux_group(self):
actual = source_sink_utils.results()
self.assert_equals(actual, ["1,2,2.0,6", "2,3,3.0,8", "3,4,4.0,10"])

def test_tumble_group_window_aggregate_function(self):
    """Checks a Pandas UDAF evaluated over a one-hour tumbling window on ``rowtime``."""
    import datetime
    from pyflink.table.window import Tumble

    # Five of the six rows fall into [03:00, 04:00); the 04:20 row falls into [04:00, 05:00).
    rows = [
        (1, 2, 3, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
        (3, 2, 4, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
        (2, 1, 2, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
        (1, 3, 1, datetime.datetime(2018, 3, 11, 3, 40, 0, 0)),
        (1, 8, 5, datetime.datetime(2018, 3, 11, 4, 20, 0, 0)),
        (2, 3, 6, datetime.datetime(2018, 3, 11, 3, 30, 0, 0))
    ]
    row_type = DataTypes.ROW([
        DataTypes.FIELD("a", DataTypes.TINYINT()),
        DataTypes.FIELD("b", DataTypes.SMALLINT()),
        DataTypes.FIELD("c", DataTypes.INT()),
        DataTypes.FIELD("rowtime", DataTypes.TIMESTAMP(3))
    ])
    source = self.t_env.from_elements(rows, row_type)

    # Sink columns: window start, window end, mean of column b.
    sink_field_names = ['a', 'b', 'c']
    sink_field_types = [
        DataTypes.TIMESTAMP(3),
        DataTypes.TIMESTAMP(3),
        DataTypes.FLOAT()
    ]
    sink = source_sink_utils.TestAppendSink(sink_field_names, sink_field_types)
    self.t_env.register_table_sink("Results", sink)
    self.t_env.create_temporary_system_function("mean_udaf", mean_udaf)

    hourly_window = Tumble.over(expr.lit(1).hours).on(expr.col("rowtime")).alias("w")
    grouped = source.window(hourly_window).group_by("w")
    aggregated = grouped.select("w.start, w.end, mean_udaf(b)")
    aggregated.execute_insert("Results").wait()

    expected = [
        "2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.2",
        "2018-03-11 04:00:00.0,2018-03-11 05:00:00.0,8.0"
    ]
    actual = source_sink_utils.results()
    self.assert_equals(actual, expected)

def test_slide_group_window_aggregate_function(self):
    """Checks two Pandas UDAFs over a one-hour window sliding every 30 minutes, keyed by ``a``."""
    import datetime
    from pyflink.table.window import Slide

    rows = [
        (1, 2, 3, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
        (3, 2, 4, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
        (2, 1, 2, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
        (1, 3, 1, datetime.datetime(2018, 3, 11, 3, 40, 0, 0)),
        (1, 8, 5, datetime.datetime(2018, 3, 11, 4, 20, 0, 0)),
        (2, 3, 6, datetime.datetime(2018, 3, 11, 3, 30, 0, 0))
    ]
    row_type = DataTypes.ROW([
        DataTypes.FIELD("a", DataTypes.TINYINT()),
        DataTypes.FIELD("b", DataTypes.SMALLINT()),
        DataTypes.FIELD("c", DataTypes.INT()),
        DataTypes.FIELD("rowtime", DataTypes.TIMESTAMP(3))
    ])
    source = self.t_env.from_elements(rows, row_type)

    # Sink columns: key a, window start, window end, mean_udaf(b), max_add(b, c, 1).
    sink_field_names = ['a', 'b', 'c', 'd', 'e']
    sink_field_types = [
        DataTypes.TINYINT(),
        DataTypes.TIMESTAMP(3),
        DataTypes.TIMESTAMP(3),
        DataTypes.FLOAT(),
        DataTypes.INT()
    ]
    sink = source_sink_utils.TestAppendSink(sink_field_names, sink_field_types)
    self.t_env.register_table_sink("Results", sink)

    max_add = udaf(MaxAdd(), result_type=DataTypes.INT(), func_type="pandas")
    self.t_env.register_function("max_add", max_add)
    self.t_env.create_temporary_system_function("mean_udaf", mean_udaf)

    sliding_window = Slide.over(expr.lit(1).hours) \
        .every(expr.lit(30).minutes) \
        .on(expr.col("rowtime")) \
        .alias("w")
    grouped = source.window(sliding_window).group_by("a, w")
    aggregated = grouped.select("a, w.start, w.end, mean_udaf(b), max_add(b, c, 1)")
    aggregated.execute_insert("Results").wait()

    expected = [
        "1,2018-03-11 02:30:00.0,2018-03-11 03:30:00.0,2.0,6",
        "1,2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.5,7",
        "1,2018-03-11 03:30:00.0,2018-03-11 04:30:00.0,5.5,14",
        "1,2018-03-11 04:00:00.0,2018-03-11 05:00:00.0,8.0,14",
        "2,2018-03-11 02:30:00.0,2018-03-11 03:30:00.0,1.0,4",
        "2,2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.0,10",
        "2,2018-03-11 03:30:00.0,2018-03-11 04:30:00.0,3.0,10",
        "3,2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.0,7",
        "3,2018-03-11 02:30:00.0,2018-03-11 03:30:00.0,2.0,7"
    ]
    actual = source_sink_utils.results()
    self.assert_equals(actual, expected)


@udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
def mean_udaf(v):
Expand Down
Expand Up @@ -31,6 +31,7 @@
import org.apache.flink.table.runtime.operators.window.TimeWindow;
import org.apache.flink.table.runtime.operators.window.grouping.HeapWindowsGrouping;
import org.apache.flink.table.runtime.operators.window.grouping.WindowsGrouping;
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.types.logical.RowType;

Expand Down Expand Up @@ -92,6 +93,11 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperator
*/
private transient LinkedList<Tuple2<RowData, TimeWindow>> inputKeyAndWindow;

/**
* The type serializer for the forwarded fields.
*/
private transient RowDataSerializer forwardedInputSerializer;

public BatchArrowPythonGroupWindowAggregateFunctionOperator(
Configuration config,
PythonFunctionInfo[] pandasAggFunctions,
Expand Down Expand Up @@ -122,6 +128,7 @@ public void open() throws Exception {
windowAggResult = new JoinedRowData();
windowsGrouping = new HeapWindowsGrouping(
maxLimitSize, windowSize, slideSize, inputTimeFieldIndex, false);
forwardedInputSerializer = new RowDataSerializer(inputType);
super.open();
}

Expand Down Expand Up @@ -152,7 +159,7 @@ protected void invokeCurrentBatch() throws Exception {

/**
 * Adds one input row to the window grouping buffer and then triggers window processing.
 *
 * <p>The row is converted with {@code forwardedInputSerializer.toBinaryRow(value)} instead of
 * being cast to {@code BinaryRowData}, so the incoming {@code RowData} does not have to be a
 * {@code BinaryRowData} already. Each row is buffered exactly once.
 *
 * @param value the input row to buffer
 * @throws Exception if buffering the row or processing the triggered windows fails
 */
@Override
public void processElementInternal(RowData value) throws Exception {
	// The converted row is copied before buffering.
	// NOTE(review): presumably the serializer reuses its binary row instance - confirm.
	windowsGrouping.addInputToBuffer(forwardedInputSerializer.toBinaryRow(value).copy());
	triggerWindowProcess();
}

Expand Down
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.physical.batch;

import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkRelFactories;
import org.apache.flink.table.planner.plan.logical.LogicalWindow;
import org.apache.flink.table.planner.plan.logical.SessionGroupWindow;
import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow;
import org.apache.flink.table.planner.plan.logical.TumblingGroupWindow;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalWindowAggregate;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupWindowAggregate;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.types.DataType;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.type.SqlTypeName;

import java.util.LinkedList;
import java.util.List;

import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;

/**
 * Physical planner rule that converts a {@link FlinkLogicalWindowAggregate} whose aggregate calls
 * are Python aggregate functions into a {@link BatchExecPythonGroupWindowAggregate}.
 */
public class BatchExecPythonWindowAggregateRule extends RelOptRule {

	public static final RelOptRule INSTANCE = new BatchExecPythonWindowAggregateRule();

	private BatchExecPythonWindowAggregateRule() {
		super(
			operand(FlinkLogicalWindowAggregate.class, operand(RelNode.class, any())),
			FlinkRelFactories.LOGICAL_BUILDER_WITHOUT_AGG_INPUT_PRUNE(),
			"BatchExecPythonWindowAggregateRule");
	}

	/**
	 * Matches window aggregates that contain at least one Python aggregate call.
	 *
	 * @throws TableException if Python aggregates are mixed with Java/Scala aggregates, or if
	 *     Pandas and general Python aggregates are mixed
	 */
	@Override
	public boolean matches(RelOptRuleCall call) {
		FlinkLogicalWindowAggregate windowAggregate = call.rel(0);
		List<AggregateCall> calls = windowAggregate.getAggCallList();

		boolean hasGeneralPython = containsPythonAggregate(calls, PythonFunctionKind.GENERAL);
		boolean hasPandas = containsPythonAggregate(calls, PythonFunctionKind.PANDAS);
		// True when at least one call is not a Python aggregate of any kind.
		boolean hasNonPython =
			!calls.stream().allMatch(c -> PythonUtil.isPythonAggregate(c, null));

		if (!hasPandas && !hasGeneralPython) {
			return false;
		}
		if (hasNonPython) {
			throw new TableException("Python UDAF and Java/Scala UDAF cannot be used together.");
		}
		if (hasPandas && hasGeneralPython) {
			throw new TableException("Pandas UDAF and non-Pandas UDAF cannot be used together.");
		}
		return true;
	}

	/**
	 * Replaces the matched logical window aggregate with a
	 * {@link BatchExecPythonGroupWindowAggregate} whose input is required to be distributed by the
	 * grouping keys (or singleton when there are none) and sorted by grouping keys and time field.
	 *
	 * @throws TableException if the window does not pass the window type check
	 */
	@Override
	public void onMatch(RelOptRuleCall call) {
		FlinkLogicalWindowAggregate windowAggregate = call.rel(0);
		RelNode child = windowAggregate.getInput();
		LogicalWindow window = windowAggregate.getWindow();

		if (!passesWindowCheck(window)) {
			throw new TableException("Window " + window + " is not supported right now.");
		}

		int[] groupKeys = windowAggregate.getGroupSet().toArray();
		// Used for the new node and as the base of the traits required from the input.
		RelTraitSet physicalTraits =
			windowAggregate.getTraitSet().replace(FlinkConventions.BATCH_PHYSICAL());

		Tuple2<int[], Seq<AggregateCall>> split = AggregateUtil.checkAndSplitAggCalls(windowAggregate);
		int[] auxGroupKeys = split._1;
		Seq<AggregateCall> callsWithoutAuxGroup = split._2;

		RelDataType childRowType = child.getRowType();
		Tuple3<int[][], DataType[][], UserDefinedFunction[]> transformed =
			AggregateUtil.transformToBatchAggregateFunctions(
				callsWithoutAuxGroup, childRowType, null);
		UserDefinedFunction[] functions = transformed._3();

		int timeFieldIndex = AggregateUtil.timeFieldIndex(
			childRowType, call.builder(), window.timeAttribute());
		RelDataType timeFieldType = childRowType.getFieldList().get(timeFieldIndex).getType();
		boolean timeFieldIsDate = timeFieldType.getSqlTypeName() == SqlTypeName.DATE;

		FlinkRelDistribution distribution = groupKeys.length != 0
			? FlinkRelDistribution.hash(groupKeys, false)
			: FlinkRelDistribution.SINGLETON();
		RelTraitSet inputTraits = physicalTraits
			.replace(distribution)
			.replace(createRelCollation(groupKeys, timeFieldIndex));

		RelNode convertedChild = RelOptRule.convert(child, inputTraits);
		call.transformTo(
			new BatchExecPythonGroupWindowAggregate(
				windowAggregate.getCluster(),
				call.builder(),
				physicalTraits,
				convertedChild,
				windowAggregate.getRowType(),
				convertedChild.getRowType(),
				groupKeys,
				auxGroupKeys,
				callsWithoutAuxGroup,
				functions,
				window,
				timeFieldIndex,
				timeFieldIsDate,
				windowAggregate.getNamedProperties()));
	}

	/** Returns whether any of the calls is a Python aggregate of the given kind. */
	private static boolean containsPythonAggregate(List<AggregateCall> calls, PythonFunctionKind kind) {
		return calls.stream().anyMatch(c -> PythonUtil.isPythonAggregate(c, kind));
	}

	/**
	 * Returns whether the window passes the type check of this rule: tumbling and sliding windows
	 * pass only when their size has a time interval type (count windows are rejected), and
	 * session windows pass unconditionally.
	 *
	 * <p>NOTE(review): the original inline comment stated that session windows are not supported,
	 * yet the check lets them through - confirm which one is intended.
	 */
	private static boolean passesWindowCheck(LogicalWindow window) {
		if (window instanceof TumblingGroupWindow) {
			return AggregateUtil.hasTimeIntervalType(((TumblingGroupWindow) window).size());
		}
		if (window instanceof SlidingGroupWindow) {
			return AggregateUtil.hasTimeIntervalType(((SlidingGroupWindow) window).size());
		}
		return window instanceof SessionGroupWindow;
	}

	/** Builds a collation over all grouping fields, in order, followed by the time field. */
	private RelCollation createRelCollation(int[] groupSet, int timeIndex) {
		List<RelFieldCollation> collations = new LinkedList<>();
		for (int i = 0; i < groupSet.length; i++) {
			collations.add(FlinkRelOptUtil.ofRelFieldCollation(groupSet[i]));
		}
		collations.add(FlinkRelOptUtil.ofRelFieldCollation(timeIndex));
		return RelCollations.of(collations);
	}
}

0 comments on commit 51368b8

Please sign in to comment.