Skip to content

Commit

Permalink
Changed fractionByKey to a map to enable arg check
Browse files Browse the repository at this point in the history
  • Loading branch information
dorx committed Jun 19, 2014
1 parent 944a10c commit 1fe1cff
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
12 changes: 3 additions & 9 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* math.ceil(fraction * S_i), where S_i is the size of the ith stratum (collection of entries
* that share the same key). When sampling without replacement, we need one additional pass over
* the RDD to guarantee sample size with a 99.99% confidence; when sampling with replacement, we
* need two additional passes over the RDD to guarantee sample size with a 99.99% confidence.
*
* Note that if the sampling rate for any stratum is < 1e-10, we will throw an exception to
* avoid not being able to ever create the sample as an artifact of the RNG's quality.
* need two additional passes.
*
* @param withReplacement whether to sample with or without replacement
* @param fractionByKey function mapping key to sampling rate
Expand All @@ -227,14 +224,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* @return RDD containing the sampled subset
*/
def sampleByKey(withReplacement: Boolean,
fractionByKey: K => Double,
fractionByKey: Map[K, Double],
seed: Long = Utils.random.nextLong,
exact: Boolean = true): RDD[(K, V)]= {

require(fractionByKey.asInstanceOf[Map[K, Double]].forall({case(k, v) => v >= 1e-10}),
"Unable to support sampling rates < 1e-10.")

if (withReplacement) {
require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
val counts = if (exact) Some(this.countByKey()) else None
val samplingFunc =
StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
Expand Down
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,11 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*
* fraction < 1e-10 not supported.
*/
def sample(withReplacement: Boolean,
fraction: Double,
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 1e-10, "Invalid fraction value: " + fraction)
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ private[spark] object SamplingUtils {
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
* rate, where success rate is defined the same as in sampling with replacement.
*
* The smallest sampling rate supported is 1e-10 (in order to avoid running into the limit of the
* RNG's resolution).
*
* @param sampleSizeLowerBound sample size
* @param total size of RDD
* @param withReplacement whether sampling with replacement
Expand All @@ -47,11 +50,11 @@ private[spark] object SamplingUtils {
val fraction = sampleSizeLowerBound.toDouble / total
if (withReplacement) {
val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
fraction + numStDev * math.sqrt(fraction / total)
math.max(1e-10, fraction + numStDev * math.sqrt(fraction / total))
} else {
val delta = 1e-4
val gamma = - math.log(delta) / total
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
math.min(1, math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
}
}
}
Expand Down

0 comments on commit 1fe1cff

Please sign in to comment.