Skip to content

Commit

Permalink
[SPARK-28057][SQL] Add method clone in catalyst TreeNode
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Implemented the `clone` method for `TreeNode` based on `mapChildren`.

## How was this patch tested?

Added a new unit test (`clone`) to `TreeNodeSuite`.

Closes #24876 from maryannxue/treenode-clone.

Authored-by: maryannxue <maryannxue@apache.org>
Signed-off-by: herman <herman@databricks.com>
  • Loading branch information
maryannxue authored and hvanhovell committed Jun 14, 2019
1 parent b508eab commit d1951aa
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 63 deletions.
Expand Up @@ -316,80 +316,92 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}

/**
* Returns a copy of this node where `f` has been applied to all the nodes children.
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
*/
def mapChildren(f: BaseType => BaseType): BaseType = {
if (children.nonEmpty) {
var changed = false
def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}
mapChildren(f, forceCopy = false)
} else {
this
}
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}
/**
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
* @param f The transform function to be applied on applicable `TreeNode` elements.
* @param forceCopy Whether to force making a copy of the nodes even if no child has been changed.
*/
private def mapChildren(
f: BaseType => BaseType,
forceCopy: Boolean): BaseType = {
var changed = false

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}
def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (forceCopy || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}

val newArgs = mapProductIterator {
val newArgs = mapProductIterator {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
Some(newChild)
} else {
Some(arg)
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
Some(newChild)
} else {
Some(arg)
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
case args: Iterable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
}
if (changed) makeCopy(newArgs) else this
} else {
this
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
case args: Iterable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
}
if (forceCopy || changed) makeCopy(newArgs, forceCopy) else this
}

/**
Expand All @@ -405,9 +417,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* that are not present in the productIterator.
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
def makeCopy(newArgs: Array[AnyRef]): BaseType = makeCopy(newArgs, allowEmptyArgs = false)

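/**
* Creates a copy of this type of tree node after a transformation.
* This private overload backs the public `makeCopy(newArgs)`, which calls it with
* `allowEmptyArgs = false`. Being private, it cannot itself be overridden; child classes that
* have constructor arguments that are not present in the productIterator override the public
* `makeCopy(newArgs)` instead.
* @param newArgs the new product arguments.
* @param allowEmptyArgs whether constructors with an empty parameter list may be considered.
*/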
private def makeCopy(
newArgs: Array[AnyRef],
allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") {
// Skip no-arg constructors that are just there for kryo.
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
Expand Down Expand Up @@ -450,6 +473,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}

/**
 * Returns a copy of this node in which every node in `children` has itself been cloned.
 * Copying is forced (`forceCopy = true`), so a copy is requested even when no child differs
 * from the original.
 */
override def clone(): BaseType =
  mapChildren(child => child.clone(), forceCopy = true)

/**
* Returns the name of this type of TreeNode. Defaults to the class name.
* Note that we remove the "Exec" suffix for physical operators here.
Expand Down
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions.DslString
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -82,6 +82,11 @@ case class SelfReferenceUDF(
def apply(key: String): Boolean = config.contains(key)
}

/**
 * Test fixture: a logical plan declared as a `LeafNode` that nevertheless carries another plan
 * as a constructor argument. The `clone` test uses it as a "leaf node with an inner child".
 *
 * @param child the wrapped plan; this node's output is taken from it.
 */
case class FakeLeafPlan(child: LogicalPlan)
  extends org.apache.spark.sql.catalyst.plans.logical.LeafNode {

  override def output: Seq[Attribute] = {
    // Delegate to the wrapped plan's attributes.
    child.output
  }
}

class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
Expand Down Expand Up @@ -673,4 +678,34 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
})
}
}

test("clone") {
  // Recursively checks that `after` is equal to `before` but is a distinct object, and that the
  // same holds pairwise for everything reachable through `children`.
  def assertDeepCopy(before: AnyRef, after: AnyRef): Unit = {
    assert(before.ne(after) && before == after)
    val originalChildren = before.asInstanceOf[TreeNode[_]].children
    val copiedChildren = after.asInstanceOf[TreeNode[_]].children
    originalChildren.zip(copiedChildren).foreach {
      case (originalChild: AnyRef, copiedChild: AnyRef) =>
        assertDeepCopy(originalChild, copiedChild)
    }
  }

  // A node whose constructor takes no arguments.
  val rowNumber = RowNumber()
  val rowNumberCopy = rowNumber.clone()
  assertDeepCopy(rowNumber, rowNumberCopy)

  // A node that overrides `makeCopy`.
  val oneRowRelation = OneRowRelation()
  val oneRowRelationCopy = oneRowRelation.clone()
  assertDeepCopy(oneRowRelation, oneRowRelationCopy)

  // Operators with several children, nested two levels deep.
  val union = Union(Seq(oneRowRelation, oneRowRelation))
  val intersect = Intersect(oneRowRelation, union, isAll = false)
  val intersectCopy = intersect.clone()
  assertDeepCopy(intersect, intersectCopy)

  // A leaf node holding a plan that is a constructor argument rather than one of its
  // `children`: the leaf itself is copied, but the inner plan must be the very same instance.
  val leaf = FakeLeafPlan(intersect)
  val leafCloned = leaf.clone()
  assertDeepCopy(leaf, leafCloned)
  val clonedInner = leafCloned.asInstanceOf[FakeLeafPlan].child
  assert(leaf.child.eq(clonedInner))
}
}

0 comments on commit d1951aa

Please sign in to comment.