Skip to content

Commit

Permalink
Fix sort based shuffle for spark sql.
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Aug 20, 2014
1 parent 0ea46ac commit fcd7bb2
Showing 1 changed file with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
Expand All @@ -37,6 +38,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

def output = child.output

/** We must copy rows when sort based shuffle is on */
protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]

def execute() = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
Expand All @@ -45,8 +49,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
@transient val hashExpressions =
newMutableProjection(expressions, child.output)()

val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
if (sortBasedShuffleOn) {
iter.map(r => (hashExpressions(r), r.copy()))
} else {
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
}
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
Expand All @@ -58,8 +66,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
implicit val ordering = new RowOrdering(sortingExpressions, child.output)

val rdd = child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Row, Null](null, null)
iter.map(row => mutablePair.update(row, null))
if (sortBasedShuffleOn) {
iter.map(row => (row.copy(), null))
} else {
val mutablePair = new MutablePair[Row, Null](null, null)
iter.map(row => mutablePair.update(row, null))
}
}
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
Expand All @@ -69,8 +81,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

case SinglePartition =>
val rdd = child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Null, Row]()
iter.map(r => mutablePair.update(null, r))
if (sortBasedShuffleOn) {
iter.map(r => (null, r.copy()))
} else {
val mutablePair = new MutablePair[Null, Row]()
iter.map(r => mutablePair.update(null, r))
}
}
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
Expand Down

0 comments on commit fcd7bb2

Please sign in to comment.