diff --git a/flink-python/pyflink/table/tests/test_pandas_udaf.py b/flink-python/pyflink/table/tests/test_pandas_udaf.py index 7db721b176052..044c1ae29d498 100644 --- a/flink-python/pyflink/table/tests/test_pandas_udaf.py +++ b/flink-python/pyflink/table/tests/test_pandas_udaf.py @@ -17,6 +17,7 @@ ################################################################################ import unittest +from pyflink.table import expressions as expr from pyflink.table.types import DataTypes from pyflink.table.udf import udaf, udf, AggregateFunction from pyflink.testing import source_sink_utils @@ -94,6 +95,100 @@ def test_group_aggregate_with_aux_group(self): actual = source_sink_utils.results() self.assert_equals(actual, ["1,2,2.0,6", "2,3,3.0,8", "3,4,4.0,10"]) + def test_tumble_group_window_aggregate_function(self): + import datetime + from pyflink.table.window import Tumble + t = self.t_env.from_elements( + [ + (1, 2, 3, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)), + (3, 2, 4, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)), + (2, 1, 2, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)), + (1, 3, 1, datetime.datetime(2018, 3, 11, 3, 40, 0, 0)), + (1, 8, 5, datetime.datetime(2018, 3, 11, 4, 20, 0, 0)), + (2, 3, 6, datetime.datetime(2018, 3, 11, 3, 30, 0, 0)) + ], + DataTypes.ROW( + [DataTypes.FIELD("a", DataTypes.TINYINT()), + DataTypes.FIELD("b", DataTypes.SMALLINT()), + DataTypes.FIELD("c", DataTypes.INT()), + DataTypes.FIELD("rowtime", DataTypes.TIMESTAMP(3))])) + + table_sink = source_sink_utils.TestAppendSink( + ['a', 'b', 'c'], + [ + DataTypes.TIMESTAMP(3), + DataTypes.TIMESTAMP(3), + DataTypes.FLOAT() + ]) + self.t_env.register_table_sink("Results", table_sink) + self.t_env.create_temporary_system_function("mean_udaf", mean_udaf) + tumble_window = Tumble.over(expr.lit(1).hours) \ + .on(expr.col("rowtime")) \ + .alias("w") + t.window(tumble_window) \ + .group_by("w") \ + .select("w.start, w.end, mean_udaf(b)") \ + .execute_insert("Results") \ + .wait() + + actual = source_sink_utils.results() + self.assert_equals(actual, + ["2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.2", + "2018-03-11 04:00:00.0,2018-03-11 05:00:00.0,8.0"]) + + def test_slide_group_window_aggregate_function(self): + import datetime + from pyflink.table.window import Slide + t = self.t_env.from_elements( + [ + (1, 2, 3, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)), + (3, 2, 4, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)), + (2, 1, 2, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)), + (1, 3, 1, datetime.datetime(2018, 3, 11, 3, 40, 0, 0)), + (1, 8, 5, datetime.datetime(2018, 3, 11, 4, 20, 0, 0)), + (2, 3, 6, datetime.datetime(2018, 3, 11, 3, 30, 0, 0)) + ], + DataTypes.ROW( + [DataTypes.FIELD("a", DataTypes.TINYINT()), + DataTypes.FIELD("b", DataTypes.SMALLINT()), + DataTypes.FIELD("c", DataTypes.INT()), + DataTypes.FIELD("rowtime", DataTypes.TIMESTAMP(3))])) + + table_sink = source_sink_utils.TestAppendSink( + ['a', 'b', 'c', 'd', 'e'], + [ + DataTypes.TINYINT(), + DataTypes.TIMESTAMP(3), + DataTypes.TIMESTAMP(3), + DataTypes.FLOAT(), + DataTypes.INT() + ]) + self.t_env.register_table_sink("Results", table_sink) + self.t_env.register_function("max_add", udaf(MaxAdd(), + result_type=DataTypes.INT(), + func_type="pandas")) + self.t_env.create_temporary_system_function("mean_udaf", mean_udaf) + slide_window = Slide.over(expr.lit(1).hours) \ + .every(expr.lit(30).minutes) \ + .on(expr.col("rowtime")) \ + .alias("w") + t.window(slide_window) \ + .group_by("a, w") \ + .select("a, w.start, w.end, mean_udaf(b), max_add(b, c, 1)") \ + .execute_insert("Results") \ + .wait() + actual = source_sink_utils.results() + self.assert_equals(actual, + ["1,2018-03-11 02:30:00.0,2018-03-11 03:30:00.0,2.0,6", + "1,2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.5,7", + "1,2018-03-11 03:30:00.0,2018-03-11 04:30:00.0,5.5,14", + "1,2018-03-11 04:00:00.0,2018-03-11 05:00:00.0,8.0,14", + "2,2018-03-11 02:30:00.0,2018-03-11 03:30:00.0,1.0,4", + "2,2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.0,10", + "2,2018-03-11 03:30:00.0,2018-03-11 04:30:00.0,3.0,10", + "3,2018-03-11 03:00:00.0,2018-03-11 04:00:00.0,2.0,7", + "3,2018-03-11 02:30:00.0,2018-03-11 03:30:00.0,2.0,7"]) + @udaf(result_type=DataTypes.FLOAT(), func_type="pandas") def mean_udaf(v): diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperator.java index 8d6f02e472cb3..f8ad1c3cbd4d4 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperator.java @@ -31,6 +31,7 @@ import org.apache.flink.table.runtime.operators.window.TimeWindow; import org.apache.flink.table.runtime.operators.window.grouping.HeapWindowsGrouping; import org.apache.flink.table.runtime.operators.window.grouping.WindowsGrouping; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; import org.apache.flink.table.runtime.util.RowIterator; import org.apache.flink.table.types.logical.RowType; @@ -92,6 +93,11 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperator */ private transient LinkedList> inputKeyAndWindow; + /** + * The type serializer for the forwarded fields. + */ + private transient RowDataSerializer forwardedInputSerializer; + public BatchArrowPythonGroupWindowAggregateFunctionOperator( Configuration config, PythonFunctionInfo[] pandasAggFunctions, @@ -122,6 +128,7 @@ public void open() throws Exception { windowAggResult = new JoinedRowData(); windowsGrouping = new HeapWindowsGrouping( maxLimitSize, windowSize, slideSize, inputTimeFieldIndex, false); + forwardedInputSerializer = new RowDataSerializer(inputType); super.open(); } @@ -152,7 +159,7 @@ protected void invokeCurrentBatch() throws Exception { @Override public void processElementInternal(RowData value) throws Exception { - windowsGrouping.addInputToBuffer((BinaryRowData) value); + windowsGrouping.addInputToBuffer(forwardedInputSerializer.toBinaryRow(value).copy()); triggerWindowProcess(); } diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java new file mode 100644 index 0000000000000..7120dfc8fae03 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.batch; + +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.functions.UserDefinedFunction; +import org.apache.flink.table.functions.python.PythonFunctionKind; +import org.apache.flink.table.planner.calcite.FlinkRelFactories; +import org.apache.flink.table.planner.plan.logical.LogicalWindow; +import org.apache.flink.table.planner.plan.logical.SessionGroupWindow; +import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow; +import org.apache.flink.table.planner.plan.logical.TumblingGroupWindow; +import org.apache.flink.table.planner.plan.nodes.FlinkConventions; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalWindowAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupWindowAggregate; +import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil; +import org.apache.flink.table.planner.plan.utils.PythonUtil; +import org.apache.flink.table.types.DataType; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.type.SqlTypeName; + +import java.util.LinkedList; +import java.util.List; + +import scala.Tuple2; +import scala.Tuple3; +import scala.collection.Seq; + +/** + * The physical rule is responsible for convert {@link FlinkLogicalWindowAggregate} to + * {@link BatchExecPythonGroupWindowAggregate}. + */ +public class BatchExecPythonWindowAggregateRule extends RelOptRule { + + public static final RelOptRule INSTANCE = new BatchExecPythonWindowAggregateRule(); + + private BatchExecPythonWindowAggregateRule() { + super(operand(FlinkLogicalWindowAggregate.class, + operand(RelNode.class, any())), + FlinkRelFactories.LOGICAL_BUILDER_WITHOUT_AGG_INPUT_PRUNE(), + "BatchExecPythonWindowAggregateRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + FlinkLogicalWindowAggregate agg = call.rel(0); + List aggCalls = agg.getAggCallList(); + + boolean existGeneralPythonFunction = + aggCalls.stream().anyMatch(x -> PythonUtil.isPythonAggregate(x, PythonFunctionKind.GENERAL)); + boolean existPandasFunction = + aggCalls.stream().anyMatch(x -> PythonUtil.isPythonAggregate(x, PythonFunctionKind.PANDAS)); + boolean existJavaFunction = + aggCalls.stream().anyMatch(x -> !PythonUtil.isPythonAggregate(x, null)); + if (existPandasFunction || existGeneralPythonFunction) { + if (existJavaFunction) { + throw new TableException("Python UDAF and Java/Scala UDAF cannot be used together."); + } + if (existPandasFunction && existGeneralPythonFunction) { + throw new TableException("Pandas UDAF and non-Pandas UDAF cannot be used together."); + } + return true; + } else { + return false; + } + + } + + @Override + public void onMatch(RelOptRuleCall call) { + FlinkLogicalWindowAggregate agg = call.rel(0); + RelNode input = agg.getInput(); + LogicalWindow window = agg.getWindow(); + + if (!(window instanceof TumblingGroupWindow && + AggregateUtil.hasTimeIntervalType(((TumblingGroupWindow) window).size()) + || window instanceof SlidingGroupWindow && + AggregateUtil.hasTimeIntervalType(((SlidingGroupWindow) window).size()) + || window instanceof SessionGroupWindow)) { + // sliding & tumbling count window and session window not supported + throw new TableException("Window " + window + " is not supported right now."); + } + + int[] groupSet = agg.getGroupSet().toArray(); + RelTraitSet traitSet = agg.getTraitSet().replace(FlinkConventions.BATCH_PHYSICAL()); + + Tuple2> auxGroupSetAndCallsTuple = AggregateUtil.checkAndSplitAggCalls(agg); + int[] auxGroupSet = auxGroupSetAndCallsTuple._1; + Seq aggCallsWithoutAuxGroupCalls = auxGroupSetAndCallsTuple._2; + + Tuple3 aggBufferTypesAndFunctions = + AggregateUtil.transformToBatchAggregateFunctions( + aggCallsWithoutAuxGroupCalls, input.getRowType(), null); + UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3(); + + int inputTimeFieldIndex = AggregateUtil.timeFieldIndex( + input.getRowType(), call.builder(), window.timeAttribute()); + RelDataType inputTimeFieldType = input.getRowType().getFieldList().get(inputTimeFieldIndex).getType(); + boolean inputTimeIsDate = inputTimeFieldType.getSqlTypeName() == SqlTypeName.DATE; + + RelTraitSet requiredTraitSet = agg.getTraitSet().replace(FlinkConventions.BATCH_PHYSICAL()); + if (groupSet.length != 0) { + FlinkRelDistribution requiredDistribution = + FlinkRelDistribution.hash(groupSet, false); + requiredTraitSet = requiredTraitSet.replace(requiredDistribution); + } else { + requiredTraitSet = requiredTraitSet.replace(FlinkRelDistribution.SINGLETON()); + } + + RelCollation sortCollation = createRelCollation(groupSet, inputTimeFieldIndex); + requiredTraitSet = requiredTraitSet.replace(sortCollation); + + RelNode newInput = RelOptRule.convert(input, requiredTraitSet); + BatchExecPythonGroupWindowAggregate windowAgg = new BatchExecPythonGroupWindowAggregate( + agg.getCluster(), + call.builder(), + traitSet, + newInput, + agg.getRowType(), + newInput.getRowType(), + groupSet, + auxGroupSet, + aggCallsWithoutAuxGroupCalls, + aggFunctions, + window, + inputTimeFieldIndex, + inputTimeIsDate, + agg.getNamedProperties()); + call.transformTo(windowAgg); + } + + private RelCollation createRelCollation(int[] groupSet, int timeIndex) { + List fields = new LinkedList<>(); + for (int value : groupSet) { + fields.add(FlinkRelOptUtil.ofRelFieldCollation(value)); + } + fields.add(FlinkRelOptUtil.ofRelFieldCollation(timeIndex)); + return RelCollations.of(fields); + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala new file mode 100644 index 0000000000000..f51ae66322064 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.physical.batch + +import java.util + +import org.apache.flink.api.dag.Transformation +import org.apache.flink.configuration.Configuration +import org.apache.flink.runtime.operators.DamBehavior +import org.apache.flink.streaming.api.operators.OneInputStreamOperator +import org.apache.flink.streaming.api.transformations.OneInputTransformation +import org.apache.flink.table.api.config.ExecutionConfigOptions +import org.apache.flink.table.data.RowData +import org.apache.flink.table.functions.UserDefinedFunction +import org.apache.flink.table.functions.python.PythonFunctionInfo +import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.codegen.agg.batch.WindowCodeGenerator +import org.apache.flink.table.planner.delegation.BatchPlanner +import org.apache.flink.table.planner.expressions.{PlannerRowtimeAttribute, PlannerWindowEnd, PlannerWindowStart} +import org.apache.flink.table.planner.plan.cost.{FlinkCost, FlinkCostFactory} +import org.apache.flink.table.planner.plan.logical.LogicalWindow +import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{BatchExecNode, ExecNode} +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupWindowAggregate.ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo +import org.apache.flink.table.types.logical.RowType + +import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.metadata.RelMetadataQuery +import org.apache.calcite.tools.RelBuilder + +import scala.collection.JavaConversions._ + +/** + * Batch physical RelNode for group widow aggregate (Python user defined aggregate function). + */ +class BatchExecPythonGroupWindowAggregate( + cluster: RelOptCluster, + relBuilder: RelBuilder, + traitSet: RelTraitSet, + inputRel: RelNode, + outputRowType: RelDataType, + inputRowType: RelDataType, + grouping: Array[Int], + auxGrouping: Array[Int], + aggCalls: Seq[AggregateCall], + aggFunctions: Array[UserDefinedFunction], + window: LogicalWindow, + inputTimeFieldIndex: Int, + inputTimeIsDate: Boolean, + namedProperties: Seq[PlannerNamedWindowProperty]) + extends BatchExecWindowAggregateBase( + cluster, + traitSet, + inputRel, + outputRowType, + inputRowType, + grouping, + auxGrouping, + aggCalls.zip(aggFunctions), + window, + namedProperties, + false, + false, + true) + with BatchExecNode[RowData] + with CommonPythonAggregate { + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new BatchExecPythonGroupWindowAggregate( + cluster, + relBuilder, + traitSet, + inputs.get(0), + outputRowType, + inputRowType, + grouping, + auxGrouping, + aggCalls, + aggFunctions, + window, + inputTimeFieldIndex, + inputTimeIsDate, + namedProperties) + } + + override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = { + val inputRowCnt = mq.getRowCount(getInput) + if (inputRowCnt == null) { + return null + } + // does not take pane optimization into consideration here + // sort is not done here + val aggCallToAggFunction = aggCalls.zip(aggFunctions) + val cpu = FlinkCost.FUNC_CPU_COST * inputRowCnt * aggCallToAggFunction.size + val averageRowSize: Double = mq.getAverageRowSize(this) + val memCost = averageRowSize + val costFactory = planner.getCostFactory.asInstanceOf[FlinkCostFactory] + costFactory.makeCost(mq.getRowCount(this), cpu, 0, 0, memCost) + } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[BatchPlanner, _]] = + List(getInput.asInstanceOf[ExecNode[BatchPlanner, _]]) + + override def replaceInputNode( + ordinalInParent: Int, + newInputNode: ExecNode[BatchPlanner, _]): Unit = { + replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode]) + } + + override def getDamBehavior: DamBehavior = DamBehavior.PIPELINED + + override protected def translateToPlanInternal( + planner: BatchPlanner): Transformation[RowData] = { + val input = getInputNodes.get(0).translateToPlan(planner) + .asInstanceOf[Transformation[RowData]] + val outputType = FlinkTypeFactory.toLogicalRowType(getRowType) + val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) + + val (windowSize: Long, slideSize: Long) = WindowCodeGenerator.getWindowDef(window) + + val groupBufferLimitSize = planner.getTableConfig.getConfiguration.getInteger( + ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT) + + val ret = createPythonOneInputTransformation( + input, + inputType, + outputType, + inputTimeFieldIndex, + groupBufferLimitSize, + windowSize, + slideSize, + getConfig(planner.getExecEnv, planner.getTableConfig)) + + if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { + ExecNode.setManagedMemoryWeight( + ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes) + } + ret + } + + private[this] def createPythonOneInputTransformation( + inputTransform: Transformation[RowData], + inputRowType: RowType, + outputRowType: RowType, + inputTimeFieldIndex: Int, + maxLimitSize: Int, + windowSize: Long, + slideSize: Long, + config: Configuration): OneInputTransformation[RowData, RowData] = { + val namePropertyTypeArray = namedProperties.map { + case PlannerNamedWindowProperty(_, p) => p match { + case PlannerWindowStart(_) => 0 + case PlannerWindowEnd(_) => 1 + case PlannerRowtimeAttribute(_) => 2 + } + }.toArray + + val (pythonUdafInputOffsets, pythonFunctionInfos) = + extractPythonAggregateFunctionInfos(aggCalls) + + val pythonOperator = getPythonGroupWindowAggregateFunctionOperator( + config, + inputRowType, + outputRowType, + inputTimeFieldIndex, + maxLimitSize, + windowSize, + slideSize, + namePropertyTypeArray, + pythonUdafInputOffsets, + pythonFunctionInfos) + + new OneInputTransformation( + inputTransform, + "BatchExecPythonGroupWindowAggregate", + pythonOperator, + InternalTypeInfo.of(outputRowType), + inputTransform.getParallelism) + } + + private[this] def getPythonGroupWindowAggregateFunctionOperator( + config: Configuration, + inputRowType: RowType, + outputRowType: RowType, + inputTimeFieldIndex: Int, + maxLimitSize: Int, + windowSize: Long, + slideSize: Long, + namedProperties: Array[Int], + udafInputOffsets: Array[Int], + pythonFunctionInfos: Array[PythonFunctionInfo]): OneInputStreamOperator[RowData, RowData] = { + val clazz = loadClass(ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME) + + val ctor = clazz.getConstructor( + classOf[Configuration], + classOf[Array[PythonFunctionInfo]], + classOf[RowType], + classOf[RowType], + classOf[Int], + classOf[Int], + classOf[Long], + classOf[Long], + classOf[Array[Int]], + classOf[Array[Int]], + classOf[Array[Int]], + classOf[Array[Int]]) + + ctor.newInstance( + config, + pythonFunctionInfos, + inputRowType, + outputRowType, + Integer.valueOf(inputTimeFieldIndex), + Integer.valueOf(maxLimitSize), + java.lang.Long.valueOf(windowSize), + java.lang.Long.valueOf(slideSize), + namedProperties, + grouping, + grouping ++ auxGrouping, + udafInputOffsets) + .asInstanceOf[OneInputStreamOperator[RowData, RowData]] + } +} + +object BatchExecPythonGroupWindowAggregate { + val ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME: String = + "org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch." + + "BatchArrowPythonGroupWindowAggregateFunctionOperator" +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 3414aea7e5a9c..02769354ee3fa 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -419,6 +419,7 @@ object FlinkBatchRuleSets { BatchExecOverAggregateRule.INSTANCE, // window agg BatchExecWindowAggregateRule.INSTANCE, + BatchExecPythonWindowAggregateRule.INSTANCE, // join BatchExecHashJoinRule.INSTANCE, BatchExecSortMergeJoinRule.INSTANCE, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala index 847f29b7b1713..06c25bd479d03 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala @@ -30,6 +30,7 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalWindowAggre import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashWindowAggregate, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecSortWindowAggregate} import org.apache.flink.table.planner.plan.utils.AggregateUtil import org.apache.flink.table.planner.plan.utils.AggregateUtil.hasTimeIntervalType +import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.types.logical.{BigIntType, IntType, LogicalType} @@ -89,7 +90,7 @@ class BatchExecWindowAggregateRule throw new TableException("GROUPING SETS are currently not supported.") } - true + !agg.getAggCallList.exists(isPythonAggregate(_)) } override def onMatch(call: RelOptRuleCall): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/PythonGroupWindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/PythonGroupWindowAggregateTest.xml new file mode 100644 index 0000000000000..6fa63d57590e0 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/PythonGroupWindowAggregateTest.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/table/PythonGroupWindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/table/PythonGroupWindowAggregateTest.scala new file mode 100644 index 0000000000000..53ef7c25a3906 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/table/PythonGroupWindowAggregateTest.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.batch.table + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api._ +import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.PandasAggregateFunction +import org.apache.flink.table.planner.utils.TableTestBase + +import org.junit.Test + +class PythonGroupWindowAggregateTest extends TableTestBase { + + @Test + def testPandasEventTimeTumblingGroupWindowOverTime(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTableSource[(Int, Long, Int, Long)]( + "MyTable", 'a, 'b, 'c, 'rowtime.rowtime) + val func = new PandasAggregateFunction + + val resultTable = sourceTable + .window(Tumble over 5.millis on 'rowtime as 'w) + .groupBy('w, 'b) + .select('b, 'w.start,'w.end, func('a, 'c)) + + util.verifyPlan(resultTable) + } + + @Test(expected = classOf[TableException]) + def testPandasEventTimeTumblingGroupWindowOverCount(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTableSource[(Int, Long, Int, Long)]( + "MyTable", 'a, 'b, 'c, 'rowtime.rowtime) + val func = new PandasAggregateFunction + + val resultTable = sourceTable + .window(Tumble over 2.rows on 'rowtime as 'w) + .groupBy('w, 'b) + .select('b, func('a, 'c)) + + util.verifyPlan(resultTable) + } + + @Test + def testPandasEventTimeSlidingGroupWindowOverTime(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTableSource[(Int, Long, Int, Long)]( + "MyTable", 'a, 'b, 'c, 'rowtime.rowtime) + val func = new PandasAggregateFunction + + val resultTable = sourceTable + .window(Slide over 5.millis every 2.millis on 'rowtime as 'w) + .groupBy('w, 'b) + .select('b, 'w.start,'w.end, func('a, 'c)) + + util.verifyPlan(resultTable) + } + + @Test(expected = classOf[TableException]) + def testPandasEventTimeSlidingGroupWindowOverCount(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTableSource[(Int, Long, Int, Long)]( + "MyTable", 'a, 'b, 'c, 'rowtime.rowtime) + val func = new PandasAggregateFunction + + val resultTable = sourceTable + .window(Slide over 5.rows every 2.rows on 'rowtime as 'w) + .groupBy('w, 'b) + .select('b, func('a, 'c)) + + util.verifyPlan(resultTable) + } + + @Test + def testPandasGroupWindowAggregateWithoutKeys(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTableSource[(Int, Long, Int, Long)]( + "MyTable", 'a, 'b, 'c, 'rowtime.rowtime) + val func = new PandasAggregateFunction + + val resultTable = sourceTable + .window(Slide over 5.millis every 2.millis on 'rowtime as 'w) + .groupBy('w) + .select('w.start,'w.end, func('a, 'c)) + + util.verifyPlan(resultTable) + } +}