Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,14 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
(5, "hello"))
}

test("SPARK-50789: reduceGroups on unresolved plan") {
  // Regression coverage for SPARK-50789: the extra select("*") keeps the
  // plan unresolved before the typed group-by/reduce is applied.
  val words = Seq("abc", "xyz", "hello")
  val unresolved = words.toDS().select("*").as[String]
  val reduced = unresolved
    .groupByKey(word => word.length)
    .reduceGroups((left, right) => left + right)
  // Groups are keyed by string length: "abc" + "xyz" (3), "hello" (5).
  checkDatasetUnorderly(reduced, (3, "abcxyz"), (5, "hello"))
}

test("groupby") {
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
.toDF("key", "seq", "value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,18 +401,38 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
assert(ds.select(aggCol).head() == 135) // 45 + 90
}

test("SPARK-50789: UDAF custom Aggregator - toColumn on unresolved plan") {
  // Regression coverage for SPARK-50789: select("*") leaves the plan
  // unresolved when the typed aggregator column is attached.
  val inputEncoder = Encoders.product[UdafTestInput]
  val data = spark
    .range(10)
    .withColumn("extra", col("id") * 2)
    .select("*")
    .as(inputEncoder)
  val result = data.select(new CompleteUdafTestInputAggregator().toColumn)
  assert(result.head() == 135) // sum(id) = 45 plus sum(extra) = 90
}

test("UDAF custom Aggregator - multiple extends - toColumn") {
  // Exercises an Aggregator defined through a multi-level inheritance chain.
  val inputEncoder = Encoders.product[UdafTestInput]
  val data = spark
    .range(10)
    .withColumn("extra", col("id") * 2)
    .as(inputEncoder)
  val result = data.select(new CompleteGrandChildUdafTestInputAggregator().toColumn)
  assert(result.head() == 540) // (45 + 90) * 4
}

test("UDAF custom aggregator - with rows - toColumn") {
test("SPARK-50789: UDAF custom Aggregator - multiple extends - toColumn on unresolved plan") {
  // Same multi-level-inheritance aggregator as above, but the select("*")
  // keeps the plan unresolved (regression coverage for SPARK-50789).
  val inputEncoder = Encoders.product[UdafTestInput]
  val data = spark
    .range(10)
    .withColumn("extra", col("id") * 2)
    .select("*")
    .as(inputEncoder)
  val result = data.select(new CompleteGrandChildUdafTestInputAggregator().toColumn)
  assert(result.head() == 540) // (45 + 90) * 4
}

test("UDAF custom Aggregator - with rows - toColumn") {
  // A row-based Aggregator must produce the same total via select and agg.
  val rows = spark.range(10).withColumn("extra", col("id") * 2)
  val totalColumn = RowAggregator.toColumn
  assert(rows.select(totalColumn).head() == 405)
  assert(rows.agg(totalColumn).head().getLong(0) == 405)
}

test("SPARK-50789: UDAF custom Aggregator - with rows - toColumn on unresolved plan") {
  // Row-based Aggregator over a plan kept unresolved by select("*")
  // (regression coverage for SPARK-50789).
  val rows = spark.range(10).withColumn("extra", col("id") * 2).select("*")
  val totalColumn = RowAggregator.toColumn
  assert(rows.select(totalColumn).head() == 405)
  assert(rows.agg(totalColumn).head().getLong(0) == 405)
}
}

// Input row type for the UDAF Aggregator tests above: `id` comes from
// spark.range and `extra` is populated as id * 2 by the tests.
case class UdafTestInput(id: Long, extra: Long)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,10 @@ class SparkConnectPlanner(
kEncoder: ExpressionEncoder[_],
vEncoder: ExpressionEncoder[_],
analyzed: LogicalPlan,
dataAttributes: Seq[Attribute],
analyzedData: LogicalPlan,
groupingAttributes: Seq[Attribute],
sortOrder: Seq[SortOrder]) {
val dataAttributes: Seq[Attribute] = analyzedData.output
val valueDeserializer: Expression =
UnresolvedDeserializer(vEncoder.deserializer, dataAttributes)
}
Expand Down Expand Up @@ -900,7 +901,7 @@ class SparkConnectPlanner(
dummyFunc.outEnc,
dummyFunc.inEnc,
qe.analyzed,
analyzed.output,
analyzed,
aliasedGroupings,
sortOrder)
}
Expand All @@ -924,7 +925,7 @@ class SparkConnectPlanner(
kEnc,
vEnc,
withGroupingKeyAnalyzed,
analyzed.output,
analyzed,
withGroupingKey.newColumns,
sortOrder)
}
Expand Down Expand Up @@ -1489,11 +1490,19 @@ class SparkConnectPlanner(
logical.OneRowRelation()
}

val logicalPlan =
if (rel.getExpressionsList.asScala.toSeq.exists(
_.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
session.sessionState.executePlan(baseRel).analyzed
} else {
baseRel
}

val projection = rel.getExpressionsList.asScala.toSeq
.map(transformExpression(_, Some(baseRel)))
.map(transformExpression(_, Some(logicalPlan)))
.map(toNamedExpression)

logical.Project(projectList = projection, child = baseRel)
logical.Project(projectList = projection, child = logicalPlan)
}

/**
Expand Down Expand Up @@ -2241,7 +2250,7 @@ class SparkConnectPlanner(

val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes)
val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq
.map(expr => transformExpressionWithTypedReduceExpression(expr, input))
.map(expr => transformExpressionWithTypedReduceExpression(expr, ds.analyzedData))
.map(toNamedExpression)
logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed)
}
Expand All @@ -2252,29 +2261,37 @@ class SparkConnectPlanner(
}
val input = transformRelation(rel.getInput)

val logicalPlan =
if (rel.getAggregateExpressionsList.asScala.toSeq.exists(
_.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
session.sessionState.executePlan(input).analyzed
} else {
input
}

val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
.map(expr => transformExpressionWithTypedReduceExpression(expr, input))
.map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)

rel.getGroupType match {
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
logical.Aggregate(
groupingExpressions = groupingExprs,
aggregateExpressions = aliasedAgg,
child = input)
child = logicalPlan)

case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
logical.Aggregate(
groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))),
aggregateExpressions = aliasedAgg,
child = input)
child = logicalPlan)

case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
logical.Aggregate(
groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))),
aggregateExpressions = aliasedAgg,
child = input)
child = logicalPlan)

case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
if (!rel.hasPivot) {
Expand All @@ -2286,15 +2303,15 @@ class SparkConnectPlanner(
rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
} else {
RelationalGroupedDataset
.collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr))
.collectPivotValues(Dataset.ofRows(session, logicalPlan), Column(pivotExpr))
.map(expressions.Literal.apply)
}
logical.Pivot(
groupByExprsOpt = Some(groupingExprs.map(toNamedExpression)),
pivotColumn = pivotExpr,
pivotValues = valueExprs,
aggregates = aggExprs,
child = input)
child = logicalPlan)

case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
Expand All @@ -2306,7 +2323,7 @@ class SparkConnectPlanner(
groupingSets = groupingSetsExprs,
userGivenGroupByExprs = groupingExprs)),
aggregateExpressions = aliasedAgg,
child = input)
child = logicalPlan)

case other => throw InvalidPlanInput(s"Unknown Group Type $other")
}
Expand Down