Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
maryannxue committed Jun 14, 2019
1 parent 8570ec0 commit 237c067
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 25 deletions.
Expand Up @@ -313,53 +313,49 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}

/**
* Returns a copy of this node where `f` has been applied to all the nodes children.
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
*/
def mapChildren(f: BaseType => BaseType): BaseType = {
if (children.nonEmpty) {
mapProductElements(f, applyToAll = false)
mapChildren(f, forceCopy = false)
} else {
this
}
}

/**
* Returns a copy of this node where `f` has been applied to all applicable `TreeNode` elements
* in the productIterator.
* @param f the transform function to be applied on applicable `TreeNode` elements.
* @param applyToAll If true, the transform function will be applied to all `TreeNode` elements
* even for non-child elements; otherwise, the function will only be applied
* on children nodes. Also, when this is true, a copy of this node will be
* returned even if no elements have been changed.
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
* @param f The transform function to be applied on applicable `TreeNode` elements.
* @param forceCopy Whether to force making a copy of the nodes even if no child has been changed.
*/
private def mapProductElements(
private def mapChildren(
f: BaseType => BaseType,
applyToAll: Boolean): BaseType = {
forceCopy: Boolean): BaseType = {
var changed = false

def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if applyToAll || containsChild(arg) =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (applyToAll || !(newChild fastEquals arg)) {
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (applyToAll || containsChild(arg1)) {
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (applyToAll || containsChild(arg2)) {
val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (applyToAll || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
if (forceCopy || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
Expand All @@ -369,26 +365,26 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}

val newArgs = mapProductIterator {
case arg: TreeNode[_] if applyToAll || containsChild(arg) =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (applyToAll || !(newChild fastEquals arg)) {
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if applyToAll || containsChild(arg) =>
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (applyToAll || !(newChild fastEquals arg)) {
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
Some(newChild)
} else {
Some(arg)
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if applyToAll || containsChild(arg) =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (applyToAll || !(newChild fastEquals arg)) {
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
Expand All @@ -402,7 +398,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case nonChild: AnyRef => nonChild
case null => null
}
if (applyToAll || changed) makeCopy(newArgs, applyToAll) else this
if (forceCopy || changed) makeCopy(newArgs, forceCopy) else this
}

/**
Expand Down Expand Up @@ -475,7 +471,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}

override def clone(): BaseType = {
mapProductElements(_.clone(), applyToAll = true)
mapChildren(_.clone(), forceCopy = true)
}

/**
Expand Down
Expand Up @@ -706,6 +706,6 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
val leaf = FakeLeafPlan(intersect)
val leafCloned = leaf.clone()
assertDifferentInstance(leaf, leafCloned)
assertDifferentInstance(leaf.child, leafCloned.asInstanceOf[FakeLeafPlan].child)
assert(leaf.child.eq(leafCloned.asInstanceOf[FakeLeafPlan].child))
}
}

0 comments on commit 237c067

Please sign in to comment.