Skip to content

Commit

Permalink
Reservoir sampling implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed Jul 18, 2014
1 parent 72e9021 commit 6940010
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,48 @@

package org.apache.spark.util.random

import scala.reflect.ClassTag

private[spark] object SamplingUtils {

/**
* Reservoir Sampling implementation.
*
* @param input input size
* @param k reservoir size
* @return (samples, input size)
*/
def reservoirSample[T: ClassTag](input: Iterator[T], k: Int): (Array[T], Int) = {
val reservoir = new Array[T](k)
// Put the first k elements in the reservoir.
var i = 0
while (i < k && input.hasNext) {
val item = input.next()
reservoir(i) = item
i += 1
}

// If we have consumed all the elements, return them. Otherwise do the replacement.
if (i < k) {
// If input size < k, trim the array to return only an array of input size.
val trimReservoir = new Array[T](i)
System.arraycopy(reservoir, 0, trimReservoir, 0, i)
(trimReservoir, i)
} else {
// If input size > k, continue the sampling process.
val rand = new XORShiftRandom
while (input.hasNext) {
val item = input.next()
val replacementIndex = rand.nextInt(i)
if (replacementIndex < k) {
reservoir(replacementIndex) = item
}
i += 1
}
(reservoir, i)
}
}

/**
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
* the time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,32 @@

package org.apache.spark.util.random

import scala.util.Random

import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
import org.scalatest.FunSuite

class SamplingUtilsSuite extends FunSuite {

test("reservoirSample") {
val input = Seq.fill(100)(Random.nextInt())

// input size < k
val (sample1, count1) = SamplingUtils.reservoirSample(input.iterator, 150)
assert(count1 === 100)
assert(input === sample1.toSeq)

// input size == k
val (sample2, count2) = SamplingUtils.reservoirSample(input.iterator, 100)
assert(count2 === 100)
assert(input === sample2.toSeq)

// input size > k
val (sample3, count3) = SamplingUtils.reservoirSample(input.iterator, 10)
assert(count3 === 100)
assert(sample3.length === 10)
}

test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001
Expand Down

0 comments on commit 6940010

Please sign in to comment.