Skip to content

Commit

Permalink
[SPARK-21072][SQL] TreeNode.mapChildren should only apply to the chil…
Browse files Browse the repository at this point in the history
…dren node.

## What changes were proposed in this pull request?

Just as the function name and comments of `TreeNode.mapChildren` mentioned, the function should be apply to all currently node children. So, the follow code should judge whether it is the children node.

https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L342

## How was this patch tested?

Existing tests.

Author: Xianyang Liu <xianyang.liu@intel.com>

Closes #18284 from ConeyLiu/treenode.
  • Loading branch information
ConeyLiu authored and cloud-fan committed Jun 16, 2017
1 parent 5d35d5c commit 87ab0ce
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
Expand Up @@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = f(arg1.asInstanceOf[BaseType])
val newChild2 = f(arg2.asInstanceOf[BaseType])
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
Expand Down
Expand Up @@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil
}

case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
override def children: Seq[Expression] = map.values.toSeq
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}

case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
nonSons: Seq[(Expression, Expression)]) extends Unevaluable {
override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}

case class JsonTestTreeNode(arg: Any) extends LeafNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
}
Expand Down Expand Up @@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite {
assert(actual === Dummy(None))
}

test("mapChildren should only works on children") {
val children = Seq((Literal(1), Literal(2)))
val nonChildren = Seq((Literal(3), Literal(4)))
val before = SeqTupleExpression(children, nonChildren)
val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren)

val actual = before mapChildren toZero
assert(actual === expect)
}

test("preserves origin") {
CurrentOrigin.setPosition(1, 1)
val add = Add(Literal(1), Literal(1))
Expand Down

0 comments on commit 87ab0ce

Please sign in to comment.