Skip to content

Commit

Permalink
[SPARK-33848][SQL][FOLLOWUP] Introduce allowList for push into (if / …
Browse files Browse the repository at this point in the history
…case) branches

### What changes were proposed in this pull request?

Introduce allowList push into (if / case) branches to fix potential bug.

### Why are the changes needed?

 Fix potential bug.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing test.

Closes #30955 from wangyum/SPARK-33848-2.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
wangyum authored and cloud-fan committed Dec 29, 2020
1 parent 3b1b209 commit 872107f
Showing 1 changed file with 34 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -553,41 +553,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
foldables.nonEmpty && others.length < 2
}

// Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match {
case _: IsNull | _: IsNotNull => true
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
true
case _: CastBase => true
case _: GetDateField | _: LastDay => true
case _: ExtractIntervalPart => true
case _: ArraySetLike => true
case _: ExtractValue => true
case _ => false
}

// Not all BinaryExpression can be pushed into (if / case) branches.
private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match {
case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true
case _: BinaryArithmetic => true
case _: BinaryMathExpression => true
case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true
case _: FindInSet | _: RoundBase => true
case _ => false
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case a: Alias => a // Skip an alias.
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = u.withNewChildren(Array(trueValue)),
falseValue = u.withNewChildren(Array(falseValue)))

case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
elseValue.map(e => u.withNewChildren(Array(e))))

case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
if supportedBinaryExpression(b) && right.foldable &&
atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.withNewChildren(Array(trueValue, right)),
falseValue = b.withNewChildren(Array(falseValue, right)))

case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
if supportedBinaryExpression(b) && left.foldable &&
atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.withNewChildren(Array(left, trueValue)),
falseValue = b.withNewChildren(Array(left, falseValue)))

case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
if supportedBinaryExpression(b) && right.foldable &&
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
elseValue.map(e => b.withNewChildren(Array(e, right))))

case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
if supportedBinaryExpression(b) && left.foldable &&
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
elseValue.map(e => b.withNewChildren(Array(left, e))))
Expand Down

0 comments on commit 872107f

Please sign in to comment.