From 5982162853932067f0df585d39b69ca8c95b49a5 Mon Sep 17 00:00:00 2001 From: Leona Yoda Date: Wed, 13 Oct 2021 15:24:13 +0900 Subject: [PATCH] [SPARK-36976][R] Add max_by/min_by API to SparkR ### What changes were proposed in this pull request? Add max_by/min_by to SparkR ### Why are the changes needed? for sparkr users' convenience ### Does this PR introduce _any_ user-facing change? yes, new methods are added ### How was this patch tested? unit test Closes #34258 from yoda-mon/max-by-min-by-r. Authored-by: Leona Yoda Signed-off-by: Hyukjin Kwon --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 46 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 8 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 16 ++++++++++ 4 files changed, 72 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 11403f6346f47..10bb02a9c0551 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -352,9 +352,11 @@ exportMethods("%<=>%", "map_values", "map_zip_with", "max", + "max_by", "md5", "mean", "min", + "min_by", "minute", "monotonically_increasing_id", "month", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 987d11087d42a..fdbf48ba870fb 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1479,6 +1479,29 @@ setMethod("max", column(jc) }) +#' @details +#' \code{max_by}: Returns the value of \code{x} associated with the maximum value of \code{y}. 
+#' +#' @rdname column_aggregate_functions +#' @aliases max_by max_by,Column-method +#' @note max_by since 3.3.0 +#' @examples +#' +#' \dontrun{ +#' df <- createDataFrame( +#' list(list("Java", 2012, 20000), list("dotNET", 2012, 5000), +#' list("dotNET", 2013, 48000), list("Java", 2013, 30000)), +#' list("course", "year", "earnings") +#' ) +#' tmp <- agg(groupBy(df, df$"course"), "max_by" = max_by(df$"year", df$"earnings")) +#' head(tmp)} +setMethod("max_by", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "max_by", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{md5}: Calculates the MD5 digest of a binary column and returns the value #' as a 32 character hex string. @@ -1531,6 +1554,29 @@ setMethod("min", column(jc) }) +#' @details +#' \code{min_by}: Returns the value of \code{x} associated with the minimum value of \code{y}. +#' +#' @rdname column_aggregate_functions +#' @aliases min_by min_by,Column-method +#' @note min_by since 3.3.0 +#' @examples +#' +#' \dontrun{ +#' df <- createDataFrame( +#' list(list("Java", 2012, 20000), list("dotNET", 2012, 5000), +#' list("dotNET", 2013, 48000), list("Java", 2013, 30000)), +#' list("course", "year", "earnings") +#' ) +#' tmp <- agg(groupBy(df, df$"course"), "min_by" = min_by(df$"year", df$"earnings")) +#' head(tmp)} +setMethod("min_by", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "min_by", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string. 
#' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ad29a7019ee18..af19e7268e710 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1190,10 +1190,18 @@ setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @name NULL setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") }) +#' @rdname column_aggregate_functions +#' @name NULL +setGeneric("max_by", function(x, y) { standardGeneric("max_by") }) + #' @rdname column_misc_functions #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) +#' @rdname column_aggregate_functions +#' @name NULL +setGeneric("min_by", function(x, y) { standardGeneric("min_by") }) + #' @rdname column_datetime_functions #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 1d8ac2b6bf0c5..b6e02bb0a01c9 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2292,6 +2292,22 @@ test_that("group by, agg functions", { unlink(jsonPath3) }) +test_that("SPARK-36976: Add max_by/min_by API to SparkR", { + df <- createDataFrame( + list(list("Java", 2012, 20000), list("dotNET", 2012, 5000), + list("dotNET", 2013, 48000), list("Java", 2013, 30000)) + ) + gd <- groupBy(df, df$"_1") + + actual1 <- agg(gd, "_2" = max_by(df$"_2", df$"_3")) + expect1 <- createDataFrame(list(list("dotNET", 2013), list("Java", 2013))) + expect_equal(collect(actual1), collect(expect1)) + + actual2 <- agg(gd, "_2" = min_by(df$"_2", df$"_3")) + expect2 <- createDataFrame(list(list("dotNET", 2012), list("Java", 2012))) + expect_equal(collect(actual2), collect(expect2)) +}) + test_that("pivot GroupedData column", { df <- createDataFrame(data.frame( earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),