From cbd1dbbff5cd625ba7749d3b90cc3a385d6538f8 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 31 Oct 2022 00:11:53 -0700 Subject: [PATCH] [SPARK-40971][CONNECT][DSL] Imports more from connect proto package to avoid calling `proto.` for Connect DSL. --- .../spark/sql/connect/dsl/package.scala | 173 +++++++++--------- 1 file changed, 82 insertions(+), 91 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 9ffc4c4a1feea..3ba773e4c04fd 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.connect.proto +import org.apache.spark.connect.proto._ import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.connect.proto.SetOperation.SetOpType import org.apache.spark.sql.SaveMode @@ -36,55 +36,53 @@ package object dsl { object expressions { // scalastyle:ignore implicit class DslString(val s: String) { - def protoAttr: proto.Expression = - proto.Expression + def protoAttr: Expression = + Expression .newBuilder() .setUnresolvedAttribute( - proto.Expression.UnresolvedAttribute + Expression.UnresolvedAttribute .newBuilder() .setUnparsedIdentifier(s)) .build() - def struct( - attrs: proto.Expression.QualifiedAttribute*): proto.Expression.QualifiedAttribute = { - val structExpr = proto.DataType.Struct.newBuilder() + def struct(attrs: Expression.QualifiedAttribute*): Expression.QualifiedAttribute = { + val structExpr = DataType.Struct.newBuilder() for (attr <- attrs) { - val structField = proto.DataType.StructField.newBuilder() + val structField = DataType.StructField.newBuilder() structField.setName(attr.getName) structField.setType(attr.getType) structExpr.addFields(structField) } - proto.Expression.QualifiedAttribute + Expression.QualifiedAttribute .newBuilder() .setName(s) - .setType(proto.DataType.newBuilder().setStruct(structExpr)) + .setType(DataType.newBuilder().setStruct(structExpr)) .build() } /** Creates a new AttributeReference of type int */ - def int: proto.Expression.QualifiedAttribute = protoQualifiedAttrWithType( - proto.DataType.newBuilder().setI32(proto.DataType.I32.newBuilder()).build()) + def int: Expression.QualifiedAttribute = protoQualifiedAttrWithType( + DataType.newBuilder().setI32(DataType.I32.newBuilder()).build()) - private def protoQualifiedAttrWithType( - dataType: proto.DataType): proto.Expression.QualifiedAttribute = - proto.Expression.QualifiedAttribute + private def protoQualifiedAttrWithType(dataType: DataType): Expression.QualifiedAttribute = + Expression.QualifiedAttribute .newBuilder() .setName(s) .setType(dataType) .build() } - implicit class DslExpression(val expr: proto.Expression) { - def as(alias: String): proto.Expression = proto.Expression + implicit class DslExpression(val expr: Expression) { + def as(alias: String): Expression = Expression .newBuilder() - .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)) + .setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr)) .build() - def <(other: proto.Expression): proto.Expression = - proto.Expression + def <(other: Expression): Expression = + Expression .newBuilder() .setUnresolvedFunction( - proto.Expression.UnresolvedFunction + Expression.UnresolvedFunction .newBuilder() .addParts("<") .addArguments(expr) @@ -100,11 +98,11 @@ package object dsl { * @return * Expression wrapping the unresolved function. */ - def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): proto.Expression = { - proto.Expression + def callFunction(nameParts: Seq[String], args: Seq[Expression]): Expression = { + Expression .newBuilder() .setUnresolvedFunction( - proto.Expression.UnresolvedFunction + Expression.UnresolvedFunction .newBuilder() .addAllParts(nameParts.asJava) .addAllArguments(args.asJava)) @@ -119,26 +117,26 @@ package object dsl { * @return * Expression wrapping the unresolved function. */ - def callFunction(name: String, args: Seq[proto.Expression]): proto.Expression = { - proto.Expression + def callFunction(name: String, args: Seq[Expression]): Expression = { + Expression .newBuilder() .setUnresolvedFunction( - proto.Expression.UnresolvedFunction + Expression.UnresolvedFunction .newBuilder() .addParts(name) .addAllArguments(args.asJava)) .build() } - implicit def intToLiteral(i: Int): proto.Expression = - proto.Expression + implicit def intToLiteral(i: Int): Expression = + Expression .newBuilder() - .setLiteral(proto.Expression.Literal.newBuilder().setI32(i)) + .setLiteral(Expression.Literal.newBuilder().setI32(i)) .build() } object commands { // scalastyle:ignore - implicit class DslCommands(val logicalPlan: proto.Relation) { + implicit class DslCommands(val logicalPlan: Relation) { def write( format: Option[String] = None, path: Option[String] = None, @@ -147,8 +145,8 @@ package object dsl { sortByColumns: Seq[String] = Seq.empty, partitionByCols: Seq[String] = Seq.empty, bucketByCols: Seq[String] = Seq.empty, - numBuckets: Option[Int] = None): proto.Command = { - val writeOp = proto.WriteOperation.newBuilder() + numBuckets: Option[Int] = None): Command = { + val writeOp = WriteOperation.newBuilder() format.foreach(writeOp.setSource(_)) mode @@ -165,24 +163,24 @@ package object dsl { partitionByCols.foreach(writeOp.addPartitioningColumns(_)) if (numBuckets.nonEmpty && bucketByCols.nonEmpty) { - val op = proto.WriteOperation.BucketBy.newBuilder() + val op = WriteOperation.BucketBy.newBuilder() numBuckets.foreach(op.setNumBuckets(_)) bucketByCols.foreach(op.addBucketColumnNames(_)) writeOp.setBucketBy(op.build()) } writeOp.setInput(logicalPlan) - proto.Command.newBuilder().setWriteOperation(writeOp.build()).build() + Command.newBuilder().setWriteOperation(writeOp.build()).build() } } } object plans { // scalastyle:ignore - implicit class DslLogicalPlan(val logicalPlan: proto.Relation) { - def select(exprs: proto.Expression*): proto.Relation = { - proto.Relation + implicit class DslLogicalPlan(val logicalPlan: Relation) { + def select(exprs: Expression*): Relation = { + Relation .newBuilder() .setProject( - proto.Project + Project .newBuilder() .setInput(logicalPlan) .addAllExpressions(exprs.toIterable.asJava) @@ -190,88 +188,85 @@ package object dsl { .build() } - def limit(limit: Int): proto.Relation = { - proto.Relation + def limit(limit: Int): Relation = { + Relation .newBuilder() .setLimit( - proto.Limit + Limit .newBuilder() .setInput(logicalPlan) .setLimit(limit)) .build() } - def offset(offset: Int): proto.Relation = { - proto.Relation + def offset(offset: Int): Relation = { + Relation .newBuilder() .setOffset( - proto.Offset + Offset .newBuilder() .setInput(logicalPlan) .setOffset(offset)) .build() } - def where(condition: proto.Expression): proto.Relation = { - proto.Relation + def where(condition: Expression): Relation = { + Relation .newBuilder() - .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) + .setFilter(Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) .build() } - def deduplicate(colNames: Seq[String]): proto.Relation = - proto.Relation + def deduplicate(colNames: Seq[String]): Relation = + Relation .newBuilder() .setDeduplicate( - proto.Deduplicate + Deduplicate .newBuilder() .setInput(logicalPlan) .addAllColumnNames(colNames.asJava)) .build() - def distinct(): proto.Relation = - proto.Relation + def distinct(): Relation = + Relation .newBuilder() .setDeduplicate( - proto.Deduplicate + Deduplicate .newBuilder() .setInput(logicalPlan) .setAllColumnsAsKeys(true)) .build() def join( - otherPlan: proto.Relation, + otherPlan: Relation, joinType: JoinType, - condition: Option[proto.Expression]): proto.Relation = { + condition: Option[Expression]): Relation = { join(otherPlan, joinType, Seq(), condition) } - def join(otherPlan: proto.Relation, condition: Option[proto.Expression]): proto.Relation = { + def join(otherPlan: Relation, condition: Option[Expression]): Relation = { join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), condition) } - def join(otherPlan: proto.Relation): proto.Relation = { + def join(otherPlan: Relation): Relation = { join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), None) } - def join(otherPlan: proto.Relation, joinType: JoinType): proto.Relation = { + def join(otherPlan: Relation, joinType: JoinType): Relation = { join(otherPlan, joinType, Seq(), None) } - def join( - otherPlan: proto.Relation, - joinType: JoinType, - usingColumns: Seq[String]): proto.Relation = { + def join(otherPlan: Relation, joinType: JoinType, usingColumns: Seq[String]): Relation = { join(otherPlan, joinType, usingColumns, None) } private def join( - otherPlan: proto.Relation, + otherPlan: Relation, joinType: JoinType = JoinType.JOIN_TYPE_INNER, usingColumns: Seq[String], - condition: Option[proto.Expression]): proto.Relation = { - val relation = proto.Relation.newBuilder() - val join = proto.Join.newBuilder() + condition: Option[Expression]): Relation = { + val relation = Relation.newBuilder() + val join = Join.newBuilder() join .setLeft(logicalPlan) .setRight(otherPlan) @@ -285,10 +280,10 @@ package object dsl { relation.setJoin(join).build() } - def as(alias: String): proto.Relation = { - proto.Relation + def as(alias: String): Relation = { + Relation .newBuilder(logicalPlan) - .setCommon(proto.RelationCommon.newBuilder().setAlias(alias)) + .setCommon(RelationCommon.newBuilder().setAlias(alias)) .build() } @@ -296,24 +291,23 @@ package object dsl { lowerBound: Double, upperBound: Double, withReplacement: Boolean, - seed: Long): proto.Relation = { - proto.Relation + seed: Long): Relation = { + Relation .newBuilder() .setSample( - proto.Sample + Sample .newBuilder() .setInput(logicalPlan) .setUpperBound(upperBound) .setLowerBound(lowerBound) .setWithReplacement(withReplacement) - .setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build()) + .setSeed(Sample.Seed.newBuilder().setSeed(seed).build()) .build()) .build() } - def groupBy(groupingExprs: proto.Expression*)( - aggregateExprs: proto.Expression*): proto.Relation = { - val agg = proto.Aggregate.newBuilder() + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = { + val agg = Aggregate.newBuilder() agg.setInput(logicalPlan) for (groupingExpr <- groupingExprs) { @@ -321,29 +315,26 @@ package object dsl { } // TODO: support aggregateExprs, which is blocked by supporting any builtin function // resolution only by name in the analyzer. - proto.Relation.newBuilder().setAggregate(agg.build()).build() + Relation.newBuilder().setAggregate(agg.build()).build() } - def except(otherPlan: proto.Relation, isAll: Boolean): proto.Relation = { - proto.Relation + def except(otherPlan: Relation, isAll: Boolean): Relation = { + Relation .newBuilder() .setSetOp( createSetOperation(logicalPlan, otherPlan, SetOpType.SET_OP_TYPE_EXCEPT, isAll)) .build() } - def intersect(otherPlan: proto.Relation, isAll: Boolean): proto.Relation = - proto.Relation + def intersect(otherPlan: Relation, isAll: Boolean): Relation = + Relation .newBuilder() .setSetOp( createSetOperation(logicalPlan, otherPlan, SetOpType.SET_OP_TYPE_INTERSECT, isAll)) .build() - def union( - otherPlan: proto.Relation, - isAll: Boolean = true, - byName: Boolean = false): proto.Relation = - proto.Relation + def union(otherPlan: Relation, isAll: Boolean = true, byName: Boolean = false): Relation = + Relation .newBuilder() .setSetOp( createSetOperation( @@ -355,12 +346,12 @@ package object dsl { .build() private def createSetOperation( - left: proto.Relation, - right: proto.Relation, + left: Relation, + right: Relation, t: SetOpType, isAll: Boolean = true, - byName: Boolean = false): proto.SetOperation.Builder = { - val setOp = proto.SetOperation + byName: Boolean = false): SetOperation.Builder = { + val setOp = SetOperation .newBuilder() .setLeftInput(left) .setRightInput(right)