From 0a795336df20c7ec969366e613286f0c060a4eeb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 23:36:57 -0700 Subject: [PATCH] [SPARK-8807] [SPARKR] Add between operator in SparkR JIRA: https://issues.apache.org/jira/browse/SPARK-8807 Add between operator in SparkR. Author: Liang-Chi Hsieh Closes #7356 from viirya/add_r_between and squashes the following commits: 7f51b44 [Liang-Chi Hsieh] Add test for non-numeric column. c6a25c5 [Liang-Chi Hsieh] Add between function. --- R/pkg/NAMESPACE | 1 + R/pkg/R/column.R | 17 +++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/test_sparkSQL.R | 12 ++++++++++++ 4 files changed, 34 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f857222452d4..331307c2077a5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -77,6 +77,7 @@ exportMethods("abs", "atan", "atan2", "avg", + "between", "cast", "cbrt", "ceiling", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 8e4b0f5bf1c4d..2892e1416cc65 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"), column(jc) }) +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname column +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + #' Casts the column to a different data type. #' #' @rdname column diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fad9d71158c51..ebe6fbd97ce86 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -567,6 +567,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) +#' @rdname column +#' @export +setGeneric("between", function(x, bounds) { standardGeneric("between") }) + #' @rdname column #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 76f74f80834a9..cdfe6481f60ea 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -638,6 +638,18 @@ test_that("column functions", { c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) c9 <- toDegrees(c) + toRadians(c) + + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) }) test_that("column binary mathfunctions", {