Separated out most of the logic in sampleByKey
into StratifiedSampler in util.random
dorx committed Jun 17, 2014
1 parent 7327611 commit 9e74ab5
Showing 3 changed files with 309 additions and 190 deletions.
183 changes: 19 additions & 164 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -49,7 +49,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
import org.apache.spark.util.random.{PoissonBounds => PB}
import org.apache.spark.util.random.{Stratum, Result, StratifiedSampler, PoissonBounds => PB}

/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -210,177 +210,32 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])

/**
* Return a subset of this RDD sampled by key (via stratified sampling).
* We guarantee a sample size = math.ceil(fraction * S_i), where S_i is the size of the ith
* stratum.
*
* If exact is set to true, we guarantee, with high probability, a sample size =
* math.ceil(fraction * S_i), where S_i is the size of the ith stratum (the collection of
* entries that share the same key). When sampling without replacement, we need one
* additional pass over the RDD to guarantee the sample size with 99.99% confidence; when
* sampling with replacement, we need two additional passes.
*
* @param withReplacement whether to sample with or without replacement
* @param fraction sampling rate
* @param fractionByKey function mapping key to sampling rate
* @param seed seed for the random number generator
* @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
* @return RDD containing the sampled subset
*/
def sampleByKey(withReplacement: Boolean,
fraction: Double,
seed: Long = Utils.random.nextLong): RDD[(K, V)] = {

class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable {
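// Bookkeeping for one stratum (all values sharing a key): how many items were seen,
// how many were accepted outright, and the waitlisted uniform draws that are resolved
// once the stratum's exact size is known.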
var waitList: ArrayBuffer[Double] = new ArrayBuffer[Double]
var q1: Option[Double] = None
var q2: Option[Double] = None

def incrNumItems(by: Long = 1L) = numItems += by

def incrNumAccepted(by: Long = 1L) = numAccepted += by

def addToWaitList(elem: Double) = waitList += elem

def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems

override def toString: String = {
"numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
" waitListSize: " + waitList.size
}
}

class Result(var resultMap: Map[K, Stratum], var cachedPartitionId: Option[Int] = None)
extends Serializable {
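// Aggregation state threaded through seqOp/combOp: one Stratum per key, plus an RNG
// that is reseeded deterministically whenever the partition changes.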
var rand: RandomDataGenerator = new RandomDataGenerator

def getEntry(key: K, numItems: Long = 0L): Stratum = {
if (resultMap.get(key).isEmpty) {
resultMap += (key -> new Stratum(numItems))
}
resultMap.get(key).get
}

def getRand(partitionId: Int): RandomDataGenerator = {
if (cachedPartitionId.isEmpty || cachedPartitionId.get != partitionId) {
cachedPartitionId = Some(partitionId)
rand.reSeed(seed + partitionId)
}
rand
}
}

// TODO implement the streaming version of sampling w/ replacement that doesn't require counts
// in order to save one pass over the RDD
val counts = if (withReplacement) Some(this.countByKey()) else None

val seqOp = (U: (TaskContext, Result), item: (K, V)) => {
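// Per-element update: depending on random draws, an item is accepted outright,
// waitlisted for the final pass, or merely counted toward its stratum's size.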
val delta = 5e-5
val result = U._2
val tc = U._1
val rng = result.getRand(tc.partitionId)
val stratum = result.getEntry(item._1)
if (withReplacement) {
// compute q1 and q2 only if they haven't been computed already
// since they don't change from iteration to iteration.
// TODO change this to the streaming version
if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
val n = counts.get(item._1)
val s = math.ceil(n * fraction).toLong
val lmbd1 = PB.getLambda1(s)
val minCount = PB.getMinCount(lmbd1)
val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
val q1 = lmbd1 / n
val q2 = lmbd2 / n
stratum.q1 = Some(q1)
stratum.q2 = Some(q2)
}
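// x1 ~ Poisson(q1) copies are accepted outright; x2 ~ Poisson(q2) candidate copies
// are waitlisted, each tagged with a uniform draw for threshold resolution later.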
val x1 = if (stratum.q1.get == 0) 0L else rng.nextPoisson(stratum.q1.get)
if (x1 > 0) {
stratum.incrNumAccepted(x1)
}
val x2 = rng.nextPoisson(stratum.q2.get).toInt
if (x2 > 0) {
stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
}
} else {
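// Without replacement: derive an acceptance threshold q1 and a waitlist threshold q2
// from concentration (Chernoff-style) bounds, so that with failure probability on the
// order of delta the exact cutoff for ceil(fraction * numItems) items lies in [q1, q2).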
val g1 = -math.log(delta) / stratum.numItems
val g2 = (2.0 / 3.0) * g1
val q1 = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))

val x = rng.nextUniform(0.0, 1.0)
if (x < q1) {
stratum.incrNumAccepted()
} else if (x < q2) {
stratum.addToWaitList(x)
}
stratum.q1 = Some(q1)
stratum.q2 = Some(q2)
}
stratum.incrNumItems()
result
}

val combOp = (r1: Result, r2: Result) => {
// take the union of both key sets in case one partition doesn't contain all keys
val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)

// Use r2 to keep the combined result since r1 is usually empty
for (key <- keyUnion) {
val entry1 = r1.resultMap.get(key)
val entry2 = r2.resultMap.get(key)
if (entry2.isEmpty && entry1.isDefined) {
r2.resultMap += (key -> entry1.get)
} else if (entry1.isDefined && entry2.isDefined) {
entry2.get.addToWaitList(entry1.get.waitList)
entry2.get.incrNumAccepted(entry1.get.numAccepted)
entry2.get.incrNumItems(entry1.get.numItems)
}
}
r2
}

val zeroU = new Result(Map[K, Stratum]())

// determine threshold for each stratum and resample
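// aggregateWithContext is an aggregate variant that also hands each seqOp the
// TaskContext, so the RNG can be reseeded per partition for reproducibility.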
val finalResult = self.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
val thresholdByKey = new mutable.HashMap[K, Double]()
for ((key, stratum) <- finalResult) {
val s = math.ceil(stratum.numItems * fraction).toLong
breakable {
if (stratum.numAccepted > s) {
logWarning("Pre-accepted too many")
thresholdByKey += (key -> stratum.q1.get)
break
}
val numWaitListAccepted = (s - stratum.numAccepted).toInt
if (numWaitListAccepted >= stratum.waitList.size) {
logWarning("WaitList too short")
thresholdByKey += (key -> stratum.q2.get)
} else {
thresholdByKey += (key -> stratum.waitList.sorted.apply(numWaitListAccepted))
}
}
}

fractionByKey: K => Double,
seed: Long = Utils.random.nextLong,
exact: Boolean = true): RDD[(K, V)] = {
if (withReplacement) {
// Poisson sampler
self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
val random = new RandomDataGenerator()
random.reSeed(seed + idx)
iter.flatMap { t =>
val q1 = finalResult.get(t._1).get.q1.get
val q2 = finalResult.get(t._1).get.q2.get
val x1 = if (q1 == 0) 0L else random.nextPoisson(q1)
val x2 = random.nextPoisson(q2).toInt
val x = x1 + (0 until x2).filter(i => random.nextUniform(0.0, 1.0) <
thresholdByKey.get(t._1).get).size
if (x > 0) {
Iterator.fill(x.toInt)(t)
} else {
Iterator.empty
}
}
}, preservesPartitioning = true)
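// Exact sampling with replacement needs per-key counts up front to size the Poisson
// thresholds; when exact guarantees aren't requested, that extra pass is skipped.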
val counts = if (exact) Some(this.countByKey()) else None
val samplingFunc =
StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
} else {
// Bernoulli sampler
self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
val random = new RandomDataGenerator
random.reSeed(seed + idx)
iter.filter(t => random.nextUniform(0.0, 1.0) < thresholdByKey.get(t._1).get)
}, preservesPartitioning = true)
val samplingFunc =
StratifiedSampler.getBernoulliSamplingFunction(self, fractionByKey, exact, seed)
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
}
}
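
As a quick illustration of the new API (not part of this commit): a minimal usage sketch, assuming a live SparkContext named sc and the pair-RDD implicits from org.apache.spark.SparkContext._; the data and fractions are made up.

import org.apache.spark.SparkContext._

val data = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("b", 5)))
// A Map[K, Double] is also a K => Double, so it can serve as fractionByKey;
// it must define a sampling rate for every key present in the RDD.
val fractions = Map("a" -> 0.5, "b" -> 0.2)
val sampled = data.sampleByKey(withReplacement = false, fractionByKey = fractions,
  seed = 42L, exact = true)
sampled.collect().foreach(println)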

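The heart of the removed block is the per-stratum threshold selection once all counts are in. A stripped-down sketch of that logic, with hypothetical names standing in for the Stratum fields (again, not code from this commit):

// Pick the uniform-draw cutoff that yields exactly ceil(fraction * numItems) samples.
def chooseThreshold(numItems: Long, numAccepted: Long, waitList: Seq[Double],
    q1: Double, q2: Double, fraction: Double): Double = {
  val s = math.ceil(numItems * fraction).toLong
  if (numAccepted > s) {
    q1 // pre-accepted too many: fall back to the outright-acceptance threshold
  } else {
    val fromWaitList = (s - numAccepted).toInt
    if (fromWaitList >= waitList.size) {
      q2 // waitlist too short: admit everything below the waitlist threshold
    } else {
      // The fromWaitList-th smallest waitlisted draw tops up the sample to exactly s.
      waitList.sorted.apply(fromWaitList)
    }
  }
}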
