@@ -226,12 +226,14 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
-      case stream: LazyList[_] => stream.map(recursiveTransform).force
+      case stream: Stream[_] => stream.map(recursiveTransform).force
+      case lazyList: LazyList[_] => lazyList.map(recursiveTransform).force
       case seq: Iterable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null

LuciferYang (Contributor, Author) commented on the re-added Stream case:

    cc @JoshRosen Did I understand your suggestion correctly? Thanks ~

JoshRosen (Contributor) replied:

    @LuciferYang, yes, this is exactly what I had in mind.
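The `.force` calls in `recursiveTransform` are the substance of these branches: `map` on `Stream` and `LazyList` is lazy, so without forcing, the transformation would run only if and when the collection happened to be traversed later. A plain-Scala sketch of what `force` buys, illustrative rather than Spark code:

```scala
// map defers the function; force traverses the whole collection so every
// element is rewritten eagerly, like the strict Iterable branch above.
def incrementAll(xs: LazyList[Int]): LazyList[Int] =
  xs.map(_ + 1).force // force returns the same LazyList, now fully evaluated
```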
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.{mutable, Map}
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -378,12 +379,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       case nonChild: AnyRef => nonChild
       case null => null
     }
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: LazyList[_] =>
-        // LazyList is lazy so we need to force materialization
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
         s.map(mapChild).force
+      case l: LazyList[_] =>
+        // LazyList is lazy so we need to force materialization
+        l.map(mapChild).force
       case s: Seq[_] =>
         s.map(mapChild)
       case m: Map[_, _] =>
@@ -801,6 +806,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       case other => other
     }
 
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case arg: TreeNode[_] if containsChild(arg) =>
         arg.asInstanceOf[BaseType].clone()
@@ -813,7 +819,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
         case (_, other) => other
       }
       case d: DataType => d // Avoid unpacking Structs
-      case args: LazyList[_] => args.map(mapChild).force // Force materialization on stream
+      case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+      case args: LazyList[_] => args.map(mapChild).force // Force materialization on LazyList
       case args: Iterable[_] => args.map(mapChild)
       case nonChild: AnyRef => nonChild
       case null => null
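Both `TreeNode` call sites pair the annotation with the restored `Stream` match arm because, in Scala 2.13, any reference to `Stream` (deprecated since 2.13.0 in favor of `LazyList`) emits a deprecation warning that a strict `-Wconf`/fatal-warnings build would turn into an error. A minimal sketch of the mechanism, assuming nothing Spark-specific:

```scala
import scala.annotation.nowarn

// @nowarn("cat=deprecation") silences deprecation warnings for the annotated
// definition only, so deprecation checking stays on for the rest of the file.
@nowarn("cat=deprecation")
def forceIfLazy(arg: Any): Any = arg match {
  case s: Stream[_]   => s.force // deprecated type, still reachable from callers
  case l: LazyList[_] => l.force
  case other          => other
}
```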
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.annotation.nowarn
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -83,6 +85,26 @@ class LogicalPlanSuite extends SparkFunSuite {
   }
 
+  test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    @nowarn("cat=deprecation")
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
+
   test("SPARK-45685: transformExpressions works with a LazyList") {
     val id1 = NamedExpression.newExprId
     val id2 = NamedExpression.newExprId
     val plan = Project(LazyList(
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.math.BigInteger
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.mutable.ArrayBuffer
 
 import org.json4s.JsonAST._
@@ -693,6 +694,22 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
+  test("transform works on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
   test("SPARK-45685: transform works on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
     // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
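The comment repeated in both tests pins down a subtle asymmetry: a `Stream`'s head is evaluated eagerly while its tail is deferred, so a rule that changed only the first element would be visible even without forcing; only a change further down can expose missing materialization. A plain-Scala illustration of that head eagerness (not Spark code):

```scala
// Stream maps its head eagerly; the tail stays deferred until forced.
// LazyList, by contrast, defers even the head.
@scala.annotation.nowarn("cat=deprecation")
def headEagerness(): Unit = {
  var computed = 0
  val s = Stream(1, 2, 3).map { x => computed += 1; x + 1 }
  assert(computed == 1) // only the head has been mapped so far
  s.force
  assert(computed == 3) // force materializes the remaining elements
}
```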
@@ -707,6 +724,16 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
+  test("withNewChildren on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    @nowarn("cat=deprecation")
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
   test("SPARK-45685: withNewChildren on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     val result = before.withNewChildren(LazyList(Literal(1), Literal(3)))
     val expected = Coalesce(LazyList(Literal(1), Literal(3)))
@@ -165,8 +165,10 @@ trait CodegenSupport extends SparkPlan {
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     val inputVars = inputVarsCandidate match {
-      case stream: LazyList[ExprCode] => stream.force
+      case stream: Stream[ExprCode] => stream.force
+      case lazyList: LazyList[ExprCode] => lazyList.force
       case other => other
     }
 
@@ -744,6 +744,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
+  test("SPARK-24500: create union with stream of children") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
+
   test("SPARK-45685: create union with LazyList of children") {
     val df = Union(LazyList(
       Range(1, 1, 1, 1),
       Range(1, 2, 1, 1)))
@@ -785,6 +785,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }
 
+  test("SPARK-26680: Stream in groupBy does not cause StackOverflowError") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val groupByCols = Stream(col("key"))
+    val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
+      .groupBy(groupByCols: _*)
+      .max("value")
+
+    checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
+  }
+
   test("SPARK-45685: LazyList in groupBy does not cause StackOverflowError") {
     val groupByCols = LazyList(col("key"))
     val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
       .groupBy(groupByCols: _*)
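The `groupByCols: _*` splat is what lets laziness reach the API at all: `Stream` and `LazyList` are `Seq`s, so they can be passed to any varargs parameter, and the callee receives the still-lazy collection, which is how the StackOverflowError of SPARK-26680 could originally be triggered. A small sketch under that assumption; `describe` is a made-up helper, not a Spark API:

```scala
object VarargsSketch {
  // any Seq can be splatted into a varargs parameter
  def describe(cols: String*): String = cols.mkString(", ")

  @scala.annotation.nowarn("cat=deprecation")
  def demo(): String = {
    val lazyCols = Stream("key", "value") // lazy past the head
    describe(lazyCols: _*)                // the callee receives the Stream itself
  }
}
```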