Skip to content

Commit

Permalink
bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
dorx committed Jul 28, 2014
1 parent 17a381b commit eaf5771
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ private[spark] object SamplingUtils {
*/
def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
withReplacement: Boolean): Double = {
val fraction = sampleSizeLowerBound.toDouble / total
if (withReplacement) {
PoissonBounds.getUpperBound(sampleSizeLowerBound)
PoissonBounds.getUpperBound(sampleSizeLowerBound) / total
} else {
BernoulliBounds.getLowerBound(1e-4, total, fraction)
val fraction = sampleSizeLowerBound.toDouble / total
BinomialBounds.getUpperBound(1e-4, total, fraction)
}
}
}
Expand Down Expand Up @@ -138,25 +138,25 @@ private[spark] object PoissonBounds {
* Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact
* sample size with high confidence when sampling without replacement.
*/
private[spark] object BernoulliBounds {
private[spark] object BinomialBounds {

val minSamplingRate = 1e-10

/**
* Returns a threshold such that if we apply Bernoulli sampling with that threshold, it is very
* unlikely to sample less than `fraction * n` items out of `n` items.
* Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`,
* it is very unlikely to have more than `fraction * n` successes.
*/
def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
val gamma = - math.log(delta) / n * (2.0 / 3.0)
math.max(minSamplingRate,
fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction))
}

/**
* Returns a threshold such that if we apply Bernoulli sampling with that threshold, it is very
* unlikely to sample more than `fraction * n` items out of `n` items.
* Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`,
* it is very unlikely to have fewer than `fraction * n` successes.
*/
def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
val gamma = - math.log(delta) / n
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@

package org.apache.spark.util.random

import cern.jet.random.Poisson
import cern.jet.random.engine.DRand

import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import cern.jet.random.Poisson
import cern.jet.random.engine.DRand

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import scala.reflect.ClassTag

/**
* Auxiliary functions and data structures for the sampleByKey method in PairRDDFunctions.
*
Expand Down Expand Up @@ -119,9 +118,9 @@ private[spark] object StratifiedSamplingUtils extends Logging {
// using an extra pass over the RDD for computing the count.
// Hence, acceptBound and waitListBound change on every iteration.
acceptResult.acceptBound =
BernoulliBounds.getUpperBound(delta, acceptResult.numItems, fraction)
BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction)
acceptResult.waitListBound =
BernoulliBounds.getLowerBound(delta, acceptResult.numItems, fraction)
BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction)

val x = rng.nextUniform()
if (x < acceptResult.acceptBound) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
val stratifiedData = data.keyBy(stratifier(fractionPositive))

val samplingRate = 0.1
val seed = defaultSeed
checkAllCombos(stratifiedData, samplingRate, seed, n)
checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
}

// vary fractionPositive
Expand All @@ -179,8 +178,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
val stratifiedData = data.keyBy(stratifier(fractionPositive))

val samplingRate = 0.1
val seed = defaultSeed
checkAllCombos(stratifiedData, samplingRate, seed, n)
checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
}

// Use the same data for the rest of the tests
Expand All @@ -197,8 +195,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {

// vary sampling rate
for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
val seed = defaultSeed
checkAllCombos(stratifiedData, samplingRate, seed, n)
checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
}
}

Expand Down

0 comments on commit eaf5771

Please sign in to comment.