From dac2f49b5dc66b67d632632907cc654f4d319434 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Sun, 23 Oct 2016 21:23:13 -0700
Subject: [PATCH] [SPARK-18067] [SQL] SortMergeJoin adds shuffle if join predicates have non partitioned columns

---
 .../spark/sql/catalyst/plans/physical/partitioning.scala | 1 +
 .../spark/sql/execution/exchange/EnsureRequirements.scala | 6 +++++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 51d78dd1233fe..82c79e5203b6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -255,6 +255,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

   override def guarantees(other: Partitioning): Boolean = other match {
     case o: HashPartitioning => this.semanticEquals(o)
+    case o: PartitioningCollection => o.partitionings.exists(this.guarantees)
     case _ => false
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index f17049949aa47..dc477cecb25b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -50,7 +50,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       numPartitions: Int): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case ClusteredDistribution(clustering) if clustering.size == 1 =>
+        HashPartitioning(clustering, numPartitions)
+      case ClusteredDistribution(clustering) =>
+        PartitioningCollection(
+          clustering.map(expression => HashPartitioning(Seq(expression), numPartitions)))
       case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
       case dist => sys.error(s"Do not know how to satisfy distribution $dist")
     }