Skip to content

Commit

Permalink
[SPARK-9803] [SPARKR] Add subset and transform + tests
Browse files Browse the repository at this point in the history
Add subset and transform
Also reorganize `[` & `[[` to subset instead of select

Note: transform is very similar to mutate. Spark does not seem to replace an existing column that has the same name in mutate (i.e. `mutate(df, age = df$age + 2)` returns a DataFrame with two columns both named 'age'), so transform does not do that for now either.
Though it is clearly stated that a column with a matching name should be replaced (should I open a JIRA for mutate/transform?)

Author: felixcheung <felixcheung_m@hotmail.com>

Closes #8503 from felixcheung/rsubset_transform.

(cherry picked from commit 2a4e00c)
Signed-off-by: Shivaram Venkataraman <shivaram@cs.berkeley.edu>
  • Loading branch information
felixcheung authored and shivaram committed Aug 29, 2015
1 parent df4a2e6 commit b7aab1d
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 17 deletions.
2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ exportMethods("arrange",
"selectExpr",
"show",
"showDF",
"subset",
"summarize",
"summary",
"take",
"transform",
"unionAll",
"unique",
"unpersist",
Expand Down
70 changes: 55 additions & 15 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ setMethod("$<-", signature(x = "DataFrame"),

# Virtual class that matches either a numeric or a character value; it is used
# as the type of the index `i` in the `[[` method for DataFrame.
setClassUnion(
  name = "numericOrcharacter",
  members = c("numeric", "character")
)

#' @rdname select
#' @rdname subset
#' @name [[
setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"),
function(x, i) {
Expand All @@ -967,7 +967,7 @@ setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"),
getColumn(x, i)
})

#' @rdname select
#' @rdname subset
#' @name [
setMethod("[", signature(x = "DataFrame", i = "missing"),
function(x, i, j, ...) {
Expand All @@ -981,20 +981,51 @@ setMethod("[", signature(x = "DataFrame", i = "missing"),
select(x, j)
})

#' @rdname select
#' @rdname subset
#' @name [
setMethod("[", signature(x = "DataFrame", i = "Column"),
function(x, i, j, ...) {
# It could handle i as "character" but it seems confusing and not required
# https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html
filtered <- filter(x, i)
if (!missing(j)) {
filtered[, j]
filtered[, j, ...]
} else {
filtered
}
})

#' Subset
#'
#' Return subsets of a DataFrame according to given conditions. The row condition
#' and the column selection are both handed to the \code{[} operator of the DataFrame.
#' @param x A DataFrame
#' @param subset A logical expression to filter on rows
#' @param select expression for the single Column or a list of columns to select from the DataFrame
#' @param ... Additional arguments passed on unchanged to the \code{[} operator
#' @return A new DataFrame containing only the rows that meet the condition with selected columns
#' @export
#' @rdname subset
#' @name subset
#' @aliases [
#' @family subsetting functions
#' @examples
#' \dontrun{
#' # Columns can be selected using `[[` and `[`
#' df[[2]] == df[["age"]]
#' df[,2] == df[,"age"]
#' df[,c("name", "age")]
#' # Or to filter rows
#' df[df$age > 20,]
#' # DataFrame can be subset on both rows and Columns
#' df[df$name == "Smith", c(1,2)]
#' df[df$age %in% c(19, 30), 1:2]
#' subset(df, df$age %in% c(19, 30), 1:2)
#' subset(df, df$age %in% c(19), select = c(1,2))
#' }
setMethod("subset", signature(x = "DataFrame"),
function(x, subset, select, ...) {
# Thin wrapper: `subset` is used as the row index and `select` as the column
# index of the DataFrame `[` operator, so the filtering/selection logic lives
# in the `[` methods rather than being duplicated here.
x[subset, select, ...]
})

#' Select
#'
#' Selects a set of columns with names or Column expressions.
Expand All @@ -1003,22 +1034,17 @@ setMethod("[", signature(x = "DataFrame", i = "Column"),
#' @return A new DataFrame with selected columns
#' @export
#' @rdname select
#' @name select
#' @family subsetting functions
#' @examples
#' \dontrun{
#' select(df, "*")
#' select(df, "col1", "col2")
#' select(df, df$name, df$age + 1)
#' select(df, c("col1", "col2"))
#' select(df, list(df$name, df$age + 1))
#' # Columns can also be selected using `[[` and `[`
#' df[[2]] == df[["age"]]
#' df[,2] == df[,"age"]
#' df[,c("name", "age")]
#' # Similar to R data frames columns can also be selected using `$`
#' df$age
#' # It can also be subset on rows and Columns
#' df[df$name == "Smith", c(1,2)]
#' df[df$age %in% c(19, 30), 1:2]
#' }
setMethod("select", signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
Expand Down Expand Up @@ -1090,7 +1116,7 @@ setMethod("selectExpr",
#' @return A DataFrame with the new column added.
#' @rdname withColumn
#' @name withColumn
#' @aliases mutate
#' @aliases mutate transform
#' @export
#' @examples
#'\dontrun{
Expand All @@ -1110,11 +1136,12 @@ setMethod("withColumn",
#'
#' Return a new DataFrame with the specified columns added.
#'
#' @param x A DataFrame
#' @param .data A DataFrame
#' @param col a named argument of the form name = col
#' @return A new DataFrame with the new columns added.
#' @rdname withColumn
#' @name mutate
#' @aliases withColumn transform
#' @export
#' @examples
#'\dontrun{
Expand All @@ -1124,10 +1151,12 @@ setMethod("withColumn",
#' df <- jsonFile(sqlContext, path)
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
#' names(newDF) # Will contain newCol, newCol2
#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2)
#' }
setMethod("mutate",
signature(x = "DataFrame"),
function(x, ...) {
signature(.data = "DataFrame"),
function(.data, ...) {
x <- .data
cols <- list(...)
stopifnot(length(cols) > 0)
stopifnot(class(cols[[1]]) == "Column")
Expand All @@ -1142,6 +1171,16 @@ setMethod("mutate",
do.call(select, c(x, x$"*", cols))
})

#' @export
#' @rdname withColumn
#' @name transform
#' @aliases withColumn mutate
setMethod("transform",
signature(`_data` = "DataFrame"),
function(`_data`, ...) {
# transform() on a DataFrame is an alias for mutate(): the DataFrame and every
# argument supplied in `...` are forwarded unchanged.
mutate(`_data`, ...)
})

#' WithColumnRenamed
#'
#' Rename an existing column in a DataFrame.
Expand Down Expand Up @@ -1269,6 +1308,7 @@ setMethod("orderBy",
#' @return A DataFrame containing only the rows that meet the condition.
#' @rdname filter
#' @name filter
#' @family subsetting functions
#' @export
#' @examples
#'\dontrun{
Expand Down
10 changes: 9 additions & 1 deletion R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ setGeneric("merge")

#' @rdname withColumn
#' @export
setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") })

#' @rdname arrange
#' @export
Expand Down Expand Up @@ -507,6 +507,10 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
standardGeneric("saveAsTable")
})

#' @rdname withColumn
#' @export
setGeneric(
  "transform",
  # The first formal is named `_data` to match base R's transform(), so calls
  # on local data frames keep working through this generic.
  function(`_data`, ...) {
    standardGeneric("transform")
  }
)

#' @rdname write.df
#' @export
setGeneric(
  "write.df",
  function(df, path, ...) {
    standardGeneric("write.df")
  }
)
Expand All @@ -531,6 +535,10 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
#' @export
setGeneric(
  "showDF",
  function(x, ...) {
    standardGeneric("showDF")
  }
)

# NOTE(review): no explicit `signature` is given, so `x`, `subset` and `select`
# presumably all take part in dispatch and are evaluated when the generic is
# called; confirm this does not interfere with base subset() on local data
# frames, which evaluates its condition inside the data frame.
#' @rdname subset
#' @export
setGeneric(
  "subset",
  function(x, subset, select, ...) {
    standardGeneric("subset")
  }
)

#' @rdname agg
#' @export
setGeneric(
  "summarize",
  function(x, ...) {
    standardGeneric("summarize")
  }
)
Expand Down
20 changes: 19 additions & 1 deletion R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ test_that("subsetting", {
df5 <- df[df$age %in% c(19), c(1,2)]
expect_equal(count(df5), 1)
expect_equal(columns(df5), c("name", "age"))

df6 <- subset(df, df$age %in% c(30), c(1,2))
expect_equal(count(df6), 1)
expect_equal(columns(df6), c("name", "age"))
})

test_that("selectExpr() on a DataFrame", {
Expand Down Expand Up @@ -1028,7 +1032,7 @@ test_that("withColumn() and withColumnRenamed()", {
expect_equal(columns(newDF2)[1], "newerAge")
})

test_that("mutate(), rename() and names()", {
test_that("mutate(), transform(), rename() and names()", {
df <- jsonFile(sqlContext, jsonPath)
newDF <- mutate(df, newAge = df$age + 2)
expect_equal(length(columns(newDF)), 3)
Expand All @@ -1042,6 +1046,20 @@ test_that("mutate(), rename() and names()", {
names(newDF2) <- c("newerName", "evenNewerAge")
expect_equal(length(names(newDF2)), 2)
expect_equal(names(newDF2)[1], "newerName")

transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2)
expect_equal(length(columns(transformedDF)), 4)
expect_equal(columns(transformedDF)[3], "newAge")
expect_equal(columns(transformedDF)[4], "newAge2")
expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30)

# test if transform on local data frames works
# ensure the proper signature is used - otherwise this will fail to run
attach(airquality)
result <- transform(Ozone, logOzone = log(Ozone))
expect_equal(nrow(result), 153)
expect_equal(ncol(result), 2)
detach(airquality)
})

test_that("write.df() on DataFrame and works with parquetFile", {
Expand Down

0 comments on commit b7aab1d

Please sign in to comment.