Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-40971][CONNECT][DSL] Do not need to use proto. to refer generated classes in Connect DSL. #38445

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect
import scala.collection.JavaConverters._
import scala.language.implicitConversions

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto._
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.connect.proto.SetOperation.SetOpType
import org.apache.spark.sql.SaveMode
Expand All @@ -36,55 +36,53 @@ package object dsl {

object expressions { // scalastyle:ignore
implicit class DslString(val s: String) {
def protoAttr: proto.Expression =
proto.Expression
def protoAttr: Expression =
Expression
.newBuilder()
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute
Expression.UnresolvedAttribute
.newBuilder()
.setUnparsedIdentifier(s))
.build()

def struct(
attrs: proto.Expression.QualifiedAttribute*): proto.Expression.QualifiedAttribute = {
val structExpr = proto.DataType.Struct.newBuilder()
def struct(attrs: Expression.QualifiedAttribute*): Expression.QualifiedAttribute = {
val structExpr = DataType.Struct.newBuilder()
for (attr <- attrs) {
val structField = proto.DataType.StructField.newBuilder()
val structField = DataType.StructField.newBuilder()
structField.setName(attr.getName)
structField.setType(attr.getType)
structExpr.addFields(structField)
}
proto.Expression.QualifiedAttribute
Expression.QualifiedAttribute
.newBuilder()
.setName(s)
.setType(proto.DataType.newBuilder().setStruct(structExpr))
.setType(DataType.newBuilder().setStruct(structExpr))
.build()
}

/** Creates a new AttributeReference of type int */
def int: proto.Expression.QualifiedAttribute = protoQualifiedAttrWithType(
proto.DataType.newBuilder().setI32(proto.DataType.I32.newBuilder()).build())
def int: Expression.QualifiedAttribute = protoQualifiedAttrWithType(
DataType.newBuilder().setI32(DataType.I32.newBuilder()).build())

private def protoQualifiedAttrWithType(
dataType: proto.DataType): proto.Expression.QualifiedAttribute =
proto.Expression.QualifiedAttribute
private def protoQualifiedAttrWithType(dataType: DataType): Expression.QualifiedAttribute =
Expression.QualifiedAttribute
.newBuilder()
.setName(s)
.setType(dataType)
.build()
}

implicit class DslExpression(val expr: proto.Expression) {
def as(alias: String): proto.Expression = proto.Expression
implicit class DslExpression(val expr: Expression) {
def as(alias: String): Expression = Expression
.newBuilder()
.setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr))
.setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr))
.build()

def <(other: proto.Expression): proto.Expression =
proto.Expression
def <(other: Expression): Expression =
Expression
.newBuilder()
.setUnresolvedFunction(
proto.Expression.UnresolvedFunction
Expression.UnresolvedFunction
.newBuilder()
.addParts("<")
.addArguments(expr)
Expand All @@ -100,11 +98,11 @@ package object dsl {
* @return
* Expression wrapping the unresolved function.
*/
def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): proto.Expression = {
proto.Expression
def callFunction(nameParts: Seq[String], args: Seq[Expression]): Expression = {
Expression
.newBuilder()
.setUnresolvedFunction(
proto.Expression.UnresolvedFunction
Expression.UnresolvedFunction
.newBuilder()
.addAllParts(nameParts.asJava)
.addAllArguments(args.asJava))
Expand All @@ -119,26 +117,26 @@ package object dsl {
* @return
* Expression wrapping the unresolved function.
*/
def callFunction(name: String, args: Seq[proto.Expression]): proto.Expression = {
proto.Expression
def callFunction(name: String, args: Seq[Expression]): Expression = {
Expression
.newBuilder()
.setUnresolvedFunction(
proto.Expression.UnresolvedFunction
Expression.UnresolvedFunction
.newBuilder()
.addParts(name)
.addAllArguments(args.asJava))
.build()
}

implicit def intToLiteral(i: Int): proto.Expression =
proto.Expression
implicit def intToLiteral(i: Int): Expression =
Expression
.newBuilder()
.setLiteral(proto.Expression.Literal.newBuilder().setI32(i))
.setLiteral(Expression.Literal.newBuilder().setI32(i))
.build()
}

object commands { // scalastyle:ignore
implicit class DslCommands(val logicalPlan: proto.Relation) {
implicit class DslCommands(val logicalPlan: Relation) {
def write(
format: Option[String] = None,
path: Option[String] = None,
Expand All @@ -147,8 +145,8 @@ package object dsl {
sortByColumns: Seq[String] = Seq.empty,
partitionByCols: Seq[String] = Seq.empty,
bucketByCols: Seq[String] = Seq.empty,
numBuckets: Option[Int] = None): proto.Command = {
val writeOp = proto.WriteOperation.newBuilder()
numBuckets: Option[Int] = None): Command = {
val writeOp = WriteOperation.newBuilder()
format.foreach(writeOp.setSource(_))

mode
Expand All @@ -165,113 +163,110 @@ package object dsl {
partitionByCols.foreach(writeOp.addPartitioningColumns(_))

if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
val op = proto.WriteOperation.BucketBy.newBuilder()
val op = WriteOperation.BucketBy.newBuilder()
numBuckets.foreach(op.setNumBuckets(_))
bucketByCols.foreach(op.addBucketColumnNames(_))
writeOp.setBucketBy(op.build())
}
writeOp.setInput(logicalPlan)
proto.Command.newBuilder().setWriteOperation(writeOp.build()).build()
Command.newBuilder().setWriteOperation(writeOp.build()).build()
}
}
}

object plans { // scalastyle:ignore
implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
def select(exprs: proto.Expression*): proto.Relation = {
proto.Relation
implicit class DslLogicalPlan(val logicalPlan: Relation) {
def select(exprs: Expression*): Relation = {
Relation
.newBuilder()
.setProject(
proto.Project
Project
.newBuilder()
.setInput(logicalPlan)
.addAllExpressions(exprs.toIterable.asJava)
.build())
.build()
}

def limit(limit: Int): proto.Relation = {
proto.Relation
def limit(limit: Int): Relation = {
Relation
.newBuilder()
.setLimit(
proto.Limit
Limit
.newBuilder()
.setInput(logicalPlan)
.setLimit(limit))
.build()
}

def offset(offset: Int): proto.Relation = {
proto.Relation
def offset(offset: Int): Relation = {
Relation
.newBuilder()
.setOffset(
proto.Offset
Offset
.newBuilder()
.setInput(logicalPlan)
.setOffset(offset))
.build()
}

def where(condition: proto.Expression): proto.Relation = {
proto.Relation
def where(condition: Expression): Relation = {
Relation
.newBuilder()
.setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
.setFilter(Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
.build()
}

def deduplicate(colNames: Seq[String]): proto.Relation =
proto.Relation
def deduplicate(colNames: Seq[String]): Relation =
Relation
.newBuilder()
.setDeduplicate(
proto.Deduplicate
Deduplicate
.newBuilder()
.setInput(logicalPlan)
.addAllColumnNames(colNames.asJava))
.build()

def distinct(): proto.Relation =
proto.Relation
def distinct(): Relation =
Relation
.newBuilder()
.setDeduplicate(
proto.Deduplicate
Deduplicate
.newBuilder()
.setInput(logicalPlan)
.setAllColumnsAsKeys(true))
.build()

def join(
otherPlan: proto.Relation,
otherPlan: Relation,
joinType: JoinType,
condition: Option[proto.Expression]): proto.Relation = {
condition: Option[Expression]): Relation = {
join(otherPlan, joinType, Seq(), condition)
}

def join(otherPlan: proto.Relation, condition: Option[proto.Expression]): proto.Relation = {
def join(otherPlan: Relation, condition: Option[Expression]): Relation = {
join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), condition)
}

def join(otherPlan: proto.Relation): proto.Relation = {
def join(otherPlan: Relation): Relation = {
join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), None)
}

def join(otherPlan: proto.Relation, joinType: JoinType): proto.Relation = {
def join(otherPlan: Relation, joinType: JoinType): Relation = {
join(otherPlan, joinType, Seq(), None)
}

def join(
otherPlan: proto.Relation,
joinType: JoinType,
usingColumns: Seq[String]): proto.Relation = {
def join(otherPlan: Relation, joinType: JoinType, usingColumns: Seq[String]): Relation = {
join(otherPlan, joinType, usingColumns, None)
}

private def join(
otherPlan: proto.Relation,
otherPlan: Relation,
joinType: JoinType = JoinType.JOIN_TYPE_INNER,
usingColumns: Seq[String],
condition: Option[proto.Expression]): proto.Relation = {
val relation = proto.Relation.newBuilder()
val join = proto.Join.newBuilder()
condition: Option[Expression]): Relation = {
val relation = Relation.newBuilder()
val join = Join.newBuilder()
join
.setLeft(logicalPlan)
.setRight(otherPlan)
Expand All @@ -285,65 +280,61 @@ package object dsl {
relation.setJoin(join).build()
}

def as(alias: String): proto.Relation = {
proto.Relation
def as(alias: String): Relation = {
Relation
.newBuilder(logicalPlan)
.setCommon(proto.RelationCommon.newBuilder().setAlias(alias))
.setCommon(RelationCommon.newBuilder().setAlias(alias))
.build()
}

def sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long): proto.Relation = {
proto.Relation
seed: Long): Relation = {
Relation
.newBuilder()
.setSample(
proto.Sample
Sample
.newBuilder()
.setInput(logicalPlan)
.setUpperBound(upperBound)
.setLowerBound(lowerBound)
.setWithReplacement(withReplacement)
.setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build())
.setSeed(Sample.Seed.newBuilder().setSeed(seed).build())
.build())
.build()
}

def groupBy(groupingExprs: proto.Expression*)(
aggregateExprs: proto.Expression*): proto.Relation = {
val agg = proto.Aggregate.newBuilder()
def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
val agg = Aggregate.newBuilder()
agg.setInput(logicalPlan)

for (groupingExpr <- groupingExprs) {
agg.addGroupingExpressions(groupingExpr)
}
// TODO: support aggregateExprs, which is blocked by supporting any builtin function
// resolution only by name in the analyzer.
proto.Relation.newBuilder().setAggregate(agg.build()).build()
Relation.newBuilder().setAggregate(agg.build()).build()
}

def except(otherPlan: proto.Relation, isAll: Boolean): proto.Relation = {
proto.Relation
def except(otherPlan: Relation, isAll: Boolean): Relation = {
Relation
.newBuilder()
.setSetOp(
createSetOperation(logicalPlan, otherPlan, SetOpType.SET_OP_TYPE_EXCEPT, isAll))
.build()
}

def intersect(otherPlan: proto.Relation, isAll: Boolean): proto.Relation =
proto.Relation
def intersect(otherPlan: Relation, isAll: Boolean): Relation =
Relation
.newBuilder()
.setSetOp(
createSetOperation(logicalPlan, otherPlan, SetOpType.SET_OP_TYPE_INTERSECT, isAll))
.build()

def union(
otherPlan: proto.Relation,
isAll: Boolean = true,
byName: Boolean = false): proto.Relation =
proto.Relation
def union(otherPlan: Relation, isAll: Boolean = true, byName: Boolean = false): Relation =
Relation
.newBuilder()
.setSetOp(
createSetOperation(
Expand All @@ -355,12 +346,12 @@ package object dsl {
.build()

private def createSetOperation(
left: proto.Relation,
right: proto.Relation,
left: Relation,
right: Relation,
t: SetOpType,
isAll: Boolean = true,
byName: Boolean = false): proto.SetOperation.Builder = {
val setOp = proto.SetOperation
byName: Boolean = false): SetOperation.Builder = {
val setOp = SetOperation
.newBuilder()
.setLeftInput(left)
.setRightInput(right)
Expand Down