Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,32 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) {
case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
if (foldables.nonEmpty) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y, f))
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y, f)), c, f)
val foldableValue = foldableExpr.eval(EmptyRow)
if (others.isEmpty) {
Literal.create(foldableValue, a.dataType)
} else if (foldableValue == 0) {
others.reduce((x, y) => Add(x, y, f))
} else {
Add(others.reduce((x, y) => Add(x, y, f)), Literal.create(foldableValue, a.dataType), f)
}
} else {
a
}
case m @ Multiply(_, _, f) if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
if (foldables.nonEmpty) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y, f))
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y, f)), c, f)
val foldableValue = foldableExpr.eval(EmptyRow)
if (others.isEmpty || (foldableValue == 0 && !m.nullable)) {
Literal.create(foldableValue, m.dataType)
} else if (foldableValue == 1) {
others.reduce((x, y) => Multiply(x, y, f))
} else {
Multiply(others.reduce((x, y) => Multiply(x, y, f)),
Literal.create(foldableValue, m.dataType), f)
}
} else {
m
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -74,4 +75,35 @@ class ReorderAssociativeOperatorSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("SPARK-49915: Handle zero and one in associative operators") {
val originalQuery =
testRelation.select(
$"a" + 0,
Literal(-3) + $"a" + 3,
$"b" * 0 * 1 * 2 * 3,
Copy link
Contributor

@cloud-fan cloud-fan Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test a non-nullable b multiplied by 0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed

Count($"b") * 0,
$"b" * 1 * 1,
($"b" + 0) * 1 * 2 * 3 * 4,
$"a" + 0 + $"b" + 0 + $"c" + 0,
$"a" + 0 + $"b" * 1 + $"c" + 0
)

val optimized = Optimize.execute(originalQuery.analyze)

val correctAnswer =
testRelation
.select(
$"a".as("(a + 0)"),
$"a".as("((-3 + a) + 3)"),
($"b" * 0).as("((((b * 0) * 1) * 2) * 3)"),
Literal(0L).as("(count(b) * 0)"),
$"b".as("((b * 1) * 1)"),
($"b" * 24).as("(((((b + 0) * 1) * 2) * 3) * 4)"),
($"a" + $"b" + $"c").as("""(((((a + 0) + b) + 0) + c) + 0)"""),
($"a" + $"b" + $"c").as("((((a + 0) + (b * 1)) + c) + 0)")
).analyze

comparePlans(optimized, correctAnswer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ Project [a#x, (b#x + c#x) AS (b + c)#x]
+- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet


-- !query
select b + 0 from t1 where a = 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should test b * 0, where the optimization should be skipped to respect NULL semantics.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are already such test cases in ReorderAssociativeOperatorSuite.scala

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, then do we need to add additional tests in this golden file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReorderAssociativeOperatorSuite is more about the plan correctness.

If we want the test to be more intuitive and specific to NULL semantics, I can add some more NULL-related cases here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-- !query analysis
Project [(b#x + 0) AS (b + 0)#x]
+- Filter (a#x = 5)
+- SubqueryAlias spark_catalog.default.t1
+- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet


-- !query
select -100 + b + 100 from t1 where a = 5
-- !query analysis
Project [((-100 + b#x) + 100) AS ((-100 + b) + 100)#x]
+- Filter (a#x = 5)
+- SubqueryAlias spark_catalog.default.t1
+- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet


-- !query
select a+10, b*0 from t1
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ insert into t1 values(7,null,null);

-- Adding anything to null gives null
select a, b+c from t1;
select b + 0 from t1 where a = 5;
select -100 + b + 100 from t1 where a = 5;

-- Multiplying null by zero gives null
select a+10, b*0 from t1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@ struct<a:int,(b + c):int>
7 NULL


-- !query
select b + 0 from t1 where a = 5
-- !query schema
struct<(b + 0):int>
-- !query output
NULL


-- !query
select -100 + b + 100 from t1 where a = 5
-- !query schema
struct<((-100 + b) + 100):int>
-- !query output
NULL


-- !query
select a+10, b*0 from t1
-- !query schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2688,7 +2688,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee")
checkAggregateRemoved(df, ansiMode)
val expectedPlanFragment = if (ansiMode) {
"PushedAggregates: [SUM(2147483647 + DEPT)], " +
"PushedAggregates: [SUM(DEPT + 2147483647)], " +
Copy link
Member Author

@yaooqinn yaooqinn Oct 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test changes because the reorder rule is now applied even when the expression contains only one foldable operand.

"PushedFilters: [], " +
"PushedGroupByExpressions: []"
} else {
Expand Down