Commit eff89e2

addressed reviewer comments.
Note that logging isn’t added to rdd.py because it seemed to conflict with the log4j logs.
dorx committed Jun 12, 2014
1 parent ecab508 commit eff89e2
Showing 5 changed files with 54 additions and 36 deletions.
core/src/main/scala/org/apache/spark/rdd/RDD.scala (13 additions & 12 deletions)

@@ -389,8 +389,6 @@ abstract class RDD[T: ClassTag](
   def takeSample(withReplacement: Boolean,
       num: Int,
       seed: Long = Utils.random.nextLong): Array[T] = {
-    var fraction = 0.0
-    var total = 0
     val numStDev = 10.0
     val initialCount = this.count()
 
@@ -407,27 +405,30 @@
         "sampling without replacement")
     }
 
-    if (initialCount > Int.MaxValue - 1) {
-      val maxSelected = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
-      if (num > maxSelected) {
-        throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
-          s"$numStDev * math.sqrt(Int.MaxValue)")
-      }
+    val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
+    if (num > maxSampleSize) {
+      throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
+        s"$numStDev * math.sqrt(Int.MaxValue)")
     }
 
-    fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement)
-    total = num
+    val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
+      withReplacement)
 
     val rand = new Random(seed)
     var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
 
     // If the first sample didn't turn out large enough, keep trying to take samples;
     // this shouldn't happen often because we use a big multiplier for the initial size
-    while (samples.length < total) {
+    var numIters = 0
+    while (samples.length < num) {
+      if (numIters > 0) {
+        logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
+      }
       samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+      numIters += 1
     }
 
-    Utils.randomizeInPlace(samples, rand).take(total)
+    Utils.randomizeInPlace(samples, rand).take(num)
   }
 
   /**
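The takeSample rewrite above keeps an oversample-and-retry structure: compute a padded sampling fraction, draw once, loop (now with a warning) only if the draw comes up short, then shuffle and trim to exactly num. A minimal standalone sketch of that pattern, in plain Scala with no Spark dependency (takeSampleLocal and its Bernoulli draw are illustrative stand-ins for RDD.sample and collect, and println stands in for logWarning):

import scala.util.Random

object TakeSampleSketch {
  // Oversample at a padded rate, retry if the draw comes up short,
  // then shuffle and trim to exactly `num` elements.
  def takeSampleLocal[T](data: Seq[T], num: Int, fraction: Double, seed: Long): Seq[T] = {
    val rand = new Random(seed)
    // Bernoulli draw: a local stand-in for RDD.sample(withReplacement = false, fraction).collect()
    def draw(): Seq[T] = data.filter(_ => rand.nextDouble() < fraction)

    var samples = draw()
    var numIters = 0
    while (samples.length < num) { // rarely loops, since `fraction` is padded upward
      numIters += 1
      println(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
      samples = draw()
    }
    rand.shuffle(samples).take(num)
  }

  def main(args: Array[String]): Unit = {
    // Draw 10 of 1000 elements; 0.02 plays the role of the computed fraction.
    println(takeSampleLocal(1 to 1000, num = 10, fraction = 0.02, seed = 42L))
  }
}

Because the fraction is padded well above num / data.size, the retry branch should almost never fire; that is the same design point the commit's logWarning is meant to surface.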
core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala (4 additions & 0 deletions)

@@ -20,6 +20,10 @@ package org.apache.spark.util.random
 private[spark] object SamplingUtils {
 
   /**
+   * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
+   * the time.
+   *
+   * How the sampling rate is determined:
    * Let p = num / total, where num is the sample size and total is the total number of
    * datapoints in the RDD. We're trying to compute q > p such that
    *   - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
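The new doc comment is cut off here by the collapsed diff; the full derivation appears again in the pyspark port of the same function later in this commit. As a rough worked example of the without-replacement branch (my arithmetic, not part of the commit): for num = 100 and total = 100000, p = num / total = 0.001 and gamma = -log(0.00005) / 100000 ≈ 9.9e-5, so q = p + gamma + sqrt(gamma^2 + 2 * gamma * p) ≈ 0.00155, roughly a 1.55x oversample; that padding is what keeps the re-sampling loop in takeSample from firing in practice.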
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala (2 additions & 2 deletions)

@@ -547,8 +547,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 2*n, seed)
-      assert(sample.size === 2*n) // Got exactly 200 elements
+      val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+      assert(sample.size === 2 * n) // Got exactly 200 elements
       // Chance of getting all distinct elements is still quite low, so test we got < 100
       assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala (2 additions & 2 deletions)

@@ -20,10 +20,10 @@ package org.apache.spark.util.random
 import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
 import org.scalatest.FunSuite
 
-class SamplingUtilsSuite extends FunSuite{
+class SamplingUtilsSuite extends FunSuite {
 
   test("computeFraction") {
-    // test that the computed fraction guarantees enough datapoints
+    // test that the computed fraction guarantees enough data points
     // in the sample with a failure rate <= 0.0001
     val n = 100000
 
python/pyspark/rdd.py (33 additions & 20 deletions)

@@ -15,8 +15,6 @@
 # limitations under the License.
 #
 
-from base64 import standard_b64encode as b64enc
-import copy
 from collections import defaultdict
 from collections import namedtuple
 from itertools import chain, ifilter, imap
@@ -364,8 +362,8 @@ def takeSample(self, withReplacement, num, seed=None):
         [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
         """
 
-        fraction = 0.0
-        total = 0
+        #TODO remove
+        logging.basicConfig(level=logging.INFO)
         numStDev = 10.0
         initialCount = self.count()
 
@@ -378,38 +376,53 @@ def takeSample(self, withReplacement, num, seed=None):
         if (not withReplacement) and num > initialCount:
             raise ValueError
 
-        if initialCount > sys.maxint - 1:
-            maxSelected = sys.maxint - int(numStDev * sqrt(sys.maxint))
-            if num > maxSelected:
-                raise ValueError
-
-        fraction = self._computeFraction(num, initialCount, withReplacement)
-        total = num
+        maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
+        if num > maxSampleSize:
+            raise ValueError
+
+        fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)
 
         samples = self.sample(withReplacement, fraction, seed).collect()
 
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
         rand = Random(seed)
-        while len(samples) < total:
+        while len(samples) < num:
             samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
 
         sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
         sampler.shuffle(samples)
-        return samples[0:total]
-
-    def _computeFraction(self, num, total, withReplacement):
-        fraction = float(num)/total
+        return samples[0:num]
+
+    @staticmethod
+    def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
+        """
+        Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
+        the time.
+
+        How the sampling rate is determined:
+        Let p = num / total, where num is the sample size and total is the total number of
+        datapoints in the RDD. We're trying to compute q > p such that
+          - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
+            where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to
+            total), i.e. the failure rate of not having a sufficiently large sample < 0.0001.
+            Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
+            num > 12, but we need a slightly larger q (9 empirically determined).
+          - when sampling without replacement, we're drawing each datapoint with prob_i
+            ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
+            rate, where success rate is defined the same as in sampling with replacement.
+        """
+        fraction = float(sampleSizeLowerBound) / total
         if withReplacement:
             numStDev = 5
-            if (num < 12):
+            if (sampleSizeLowerBound < 12):
                 numStDev = 9
-            return fraction + numStDev * sqrt(fraction/total)
+            return fraction + numStDev * sqrt(fraction / total)
         else:
             delta = 0.00005
-            gamma = - log(delta)/total
-            return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction))
+            gamma = - log(delta) / total
+            return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
 
     def union(self, other):
         """
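To get a feel for the rate this commit computes, here is a hedged, self-contained Scala transcription of _computeFractionForSampleSize above (the object and method names are invented for this sketch; the arithmetic mirrors the Python body line for line):

import scala.math.{log, min, sqrt}

object FractionSketch {
  // Transcription of _computeFractionForSampleSize from the diff above.
  def computeFraction(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = {
    val fraction = sampleSizeLowerBound.toDouble / total
    if (withReplacement) {
      // Poisson case: pad by numStDev standard deviations (9 for tiny samples).
      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
      fraction + numStDev * sqrt(fraction / total)
    } else {
      // Binomial case: delta = 0.00005 gives the 99.99% success guarantee.
      val delta = 0.00005
      val gamma = -log(delta) / total
      min(1.0, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
    }
  }

  def main(args: Array[String]): Unit = {
    // Sampling 100 of 100000 rows without replacement oversamples at
    // q ≈ 0.00155 rather than the naive p = 0.001.
    println(computeFraction(100, 100000L, withReplacement = false))
  }
}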
