Skip to content

Commit

Permalink
[SPARK-40707][CONNECT] Add groupby to connect DSL and test more than …
Browse files Browse the repository at this point in the history
…one grouping expressions

### What changes were proposed in this pull request?

1. Add `groupBy` to the Connect DSL and test it with more than one grouping expression.
2. Pass a limited set of data types through the Connect proto for LocalRelation's attributes.
3. Clean up the unused `SparkConnectSessionTest` trait in the testing code.

### Why are the changes needed?

Enhance connect's support for GROUP BY.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

Closes #38155 from amaliujia/support_more_than_one_grouping_set.

Authored-by: Rui Wang <rui.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
amaliujia authored and cloud-fan committed Oct 11, 2022
1 parent d59f71c commit 4e4a848
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 57 deletions.
Expand Up @@ -44,8 +44,8 @@ message CreateScalarFunction {
repeated string parts = 1;
FunctionLanguage language = 2;
bool temporary = 3;
repeated Type argument_types = 4;
Type return_type = 5;
repeated DataType argument_types = 4;
DataType return_type = 5;

// How the function body is defined:
oneof function_definition {
Expand Down
Expand Up @@ -65,10 +65,10 @@ message Expression {
// Timestamp in units of microseconds since the UNIX epoch.
int64 timestamp_tz = 27;
bytes uuid = 28;
Type null = 29; // a typed null literal
DataType null = 29; // a typed null literal
List list = 30;
Type.List empty_list = 31;
Type.Map empty_map = 32;
DataType.List empty_list = 31;
DataType.Map empty_map = 32;
UserDefined user_defined = 33;
}

Expand Down Expand Up @@ -164,5 +164,6 @@ message Expression {
// by the analyzer.
message QualifiedAttribute {
string name = 1;
DataType type = 2;
}
}
18 changes: 2 additions & 16 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
Expand Up @@ -130,22 +130,8 @@ message Fetch {
// Relation of type [[Aggregate]].
message Aggregate {
Relation input = 1;

// Grouping sets are used in rollups
repeated GroupingSet grouping_sets = 2;

// Measures
repeated Measure measures = 3;

message GroupingSet {
repeated Expression aggregate_expressions = 1;
}

message Measure {
AggregateFunction function = 1;
// Conditional filter for SUM(x FILTER WHERE x < 10)
Expression filter = 2;
}
repeated Expression grouping_expressions = 2;
repeated AggregateFunction result_expressions = 3;

message AggregateFunction {
string name = 1;
Expand Down
12 changes: 6 additions & 6 deletions connector/connect/src/main/protobuf/spark/connect/types.proto
Expand Up @@ -22,9 +22,9 @@ package spark.connect;
option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";

// This message describes the logical [[Type]] of something. It does not carry the value
// This message describes the logical [[DataType]] of something. It does not carry the value
// itself but only describes it.
message Type {
message DataType {
oneof kind {
Boolean bool = 1;
I8 i8 = 2;
Expand Down Expand Up @@ -168,20 +168,20 @@ message Type {
}

message Struct {
repeated Type types = 1;
repeated DataType types = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}

message List {
Type type = 1;
DataType DataType = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}

message Map {
Type key = 1;
Type value = 2;
DataType key = 1;
DataType value = 2;
uint32 type_variation_reference = 3;
Nullability nullability = 4;
}
Expand Down
Expand Up @@ -67,6 +67,19 @@ package object dsl {
}
relation.setJoin(join).build()
}

/**
 * Builds an aggregate relation on top of `logicalPlan`, grouped by the given expressions.
 *
 * @param groupingExprs expressions added, in order, as the aggregate's grouping expressions
 * @param aggregateExprs currently unused; see the TODO in the body
 * @return a relation wrapping the built aggregate
 */
def groupBy(
    groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
  // TODO: support aggregateExprs, which is blocked by supporting any builtin function
  // resolution only by name in the analyzer.
  val aggregate = groupingExprs
    .foldLeft(proto.Aggregate.newBuilder().setInput(logicalPlan)) { (builder, expr) =>
      // Each grouping expression is appended to the same builder, preserving argument order.
      builder.addGroupingExpressions(expr)
    }
    .build()

  proto.Relation.newBuilder().setAggregate(aggregate).build()
}
}
}
}
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

/**
 * Conversions between the Connect proto data type message and Catalyst data types.
 *
 * Only the `I32` <-> `IntegerType` and `STRING` <-> `StringType` pairs are handled; any other
 * input is rejected by throwing `InvalidPlanInput`.
 */
object DataTypeProtoConverter {

  /**
   * Maps a Connect proto data type onto its Catalyst counterpart.
   *
   * @param t the proto data type; its `kind` case selects the result
   * @return `IntegerType` for `I32`, `StringType` for `STRING`
   */
  def toCatalystType(t: proto.DataType): DataType = {
    val kind = t.getKindCase
    kind match {
      case proto.DataType.KindCase.I32 => IntegerType
      case proto.DataType.KindCase.STRING => StringType
      case unsupported =>
        throw InvalidPlanInput(s"Does not support convert $unsupported to catalyst types.")
    }
  }

  /**
   * Maps a Catalyst data type onto the Connect proto data type message.
   *
   * @param t the Catalyst data type to convert
   * @return a proto data type with `i32` set for `IntegerType` or `string` set for `StringType`
   */
  def toConnectProtoType(t: DataType): proto.DataType = {
    val builder = proto.DataType.newBuilder()
    t match {
      case IntegerType =>
        builder.setI32(proto.DataType.I32.getDefaultInstance).build()
      case StringType =>
        builder.setString(proto.DataType.String.getDefaultInstance).build()
      case unsupported =>
        throw InvalidPlanInput(
          s"Does not support convert ${unsupported.typeName} to connect proto types.")
    }
  }
}
Expand Up @@ -77,8 +77,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}

private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
// TODO: use data type from the proto.
AttributeReference(exp.getName, IntegerType)()
AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
}

private def transformReadRel(
Expand Down Expand Up @@ -271,11 +270,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {

private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
assert(rel.hasInput)
assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported")

val groupingSet = rel.getGroupingSetsList.asScala.take(1)
val ge = groupingSet
.flatMap(f => f.getAggregateExpressionsList.asScala)
val groupingExprs =
rel.getGroupingExpressionsList.asScala
.map(transformExpression)
.map {
case x @ UnresolvedAttribute(_) => x
Expand All @@ -284,18 +281,18 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {

logical.Aggregate(
child = transformRelation(rel.getInput),
groupingExpressions = ge.toSeq,
groupingExpressions = groupingExprs.toSeq,
aggregateExpressions =
(rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq)
rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq)
}

private def transformAggregateExpression(
exp: proto.Aggregate.Measure): expressions.NamedExpression = {
val fun = exp.getFunction.getName
exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = {
val fun = exp.getName
UnresolvedAlias(
UnresolvedFunction(
name = fun,
arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq,
arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq,
isDistinct = false))
}

Expand Down
Expand Up @@ -45,11 +45,6 @@ trait SparkConnectPlanTest {
.build()
}

trait SparkConnectSessionTest {
protected var spark: SparkSession

}

/**
* This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to
* ensure that the transformation from Proto to LogicalPlan works and that the right nodes are
Expand Down Expand Up @@ -222,16 +217,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

val agg = proto.Aggregate.newBuilder
.setInput(readRel)
.addAllMeasures(
Seq(
proto.Aggregate.Measure.newBuilder
.setFunction(proto.Aggregate.AggregateFunction.newBuilder
.setName("sum")
.addArguments(unresolvedAttribute))
.build()).asJava)
.addGroupingSets(proto.Aggregate.GroupingSet.newBuilder
.addAggregateExpressions(unresolvedAttribute)
.build())
.addResultExpressions(
proto.Aggregate.AggregateFunction.newBuilder
.setName("sum")
.addArguments(unresolvedAttribute))
.addGroupingExpressions(unresolvedAttribute)
.build()

val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
Expand Down
Expand Up @@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
*/
class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {

lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))
lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))

lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))

lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)
lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)

lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)

Expand Down Expand Up @@ -81,12 +81,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}
}

// Builds a two-column groupBy through the Connect DSL, transforms it, and compares the result
// with the equivalent groupBy built directly on the Spark test relation.
test("Aggregate with more than 1 grouping expressions") {
val connectPlan = {
// The Connect DSL imports are scoped to this block only; the Spark-side plan is built outside it.
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
// The second (aggregate expressions) argument list is left empty.
transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)())
}
val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
// NOTE(review): the trailing `false` presumably turns off comparePlans' analysis check;
// confirm against the comparePlans signature.
comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
}

private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()
// TODO: set data types for each local relation attribute one proto supports data type.
for (attr <- attrs) {
localRelationBuilder.addAttributes(
proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()
proto.Expression.QualifiedAttribute.newBuilder()
.setName(attr.name)
.setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
)
}
proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
Expand Down

0 comments on commit 4e4a848

Please sign in to comment.