diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index 6fd664d905408..021b4fea26e2a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -460,6 +460,14 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi (5, "hello")) } + test("SPARK-50789: reduceGroups on unresolved plan") { + val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String] + checkDatasetUnorderly( + ds.groupByKey(_.length).reduceGroups(_ + _), + (3, "abcxyz"), + (5, "hello")) + } + test("groupby") { val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) .toDF("key", "seq", "value") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index 8415444c10aac..19275326d6421 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -401,6 +401,13 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession assert(ds.select(aggCol).head() == 135) // 45 + 90 } + test("SPARK-50789: UDAF custom Aggregator - toColumn on unresolved plan") { + val encoder = Encoders.product[UdafTestInput] + val aggCol = new CompleteUdafTestInputAggregator().toColumn + val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*").as(encoder) + assert(ds.select(aggCol).head() == 135) // 45 + 90 + } + test("UDAF custom Aggregator - multiple extends - toColumn") { val encoder = Encoders.product[UdafTestInput] val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn @@ -408,11 +415,24 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4 } - test("UDAF custom aggregator - with rows - toColumn") { + test("SPARK-50789: UDAF custom Aggregator - multiple extends - toColumn on unresolved plan") { + val encoder = Encoders.product[UdafTestInput] + val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn + val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*").as(encoder) + assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4 + } + + test("UDAF custom Aggregator - with rows - toColumn") { val ds = spark.range(10).withColumn("extra", col("id") * 2) assert(ds.select(RowAggregator.toColumn).head() == 405) assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405) } + + test("SPARK-50789: UDAF custom Aggregator - with rows - toColumn on unresolved plan") { + val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*") + assert(ds.select(RowAggregator.toColumn).head() == 405) + assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405) + } } case class UdafTestInput(id: Long, extra: Long) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index c0b4384af8b6d..6ab69aea12e5d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -845,9 +845,10 @@ class SparkConnectPlanner( kEncoder: ExpressionEncoder[_], vEncoder: ExpressionEncoder[_], analyzed: LogicalPlan, - dataAttributes: Seq[Attribute], + analyzedData: LogicalPlan, groupingAttributes: Seq[Attribute], sortOrder: Seq[SortOrder]) { + val dataAttributes: Seq[Attribute] = analyzedData.output val valueDeserializer: Expression = UnresolvedDeserializer(vEncoder.deserializer, dataAttributes) } @@ -900,7 +901,7 @@ class SparkConnectPlanner( dummyFunc.outEnc, dummyFunc.inEnc, qe.analyzed, - analyzed.output, + analyzed, aliasedGroupings, sortOrder) } @@ -924,7 +925,7 @@ class SparkConnectPlanner( kEnc, vEnc, withGroupingKeyAnalyzed, - analyzed.output, + analyzed, withGroupingKey.newColumns, sortOrder) } @@ -1489,11 +1490,19 @@ class SparkConnectPlanner( logical.OneRowRelation() } + val logicalPlan = + if (rel.getExpressionsList.asScala.toSeq.exists( + _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) { + session.sessionState.executePlan(baseRel).analyzed + } else { + baseRel + } + val projection = rel.getExpressionsList.asScala.toSeq - .map(transformExpression(_, Some(baseRel))) + .map(transformExpression(_, Some(logicalPlan))) .map(toNamedExpression) - logical.Project(projectList = projection, child = baseRel) + logical.Project(projectList = projection, child = logicalPlan) } /** @@ -2241,7 +2250,7 @@ class SparkConnectPlanner( val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes) val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq - .map(expr => transformExpressionWithTypedReduceExpression(expr, input)) + .map(expr => transformExpressionWithTypedReduceExpression(expr, ds.analyzedData)) .map(toNamedExpression) logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed) } @@ -2252,9 +2261,17 @@ class SparkConnectPlanner( } val input = transformRelation(rel.getInput) + val logicalPlan = + if (rel.getAggregateExpressionsList.asScala.toSeq.exists( + _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) { + session.sessionState.executePlan(input).analyzed + } else { + input + } + val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression) val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq - .map(expr => transformExpressionWithTypedReduceExpression(expr, input)) + .map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan)) val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression) rel.getGroupType match { @@ -2262,19 +2279,19 @@ class SparkConnectPlanner( logical.Aggregate( groupingExpressions = groupingExprs, aggregateExpressions = aliasedAgg, - child = input) + child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP => logical.Aggregate( groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))), aggregateExpressions = aliasedAgg, - child = input) + child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_CUBE => logical.Aggregate( groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))), aggregateExpressions = aliasedAgg, - child = input) + child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT => if (!rel.hasPivot) { @@ -2286,7 +2303,7 @@ class SparkConnectPlanner( rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral) } else { RelationalGroupedDataset - .collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr)) + .collectPivotValues(Dataset.ofRows(session, logicalPlan), Column(pivotExpr)) .map(expressions.Literal.apply) } logical.Pivot( @@ -2294,7 +2311,7 @@ class SparkConnectPlanner( pivotColumn = pivotExpr, pivotValues = valueExprs, aggregates = aggExprs, - child = input) + child = logicalPlan) case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS => val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets => @@ -2306,7 +2323,7 @@ class SparkConnectPlanner( groupingSets = groupingSetsExprs, userGivenGroupByExprs = groupingExprs)), aggregateExpressions = aliasedAgg, - child = input) + child = logicalPlan) case other => throw InvalidPlanInput(s"Unknown Group Type $other") }