Permalink
| /* | |
| * Licensed to the Apache Software Foundation (ASF) under one or more | |
| * contributor license agreements. See the NOTICE file distributed with | |
| * this work for additional information regarding copyright ownership. | |
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |
| * (the "License"); you may not use this file except in compliance with | |
| * the License. You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| package org.apache.spark | |
| import java.io.{IOException, ObjectInputStream, ObjectOutputStream} | |
| import scala.collection.mutable | |
| import scala.collection.mutable.ArrayBuffer | |
| import scala.reflect.ClassTag | |
| import scala.util.hashing.byteswap32 | |
| import org.apache.spark.rdd.{PartitionPruningRDD, RDD} | |
| import org.apache.spark.serializer.JavaSerializer | |
| import org.apache.spark.util.{CollectionsUtils, Utils} | |
| import org.apache.spark.util.random.SamplingUtils | |
/**
 * Strategy describing how the elements of a key-value pair RDD are distributed by key.
 * Every key is assigned a partition ID in the range 0 to `numPartitions - 1`.
 */
abstract class Partitioner extends Serializable {

  /** Total number of partitions this partitioner distributes keys across. */
  def numPartitions: Int

  /**
   * Computes the partition a key belongs to.
   *
   * @param key the key to place; typed as `Any`, so implementations decide how to interpret it
   * @return the partition ID, expected to lie between 0 and `numPartitions - 1`
   */
  def getPartition(key: Any): Int
}
object Partitioner {
  /**
   * Picks the partitioner for a cogroup-like operation over one or more RDDs.
   *
   * Only RDDs that already carry a partitioner with a positive number of partitions are
   * considered "partitioned". If at least one such RDD exists, the partitioner of the one
   * with the most partitions is reused.
   *
   * Otherwise a new [[HashPartitioner]] is created. Its size is the SparkContext's
   * `defaultParallelism` when `spark.default.parallelism` is present in the configuration,
   * and the largest partition count among all the given RDDs when it is not. Matching the
   * largest upstream RDD should be least likely to cause out-of-memory errors.
   *
   * The signature takes (rdd, others) rather than a single vararg so that callers are forced
   * to pass at least one RDD.
   *
   * @param rdd    the first RDD; its context supplies the configuration and default parallelism
   * @param others any further RDDs taking part in the operation
   * @return the partitioner to use for the combined operation
   */
  def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
    val allRdds: Seq[RDD[_]] = Seq(rdd) ++ others
    val alreadyPartitioned = allRdds.filter(_.partitioner.exists(_.numPartitions > 0))

    if (alreadyPartitioned.isEmpty) {
      val context = rdd.context
      val targetPartitions =
        if (context.conf.contains("spark.default.parallelism")) {
          context.defaultParallelism
        } else {
          allRdds.map(_.partitions.length).max
        }
      new HashPartitioner(targetPartitions)
    } else {
      // Safe to call `.get`: the filter above only keeps RDDs whose partitioner is defined.
      alreadyPartitioned.maxBy(_.partitions.length).partitioner.get
    }
  }
}
/**
 * A [[org.apache.spark.Partitioner]] that assigns keys to partitions based on the value of
 * Java's `Object.hashCode`.
 *
 * Java arrays derive their hashCodes from identity rather than from their contents, so
 * partitioning an RDD[Array[_]] or RDD[(Array[_], _)] with a HashPartitioner will produce an
 * unexpected or incorrect result.
 *
 * @param partitions number of partitions; must not be negative
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  /** Null keys always go to partition 0; any other key is placed according to its hashCode. */
  def getPartition(key: Any): Int =
    if (key == null) 0 else Utils.nonNegativeMod(key.hashCode, numPartitions)

  /** Two HashPartitioners are equal exactly when they have the same number of partitions. */
  override def equals(other: Any): Boolean = other match {
    case that: HashPartitioner => that.numPartitions == numPartitions
    case _ => false
  }

  // Kept consistent with equals, which compares numPartitions only.
  override def hashCode: Int = numPartitions
}
/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * @note The actual number of partitions created by the RangePartitioner might not be the same
 * as the `partitions` parameter, in the case where the number of sampled records is less than
 * the value of `partitions`.
 *
 * @param partitions requested number of output partitions; must not be negative
 * @param rdd key-value RDD whose keys are sampled (at construction time) to pick the bounds
 * @param ascending whether partition IDs increase (true) or decrease (false) with the key order
 * @param samplePointsPerPartitionHint number of sample points to aim for per requested output
 *                                     partition; must be positive. The total sample size is
 *                                     this value times `partitions`, capped at 1M.
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true,
    samplePointsPerPartitionHint: Int = 20)
  extends Partitioner {

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
  require(samplePointsPerPartitionHint > 0,
    s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")

  // Declared as a var (like ascending, rangeBounds and binarySearch) so that readObject below
  // can re-assign it when the instance is restored.
  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      // Computed in Double so that a large hint times a large partition count cannot overflow.
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }

  /** One more than the number of bounds actually selected (see the class-level note). */
  def numPartitions: Int = rangeBounds.length + 1

  // Determine which binary search method to use only once.
  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  /**
   * Finds the range that `key` falls into.
   *
   * @param key the key to place; it is cast to `K` without a check
   * @return the index of the first bound that is not smaller than the key (or the number of
   *         bounds if there is none), mirrored when the partitioner is descending
   */
  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // With at most 128 bounds, a naive linear scan is used.
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  /** Equal when the other RangePartitioner has the same bounds and the same direction. */
  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  // Combines exactly the state that equals compares: every bound, then the direction.
  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  /**
   * Custom serialization hook. With a JavaSerializer the default field-by-field form is used;
   * with any other configured serializer the fields are written one by one and the bounds
   * (preceded by their ClassTag) go through that serializer via a nested stream.
   * The field order here must match readObject.
   */
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case _: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  /** Counterpart of writeObject: restores the fields in the same order they were written. */
  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case _: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          // The ClassTag read first is made implicit so it is in scope when the bounds are read.
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}
private[spark] object RangePartitioner {

  /**
   * Sketches the input RDD by reservoir-sampling every partition and counting its items.
   *
   * @param rdd the input RDD to sketch
   * @param sampleSizePerPartition max sample size per partition
   * @return (total number of items, an array of (partitionId, number of items, sample))
   */
  def sketch[K : ClassTag](
      rdd: RDD[K],
      sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
    // Read the id into a local so the closure below refers to this Int rather than to `rdd`.
    val rddId = rdd.id
    val perPartition = rdd.mapPartitionsWithIndex { (partitionIndex, records) =>
      // The seed is derived from the partition index and the RDD id.
      val seed = byteswap32(partitionIndex ^ (rddId << 16))
      val (reservoir, count) =
        SamplingUtils.reservoirSampleAndCount(records, sampleSizePerPartition, seed)
      Iterator.single((partitionIndex, count, reservoir))
    }.collect()
    val totalItems = perPartition.foldLeft(0L) { (running, entry) => running + entry._2 }
    (totalItems, perPartition)
  }

  /**
   * Picks range bounds from weighted candidates, where a candidate's weight says how many
   * items it stands for (usually 1 over the probability with which it was sampled).
   *
   * The candidates are sorted by key and their weights accumulated; each time the running
   * weight reaches the next multiple of (total weight / partitions), the current key becomes
   * a bound, unless it is not strictly greater than the bound chosen before it.
   *
   * @param candidates unordered candidates with weights
   * @param partitions number of partitions
   * @return selected bounds, at most `partitions - 1` of them, in increasing key order
   */
  def determineBounds[K : Ordering : ClassTag](
      candidates: ArrayBuffer[(K, Float)],
      partitions: Int): Array[K] = {
    val keyOrdering = implicitly[Ordering[K]]
    val sorted = candidates.sortBy(_._1)
    val stride = sorted.map(_._2.toDouble).sum / partitions
    val maxBounds = partitions - 1

    val selected = ArrayBuffer.empty[K]
    var runningWeight = 0.0
    var threshold = stride
    val remaining = sorted.iterator
    while (remaining.hasNext && selected.length < maxBounds) {
      val (key, weight) = remaining.next()
      runningWeight += weight
      if (runningWeight >= threshold) {
        // Skip duplicate values: a key only qualifies if it is greater than the last bound.
        if (selected.lastOption.forall(keyOrdering.gt(key, _))) {
          selected += key
          threshold += stride
        }
      }
    }
    selected.toArray
  }
}