From 4e4a848c2759577464f4c11c4ea938c7d931f214 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Tue, 11 Oct 2022 12:35:08 +0800
Subject: [PATCH] [SPARK-40707][CONNECT] Add groupBy to Connect DSL and test more than one grouping expression

### What changes were proposed in this pull request?

1. Add `groupBy` to the Connect DSL and test more than one grouping expression.
2. Pass a limited set of data types through the Connect proto for LocalRelation's attributes.
3. Clean up an unused trait in the testing code.

### Why are the changes needed?

Enhance Connect's support for GROUP BY.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

Closes #38155 from amaliujia/support_more_than_one_grouping_set.

Authored-by: Rui Wang
Signed-off-by: Wenchen Fan
---
 .../protobuf/spark/connect/commands.proto     |  4 +-
 .../protobuf/spark/connect/expressions.proto  |  7 +--
 .../protobuf/spark/connect/relations.proto    | 18 +-------
 .../main/protobuf/spark/connect/types.proto   | 12 ++---
 .../spark/sql/connect/dsl/package.scala       | 13 ++++++
 .../planner/DataTypeProtoConverter.scala      | 46 +++++++++++++++++++
 .../connect/planner/SparkConnectPlanner.scala | 19 ++++----
 .../planner/SparkConnectPlannerSuite.scala    | 20 ++------
 .../planner/SparkConnectProtoSuite.scala      | 19 ++++++--
 9 files changed, 101 insertions(+), 57 deletions(-)
 create mode 100644 connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala

diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto
index 425857b842e56..0a83e4543f5ec 100644
--- a/connector/connect/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto
@@ -44,8 +44,8 @@ message CreateScalarFunction {
   repeated string parts = 1;
   FunctionLanguage language = 2;
   bool temporary = 3;
-  repeated Type argument_types = 4;
-  Type return_type = 5;
+  repeated DataType argument_types = 4;
+  DataType return_type = 5;
 
   // How the function body is defined:
   oneof function_definition {
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 9b3029a32b0a7..791b1b5887b74 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -65,10 +65,10 @@ message Expression {
       // Timestamp in units of microseconds since the UNIX epoch.
       int64 timestamp_tz = 27;
       bytes uuid = 28;
-      Type null = 29; // a typed null literal
+      DataType null = 29; // a typed null literal
       List list = 30;
-      Type.List empty_list = 31;
-      Type.Map empty_map = 32;
+      DataType.List empty_list = 31;
+      DataType.Map empty_map = 32;
       UserDefined user_defined = 33;
     }
 
@@ -164,5 +164,6 @@ message Expression {
   // by the analyzer.
   message QualifiedAttribute {
     string name = 1;
+    DataType type = 2;
   }
 }
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 25bc4e8a16b18..30f36fa6ceb52 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -130,22 +130,8 @@ message Fetch {
 // Relation of type [[Aggregate]].
 message Aggregate {
   Relation input = 1;
-
-  // Grouping sets are used in rollups
-  repeated GroupingSet grouping_sets = 2;
-
-  // Measures
-  repeated Measure measures = 3;
-
-  message GroupingSet {
-    repeated Expression aggregate_expressions = 1;
-  }
-
-  message Measure {
-    AggregateFunction function = 1;
-    // Conditional filter for SUM(x FILTER WHERE x < 10)
-    Expression filter = 2;
-  }
+  repeated Expression grouping_expressions = 2;
+  repeated AggregateFunction result_expressions = 3;
 
   message AggregateFunction {
     string name = 1;
diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto
index c46afa2afc651..98b0c48b1e016 100644
--- a/connector/connect/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -22,9 +22,9 @@ package spark.connect;
 option java_multiple_files = true;
 option java_package = "org.apache.spark.connect.proto";
 
-// This message describes the logical [[Type]] of something. It does not carry the value
+// This message describes the logical [[DataType]] of something. It does not carry the value
 // itself but only describes it.
-message Type {
+message DataType {
   oneof kind {
     Boolean bool = 1;
     I8 i8 = 2;
@@ -168,20 +168,20 @@
   }
 
   message Struct {
-    repeated Type types = 1;
+    repeated DataType types = 1;
     uint32 type_variation_reference = 2;
     Nullability nullability = 3;
   }
 
   message List {
-    Type type = 1;
+    DataType type = 1;
     uint32 type_variation_reference = 2;
     Nullability nullability = 3;
   }
 
   message Map {
-    Type key = 1;
-    Type value = 2;
+    DataType key = 1;
+    DataType value = 2;
     uint32 type_variation_reference = 3;
     Nullability nullability = 4;
   }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 234b423a80316..3ccf71c26b744 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -67,6 +67,19 @@ package object dsl {
         }
         relation.setJoin(join).build()
       }
+
+      def groupBy(
+          groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+        val agg = proto.Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        // TODO: support aggregateExprs; this is blocked on the analyzer being able to
+        // resolve builtin functions by name alone.
+        proto.Relation.newBuilder().setAggregate(agg.build()).build()
+      }
     }
   }
 }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
new file mode 100644
index 0000000000000..b31855bfca993
--- /dev/null
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+
+/**
+ * This object offers methods to convert between Connect proto data types and Catalyst data types.
+ */
+object DataTypeProtoConverter {
+  def toCatalystType(t: proto.DataType): DataType = {
+    t.getKindCase match {
+      case proto.DataType.KindCase.I32 => IntegerType
+      case proto.DataType.KindCase.STRING => StringType
+      case _ =>
+        throw InvalidPlanInput(s"Cannot convert ${t.getKindCase} to a Catalyst type.")
+    }
+  }
+
+  def toConnectProtoType(t: DataType): proto.DataType = {
+    t match {
+      case IntegerType =>
+        proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
+      case StringType =>
+        proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+      case _ =>
+        throw InvalidPlanInput(s"Cannot convert ${t.typeName} to a Connect proto type.")
+    }
+  }
+}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e3bb7e2932273..66560f5e62f6f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -77,8 +77,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
   }
 
   private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
-    // TODO: use data type from the proto.
-    AttributeReference(exp.getName, IntegerType)()
+    AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
   }
 
   private def transformReadRel(
@@ -271,11 +270,9 @@
 
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
     assert(rel.hasInput)
-    assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported")
 
-    val groupingSet = rel.getGroupingSetsList.asScala.take(1)
-    val ge = groupingSet
-      .flatMap(f => f.getAggregateExpressionsList.asScala)
+    val groupingExprs =
+      rel.getGroupingExpressionsList.asScala
       .map(transformExpression)
       .map {
         case x @ UnresolvedAttribute(_) => x
@@ -284,18 +281,18 @@
 
     logical.Aggregate(
       child = transformRelation(rel.getInput),
-      groupingExpressions = ge.toSeq,
+      groupingExpressions = groupingExprs.toSeq,
       aggregateExpressions =
-        (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq)
+        rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq)
   }
 
   private def transformAggregateExpression(
-      exp: proto.Aggregate.Measure): expressions.NamedExpression = {
-    val fun = exp.getFunction.getName
+      exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = {
+    val fun = exp.getName
     UnresolvedAlias(
       UnresolvedFunction(
         name = fun,
-        arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq,
+        arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = false))
   }
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 37d80e01f72b4..10e17f121f0e5 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -45,11 +45,6 @@ trait SparkConnectPlanTest {
       .build()
 }
 
-trait SparkConnectSessionTest {
-  protected var spark: SparkSession
-
-}
-
 /**
  * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to
  * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are
@@ -222,16 +217,11 @@
 
     val agg = proto.Aggregate.newBuilder
      .setInput(readRel)
-      .addAllMeasures(
-        Seq(
-          proto.Aggregate.Measure.newBuilder
-            .setFunction(proto.Aggregate.AggregateFunction.newBuilder
-              .setName("sum")
-              .addArguments(unresolvedAttribute))
-            .build()).asJava)
-      .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder
-        .addAggregateExpressions(unresolvedAttribute)
-        .build())
+      .addResultExpressions(
+        proto.Aggregate.AggregateFunction.newBuilder
+          .setName("sum")
+          .addArguments(unresolvedAttribute))
+      .addGroupingExpressions(unresolvedAttribute)
       .build()
 
     val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 4f3f0fea387e0..441a3a9f1e41f 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  */
 class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
-  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))
+  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))
 
   lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))
 
-  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)
+  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)
 
   lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)
 
@@ -81,12 +81,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
   }
 
+  test("Aggregate with more than one grouping expression") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)())
+    }
+    val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+  }
+
   private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
-    // TODO: set data types for each local relation attribute one proto supports data type.
     for (attr <- attrs) {
       localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()
+        proto.Expression.QualifiedAttribute.newBuilder()
+          .setName(attr.name)
+          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
       )
     }
     proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
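
Editor's note: a minimal usage sketch of how the pieces in this patch fit together; it is not part of the patch itself. It assumes the test helpers from SparkConnectProtoSuite (`connectTestRelation`, `transform`) are in scope, along with the Connect DSL implicits.

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.connect.dsl.expressions._
    import org.apache.spark.sql.connect.dsl.plans._
    import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
    import org.apache.spark.sql.types.IntegerType

    // Build a Connect proto plan equivalent to `... GROUP BY id, name`. The second
    // argument list is empty because aggregate expressions are still a TODO in the DSL.
    val rel: proto.Relation =
      connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)()

    // Convert the proto plan into a Catalyst LogicalPlan via SparkConnectPlanner.
    val plan = transform(rel)

    // Round-trip a Catalyst data type through the new converter.
    val protoType = DataTypeProtoConverter.toConnectProtoType(IntegerType)
    val catalystType = DataTypeProtoConverter.toCatalystType(protoType) // IntegerType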