From a0366f29046b2eeac572bc3a298d9ee99eddd27c Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Sat, 11 Jul 2020 19:04:16 -0700 Subject: [PATCH 1/8] initial commit --- .../exchange/EnsureRequirements.scala | 60 +++++++++++--- .../spark/sql/execution/PlannerSuite.scala | 82 +++++++++++++++++++ .../spark/sql/sources/BucketedReadSuite.scala | 31 +++++++ 3 files changed, 161 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 3242ac21ab324..720c022c05d18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -130,9 +130,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftKeys: IndexedSeq[Expression], rightKeys: IndexedSeq[Expression], expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = { if (expectedOrderOfKeys.size != currentOrderOfKeys.size) { - return (leftKeys, rightKeys) + return None + } + + // Check if the current order already satisfies the expected order. + if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) { + return Some(leftKeys, rightKeys) } // Build a lookup between an expression and the positions its holds in the current key seq. @@ -159,10 +164,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { rightKeysBuffer += rightKeys(index) case _ => // The expression cannot be found, or we have exhausted all indices for that expression. - return (leftKeys, rightKeys) + return None } } - (leftKeysBuffer, rightKeysBuffer) + Some(leftKeysBuffer, rightKeysBuffer) } private def reorderJoinKeys( @@ -171,19 +176,50 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftPartitioning: Partitioning, rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - (leftPartitioning, rightPartitioning) match { - case (HashPartitioning(leftExpressions, _), _) => - reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) - case (_, HashPartitioning(rightExpressions, _)) => - reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) - case _ => - (leftKeys, rightKeys) - } + reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, rightPartitioning) + .getOrElse((leftKeys, rightKeys)) } else { (leftKeys, rightKeys) } } + /** + * Recursively reorders the join keys based on the partitioning. It starts reordering + * keys to match HashPartitioning on either side, followed by PartitioningCollection. 
+ */ + private def reorderJoinKeysRecursively( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): Option[(Seq[Expression], Seq[Expression])] = { + (leftPartitioning, rightPartitioning) match { + case (HashPartitioning(leftExpressions, _), _) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning)) + case (_, HashPartitioning(rightExpressions, _)) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, leftPartitioning, UnknownPartitioning(0))) + case (PartitioningCollection(partitionings), _) => + partitionings.foreach { p => + reorderJoinKeysRecursively(leftKeys, rightKeys, p, rightPartitioning).map { k => + return Some(k) + } + } + reorderJoinKeysRecursively(leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning) + case (_, PartitioningCollection(partitionings)) => + partitionings.foreach { p => + reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, p).map { k => + return Some(k) + } + } + None + case _ => + None + } + } + /** * When the physical operators are created for JOIN, the ordering of join keys is based on order * in which the join keys appear in the user query. That might not match with the output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index d428b7ebc0e91..0bbd97eb0ab34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -994,6 +994,88 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } } } + + test("EnsureRequirements.reorder should fallback to the right side HashPartitioning") { + val plan1 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) + val plan2 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5)) + // The left keys cannot be reordered to match the left partitioning, and it should + // fall back to reorder the right side. + val smjExec = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + assert(leftKeys !== smjExec.leftKeys) + assert(rightKeys !== smjExec.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitioningExpressions) + case _ => fail(outputPlan.toString) + } + } + + test("EnsureRequirements.reorder should handle PartitioningCollection") { + // PartitioningCollection on the left side of join. 
+ val plan1 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq( + HashPartitioning(exprA :: exprB :: Nil, 5), + HashPartitioning(exprA :: Nil, 5)))) + val plan2 = DummySparkPlan() + val smjExec1 = SortMergeJoinExec( + exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec1) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _), + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + assert(leftKeys !== smjExec1.leftKeys) + assert(rightKeys !== smjExec1.rightKeys) + assert(leftKeys === leftPartitionings(0).asInstanceOf[HashPartitioning].expressions) + assert(rightKeys === rightPartitioningExpressions) + case _ => fail(outputPlan.toString) + } + + // PartitioningCollection on the right side of join. + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1) + val outputPlan2 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2) + outputPlan2 match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + assert(leftKeys !== smjExec2.leftKeys) + assert(rightKeys !== smjExec2.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions) + case _ => fail(outputPlan2.toString) + } + + // Both sides are PartitioningCollection and falls back to the right side. 
+ val smjExec3 = SortMergeJoinExec( + exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1) + val outputPlan3 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2) + outputPlan3 match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + assert(leftKeys !== smjExec2.leftKeys) + assert(rightKeys !== smjExec2.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions) + case _ => fail(outputPlan3.toString) + } + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index b6767eb3132ea..a53238aaa6176 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -943,6 +943,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { } } + test("terry - hashpartitioning") { + withTable("t1", "t2") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", + SQLConf.SHUFFLE_PARTITIONS.key -> "4") { + val df1 = (0 until 10).map(i => (i % 5, i % 13)).toDF("i1", "j1") + val df2 = (0 until 10).map(i => (i % 7, i % 11)).toDF("i2", "j2") + + df1.write.format("parquet").bucketBy(4, "i1", "j1").saveAsTable("t1") + df2.write.format("parquet").bucketBy(4, "i2", "j2").saveAsTable("t2") + + val t1 = spark.table("t1") + val t2 = spark.table("t2") + val join = t1.join(t2, t1("i1") === t2("j2") && t1("i1") === t2("i2")) + join.explain + } + } + } + + + test("terry - collectionpartition") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i1", "j1", "k1") + val df2 = (0 until 100).map(i => (i % 7, i % 11, i.toString)).toDF("i2", "j2", "k2") + val df3 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i3", "j3", "k3") + val join = df1.join(df2, df1("i1") === df2("i2") && df1("j1") === df2("j2")) + val join2 = join.join(df3, join("j1") === df3("j3") && join("i1") === df3("i3")) + join2.explain + } + } + test("bucket coalescing is applied when join expressions match with partitioning expressions") { withTable("t1", "t2") { df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1") From 99493e422f6d3540cf342a4ff23686d1aee4ac14 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Sat, 11 Jul 2020 19:11:27 -0700 Subject: [PATCH 2/8] update comments --- .../exchange/EnsureRequirements.scala | 4 +-- .../spark/sql/sources/BucketedReadSuite.scala | 31 ------------------- 2 files changed, 2 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 720c022c05d18..c6d441b80d45d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -184,8 +184,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } /** - * Recursively reorders the join keys based on the partitioning. 
It starts reordering - * keys to match HashPartitioning on either side, followed by PartitioningCollection. + * Recursively reorders the join keys based on partitioning. It starts reordering the + * join keys to match HashPartitioning on either side, followed by PartitioningCollection. */ private def reorderJoinKeysRecursively( leftKeys: Seq[Expression], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index a53238aaa6176..b6767eb3132ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -943,37 +943,6 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { } } - test("terry - hashpartitioning") { - withTable("t1", "t2") { - withSQLConf( - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", - SQLConf.SHUFFLE_PARTITIONS.key -> "4") { - val df1 = (0 until 10).map(i => (i % 5, i % 13)).toDF("i1", "j1") - val df2 = (0 until 10).map(i => (i % 7, i % 11)).toDF("i2", "j2") - - df1.write.format("parquet").bucketBy(4, "i1", "j1").saveAsTable("t1") - df2.write.format("parquet").bucketBy(4, "i2", "j2").saveAsTable("t2") - - val t1 = spark.table("t1") - val t2 = spark.table("t2") - val join = t1.join(t2, t1("i1") === t2("j2") && t1("i1") === t2("i2")) - join.explain - } - } - } - - - test("terry - collectionpartition") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { - val df1 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i1", "j1", "k1") - val df2 = (0 until 100).map(i => (i % 7, i % 11, i.toString)).toDF("i2", "j2", "k2") - val df3 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i3", "j3", "k3") - val join = df1.join(df2, df1("i1") === df2("i2") && df1("j1") === df2("j2")) - val join2 = join.join(df3, join("j1") === df3("j3") && join("i1") === df3("i3")) - join2.explain - } - } - test("bucket coalescing is applied when join expressions match with partitioning expressions") { withTable("t1", "t2") { df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1") From 83086494f14f16fabc762921545de4c73e1d1743 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 15 Jul 2020 12:17:17 -0700 Subject: [PATCH 3/8] Fix merge error --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 0871dbf3093d9..68b2a8d9619b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -167,7 +167,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { return None } } - (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) + Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) } private def reorderJoinKeys( From e5b078f066058e41d5fa861ec42a3a8bf16ca66e Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 6 Aug 2020 18:47:01 -0700 Subject: [PATCH 4/8] address PR comments --- .../spark/sql/execution/PlannerSuite.scala | 82 ---------- .../exchange/EnsureRequirementsSuite.scala | 146 ++++++++++++++++++ 2 files changed, 146 insertions(+), 82 deletions(-) create mode 100644 
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0bbd97eb0ab34..d428b7ebc0e91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -994,88 +994,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } } } - - test("EnsureRequirements.reorder should fallback to the right side HashPartitioning") { - val plan1 = DummySparkPlan( - outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) - val plan2 = DummySparkPlan( - outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5)) - // The left keys cannot be reordered to match the left partitioning, and it should - // fall back to reorder the right side. - val smjExec = SortMergeJoinExec( - exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) - outputPlan match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => - assert(leftKeys !== smjExec.leftKeys) - assert(rightKeys !== smjExec.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitioningExpressions) - case _ => fail(outputPlan.toString) - } - } - - test("EnsureRequirements.reorder should handle PartitioningCollection") { - // PartitioningCollection on the left side of join. - val plan1 = DummySparkPlan( - outputPartitioning = PartitioningCollection(Seq( - HashPartitioning(exprA :: exprB :: Nil, 5), - HashPartitioning(exprA :: Nil, 5)))) - val plan2 = DummySparkPlan() - val smjExec1 = SortMergeJoinExec( - exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec1) - outputPlan match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _), - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => - assert(leftKeys !== smjExec1.leftKeys) - assert(rightKeys !== smjExec1.rightKeys) - assert(leftKeys === leftPartitionings(0).asInstanceOf[HashPartitioning].expressions) - assert(rightKeys === rightPartitioningExpressions) - case _ => fail(outputPlan.toString) - } - - // PartitioningCollection on the right side of join. 
- val smjExec2 = SortMergeJoinExec( - exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1) - val outputPlan2 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2) - outputPlan2 match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => - assert(leftKeys !== smjExec2.leftKeys) - assert(rightKeys !== smjExec2.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions) - case _ => fail(outputPlan2.toString) - } - - // Both sides are PartitioningCollection and falls back to the right side. - val smjExec3 = SortMergeJoinExec( - exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1) - val outputPlan3 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2) - outputPlan3 match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => - assert(leftKeys !== smjExec2.leftKeys) - assert(rightKeys !== smjExec2.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions) - case _ => fail(outputPlan3.toString) - } - } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala new file mode 100644 index 0000000000000..dfe5a5a26d9a2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} +import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.test.SharedSparkSession + +class EnsureRequirementsSuite extends SharedSparkSession { + private val exprA = Literal(1) + private val exprB = Literal(2) + private val exprC = Literal(3) + + test("EnsureRequirements.reorder should handle PartitioningCollection") { + val plan1 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq( + HashPartitioning(exprA :: exprB :: Nil, 5), + HashPartitioning(exprA :: Nil, 5)))) + val plan2 = DummySparkPlan() + + // Test PartitioningCollection on the left side of join. + val smjExec1 = SortMergeJoinExec( + exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2) + EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _), + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + assert(leftKeys !== smjExec1.leftKeys) + assert(rightKeys !== smjExec1.rightKeys) + assert(leftKeys === leftPartitionings.head.asInstanceOf[HashPartitioning].expressions) + assert(rightKeys === rightPartitioningExpressions) + case other => fail(other.toString) + } + + // Test PartitioningCollection on the right side of join. + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + assert(leftKeys !== smjExec2.leftKeys) + assert(rightKeys !== smjExec2.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions) + case other => fail(other.toString) + } + + // Both sides are PartitioningCollection, but left side cannot be reorderd to match + // and it should fall back to the right side. 
+ val smjExec3 = SortMergeJoinExec( + exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + assert(leftKeys !== smjExec3.leftKeys) + assert(rightKeys !== smjExec3.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions) + case other => fail(other.toString) + } + } + + test("EnsureRequirements.reorder should fallback to the other side partitioning") { + val plan1 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) + val plan2 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5)) + + // Test fallback to the right side, which has PartitioningCollection. + val smjExec1 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) + EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + assert(leftKeys !== smjExec1.leftKeys) + assert(rightKeys !== smjExec1.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitioningExpressions) + case other => fail(other.toString) + } + + // Test fallback to the right side, which has PartitioningCollection. + val plan3 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq(HashPartitioning(exprB :: exprC :: Nil, 5)))) + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3) + EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + assert(leftKeys !== smjExec2.leftKeys) + assert(rightKeys !== smjExec2.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions) + case other => fail(other.toString) + } + + // The right side has HashPartitioning, so it is matched first, but no reordering match is + // found, and it should fall back to the left side, which has a PartitioningCollection. 
+ val smjExec3 = SortMergeJoinExec( + exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _), + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + assert(leftKeys !== smjExec3.leftKeys) + assert(rightKeys !== smjExec3.rightKeys) + assert(leftKeys === leftPartitionings.head.asInstanceOf[HashPartitioning].expressions) + assert(rightKeys === rightPartitioningExpressions) + case other => fail(other.toString) + } + } +} From 89ad6ef4d9ca67fb4dc06bd66b7db9d7430f9fd0 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 7 Aug 2020 11:14:57 -0700 Subject: [PATCH 5/8] address comments --- .../exchange/EnsureRequirements.scala | 28 +++++++++++-------- .../exchange/EnsureRequirementsSuite.scala | 4 +-- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 68b2a8d9619b9..47cf38dd53cf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -176,7 +176,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftPartitioning: Partitioning, rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, rightPartitioning) + reorderJoinKeysRecursively( + leftKeys, + rightKeys, + Some(leftPartitioning), + Some(rightPartitioning)) .getOrElse((leftKeys, rightKeys)) } else { (leftKeys, rightKeys) @@ -190,27 +194,27 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def reorderJoinKeysRecursively( leftKeys: Seq[Expression], rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): Option[(Seq[Expression], Seq[Expression])] = { + leftPartitioning: Option[Partitioning], + rightPartitioning: Option[Partitioning]): Option[(Seq[Expression], Seq[Expression])] = { (leftPartitioning, rightPartitioning) match { - case (HashPartitioning(leftExpressions, _), _) => + case (Some(HashPartitioning(leftExpressions, _)), _) => reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) .orElse(reorderJoinKeysRecursively( - leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning)) - case (_, HashPartitioning(rightExpressions, _)) => + leftKeys, rightKeys, None, rightPartitioning)) + case (_, Some(HashPartitioning(rightExpressions, _))) => reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( - leftKeys, rightKeys, leftPartitioning, UnknownPartitioning(0))) - case (PartitioningCollection(partitionings), _) => + leftKeys, rightKeys, leftPartitioning, None)) + case (Some(PartitioningCollection(partitionings)), _) => partitionings.foreach { p => - reorderJoinKeysRecursively(leftKeys, rightKeys, p, rightPartitioning).map { k => + reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning).map { k => return Some(k) } } - reorderJoinKeysRecursively(leftKeys, rightKeys, 
UnknownPartitioning(0), rightPartitioning) - case (_, PartitioningCollection(partitionings)) => + reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning) + case (_, Some(PartitioningCollection(partitionings))) => partitionings.foreach { p => - reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, p).map { k => + reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)).map { k => return Some(k) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index dfe5a5a26d9a2..5231b63375fc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -29,7 +29,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { private val exprB = Literal(2) private val exprC = Literal(3) - test("EnsureRequirements.reorder should handle PartitioningCollection") { + test("reorder should handle PartitioningCollection") { val plan1 = DummySparkPlan( outputPartitioning = PartitioningCollection(Seq( HashPartitioning(exprA :: exprB :: Nil, 5), @@ -86,7 +86,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("EnsureRequirements.reorder should fallback to the other side partitioning") { + test("reorder should fallback to the other side partitioning") { val plan1 = DummySparkPlan( outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) val plan2 = DummySparkPlan( From 1729c8ba83178663e98f546312b7506d2852913c Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 24 Aug 2020 10:16:27 -0700 Subject: [PATCH 6/8] Address PR comment --- .../spark/sql/execution/exchange/EnsureRequirementsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 5231b63375fc1..6027f6521c1a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -92,7 +92,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { val plan2 = DummySparkPlan( outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5)) - // Test fallback to the right side, which has PartitioningCollection. + // Test fallback to the right side, which has HashPartitioning. 
val smjExec1 = SortMergeJoinExec( exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { From 10b4d5a664ad685b2eb7ae09b0a6706e8bdbe812 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 6 Oct 2020 18:40:19 -0700 Subject: [PATCH 7/8] Address PR comments --- .../exchange/EnsureRequirements.scala | 9 +-- .../exchange/EnsureRequirementsSuite.scala | 60 ++++++++----------- 2 files changed, 27 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ddd3168942026..3adb667718731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -211,12 +211,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) case (Some(PartitioningCollection(partitionings)), _) => - partitionings.foreach { p => - reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning).map { k => - return Some(k) - } - } - reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning) + partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => + res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning)) + }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning)) case (_, Some(PartitioningCollection(partitionings))) => partitionings.foreach { p => reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)).map { k => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 6027f6521c1a6..6081365f61290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -41,14 +41,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2) EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _), - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => assert(leftKeys !== smjExec1.leftKeys) assert(rightKeys !== smjExec1.rightKeys) - assert(leftKeys === leftPartitionings.head.asInstanceOf[HashPartitioning].expressions) - assert(rightKeys === rightPartitioningExpressions) + assert(leftKeys === Seq(exprA, exprB)) + assert(rightKeys === Seq(exprB, exprA)) case other => fail(other.toString) } @@ -57,14 +55,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1) EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - 
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => assert(leftKeys !== smjExec2.leftKeys) assert(rightKeys !== smjExec2.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions) + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprA, exprB)) case other => fail(other.toString) } @@ -74,14 +70,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1) EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => assert(leftKeys !== smjExec3.leftKeys) assert(rightKeys !== smjExec3.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions) + assert(leftKeys === Seq(exprC, exprA)) + assert(rightKeys === Seq(exprA, exprB)) case other => fail(other.toString) } } @@ -97,14 +91,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => assert(leftKeys !== smjExec1.leftKeys) assert(rightKeys !== smjExec1.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitioningExpressions) + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprB, exprC)) case other => fail(other.toString) } @@ -115,14 +107,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3) EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) => + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => assert(leftKeys !== smjExec2.leftKeys) assert(rightKeys !== smjExec2.rightKeys) - assert(leftKeys === leftPartitioningExpressions) - assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions) + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprB, exprC)) case other => fail(other.toString) } @@ -132,14 +122,12 
@@ class EnsureRequirementsSuite extends SharedSparkSession { exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1) EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, - DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _), - SortExec(_, _, - ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => assert(leftKeys !== smjExec3.leftKeys) assert(rightKeys !== smjExec3.rightKeys) - assert(leftKeys === leftPartitionings.head.asInstanceOf[HashPartitioning].expressions) - assert(rightKeys === rightPartitioningExpressions) + assert(leftKeys === Seq(exprB, exprC)) + assert(rightKeys === Seq(exprB, exprA)) case other => fail(other.toString) } } From 3cd6df91b74fc98487ffb1f9920bada5b4040e96 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 7 Oct 2020 09:28:19 -0700 Subject: [PATCH 8/8] Address PR comments --- .../sql/execution/exchange/EnsureRequirements.scala | 9 +++------ .../execution/exchange/EnsureRequirementsSuite.scala | 12 ------------ 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 3adb667718731..3641654b89b76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -215,12 +215,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning)) }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning)) case (_, Some(PartitioningCollection(partitionings))) => - partitionings.foreach { p => - reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)).map { k => - return Some(k) - } - } - None + partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => + res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p))) + }.orElse(None) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 6081365f61290..38e68cd2512e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -43,8 +43,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => - assert(leftKeys !== smjExec1.leftKeys) - assert(rightKeys !== smjExec1.rightKeys) assert(leftKeys === Seq(exprA, exprB)) assert(rightKeys === Seq(exprB, exprA)) case other => fail(other.toString) @@ -57,8 +55,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), SortExec(_, _, 
DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => - assert(leftKeys !== smjExec2.leftKeys) - assert(rightKeys !== smjExec2.rightKeys) assert(leftKeys === Seq(exprB, exprA)) assert(rightKeys === Seq(exprA, exprB)) case other => fail(other.toString) @@ -72,8 +68,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => - assert(leftKeys !== smjExec3.leftKeys) - assert(rightKeys !== smjExec3.rightKeys) assert(leftKeys === Seq(exprC, exprA)) assert(rightKeys === Seq(exprA, exprB)) case other => fail(other.toString) @@ -93,8 +87,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => - assert(leftKeys !== smjExec1.leftKeys) - assert(rightKeys !== smjExec1.rightKeys) assert(leftKeys === Seq(exprB, exprA)) assert(rightKeys === Seq(exprB, exprC)) case other => fail(other.toString) @@ -109,8 +101,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => - assert(leftKeys !== smjExec2.leftKeys) - assert(rightKeys !== smjExec2.rightKeys) assert(leftKeys === Seq(exprB, exprA)) assert(rightKeys === Seq(exprB, exprC)) case other => fail(other.toString) @@ -124,8 +114,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => - assert(leftKeys !== smjExec3.leftKeys) - assert(rightKeys !== smjExec3.rightKeys) assert(leftKeys === Seq(exprB, exprC)) assert(rightKeys === Seq(exprB, exprA)) case other => fail(other.toString)
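
As a rough illustration of the reordering strategy these commits converge on, below is a minimal, self-contained Scala sketch. It is not part of the patches: the Key, HashPart and PartCollection types are invented stand-ins for Spark's Expression, HashPartitioning and PartitioningCollection, and the real rule's handling of semantic equality, duplicate keys and the "already ordered" fast path is deliberately omitted.

// Minimal, self-contained sketch (not part of the patches above) of the join-key
// reordering idea: try to line the join keys up with a HashPartitioning on either
// side, and for a PartitioningCollection try each member before giving up on that side.
object ReorderJoinKeysSketch {
  type Key = String

  sealed trait Partitioning
  final case class HashPart(keys: Seq[Key]) extends Partitioning
  final case class PartCollection(parts: Seq[Partitioning]) extends Partitioning

  // Reorder the (leftKeys, rightKeys) pairs so that `current` (one of the two sides)
  // lines up with `expected`; None means the keys cannot be fully matched.
  def reorder(
      leftKeys: Seq[Key],
      rightKeys: Seq[Key],
      expected: Seq[Key],
      current: Seq[Key]): Option[(Seq[Key], Seq[Key])] = {
    if (expected.size != current.size || !expected.forall(current.contains)) {
      None
    } else {
      val pairs = leftKeys.zip(rightKeys)
      val picked = expected.map(k => pairs(current.indexOf(k)))
      Some((picked.map(_._1), picked.map(_._2)))
    }
  }

  // Mirrors the shape of reorderJoinKeysRecursively after the foldLeft/orElse cleanup:
  // match a HashPartitioning on either side first, then recurse into collections.
  def reorderRecursively(
      leftKeys: Seq[Key],
      rightKeys: Seq[Key],
      left: Option[Partitioning],
      right: Option[Partitioning]): Option[(Seq[Key], Seq[Key])] = {
    (left, right) match {
      case (Some(HashPart(exprs)), _) =>
        reorder(leftKeys, rightKeys, exprs, leftKeys)
          .orElse(reorderRecursively(leftKeys, rightKeys, None, right))
      case (_, Some(HashPart(exprs))) =>
        reorder(leftKeys, rightKeys, exprs, rightKeys)
          .orElse(reorderRecursively(leftKeys, rightKeys, left, None))
      case (Some(PartCollection(parts)), _) =>
        parts.foldLeft(Option.empty[(Seq[Key], Seq[Key])]) { (res, p) =>
          res.orElse(reorderRecursively(leftKeys, rightKeys, Some(p), right))
        }.orElse(reorderRecursively(leftKeys, rightKeys, None, right))
      case (_, Some(PartCollection(parts))) =>
        parts.foldLeft(Option.empty[(Seq[Key], Seq[Key])]) { (res, p) =>
          res.orElse(reorderRecursively(leftKeys, rightKeys, left, Some(p)))
        }
      case _ =>
        None
    }
  }

  def main(args: Array[String]): Unit = {
    // The left child exposes a collection of hash partitionings; the join keys ("b", "a")
    // can be reordered to ("a", "b") to line up with HashPart(Seq("a", "b")), so in the
    // real rule only the other side would need a shuffle.
    val reordered = reorderRecursively(
      leftKeys = Seq("b", "a"),
      rightKeys = Seq("y", "x"),
      left = Some(PartCollection(Seq(HashPart(Seq("a", "b")), HashPart(Seq("a"))))),
      right = None)
    assert(reordered == Some((Seq("a", "b"), Seq("x", "y"))))
    println(reordered)
  }
}

The foldLeft/orElse form adopted in the last two commits picks the first member of a PartitioningCollection that yields a reordering, just like the earlier foreach loop with an early return, but it does so without a non-local return; orElse is by-name, so later members are not explored once a match is found.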