[SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases

The existing `sample` functions lack the `seed` parameter, although the corresponding generic in `generics` declares it. As a result, a caller can pass `seed`, but the value is never used.

This can cause SparkR unit tests to fail. For example, I hit the failure in another PR:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull

Author: gatorsmile <gatorsmile@gmail.com>

Closes #10160 from gatorsmile/sampleR.
gatorsmile authored and shivaram committed Dec 12, 2015
1 parent 1e799d6 commit 1e3526c
Showing 2 changed files with 15 additions and 6 deletions.
17 changes: 11 additions & 6 deletions R/pkg/R/DataFrame.R
@@ -662,6 +662,7 @@ setMethod("unique",
 #' @param x A SparkSQL DataFrame
 #' @param withReplacement Sampling with replacement or not
 #' @param fraction The (rough) sample target fraction
+#' @param seed Randomness seed value
 #'
 #' @family DataFrame functions
 #' @rdname sample
@@ -677,13 +678,17 @@ setMethod("unique",
 #' collect(sample(df, TRUE, 0.5))
 #'}
 setMethod("sample",
-          # TODO : Figure out how to send integer as java.lang.Long to JVM so
-          # we can send seed as an argument through callJMethod
           signature(x = "DataFrame", withReplacement = "logical",
                     fraction = "numeric"),
-          function(x, withReplacement, fraction) {
+          function(x, withReplacement, fraction, seed) {
             if (fraction < 0.0) stop(cat("Negative fraction value:", fraction))
-            sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+            if (!missing(seed)) {
+              # TODO : Figure out how to send integer as java.lang.Long to JVM so
+              # we can send seed as an argument through callJMethod
+              sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed))
+            } else {
+              sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+            }
             dataFrame(sdf)
           })

@@ -692,8 +697,8 @@ setMethod("sample",
 setMethod("sample_frac",
           signature(x = "DataFrame", withReplacement = "logical",
                     fraction = "numeric"),
-          function(x, withReplacement, fraction) {
-            sample(x, withReplacement, fraction)
+          function(x, withReplacement, fraction, seed) {
+            sample(x, withReplacement, fraction, seed)
           })

 #' nrow
4 changes: 4 additions & 0 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -724,6 +724,10 @@ test_that("sample on a DataFrame", {
   sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result
   expect_true(count(sampled2) < 3)

+  count1 <- count(sample(df, FALSE, 0.1, 0))
+  count2 <- count(sample(df, FALSE, 0.1, 0))
+  expect_equal(count1, count2)
+
   # Also test sample_frac
   sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
   expect_true(count(sampled3) < 3)
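
The optional-`seed` pattern used in this change can be tried without a Spark session. The following is a minimal sketch in base R only: `sample_rows` and `sample_rows_frac` are hypothetical stand-ins invented for illustration, not SparkR functions, and they use plain closures rather than S4 methods or `callJMethod`. It shows two things the diff relies on: `missing(seed)` selects the seeded or unseeded path, and a wrapper that forwards `seed` keeps it missing when the caller omitted it.

```r
# Hypothetical stand-in for the Spark-backed method: base R only, no JVM.
sample_rows <- function(n, fraction, seed) {
  if (fraction < 0.0) stop("Negative fraction value: ", fraction)
  if (!missing(seed)) {
    # A seed was supplied: fix the RNG state so the draw is reproducible.
    set.seed(as.integer(seed))
  }
  # Keep each row index independently with probability `fraction`.
  which(runif(n) < fraction)
}

# Hypothetical wrapper mirroring how sample_frac forwards its arguments.
# Forwarding a missing `seed` leaves it missing inside sample_rows.
sample_rows_frac <- function(n, fraction, seed) {
  sample_rows(n, fraction, seed)
}

first  <- sample_rows(100, 0.1, seed = 0)
second <- sample_rows(100, 0.1, seed = 0)
identical(first, second)  # same seed, same rows

# No seed: takes the unseeded branch without raising an error.
via_wrapper <- sample_rows_frac(100, 0.1)
```

The equality check at the end corresponds to the `count1`/`count2` assertion added to the test file: two calls with the same seed should agree.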
