Skip to content

Commit

Permalink
[SPARK-24054][R] Add array_position function / element_at functions
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR proposes to add array_position and element_at in R side too.

array_position:

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb))
head(select(mutated, array_position(mutated$v1, 1)))
```

```
  array_position(v1, 1.0)
1                       2
2                       2
3                       2
4                       3
5                       0
6                       3
```

element_at:

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
head(select(mutated, element_at(mutated$v1, 1)))
```

```
  element_at(v1, 1.0)
1                21.0
2                21.0
3                22.8
4                21.4
5                18.7
6                18.1
```

```r
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
mutated <- mutate(df, v1 = create_map(df$model, df$cyl))
head(select(mutated, element_at(mutated$v1, "Valiant")))
```

```
  element_at(v3, Valiant)
1                      NA
2                      NA
3                      NA
4                      NA
5                      NA
6                       6
```

## How was this patch tested?

Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified.

Author: hyukjinkwon <gurwls223@apache.org>

Closes #21130 from HyukjinKwon/sparkr_array_position_element_at.
  • Loading branch information
HyukjinKwon committed Apr 24, 2018
1 parent c303b1b commit 87e8a57
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 4 deletions.
2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ exportMethods("%<=>%",
"approxCountDistinct",
"approxQuantile",
"array_contains",
"array_position",
"asc",
"ascii",
"asin",
Expand Down Expand Up @@ -245,6 +246,7 @@ exportMethods("%<=>%",
"decode",
"dense_rank",
"desc",
"element_at",
"encode",
"endsWith",
"exp",
Expand Down
42 changes: 40 additions & 2 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ NULL
#' the map or array of maps.
#' \item \code{from_json}: it is the column containing the JSON string.
#' }
#' @param value A value to compute on.
#' \itemize{
#' \item \code{array_contains}: a value to be checked if contained in the column.
#' \item \code{array_position}: a value to locate in the given array.
#' }
#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains
#' additional named properties to control how it is converted, accepts the same
#' options as the JSON data source.
Expand All @@ -201,14 +206,16 @@ NULL
#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
#' head(tmp2)
#' head(select(tmp, posexplode(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_keys(tmp3$v3)))
#' head(select(tmp3, map_values(tmp3$v3)))}
#' head(select(tmp3, map_values(tmp3$v3)))
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))}
NULL

#' Window functions for Column operations
Expand Down Expand Up @@ -2975,7 +2982,6 @@ setMethod("row_number",
#' \code{array_contains}: Returns null if the array is null, true if the array contains
#' the value, and false otherwise.
#'
#' @param value a value to be checked if contained in the column
#' @rdname column_collection_functions
#' @aliases array_contains array_contains,Column-method
#' @note array_contains since 1.6.0
Expand All @@ -2986,6 +2992,22 @@ setMethod("array_contains",
column(jc)
})

#' @details
#' \code{array_position}: Locates the position of the first occurrence of the given value
#' in the given array. Returns NA if either of the arguments are NA.
#' Note: The position is not zero based, but 1 based index. Returns 0 if the given
#' value could not be found in the array.
#'
#' @rdname column_collection_functions
#' @aliases array_position array_position,Column-method
#' @note array_position since 2.4.0
setMethod("array_position",
signature(x = "Column", value = "ANY"),
function(x, value) {
jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value)
column(jc)
})

#' @details
#' \code{map_keys}: Returns an unordered array containing the keys of the map.
#'
Expand All @@ -3012,6 +3034,22 @@ setMethod("map_values",
column(jc)
})

#' @details
#' \code{element_at}: Returns element of array at given index in \code{extraction} if
#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map.
#' Note: The position is not zero based, but 1 based index.
#'
#' @param extraction index to check for in array or key to check for in map
#' @rdname column_collection_functions
#' @aliases element_at element_at,Column-method
#' @note element_at since 2.4.0
setMethod("element_at",
signature(x = "Column", extraction = "ANY"),
function(x, extraction) {
jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction)
column(jc)
})

#' @details
#' \code{explode}: Creates a new row for each element in the given array or map column.
#'
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun
#' @name NULL
setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_position", function(x, value) { standardGeneric("array_position") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })
Expand Down Expand Up @@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") })
#' @name NULL
setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("encode", function(x, charset) { standardGeneric("encode") })
Expand Down
13 changes: 11 additions & 2 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1479,24 +1479,33 @@ test_that("column functions", {
df5 <- createDataFrame(list(list(a = "010101")))
expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15")

# Test array_contains() and sort_array()
# Test array_contains(), array_position(), element_at() and sort_array()
df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
expect_equal(result, c(TRUE, FALSE))

result <- collect(select(df, array_position(df[[1]], 1L)))[[1]]
expect_equal(result, c(1, 0))

result <- collect(select(df, element_at(df[[1]], 1L)))[[1]]
expect_equal(result, c(1, 6))

result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L)))
result <- collect(select(df, sort_array(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L)))

# Test map_keys() and map_values()
# Test map_keys(), map_values() and element_at()
df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
result <- collect(select(df, map_keys(df$map)))[[1]]
expect_equal(result, list(list("x", "y")))

result <- collect(select(df, map_values(df$map)))[[1]]
expect_equal(result, list(list(1, 2)))

result <- collect(select(df, element_at(df$map, "y")))[[1]]
expect_equal(result, 2)

# Test that stats::lag is working
expect_equal(length(lag(ldeaths, 12)), 72)

Expand Down

0 comments on commit 87e8a57

Please sign in to comment.