SPARK-1240: handle the case of empty RDD when takeSample
https://spark-project.atlassian.net/browse/SPARK-1240

It seems that the current implementation does not handle the empty RDD case when running takeSample.

In this patch, before calling sample() inside the takeSample API, I add a check for this case and return an empty Array when the RDD is empty; in sample(), I also add a check for invalid fraction values.

In the test cases, I also add several lines covering this case.
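
A minimal sketch of the resulting behavior (assuming a local SparkContext on a Spark build of this era; the demo object is illustrative and not part of the patch):

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative only: exercises the two guards this patch adds.
object Spark1240Demo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("SPARK-1240-demo"))

    // takeSample on an empty RDD now returns an empty Array immediately
    // instead of deriving a sampling fraction from a zero element count.
    val empty = sc.parallelize(Seq.empty[Int], 2)
    assert(empty.takeSample(withReplacement = false, 20, 1).isEmpty)

    // sample() now rejects a negative fraction eagerly via require().
    try {
      sc.parallelize(1 to 10).sample(withReplacement = false, -0.5, 1)
    } catch {
      case e: IllegalArgumentException =>
        println(e.getMessage) // requirement failed: Invalid fraction value: -0.5
    }

    sc.stop()
  }
}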

Author: CodingCat <zhunansjtu@gmail.com>

Closes apache#135 from CodingCat/SPARK-1240 and squashes the following commits:

fef57d4 [CodingCat] fix the same problem in PySpark
36db06b [CodingCat] create new test cases for takeSample from an empty rdd
810948d [CodingCat] further fix
a40e8fb [CodingCat] replace if with require
ad483fd [CodingCat] handle the case with empty RDD when take sample

Conflicts:
	core/src/main/scala/org/apache/spark/rdd/RDD.scala
	core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
CodingCat authored and James Z.M. Gao committed Mar 18, 2014
1 parent 9f95130 commit 9a061b7
Showing 3 changed files with 19 additions and 2 deletions.
10 changes: 8 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -319,8 +319,10 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
+  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+    require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     new SampledRDD(this, withReplacement, fraction, seed)
+  }
 
   def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
     var fraction = 0.0
@@ -333,6 +335,10 @@ abstract class RDD[T: ClassTag](
       throw new IllegalArgumentException("Negative number of elements requested")
     }
 
+    if (initialCount == 0) {
+      return new Array[T](0)
+    }
+
     if (initialCount > Integer.MAX_VALUE - 1) {
       maxSelected = Integer.MAX_VALUE - 1
     } else {
@@ -351,7 +357,7 @@ abstract class RDD[T: ClassTag](
     var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
 
     // If the first sample didn't turn out large enough, keep trying to take samples;
-    // this shouldn't happen often because we use a big multiplier for thei initial size
+    // this shouldn't happen often because we use a big multiplier for the initial size
     while (samples.length < total) {
       samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
     }
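
For context, the retry loop above is expected to run rarely because takeSample oversamples. A rough sketch of that heuristic, under the assumption that the fraction is inflated by a constant multiplier as the comment suggests (the helper name and the value 3.0 are illustrative, not taken from this diff):

// Illustrative sketch of the oversampling idea referenced in the comment above;
// the multiplier value and exact formula are assumptions, not shown by this diff.
def oversampleFraction(num: Int, initialCount: Long, multiplier: Double = 3.0): Double =
  math.min(multiplier * (num + 1) / initialCount, 1.0)

// e.g. oversampleFraction(20, 100) == 0.63 rather than a bare 20/100 = 0.2,
// so `while (samples.length < total)` rarely needs a second pass.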
7 changes: 7 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -455,6 +455,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("takeSample") {
     val data = sc.parallelize(1 to 100, 2)
+
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20) // Got exactly 20 elements
@@ -486,6 +487,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("takeSample from an empty rdd") {
+    val emptySet = sc.parallelize(Seq.empty[Int], 2)
+    val sample = emptySet.takeSample(false, 20, 1)
+    assert(sample.length === 0)
+  }
+
   test("runJob on an invalid partition") {
     intercept[IllegalArgumentException] {
       sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
4 changes: 4 additions & 0 deletions python/pyspark/rdd.py
@@ -252,6 +252,7 @@ def sample(self, withReplacement, fraction, seed):
         >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
         [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
         """
+        assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     # this is ported from scala/spark/RDD.scala
@@ -272,6 +273,9 @@ def takeSample(self, withReplacement, num, seed):
         if (num < 0):
             raise ValueError
 
+        if (initialCount == 0):
+            return list()
+
         if initialCount > sys.maxint - 1:
             maxSelected = sys.maxint - 1
         else:
