SparkR: let array_repeat() take `count` as either a Column or a numeric
constant (coerced with as.integer), and declare arrays_overlap() with its
arguments in (x, y) order. Adds the "numericOrColumn" class union used for
dispatch and a test for the constant-count form.

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a1c9495b0795e..70eb7a874b75c 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2297,6 +2297,8 @@ setMethod("rename",
 
 setClassUnion("characterOrColumn", c("character", "Column"))
 
+setClassUnion("numericOrColumn", c("numeric", "Column"))
+
 #' Arrange Rows by Variables
 #'
 #' Sort a SparkDataFrame by the specified column(s).
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 50137130ab98e..920c43c4916b8 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3052,14 +3052,19 @@ setMethod("array_position",
 #' \code{array_repeat}: Creates an array containing the left argument repeated the number of times
 #' given by the right argument.
 #'
-#' @param n Column determining the number of repetitions.
+#' @param count Column or constant determining the number of repetitions.
 #' @rdname column_collection_functions
-#' @aliases array_repeat array_repeat,Column-method
+#' @aliases array_repeat array_repeat,Column,numericOrColumn-method
 #' @note array_repeat since 2.4.0
 setMethod("array_repeat",
-          signature(x = "Column", n = "Column"),
-          function(x, n) {
-            jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, n@jc)
+          signature(x = "Column", count = "numericOrColumn"),
+          function(x, count) {
+            if (inherits(count, "Column")) {
+              count <- count@jc
+            } else {
+              count <- as.integer(count)
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count)
             column(jc)
           })
 
@@ -3086,9 +3091,9 @@ setMethod("array_sort",
 #' @aliases arrays_overlap arrays_overlap,Column-method
 #' @note arrays_overlap since 2.4.0
 setMethod("arrays_overlap",
-          signature(y = "Column", x = "Column"),
-          function(y, x) {
-            jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", y@jc, x@jc)
+          signature(x = "Column", y = "Column"),
+          function(x, y) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc)
             column(jc)
           })
 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 12b3b799cd718..8894cb1c5b92f 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -771,7 +771,7 @@ setGeneric("array_position", function(x, value) { standardGeneric("array_positio
 
 #' @rdname column_collection_functions
 #' @name NULL
-setGeneric("array_repeat", function(x, n) { standardGeneric("array_repeat") })
+setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") })
 
 #' @rdname column_collection_functions
 #' @name NULL
@@ -779,7 +779,7 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
 
 #' @rdname column_collection_functions
 #' @name NULL
-setGeneric("arrays_overlap", function(y, x) { standardGeneric("arrays_overlap") })
+setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") })
 
 #' @rdname column_string_functions
 #' @name NULL
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 6c0c010e691f0..16c1fd5a065eb 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1508,6 +1508,9 @@ test_that("column functions", {
   result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]
   expect_equal(result, list(list("a", "a", "a"), list("b", "b")))
 
+  result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]]
+  expect_equal(result, list(list("a", "a"), list("b", "b")))
+
   # Test arrays_overlap()
   df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)),
                              list(list(1L, 2L), list(3L, 4L)),