Skip to content

Commit

Permalink
[SPARK-48035][SQL] Fix try_add/try_multiply being semantically equal to a…
Browse files Browse the repository at this point in the history
…dd/multiply

### What changes were proposed in this pull request?
- This PR fixes a correctness bug in commutative operator canonicalization where we currently do not take into account the evaluation mode during operand reordering.
- As a result, the following condition will be incorrectly true:
```
val l1 = Literal(1)
val l2 = Literal(2)
val l3 = Literal(3)
val expr1 = Add(Add(l1, l2), l3)
val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3)
expr1.semanticEquals(expr2)
```
- To fix the issue, we now reorder commutative operands only if all operators have the same evaluation mode.

### Why are the changes needed?
- To fix a correctness bug.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Added unit tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46307 from db-scnakandala/db-scnakandala/master.

Authored-by: Supun Nakandala <supun.nakandala@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
db-scnakandala authored and HyukjinKwon committed May 7, 2024
1 parent c94fe37 commit 7290000
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,20 @@ trait CommutativeExpression extends Expression {
}
reorderResult
}

/**
 * Collects the evaluation modes of a chain of commutative expressions.
 *
 * The canonicalized methods of [[Add]] and [[Multiply]] use this to verify that every
 * operator in the chain shares the same evaluation mode before reordering operands;
 * reordering across mixed modes (e.g. ANSI vs. TRY) would change semantics.
 */
protected def collectEvalModes(
    e: Expression,
    f: PartialFunction[CommutativeExpression, Seq[EvalMode.Value]]
): Seq[EvalMode.Value] = e match {
  case commutative: CommutativeExpression =>
    // `lift` turns the partial function into a total one: `None` means this node is not
    // part of the commutative chain we are inspecting, so we stop descending here.
    f.lift(commutative) match {
      case Some(modes) =>
        modes ++ commutative.children.flatMap(child => collectEvalModes(child, f))
      case None => Nil
    }
  case _ => Nil
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,13 +452,14 @@ case class Add(
copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
// TODO: do not reorder consecutive `Add`s with different `evalMode`
val reorderResult = buildCanonicalizedPlan(
val evalModes = collectEvalModes(this, {case Add(_, _, evalMode) => Seq(evalMode)})
lazy val reorderResult = buildCanonicalizedPlan(
{ case Add(l, r, _) => Seq(l, r) },
{ case (l: Expression, r: Expression) => Add(l, r, evalMode)},
Some(evalMode)
)
if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
if (resolved && evalModes.forall(_ == evalMode) && reorderResult.resolved &&
reorderResult.dataType == dataType) {
reorderResult
} else {
// SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is
Expand Down Expand Up @@ -608,12 +609,16 @@ case class Multiply(
newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
// TODO: do not reorder consecutive `Multiply`s with different `evalMode`
buildCanonicalizedPlan(
{ case Multiply(l, r, _) => Seq(l, r) },
{ case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
Some(evalMode)
)
val evalModes = collectEvalModes(this, {case Multiply(_, _, evalMode) => Seq(evalMode)})
if (evalModes.forall(_ == evalMode)) {
buildCanonicalizedPlan(
{ case Multiply(l, r, _) => Seq(l, r) },
{ case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
Some(evalMode)
)
} else {
withCanonicalizedChildren
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,4 +454,29 @@ class CanonicalizeSuite extends SparkFunSuite {
// different.
assert(common3.canonicalized != common4.canonicalized)
}

// NOTE: the original test name concatenation was missing a space, producing
// "...take into account theevaluation mode..."; fixed here.
test("SPARK-48035: Add/Multiply operator canonicalization should take into account the " +
  "evaluation mode of the operands before operand reordering") {
  // Run with a low and a high threshold to exercise both the pairwise and the
  // multi-commutative canonicalization code paths.
  Seq(1, 10).foreach { multiCommutativeOpOptThreshold =>
    val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
    SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key,
      multiCommutativeOpOptThreshold.toString)
    try {
      val l1 = Literal(1)
      val l2 = Literal(2)
      val l3 = Literal(3)

      // An Add chain containing a TRY-mode operator must not canonicalize to the
      // same plan as an all-default-mode chain, even with reordered operands.
      val expr1 = Add(Add(l1, l2), l3)
      val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3)
      assert(!expr1.semanticEquals(expr2))

      // Same invariant for Multiply.
      val expr3 = Multiply(Multiply(l1, l2), l3)
      val expr4 = Multiply(Multiply(l2, l1, EvalMode.TRY), l3)
      assert(!expr3.semanticEquals(expr4))
    } finally {
      // Restore the session conf so later tests see the default threshold.
      SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key,
        default.toString)
    }
  }
}
}

0 comments on commit 7290000

Please sign in to comment.