From cf6c7e91e2eb9331cf9123a446a14b71f625200a Mon Sep 17 00:00:00 2001
From: "tanel.kiis@gmail.com"
Date: Thu, 17 Sep 2020 22:46:53 +0300
Subject: [PATCH] Bitwise operations are commutative

Canonicalize now passes BitwiseOr, BitwiseAnd and BitwiseXor through
orderCommutative and rebuilds the expression with reduce, mirroring the
existing And case but without that case's determinism guard.

CanonicalizeSuite gains a test which asserts, for each of the three
operators, that `a op (b op c)` and `(a op b) op c` have equal canonical
forms, and that `a op (b op a)` has a different one.
---
 .../sql/catalyst/expressions/Canonicalize.scala  |  7 +++++++
 .../catalyst/expressions/CanonicalizeSuite.scala | 16 ++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index a8031086d82f7..1ecf4372cfb58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -80,6 +80,13 @@ object Canonicalize {
       orderCommutative(a, { case And(l, r) if l.deterministic && r.deterministic => Seq(l, r)})
         .reduce(And)
 
+    case o: BitwiseOr =>
+      orderCommutative(o, { case BitwiseOr(l, r) => Seq(l, r) }).reduce(BitwiseOr)
+    case a: BitwiseAnd =>
+      orderCommutative(a, { case BitwiseAnd(l, r) => Seq(l, r) }).reduce(BitwiseAnd)
+    case x: BitwiseXor =>
+      orderCommutative(x, { case BitwiseXor(l, r) => Seq(l, r) }).reduce(BitwiseXor)
+
     case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
     case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index a043b4cbed1f1..d822fe736ef89 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.TimeZone
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Range
 import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
@@ -95,4 +96,19 @@ class CanonicalizeSuite extends SparkFunSuite {
     val castWithTimeZoneId = Cast(literal, LongType, Some(TimeZone.getDefault.getID))
     assert(castWithTimeZoneId.semanticEquals(cast))
   }
+
+  test("SPARK-32927: Bitwise operations are commutative") {
+    Seq(
+      (l: Expression, r: Expression) => BitwiseOr(l, r),
+      (l: Expression, r: Expression) => BitwiseAnd(l, r),
+      (l: Expression, r: Expression) => BitwiseXor(l, r)
+    ).foreach(f => {
+      val e1 = f('a, f('b, 'c))
+      val e2 = f(f('a, 'b), 'c)
+      val e3 = f('a, f('b, 'a))
+
+      assert(e1.canonicalized == e2.canonicalized)
+      assert(e1.canonicalized != e3.canonicalized)
+    })
+  }
 }