Commit: cleanUp
Addressed reviewer comments and improved the code documentation.
Added commons-math3 as a dependency of Spark (okay’ed by Matei). “mvn
clean install” compiled. Recovered files that were accidentally reverted
in the merge.
TODOs: figure out API for sampleByKeyExact and update Java, Python, and
the markdown file accordingly.
dorx committed Jun 18, 2014
1 parent 90d94c0 commit 0214a76
Showing 5 changed files with 129 additions and 20 deletions.
2 changes: 0 additions & 2 deletions core/pom.xml
@@ -70,8 +70,6 @@
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -874,6 +874,27 @@ abstract class RDD[T: ClassTag](
jobResult
}

/**
* A version of {@link #aggregate()} that passes the TaskContext to the function that does
* aggregation for each partition.
*/
def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U,
combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
// pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
(arg1._1, combOp(arg1._2, arg2._2))
val cleanSeqOp = sc.clean(paddedSeqOp)
val cleanCombOp = sc.clean(paddedCombOp)
val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
(it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp))._2
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}

/**
* Return the number of elements in the RDD.
*/
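For reference, here is a minimal driver-side sketch of how the new aggregateWithContext method could be exercised. This is a hypothetical example (the object name, local master, and data are made up, not part of this commit): the seqOp receives a (TaskContext, accumulator) pair, so per-partition information such as the partition ID can flow into the aggregation.

import org.apache.spark.{SparkConf, SparkContext}

object AggregateWithContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 10, numSlices = 2)

    // Count elements per partition: seqOp is handed (TaskContext, Map[Int, Int]) and folds
    // each element into the count keyed by the partition that processed it.
    val countsByPartition = rdd.aggregateWithContext(Map.empty[Int, Int])(
      (ctxAndAcc, _) => {
        val (tc, acc) = ctxAndAcc
        acc.updated(tc.partitionId, acc.getOrElse(tc.partitionId, 0) + 1)
      },
      (m1, m2) => m2.foldLeft(m1) { case (m, (k, v)) => m.updated(k, m.getOrElse(k, 0) + v) }
    )

    println(countsByPartition) // e.g. Map(0 -> 5, 1 -> 5)
    sc.stop()
  }
}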
core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -17,7 +17,7 @@

package org.apache.spark.util.random

import org.apache.commons.math3.distribution.{PoissonDistribution, NormalDistribution}
import org.apache.commons.math3.distribution.PoissonDistribution

private[spark] object SamplingUtils {

@@ -43,7 +43,7 @@ private[spark] object SamplingUtils {
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
*/
def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
withReplacement: Boolean): Double = {
withReplacement: Boolean): Double = {
val fraction = sampleSizeLowerBound.toDouble / total
if (withReplacement) {
val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
@@ -56,12 +56,29 @@
}
}

/**
* Utility functions that help us determine bounds on the adjusted sampling rate that guarantee
* exact sample sizes with high confidence when sampling with replacement.
*
* To guarantee the sample size, the algorithm instantly accepts items whose associated value
* drawn from Pois(s) is less than the lower bound, and places items whose value falls between
* the lower and upper bounds on a waitlist. The final sample consists of all items accepted on
* the fly plus the portion of the waitlist needed to reach the exact sample size.
*/
private[spark] object PoissonBounds {

val delta = 1e-4 / 3.0
val phi = new NormalDistribution().cumulativeProbability(1.0 - delta)

def getLambda1(s: Double): Double = {
/**
* Compute the threshold for accepting items on the fly. The threshold is a fairly small value,
* so an item whose associated value is below it is highly likely to end up in the final sample.
* Hence we instantly accept any item whose associated value is less than the value returned by
* this function.
*
* @param s sample size
* @return threshold for accepting items on the fly
*/
def getLowerBound(s: Double): Double = {
var lb = math.max(0.0, s - math.sqrt(s / delta)) // Chebyshev's inequality
var ub = s
while (lb < ub - 1.0) {
@@ -79,7 +96,16 @@ private[spark] object PoissonBounds {
poisson.inverseCumulativeProbability(delta)
}

def getLambda2(s: Double): Double = {
/**
* Compute the threshold for waitlisting items. An item is waitlisted if its associated value is
* greater than the lower bound determined above but below the upper bound computed here.
* The upper bound is chosen such that we only need to keep about log(s) items on the waitlist
* while still being able to guarantee the sample size with high confidence.
*
* @param s sample size
* @return threshold for waitlisting items
*/
def getUpperBound(s: Double): Double = {
var lb = s
var ub = s + math.sqrt(s / delta) // Chebyshev's inequality
while (lb < ub - 1.0) {
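To make the accept/waitlist scheme described in the PoissonBounds documentation above concrete, here is a small standalone sketch of the final selection step (a hypothetical helper, not part of this patch): items accepted on the fly are kept unconditionally, and the waitlist is sorted by associated value so that only as many waitlisted items as needed are admitted to reach the exact sample size s.

object WaitlistSelectionSketch {

  // Accepted items are kept as-is; the waitlist holds (associatedValue, item) pairs.
  def selectExact[T](accepted: Seq[T], waitlist: Seq[(Double, T)], s: Int): Seq[T] = {
    val needed = math.max(0, s - accepted.size)
    // Admit the waitlisted items with the smallest associated values first,
    // stopping once the exact sample size is reached.
    accepted ++ waitlist.sortBy(_._1).take(needed).map(_._2)
  }

  def main(args: Array[String]): Unit = {
    val accepted = Seq("a", "b", "c")                      // accepted on the fly
    val waitlist = Seq((0.7, "d"), (0.2, "e"), (0.5, "f")) // (associated value, item)
    // Target sample size of 5: the two smallest-valued waitlisted items are admitted.
    println(selectExact(accepted, waitlist, s = 5))        // List(a, b, c, e, f)
  }
}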
core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
@@ -26,6 +26,9 @@ import scala.Some
import org.apache.spark.rdd.RDD

private[spark] object StratifiedSampler extends Logging {
/**
* Returns the function used by aggregate to collect sampling statistics for each partition.
*/
def getSeqOp[K, V](withReplacement: Boolean,
fractionByKey: (K => Double),
counts: Option[Map[K, Long]]): ((TaskContext, Result[K]), (K, V)) => Result[K] = {
@@ -43,9 +46,9 @@ private[spark] object StratifiedSampler extends Logging {
if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
val n = counts.get(item._1)
val s = math.ceil(n * fraction).toLong
val lmbd1 = PB.getLambda1(s)
val lmbd1 = PB.getLowerBound(s)
val minCount = PB.getMinCount(lmbd1)
val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
val lmbd2 = if (lmbd1 == 0) PB.getUpperBound(s) else PB.getUpperBound(s - minCount)
val q1 = lmbd1 / n
val q2 = lmbd2 / n
stratum.q1 = Some(q1)
@@ -60,6 +63,8 @@ private[spark] object StratifiedSampler extends Logging {
stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
}
} else {
// We use the streaming version of the algorithm for sampling without replacement.
// Hence, q1 and q2 change on every iteration.
val g1 = - math.log(delta) / stratum.numItems
val g2 = (2.0 / 3.0) * g1
val q1 = math.max(0, fraction + g2 - math.sqrt((g2 * g2 + 3 * g2 * fraction)))
@@ -79,7 +84,11 @@ private[spark] object StratifiedSampler extends Logging {
}
}

def getCombOp[K](): (Result[K], Result[K]) => Result[K] = {
/**
* Returns the function used by aggregate to combine results from different partitions, as
* returned by seqOp.
*/
def getCombOp[K](): (Result[K], Result[K]) => Result[K] = {
(r1: Result[K], r2: Result[K]) => {
// take union of both key sets in case one partition doesn't contain all keys
val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
@@ -100,6 +109,10 @@ private[spark] object StratifiedSampler extends Logging {
}
}

/**
* Given the result returned by the aggregate function, determine the per-key threshold used to
* accept items so that the exact sample size is generated for each stratum.
*/
def computeThresholdByKey[K](finalResult: Map[K, Stratum], fractionByKey: (K => Double)):
(K => Double) = {
val thresholdByKey = new mutable.HashMap[K, Double]()
@@ -122,11 +135,15 @@ private[spark] object StratifiedSampler extends Logging {
thresholdByKey
}

def computeThresholdByKey[K](finalResult: Map[K, String]): (K => String) = {
finalResult
}

def getBernoulliSamplingFunction[K, V](rdd:RDD[(K, V)],
/**
* Return the per-partition sampling function used for sampling without replacement.
*
* When an exact sample size is required, we make an additional pass over the RDD to determine
* the exact sampling rate that guarantees sample size with high confidence.
*
* The sampling function has a unique seed per partition.
*/
def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)],
fractionByKey: K => Double,
exact: Boolean,
seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
@@ -146,6 +163,16 @@ private[spark] object StratifiedSampler extends Logging {
}
}

/**
* Return the per-partition sampling function used for sampling with replacement.
*
* When an exact sample size is required, we make two additional passes over the RDD to determine
* the exact sampling rate that guarantees sample size with high confidence. The first pass
* counts the number of items in each stratum (group of items with the same key) in the RDD, and
* the second pass uses those counts to determine the exact sampling rates.
*
* The sampling function has a unique seed per partition.
*/
def getPoissonSamplingFunction[K, V](rdd: RDD[(K, V)],
fractionByKey: K => Double,
exact: Boolean,
@@ -191,6 +218,10 @@
}
}

/**
* Object used by seqOp to keep track of the number of items accepted and items waitlisted per
* stratum, as well as the bounds for accepting and waitlisting items.
*/
private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L)
extends Serializable {

@@ -205,13 +236,14 @@ private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0
def addToWaitList(elem: Double) = waitList += elem

def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems

override def toString() = {
"numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
" waitListSize:" + waitList.size
}
}

/**
* Object used by seqOp and combOp to keep track of the sampling statistics for all strata.
*
* When used by seqOp for each partition, we also keep track of the partition ID in this object
* to make sure a single random number generator with a unique seed is used for each partition.
*/
private[random] class Result[K](var resultMap: Map[K, Stratum],
var cachedPartitionId: Option[Int] = None,
val seed: Long)
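As a rough illustration of how a per-partition sampling function of this shape gets applied, the driver-side sketch below (hypothetical code, not from this diff) performs plain per-key Bernoulli sampling without the exact-size machinery: the function takes a partition index and an iterator, and is handed to mapPartitionsWithIndex with the index folded into the seed so that each partition uses its own random number generator.

import scala.util.Random

import org.apache.spark.{SparkConf, SparkContext}

object StratifiedSamplingWiringSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("c", 5)), 2)

    val fractions = Map("a" -> 0.5, "b" -> 0.5, "c" -> 1.0)
    val seed = 42L

    // Same shape as the functions returned by getBernoulliSamplingFunction:
    // (partition index, iterator) => iterator, with a unique seed per partition.
    val samplingFunc: (Int, Iterator[(String, Int)]) => Iterator[(String, Int)] =
      (idx, iter) => {
        val rng = new Random(seed + idx)
        iter.filter { case (k, _) => rng.nextDouble() < fractions(k) }
      }

    val sample = pairs.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
    println(sample.collect().toSeq)
    sc.stop()
  }
}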
32 changes: 32 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -141,6 +141,38 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}

test("aggregateWithContext") {
val data = Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))
val numPartitions = 2
val pairs = sc.makeRDD(data, numPartitions)
// compute each key's total partitionId contribution, used to offset the expected sums
type StringMap = HashMap[String, Int]
val partitions = pairs.collectPartitions()
val offSets = new StringMap
for (i <- 0 to numPartitions - 1) {
partitions(i).foreach({ case (k, v) => offSets.put(k, offSets.getOrElse(k, 0) + i)})
}
val emptyMap = new StringMap {
override def default(key: String): Int = 0
}
val mergeElement: ((TaskContext, StringMap), (String, Int)) => StringMap = (arg1, pair) => {
val stringMap = arg1._2
val tc = arg1._1
stringMap(pair._1) += pair._2 + tc.partitionId
stringMap
}
val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
for ((key, value) <- map2) {
map1(key) += value
}
map1
}
val result = pairs.aggregateWithContext(emptyMap)(mergeElement, mergeMaps)
val expected = Set(("a", 6), ("b", 2), ("c", 5))
.map({ case (k, v) => (k -> (offSets.getOrElse(k, 0) + v))})
assert(result.toSet === expected)
}

test("basic caching") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))