diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 52c018baa5f7b..37053bb6f37ad 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,11 +19,15 @@ package org.apache.spark
 
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
-import scala.reflect.ClassTag
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.{ClassTag, classTag}
+import scala.util.hashing.byteswap32
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.{CollectionsUtils, Utils}
+import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}
 
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     private var ascending: Boolean = true)
   extends Partitioner {
 
+  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
+  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
+
   private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
   private var rangeBounds: Array[K] = {
-    if (partitions == 1) {
-      Array()
+    if (partitions <= 1) {
+      Array.empty
     } else {
-      val rddSize = rdd.count()
-      val maxSampleSize = partitions * 20.0
-      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
-      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
-      if (rddSample.length == 0) {
-        Array()
+      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
+      val sampleSize = math.min(20.0 * partitions, 1e6)
+      // Assume the input partitions are roughly balanced and over-sample a little bit.
+      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
+      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
+      if (numItems == 0L) {
+        Array.empty
       } else {
-        val bounds = new Array[K](partitions - 1)
-        for (i <- 0 until partitions - 1) {
-          val index = (rddSample.length - 1) * (i + 1) / partitions
-          bounds(i) = rddSample(index)
+        // If a partition contains much more than the average number of items, we re-sample from it
+        // to ensure that enough items are collected from that partition.
+        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
+        val candidates = ArrayBuffer.empty[(K, Float)]
+        val imbalancedPartitions = mutable.Set.empty[Int]
+        sketched.foreach { case (idx, n, sample) =>
+          if (fraction * n > sampleSizePerPartition) {
+            imbalancedPartitions += idx
+          } else {
+            // The weight is 1 over the sampling probability.
+            val weight = (n.toDouble / sample.size).toFloat
+            for (key <- sample) {
+              candidates += ((key, weight))
+            }
+          }
+        }
+        if (imbalancedPartitions.nonEmpty) {
+          // Re-sample imbalanced partitions with the desired sampling probability.
+          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
+          val seed = byteswap32(-rdd.id - 1)
+          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
+          val weight = (1.0 / fraction).toFloat
+          candidates ++= reSampled.map(x => (x, weight))
         }
-        bounds
+        RangePartitioner.determineBounds(candidates, partitions)
       }
     }
   }
@@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     }
   }
 }
+
+private[spark] object RangePartitioner {
+
+  /**
+   * Sketches the input RDD via reservoir sampling on each partition.
+   *
+   * @param rdd the input RDD to sketch
+   * @param sampleSizePerPartition max sample size per partition
+   * @return (total number of items, an array of (partitionId, number of items, sample))
+   */
+  def sketch[K : ClassTag](
+      rdd: RDD[K],
+      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+    val shift = rdd.id
+    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
+    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
+      val seed = byteswap32(idx ^ (shift << 16))
+      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
+        iter, sampleSizePerPartition, seed)
+      Iterator((idx, n, sample))
+    }.collect()
+    val numItems = sketched.map(_._2.toLong).sum
+    (numItems, sketched)
+  }
+
+  /**
+   * Determines the bounds for range partitioning from candidates with weights indicating how many
+   * items each represents. Usually this is 1 over the probability used to sample this candidate.
+   *
+   * @param candidates unordered candidates with weights
+   * @param partitions number of partitions
+   * @return selected bounds
+   */
+  def determineBounds[K : Ordering : ClassTag](
+      candidates: ArrayBuffer[(K, Float)],
+      partitions: Int): Array[K] = {
+    val ordering = implicitly[Ordering[K]]
+    val ordered = candidates.sortBy(_._1)
+    val numCandidates = ordered.size
+    val sumWeights = ordered.map(_._2.toDouble).sum
+    val step = sumWeights / partitions
+    var cumWeight = 0.0
+    var target = step
+    val bounds = ArrayBuffer.empty[K]
+    var i = 0
+    var j = 0
+    var previousBound = Option.empty[K]
+    while ((i < numCandidates) && (j < partitions - 1)) {
+      val (key, weight) = ordered(i)
+      cumWeight += weight
+      if (cumWeight > target) {
+        // Skip duplicate values.
+        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
+          bounds += key
+          target += step
+          j += 1
+          previousBound = Some(key)
+        }
+      }
+      i += 1
+    }
+    bounds.toArray
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 4658a08064280..fc0cee3e8749d 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark
 
+import scala.collection.mutable.ArrayBuffer
 import scala.math.abs
 
 import org.scalatest.{FunSuite, PrivateMethodTester}
@@ -52,14 +53,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
 
     assert(p2 === p2)
     assert(p4 === p4)
-    assert(p2 != p4)
-    assert(p4 != p2)
+    assert(p2 === p4)
     assert(p4 === anotherP4)
     assert(anotherP4 === p4)
     assert(descendingP2 === descendingP2)
     assert(descendingP4 === descendingP4)
-    assert(descendingP2 != descendingP4)
-    assert(descendingP4 != descendingP2)
+    assert(descendingP2 === descendingP4)
     assert(p2 != descendingP2)
     assert(p4 != descendingP4)
     assert(descendingP2 != p2)
@@ -102,6 +101,63 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     partitioner.getPartition(Row(100))
   }
 
+  test("RangePartitioner.sketch") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(i)(random.nextDouble())
+    }.cache()
+    val sampleSizePerPartition = 10
+    val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition)
+    assert(count === rdd.count())
+    sketched.foreach { case (idx, n, sample) =>
+      assert(n === idx)
+      assert(sample.size === math.min(n, sampleSizePerPartition))
+    }
+  }
+
+  test("RangePartitioner.determineBounds") {
+    assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty,
+      "Bounds on an empty candidates set should be empty.")
+    val candidates = ArrayBuffer(
+      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
+    assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7))
+  }
+
+  test("RangePartitioner should run only one job if data is roughly balanced") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(5000 * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(10, 20, 40)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 3.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should work well on unbalanced data") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(2, 4, 8)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 3.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should return a single partition for empty RDDs") {
+    val empty1 = sc.emptyRDD[(Int, Double)]
+    val partitioner1 = new RangePartitioner(0, empty1)
+    assert(partitioner1.numPartitions === 1)
+    val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)])
+    val partitioner2 = new RangePartitioner(2, empty2)
+    assert(partitioner2.numPartitions === 1)
+  }
+
   test("HashPartitioner not equal to RangePartitioner") {
     val rdd = sc.parallelize(1 to 10).map(x => (x, x))
     val rangeP2 = new RangePartitioner(2, rdd)
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6654ec2d7c656..fdc83bc0a5f8e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -613,6 +613,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("sort an empty RDD") {
+    val data = sc.emptyRDD[Int]
+    assert(data.sortBy(x => x).collect() === Array.empty)
+  }
+
   test("sortByKey") {
     val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
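
A note on the sampling math in the new rangeBounds computation, worked through against the "RangePartitioner should work well on unbalanced data" test above (the numbers are worked out here for illustration and do not appear in the patch): that RDD has 20 input partitions and partition i holds 20 * i^3 pairs, so numItems = 20 * (1^3 + ... + 19^3) = 722,000. With numPartitions = 4, sampleSize = min(20.0 * 4, 1e6) = 80 and sampleSizePerPartition = ceil(3.0 * 80 / 20) = 12. The global fraction is 80 / 722,000 ≈ 1.1e-4, and fraction * n exceeds 12 only for partitions 18 and 19 (116,640 and 137,180 pairs), so just those two are re-sampled through PartitionPruningRDD at that fraction; every other partition contributes its reservoir sample directly, each key weighted by n / sample.size.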
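
To make the bounds selection easy to check outside Spark, here is a small standalone Scala sketch that restates the weighted walk from RangePartitioner.determineBounds, specialized to Double keys; the object and method names are illustrative only, and the candidates reproduce the PartitioningSuite case, which should yield bounds 0.4 and 0.7.

// Standalone sketch (illustrative names, not part of the patch): the weighted bounds walk
// from RangePartitioner.determineBounds, specialized to Double keys so it runs without Spark.
import scala.collection.mutable.ArrayBuffer

object WeightedBoundsSketch {
  // Each candidate is (key, weight); the weight is 1 over the probability used to sample it.
  def determineBounds(candidates: ArrayBuffer[(Double, Float)], partitions: Int): Array[Double] = {
    val ordered = candidates.sortBy(_._1)
    val sumWeights = ordered.map(_._2.toDouble).sum
    val step = sumWeights / partitions // target cumulative weight per output partition
    var cumWeight = 0.0
    var target = step
    val bounds = ArrayBuffer.empty[Double]
    var i = 0
    var j = 0
    var previousBound = Option.empty[Double]
    while (i < ordered.size && j < partitions - 1) {
      val (key, weight) = ordered(i)
      cumWeight += weight
      if (cumWeight > target) {
        // Skip duplicate keys so that the emitted bounds stay strictly increasing.
        if (previousBound.isEmpty || key > previousBound.get) {
          bounds += key
          target += step
          j += 1
          previousBound = Some(key)
        }
      }
      i += 1
    }
    bounds.toArray
  }

  def main(args: Array[String]): Unit = {
    // Same candidates as the PartitioningSuite test: total weight 10, so with 3 partitions a
    // bound is emitted once the cumulative weight passes 10/3 (at key 0.4) and 20/3 (at 0.7).
    val candidates = ArrayBuffer(
      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
    println(determineBounds(candidates, 3).mkString(", ")) // prints: 0.4, 0.7
  }
}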