Commit

[SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size

dorx committed Jun 9, 2014
1 parent e3fd6a6 commit 9ee94ee
Showing 7 changed files with 379 additions and 4 deletions.
1 change: 0 additions & 1 deletion core/pom.xml
@@ -70,7 +70,6 @@
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
186 changes: 185 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -27,8 +27,12 @@ import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.control.Breaks._

import com.clearspring.analytics.stream.cardinality.HyperLogLog

import org.apache.commons.math3.random.RandomDataGenerator

import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.io.SequenceFile.CompressionType
@@ -46,7 +50,8 @@ import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.SerializableHyperLogLog
import org.apache.spark.util.{Utils, SerializableHyperLogLog}
import org.apache.spark.util.random.{PoissonBounds => PB}

/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -155,6 +160,182 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
foldByKey(zeroValue, defaultPartitioner(self))(func)
}

/**
* Return a subset of this RDD sampled by key (via stratified sampling).
* For each stratum (key), we guarantee a sample size of exactly math.ceil(fraction * S_i),
* where S_i is the size of the ith stratum.
*
* @param withReplacement whether to sample with or without replacement
* @param fraction sampling rate
* @param seed seed for the random number generator
* @return RDD containing the sampled subset
*/
def sampleByKey(withReplacement: Boolean,
fraction: Double,
seed: Long = Utils.random.nextLong): RDD[(K, V)] = {

class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable {
var waitList: ArrayBuffer[Double] = new ArrayBuffer[Double]
var q1: Option[Double] = None
var q2: Option[Double] = None

def incrNumItems(by: Long = 1L) = numItems += by

def incrNumAccepted(by: Long = 1L) = numAccepted += by

def addToWaitList(elem: Double) = waitList += elem

def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems

override def toString() = {
"numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
" waitListSize:" + waitList.size
}
}

class Result(var resultMap: Map[K, Stratum], var cachedPartitionId: Option[Int] = None)
extends Serializable {
var rand: RandomDataGenerator = new RandomDataGenerator

def getEntry(key: K, numItems: Long = 0L): Stratum = {
if (resultMap.get(key).isEmpty) {
resultMap += (key -> new Stratum(numItems))
}
resultMap.get(key).get
}

def getRand(partitionId: Int): RandomDataGenerator = {
if (cachedPartitionId.isEmpty || cachedPartitionId.get != partitionId) {
cachedPartitionId = Some(partitionId)
rand.reSeed(seed + partitionId)
}
rand
}
}

// TODO implement the streaming version of sampling w/ replacement that doesn't require counts
// in order to save one pass over the RDD
val counts = if (withReplacement) Some(this.countByKey()) else None

val seqOp = (U: (TaskContext, Result), item: (K, V)) => {
val delta = 5e-5
val result = U._2
val tc = U._1
val rng = result.getRand(tc.partitionId)
val stratum = result.getEntry(item._1)
if (withReplacement) {
// compute q1 and q2 only if they haven't been computed already
// since they don't change from iteration to iteration.
// TODO change this to the streaming version
if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
val n = counts.get(item._1)
val s = math.ceil(n * fraction).toLong
val lmbd1 = PB.getLambda1(s)
val minCount = PB.getMinCount(lmbd1)
val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
val q1 = lmbd1 / n
val q2 = lmbd2 / n
stratum.q1 = Some(q1)
stratum.q2 = Some(q2)
}
val x1 = if (stratum.q1.get == 0) 0L else rng.nextPoisson(stratum.q1.get)
if (x1 > 0) {
stratum.incrNumAccepted(x1)
}
val x2 = rng.nextPoisson(stratum.q2.get).toInt
if (x2 > 0) {
stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
}
} else {
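// Bernoulli sampling without replacement: q1 is a high-probability lower bound and q2 a
// high-probability upper bound on the final per-stratum acceptance threshold. Draws below q1
// are accepted outright, draws in [q1, q2) go to the wait list, and the exact threshold is
// picked from the sorted wait list after the aggregation so that exactly
// math.ceil(fraction * numItems) items are sampled per stratum.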
val g1 = -math.log(delta) / stratum.numItems
val g2 = (2.0 / 3.0) * g1
val q1 = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))

val x = rng.nextUniform(0.0, 1.0)
if (x < q1) {
stratum.incrNumAccepted()
} else if (x < q2) {
stratum.addToWaitList(x)
}
stratum.q1 = Some(q1)
stratum.q2 = Some(q2)
}
stratum.incrNumItems()
result
}

val combOp = (r1: Result, r2: Result) => {
// take union of both key sets in case one partition doesn't contain all keys
val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)

// use r2 to keep the combined result since r1 is usually empty
for (key <- keyUnion) {
val entry1 = r1.resultMap.get(key)
val entry2 = r2.resultMap.get(key)
if (entry2.isEmpty && entry1.isDefined) {
r2.resultMap += (key -> entry1.get)
} else if (entry1.isDefined && entry2.isDefined) {
entry2.get.addToWaitList(entry1.get.waitList)
entry2.get.incrNumAccepted(entry1.get.numAccepted)
entry2.get.incrNumItems(entry1.get.numItems)
}
}
r2
}

val zeroU = new Result(Map[K, Stratum]())

// determine threshold for each stratum and resample
val finalResult = self.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
val thresholdByKey = new mutable.HashMap[K, Double]()
for ((key, stratum) <- finalResult) {
val s = math.ceil(stratum.numItems * fraction).toLong
breakable {
if (stratum.numAccepted > s) {
logWarning("Pre-accepted too many")
thresholdByKey += (key -> stratum.q1.get)
break
}
val numWaitListAccepted = (s - stratum.numAccepted).toInt
if (numWaitListAccepted >= stratum.waitList.size) {
logWarning("WaitList too short")
thresholdByKey += (key -> stratum.q2.get)
} else {
thresholdByKey += (key -> stratum.waitList.sorted.apply(numWaitListAccepted))
}
}
}

if (withReplacement) {
// Poisson sampler
self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
val random = new RandomDataGenerator()
random.reSeed(seed + idx)
iter.flatMap { t =>
val q1 = finalResult.get(t._1).get.q1.get
val q2 = finalResult.get(t._1).get.q2.get
val x1 = if (q1 == 0) 0L else random.nextPoisson(q1)
val x2 = random.nextPoisson(q2).toInt
val x = x1 + (0 until x2).filter(i => random.nextUniform(0.0, 1.0) <
thresholdByKey.get(t._1).get).size
if (x > 0) {
Iterator.fill(x.toInt)(t)
} else {
Iterator.empty
}
}
}, preservesPartitioning = true)
} else {
// Bernoulli sampler
self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
val random = new RandomDataGenerator
random.reSeed(seed + idx)
iter.filter(t => random.nextUniform(0.0, 1.0) < thresholdByKey.get(t._1).get)
}, preservesPartitioning = true)
}
}
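
A minimal usage sketch, assuming a SparkContext `sc` is in scope; the keys act as the strata and the resulting counts follow the exact-size guarantee described above:

val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 3), ("a", 4), ("b", 5), ("b", 6)))
// each stratum keeps exactly math.ceil(0.5 * S_i) elements: 2 for key "a", 1 for key "b"
val sampled = pairs.sampleByKey(withReplacement = false, fraction = 0.5, seed = 42L)
sampled.countByKey()  // Map(a -> 2, b -> 1)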

/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
@@ -442,6 +623,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])

/**
* Return the key-value pairs in this RDD to the master as a Map.
*
* Warning: this doesn't return a multimap (so if you have multiple values for the same key, only
* one value per key is preserved in the returned map).
*/
def collectAsMap(): Map[K, V] = {
val data = self.collect()
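
The warning above in a minimal sketch (assuming a SparkContext `sc`; only one of the two values for key 1 survives in the result):

sc.parallelize(Seq((1, "a"), (1, "b"), (2, "c"))).collectAsMap()
// e.g. Map(2 -> c, 1 -> b)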
18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -880,6 +880,24 @@ abstract class RDD[T: ClassTag](
jobResult
}

/**
* A version of {@link #aggregate()} that passes the TaskContext to the function that does
* aggregation for each partition.
*/
def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
// pad seqOp and combOp with the TaskContext to conform to aggregate's signature in TraversableOnce
val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) => (arg1._1, combOp(arg1._2, arg2._2))
val cleanSeqOp = sc.clean(paddedSeqOp)
val cleanCombOp = sc.clean(paddedCombOp)
val aggregatePartition = (tc: TaskContext, it: Iterator[T]) => it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp)._2
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}
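
A minimal sketch of using the new method, assuming a SparkContext `sc`: count elements per partition by reading the partition id from the TaskContext handed to seqOp.

import org.apache.spark.TaskContext

val nums = sc.parallelize(1 to 100, numSlices = 4)
val seqOp = (arg: (TaskContext, Map[Int, Long]), x: Int) => {
  val (tc, acc) = arg
  acc.updated(tc.partitionId, acc.getOrElse(tc.partitionId, 0L) + 1L)
}
val combOp = (a: Map[Int, Long], b: Map[Int, Long]) =>
  (a.keySet ++ b.keySet).map(k => k -> (a.getOrElse(k, 0L) + b.getOrElse(k, 0L))).toMap
val countsPerPartition = nums.aggregateWithContext(Map.empty[Int, Long])(seqOp, combOp)
// e.g. Map(0 -> 25, 1 -> 25, 2 -> 25, 3 -> 25)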

/**
* Return the number of elements in the RDD.
*/
56 changes: 56 additions & 0 deletions core/src/main/scala/org/apache/spark/util/random/PoissonBounds.scala
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.random

import org.apache.commons.math3.distribution.{PoissonDistribution, NormalDistribution}

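/**
* Poisson bounds used by the stratified sampler to guarantee an exact sample size. Roughly:
* getLambda1(s) returns a Poisson mean small enough that a draw rarely exceeds s (items that can
* be accepted outright), getLambda2(s) returns a mean large enough that a draw rarely falls below
* s (items to keep on the wait list), and getMinCount(lmbd) is a high-probability lower bound on
* a Poisson(lmbd) draw. "Rarely" here means with probability at most delta.
*/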
private[random] object PoissonBounds {

val delta = 1e-4 / 3.0
val phi = new NormalDistribution().cumulativeProbability(1.0 - delta)

def getLambda1(s: Double): Double = {
var lb = math.max(0.0, s - math.sqrt(s / delta)) // Chebyshev's inequality
var ub = s
while (lb < ub - 1.0) {
val m = (lb + ub) / 2.0
val poisson = new PoissonDistribution(m, 1e-15)
val y = poisson.inverseCumulativeProbability(1 - delta)
if (y > s) ub = m else lb = m
}
lb
}

def getMinCount(lmbd: Double): Double = {
if (lmbd == 0) return 0
val poisson = new PoissonDistribution(lmbd, 1e-15)
poisson.inverseCumulativeProbability(delta)
}

def getLambda2(s: Double): Double = {
var lb = s
var ub = s + math.sqrt(s / delta) // Chebyshev's inequality
while (lb < ub - 1.0) {
val m = (lb + ub) / 2.0
val poisson = new PoissonDistribution(m, 1e-15)
val y = poisson.inverseCumulativeProbability(delta)
if (y >= s) ub = m else lb = m
}
ub
}
}
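
A sketch of how sampleByKey (above) consumes these bounds for one stratum of size n with target sample size s = ceil(fraction * n); the values are illustrative and the calls mirror the with-replacement branch of seqOp (PoissonBounds is private[random], so this only compiles inside org.apache.spark.util.random):

val n = 10000L
val fraction = 0.1
val s = math.ceil(fraction * n).toLong                   // target sample size for this stratum
val lmbd1 = PoissonBounds.getLambda1(s)                  // mean of the "accept outright" Poisson
val minCount = PoissonBounds.getMinCount(lmbd1)          // w.h.p. at least this many accepted
val lmbd2 =
  if (lmbd1 == 0) PoissonBounds.getLambda2(s) else PoissonBounds.getLambda2(s - minCount)
val (q1, q2) = (lmbd1 / n, lmbd2 / n)                    // per-element rates used by the sampler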