From c6a25c5cda92431d8a98b458914530fecc6caa5a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 10 Jul 2015 00:59:59 +0800 Subject: [PATCH] Add between function. --- R/pkg/NAMESPACE | 1 + R/pkg/R/column.R | 17 +++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/test_sparkSQL.R | 7 +++++++ 4 files changed, 29 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f857222452d4..331307c2077a5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -77,6 +77,7 @@ exportMethods("abs", "atan", "atan2", "avg", + "between", "cast", "cbrt", "ceiling", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 8e4b0f5bf1c4d..2892e1416cc65 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"), column(jc) }) +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname column +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + #' Casts the column to a different data type. #' #' @rdname column diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fad9d71158c51..ebe6fbd97ce86 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -567,6 +567,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) +#' @rdname column +#' @export +setGeneric("between", function(x, bounds) { standardGeneric("between") }) + #' @rdname column #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b0ea38854304e..a949a4f010b2c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -612,6 +612,13 @@ test_that("column functions", { c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) c9 <- toDegrees(c) + toRadians(c) + + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) }) test_that("column binary mathfunctions", {