Skip to content

Commit

Permalink
[SPARK-36976][R] Add max_by/min_by API to SparkR
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Add max_by/min_by to SparkR

### Why are the changes needed?

For the convenience of SparkR users, matching the `max_by`/`min_by` functions already available in Spark SQL.

### Does this PR introduce _any_ user-facing change?

yes, new methods are added

### How was this patch tested?

Unit tests added in `R/pkg/tests/fulltests/test_sparkSQL.R`.

Closes #34258 from yoda-mon/max-by-min-by-r.

Authored-by: Leona Yoda <yodal@oss.nttdata.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
yoda-mon authored and HyukjinKwon committed Oct 13, 2021
1 parent bc7e4f5 commit 5982162
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
Expand Up @@ -352,9 +352,11 @@ exportMethods("%<=>%",
"map_values",
"map_zip_with",
"max",
"max_by",
"md5",
"mean",
"min",
"min_by",
"minute",
"monotonically_increasing_id",
"month",
Expand Down
46 changes: 46 additions & 0 deletions R/pkg/R/functions.R
Expand Up @@ -1479,6 +1479,29 @@ setMethod("max",
column(jc)
})

#' @details
#' \code{max_by}: Returns the value of \code{x} associated with the maximum
#' value of \code{y}.
#'
#' @rdname column_aggregate_functions
#' @aliases max_by max_by,Column-method
#' @note max_by since 3.3.0
#' @examples
#'
#' \dontrun{
#' df <- createDataFrame(
#'   list(list("Java", 2012, 20000), list("dotNET", 2012, 5000),
#'        list("dotNET", 2013, 48000), list("Java", 2013, 30000)),
#'   list("course", "year", "earnings")
#' )
#' tmp <- agg(groupBy(df, df$"course"), "max_by" = max_by(df$"year", df$"earnings"))
#' head(tmp)}
setMethod("max_by",
          signature(x = "Column", y = "Column"),
          function(x, y) {
            # Delegate to the JVM-side aggregate function of the same name.
            column(callJStatic("org.apache.spark.sql.functions", "max_by", x@jc, y@jc))
          })

#' @details
#' \code{md5}: Calculates the MD5 digest of a binary column and returns the value
#' as a 32 character hex string.
Expand Down Expand Up @@ -1531,6 +1554,29 @@ setMethod("min",
column(jc)
})

#' @details
#' \code{min_by}: Returns the value of \code{x} associated with the minimum
#' value of \code{y}.
#'
#' @rdname column_aggregate_functions
#' @aliases min_by min_by,Column-method
#' @note min_by since 3.3.0
#' @examples
#'
#' \dontrun{
#' df <- createDataFrame(
#'   list(list("Java", 2012, 20000), list("dotNET", 2012, 5000),
#'        list("dotNET", 2013, 48000), list("Java", 2013, 30000)),
#'   list("course", "year", "earnings")
#' )
#' tmp <- agg(groupBy(df, df$"course"), "min_by" = min_by(df$"year", df$"earnings"))
#' head(tmp)}
setMethod("min_by",
          signature(x = "Column", y = "Column"),
          function(x, y) {
            # Delegate to the JVM-side aggregate function of the same name.
            column(callJStatic("org.apache.spark.sql.functions", "min_by", x@jc, y@jc))
          })

#' @details
#' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string.
#'
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
Expand Up @@ -1190,10 +1190,18 @@ setGeneric("map_values", function(x) { standardGeneric("map_values") })
#' @name NULL
setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") })

#' @rdname column_aggregate_functions
#' @name NULL
# Generic for the max_by aggregate; the Column method is defined in functions.R.
setGeneric("max_by", function(x, y) { standardGeneric("max_by") })

#' @rdname column_misc_functions
#' @name NULL
setGeneric("md5", function(x) { standardGeneric("md5") })

#' @rdname column_aggregate_functions
#' @name NULL
# Generic for the min_by aggregate; the Column method is defined in functions.R.
setGeneric("min_by", function(x, y) { standardGeneric("min_by") })

#' @rdname column_datetime_functions
#' @name NULL
setGeneric("minute", function(x) { standardGeneric("minute") })
Expand Down
16 changes: 16 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Expand Up @@ -2292,6 +2292,22 @@ test_that("group by, agg functions", {
unlink(jsonPath3)
})

test_that("SPARK-36976: Add max_by/min_by API to SparkR", {
  # No explicit schema: columns are auto-named _1 (course), _2 (year),
  # _3 (earnings).
  df <- createDataFrame(
    list(list("Java", 2012, 20000), list("dotNET", 2012, 5000),
         list("dotNET", 2013, 48000), list("Java", 2013, 30000))
  )
  grouped <- groupBy(df, df$"_1")

  # max_by: the year with the highest earnings for each course.
  max_result <- agg(grouped, "_2" = max_by(df$"_2", df$"_3"))
  max_expected <- createDataFrame(list(list("dotNET", 2013), list("Java", 2013)))
  expect_equal(collect(max_result), collect(max_expected))

  # min_by: the year with the lowest earnings for each course.
  min_result <- agg(grouped, "_2" = min_by(df$"_2", df$"_3"))
  min_expected <- createDataFrame(list(list("dotNET", 2012), list("Java", 2012)))
  expect_equal(collect(min_result), collect(min_expected))
})

test_that("pivot GroupedData column", {
df <- createDataFrame(data.frame(
earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
Expand Down

0 comments on commit 5982162

Please sign in to comment.