
[SPARK-2568] RangePartitioner should run only one job if data is balanced #1562

Closed
wants to merge 13 commits
121 changes: 106 additions & 15 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,11 +19,15 @@ package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

-import scala.reflect.ClassTag
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.{ClassTag, classTag}
+import scala.util.hashing.byteswap32

-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.{CollectionsUtils, Utils}
+import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}

/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     private var ascending: Boolean = true)
Contributor:

It'd be great to update the documentation on when this results in two passes vs one pass. We should probably update the documentation for sortByKey and various other sorts that use this too. Let's do that in another PR.
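
For context, sortByKey is the main caller of RangePartitioner; a minimal usage sketch (illustration only, not part of the diff, assuming a live SparkContext sc and the pair-RDD implicits, e.g. import org.apache.spark.SparkContext._ on Spark 1.x):

    import org.apache.spark.SparkContext._

    val pairs = sc.parallelize(1 to 100000).map(x => (x % 1000, x))
    // Building the sort's RangePartitioner triggers the sampling described below;
    // with roughly balanced input partitions it now costs a single job.
    val sorted = pairs.sortByKey(ascending = true, numPartitions = 10)
    sorted.count()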

   extends Partitioner {

+  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
+  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

   private var ordering = implicitly[Ordering[K]]

   // An array of upper bounds for the first (partitions - 1) partitions
   private var rangeBounds: Array[K] = {
-    if (partitions == 1) {
-      Array()
+    if (partitions <= 1) {
+      Array.empty
     } else {
-      val rddSize = rdd.count()
-      val maxSampleSize = partitions * 20.0
-      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
-      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
-      if (rddSample.length == 0) {
-        Array()
+      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
+      val sampleSize = math.min(20.0 * partitions, 1e6)
+      // Assume the input partitions are roughly balanced and over-sample a little bit.
+      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
+      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
+      if (numItems == 0L) {
+        Array.empty
       } else {
-        val bounds = new Array[K](partitions - 1)
-        for (i <- 0 until partitions - 1) {
-          val index = (rddSample.length - 1) * (i + 1) / partitions
-          bounds(i) = rddSample(index)
+        // If a partition contains much more than the average number of items, we re-sample from it
+        // to ensure that enough items are collected from that partition.
+        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
+        val candidates = ArrayBuffer.empty[(K, Float)]
+        val imbalancedPartitions = mutable.Set.empty[Int]
+        sketched.foreach { case (idx, n, sample) =>
+          if (fraction * n > sampleSizePerPartition) {
+            imbalancedPartitions += idx
+          } else {
+            // The weight is 1 over the sampling probability.
+            val weight = (n.toDouble / sample.size).toFloat
+            for (key <- sample) {
+              candidates += ((key, weight))
+            }
Contributor:

Can probably just write

for (key <- samples) {
  candidates += ((key, weight))
}

Same with the foreach above. It will be slightly more readable, but no big deal.

Contributor Author:

I think foreach should be faster than for here.

Contributor:

I believe they translate to the same thing in this case. Did you see the code being different?

Contributor Author:

Yes, they are the same. Changed to for.
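
For reference, the two spellings discussed above compile to the same call; a minimal standalone illustration (not part of the diff):

    import scala.collection.mutable.ArrayBuffer

    val sample = Array("a", "b", "c")
    val weight = 0.5f
    val candidates = ArrayBuffer.empty[(String, Float)]

    // The for-comprehension form...
    for (key <- sample) {
      candidates += ((key, weight))
    }

    // ...desugars to exactly this foreach call, so the two are equivalent in cost.
    sample.foreach { key =>
      candidates += ((key, weight))
    }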

+          }
+        }
+        if (imbalancedPartitions.nonEmpty) {
+          // Re-sample imbalanced partitions with the desired sampling probability.
+          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
+          val seed = byteswap32(-rdd.id - 1)
+          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
+          val weight = (1.0 / fraction).toFloat
+          candidates ++= reSampled.map(x => (x, weight))
         }
-        bounds
+        RangePartitioner.determineBounds(candidates, partitions)
       }
     }
   }
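
To make the one-pass vs. two-pass behaviour from the review comment above concrete: a second (pruned) job is submitted only for partitions whose size makes fraction * n exceed the per-partition sample cap. A small standalone sketch of that check, using the same constants as the code above (the partition sizes here are made up):

    val partitions = 10              // requested output partitions
    val numInputPartitions = 4
    val sampleSize = math.min(20.0 * partitions, 1e6)                                   // 200 samples wanted
    val sampleSizePerPartition = math.ceil(3.0 * sampleSize / numInputPartitions).toInt // 150 per input partition

    // Three small input partitions and one that holds almost all of the data.
    val partitionSizes = Seq(1000L, 1000L, 1000L, 97000L)
    val numItems = partitionSizes.sum
    val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)                   // 0.002

    val needsSecondPass = partitionSizes.map(n => fraction * n > sampleSizePerPartition)
    println(needsSecondPass)  // List(false, false, false, true): only the huge partition is re-sampled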
Expand Down Expand Up @@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
}
}

+private[spark] object RangePartitioner {

+  /**
+   * Sketches the input RDD via reservoir sampling on each partition.
+   *
+   * @param rdd the input RDD to sketch
+   * @param sampleSizePerPartition max sample size per partition
+   * @return (total number of items, an array of (partitionId, number of items, sample))
+   */
+  def sketch[K:ClassTag](
+      rdd: RDD[K],
+      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+    val shift = rdd.id
+    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
+    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
+      val seed = byteswap32(idx ^ (shift << 16))
+      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
+        iter, sampleSizePerPartition, seed)
+      Iterator((idx, n, sample))
+    }.collect()
+    val numItems = sketched.map(_._2.toLong).sum
+    (numItems, sketched)
+  }

+  /**
+   * Determines the bounds for range partitioning from candidates with weights indicating how many
+   * items each represents. Usually this is 1 over the probability used to sample this candidate.
+   *
+   * @param candidates unordered candidates with weights
+   * @param partitions number of partitions
+   * @return selected bounds
+   */
+  def determineBounds[K:Ordering:ClassTag](
+      candidates: ArrayBuffer[(K, Float)],
+      partitions: Int): Array[K] = {
+    val ordering = implicitly[Ordering[K]]
+    val ordered = candidates.sortBy(_._1)
+    val numCandidates = ordered.size
+    val sumWeights = ordered.map(_._2.toDouble).sum
+    val step = sumWeights / partitions
+    var cumWeight = 0.0
+    var target = step
+    val bounds = ArrayBuffer.empty[K]
+    var i = 0
+    var j = 0
+    var previousBound = Option.empty[K]
+    while ((i < numCandidates) && (j < partitions - 1)) {
+      val (key, weight) = ordered(i)
+      cumWeight += weight
+      if (cumWeight > target) {
+        // Skip duplicate values.
+        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
+          bounds += key
+          target += step
+          j += 1
+          previousBound = Some(key)
+        }
+      }
+      i += 1
+    }
+    bounds.toArray
+  }
+}
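
The weighted selection in determineBounds is easy to trace by hand. A standalone sketch, specialized to Double keys and run on the same candidates as the determineBounds test below (illustration only, same weighted walk as the method above):

    import scala.collection.mutable.ArrayBuffer

    // Simplified re-implementation for illustration (Double keys only).
    def determineBoundsSketch(candidates: ArrayBuffer[(Double, Float)], partitions: Int): Array[Double] = {
      val ordered = candidates.sortBy(_._1)
      val step = ordered.map(_._2.toDouble).sum / partitions
      var cumWeight = 0.0
      var target = step
      val bounds = ArrayBuffer.empty[Double]
      var previousBound = Option.empty[Double]
      var i = 0
      var j = 0
      while (i < ordered.size && j < partitions - 1) {
        val (key, weight) = ordered(i)
        cumWeight += weight
        if (cumWeight > target && (previousBound.isEmpty || key > previousBound.get)) {
          bounds += key       // key becomes an upper bound once the cumulative weight passes the target
          target += step
          j += 1
          previousBound = Some(key)
        }
        i += 1
      }
      bounds.toArray
    }

    val candidates = ArrayBuffer(
      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
    // Total weight 10, step 10/3: the cumulative weight first passes 3.33 at key 0.4 and 6.67 at 0.7.
    println(determineBoundsSketch(candidates, 3).mkString(", "))  // 0.4, 0.7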
64 changes: 60 additions & 4 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark

+import scala.collection.mutable.ArrayBuffer
import scala.math.abs

import org.scalatest.{FunSuite, PrivateMethodTester}
@@ -52,14 +53,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet

     assert(p2 === p2)
     assert(p4 === p4)
-    assert(p2 != p4)
-    assert(p4 != p2)
+    assert(p2 === p4)
     assert(p4 === anotherP4)
     assert(anotherP4 === p4)
     assert(descendingP2 === descendingP2)
     assert(descendingP4 === descendingP4)
-    assert(descendingP2 != descendingP4)
-    assert(descendingP4 != descendingP2)
+    assert(descendingP2 === descendingP4)
     assert(p2 != descendingP2)
     assert(p4 != descendingP4)
     assert(descendingP2 != p2)
@@ -102,6 +101,63 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
partitioner.getPartition(Row(100))
}

test("RangPartitioner.sketch") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
Iterator.fill(i)(random.nextDouble())
}.cache()
val sampleSizePerPartition = 10
val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition)
assert(count === rdd.count())
sketched.foreach { case (idx, n, sample) =>
assert(n === idx)
assert(sample.size === math.min(n, sampleSizePerPartition))
}
}
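
The test above exercises RangePartitioner.sketch, which relies on SamplingUtils.reservoirSampleAndCount; that utility is not part of this diff, so for reference here is a minimal standalone version of reservoir sampling with a count (simplified, not Spark's implementation):

    import scala.reflect.ClassTag
    import scala.util.Random

    // Keep a uniform sample of at most k items from a stream of unknown length, and count it.
    def reservoirSampleAndCount[T: ClassTag](input: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
      val rng = new Random(seed)
      val reservoir = new Array[T](k)
      var n = 0L
      input.foreach { item =>
        if (n < k) {
          reservoir(n.toInt) = item                      // fill the reservoir first
        } else {
          val j = (rng.nextDouble() * (n + 1)).toLong    // uniform index in [0, n]
          if (j < k) reservoir(j.toInt) = item           // replace with probability k / (n + 1)
        }
        n += 1
      }
      if (n < k) (reservoir.take(n.toInt), n) else (reservoir, n)
    }

    val (sample, count) = reservoirSampleAndCount((1 to 1000).iterator, 10, seed = 42L)
    println(s"count = $count, sample size = ${sample.length}")  // count = 1000, sample size = 10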

test("RangePartitioner.determineBounds") {
assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty,
"Bounds on an empty candidates set should be empty.")
val candidates = ArrayBuffer(
(0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7))
}

test("RangePartitioner should run only one job if data is roughly balanced") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
Iterator.fill(5000 * i)((random.nextDouble() + i, i))
}.cache()
for (numPartitions <- Seq(10, 20, 40)) {
val partitioner = new RangePartitioner(numPartitions, rdd)
assert(partitioner.numPartitions === numPartitions)
val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
assert(counts.max < 3.0 * counts.min)
}
}
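
The "only one job" claim in the title can also be checked directly by counting job submissions; a rough sketch (illustration only, reusing the rdd from the test above; listener events are delivered asynchronously, so a real test would drain the listener bus before reading the counter):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

    var jobCount = 0
    sc.addSparkListener(new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = { jobCount += 1 }
    })

    new RangePartitioner(10, rdd)  // balanced data: expected to submit a single sampling job
    // ... after the listener bus has drained, jobCount should have increased by exactly 1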

test("RangePartitioner should work well on unbalanced data") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i))
}.cache()
for (numPartitions <- Seq(2, 4, 8)) {
val partitioner = new RangePartitioner(numPartitions, rdd)
assert(partitioner.numPartitions === numPartitions)
val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
assert(counts.max < 3.0 * counts.min)
}
}

Contributor:

Can you add some tests where the whole RDD has 0 elements, and some tests where individual partitions have 0 elements and others have more? That's where divide by zero errors can happen.

Contributor Author:

The first partition in this test contains 0 elements. I will add a test where the whole RDD has 0 elements.

test("RangePartitioner should return a single partition for empty RDDs") {
val empty1 = sc.emptyRDD[(Int, Double)]
val partitioner1 = new RangePartitioner(0, empty1)
assert(partitioner1.numPartitions === 1)
val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)])
val partitioner2 = new RangePartitioner(2, empty2)
assert(partitioner2.numPartitions === 1)
}

test("HashPartitioner not equal to RangePartitioner") {
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
5 changes: 5 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -613,6 +613,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}

test("sort an empty RDD") {
val data = sc.emptyRDD[Int]
assert(data.sortBy(x => x).collect() === Array.empty)
}

test("sortByKey") {
val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
