[SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.
Changes include
1. Rename `sortDF` to `arrange`
2. Add new aliases `group_by`, `sample_frac`, and `summarize`
3. Add more user-friendly column addition (`mutate`) and renaming (`rename`)
4. Support `mean` as an alias for `avg` in Scala, and also support `n_distinct` and `n` as in dplyr

Using these changes, we can run almost all of the examples described in http://cran.rstudio.com/web/packages/dplyr/vignettes/introduction.html with the same syntax.

The only thing missing in SparkR is automatic resolution of column names used in an expression, i.e. making something like `select(flights, delay)` work as it does in dplyr; right now we need `select(flights, flights$delay)` or `select(flights, "delay")`. But this is a complicated change, and I'll file a new issue for it.
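
A rough end-to-end sketch of the new API, in the spirit of the dplyr vignette (the `flights` data and its column names are hypothetical; the init/load pattern follows the examples in the diff below):

sc <- sparkR.init()
sqlCtx <- sparkRSQL.init(sc)
flights <- jsonFile(sqlCtx, "path/to/flights.json")

# dplyr-style verbs now available in SparkR
flights <- mutate(flights, speed = flights$distance / flights$air_time * 60)
flights <- rename(flights, delay = flights$arr_delay)
sorted  <- arrange(flights, desc(flights$delay))
sampled <- sample_frac(flights, FALSE, 0.1)
byDest  <- group_by(flights, flights$dest)
counts  <- agg(byDest, n(flights$dest))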

cc sun-rui rxin

Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>

Closes #6005 from shivaram/sparkr-df-api and squashes the following commits:

5e0716a [Shivaram Venkataraman] Fix some roxygen bugs
1254953 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into sparkr-df-api
0521149 [Shivaram Venkataraman] Changes to make SparkR DataFrame dplyr friendly. Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr
shivaram committed May 9, 2015
1 parent b6c797b commit 0a901dd
Showing 8 changed files with 249 additions and 29 deletions.
11 changes: 9 additions & 2 deletions R/pkg/NAMESPACE
@@ -9,7 +9,8 @@ export("print.jobj")

exportClasses("DataFrame")

exportMethods("cache",
exportMethods("arrange",
"cache",
"collect",
"columns",
"count",
@@ -20,6 +21,7 @@ exportMethods("cache",
"explain",
"filter",
"first",
"group_by",
"groupBy",
"head",
"insertInto",
@@ -28,12 +30,15 @@ exportMethods("cache",
"join",
"limit",
"orderBy",
"mutate",
"names",
"persist",
"printSchema",
"registerTempTable",
"rename",
"repartition",
"sampleDF",
"sample_frac",
"saveAsParquetFile",
"saveAsTable",
"saveDF",
@@ -42,7 +47,7 @@ exportMethods("cache",
"selectExpr",
"show",
"showDF",
"sortDF",
"summarize",
"take",
"unionAll",
"unpersist",
@@ -72,6 +77,8 @@ exportMethods("abs",
"max",
"mean",
"min",
"n",
"n_distinct",
"rlike",
"sqrt",
"startsWith",
127 changes: 115 additions & 12 deletions R/pkg/R/DataFrame.R
@@ -480,6 +480,7 @@ setMethod("distinct",
#' @param withReplacement Sampling with replacement or not
#' @param fraction The (rough) sample target fraction
#' @rdname sampleDF
#' @aliases sample_frac
#' @export
#' @examples
#'\dontrun{
@@ -501,6 +502,15 @@ setMethod("sampleDF",
dataFrame(sdf)
})

#' @rdname sampleDF
#' @aliases sampleDF
setMethod("sample_frac",
signature(x = "DataFrame", withReplacement = "logical",
fraction = "numeric"),
function(x, withReplacement, fraction) {
sampleDF(x, withReplacement, fraction)
})
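
A quick usage sketch for the new alias (assuming `df` is an existing SparkR DataFrame):

# Same as sampleDF(df, FALSE, 0.1): a ~10% sample without replacement;
# the fraction is a target, not an exact row count.
sampled <- sample_frac(df, FALSE, 0.1)
count(sampled)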

#' Count
#'
#' Returns the number of rows in a DataFrame
@@ -682,7 +692,8 @@ setMethod("toRDD",
#' @param x a DataFrame
#' @return a GroupedData
#' @seealso GroupedData
#' @rdname DataFrame
#' @aliases group_by
#' @rdname groupBy
#' @export
#' @examples
#' \dontrun{
@@ -705,19 +716,36 @@ setMethod("groupBy",
groupedData(sgd)
})

#' Agg
#' @rdname groupBy
#' @aliases group_by
setMethod("group_by",
signature(x = "DataFrame"),
function(x, ...) {
groupBy(x, ...)
})

#' Summarize data across columns
#'
#' Compute aggregates by specifying a list of columns
#'
#' @param x a DataFrame
#' @rdname DataFrame
#' @aliases summarize
#' @export
setMethod("agg",
signature(x = "DataFrame"),
function(x, ...) {
agg(groupBy(x), ...)
})

#' @rdname DataFrame
#' @aliases agg
setMethod("summarize",
signature(x = "DataFrame"),
function(x, ...) {
agg(x, ...)
})
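
A quick usage sketch for the alias (assuming `df` is a SparkR DataFrame with an `age` column; the column name is hypothetical):

# Equivalent ways to aggregate over all rows: summarize(df, ...)
# delegates to agg(df, ...), which aggregates over groupBy(df)
# with no grouping columns.
agg(df, max(df$age))
summarize(df, max(df$age))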


############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
@@ -886,7 +914,7 @@ setMethod("select",
signature(x = "DataFrame", col = "list"),
function(x, col) {
cols <- lapply(col, function(c) {
if (class(c)== "Column") {
if (class(c) == "Column") {
c@jc
} else {
col(c)@jc
@@ -946,6 +974,42 @@ setMethod("withColumn",
select(x, x$"*", alias(col, colName))
})

#' Mutate
#'
#' Return a new DataFrame with the specified columns added.
#'
#' @param x A DataFrame
#' @param col a named argument of the form name = col
#' @return A new DataFrame with the new columns added.
#' @rdname withColumn
#' @aliases withColumn
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
#' names(newDF) # Will contain newCol, newCol2
#' }
setMethod("mutate",
signature(x = "DataFrame"),
function(x, ...) {
cols <- list(...)
stopifnot(length(cols) > 0)
stopifnot(class(cols[[1]]) == "Column")
ns <- names(cols)
if (!is.null(ns)) {
for (n in ns) {
if (n != "") {
cols[[n]] <- alias(cols[[n]], n)
}
}
}
do.call(select, c(x, x$"*", cols))
})

#' WithColumnRenamed
#'
#' Rename an existing column in a DataFrame.
@@ -977,29 +1041,67 @@ setMethod("withColumnRenamed",
select(x, cols)
})

#' Rename
#'
#' Rename an existing column in a DataFrame.
#'
#' @param x A DataFrame
#' @param newCol A named pair of the form new_column_name = existing_column
#' @return A DataFrame with the column name changed.
#' @rdname withColumnRenamed
#' @aliases withColumnRenamed
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' newDF <- rename(df, col1 = df$newCol1)
#' }
setMethod("rename",
signature(x = "DataFrame"),
function(x, ...) {
renameCols <- list(...)
stopifnot(length(renameCols) > 0)
stopifnot(class(renameCols[[1]]) == "Column")
newNames <- names(renameCols)
oldNames <- lapply(renameCols, function(col) {
callJMethod(col@jc, "toString")
})
cols <- lapply(columns(x), function(c) {
if (c %in% oldNames) {
alias(col(c), newNames[[match(c, oldNames)]])
} else {
col(c)
}
})
select(x, cols)
})

setClassUnion("characterOrColumn", c("character", "Column"))

#' SortDF
#' Arrange
#'
#' Sort a DataFrame by the specified column(s).
#'
#' @param x A DataFrame to be sorted.
#' @param col Either a Column object or character vector indicating the field to sort on
#' @param ... Additional sorting fields
#' @return A DataFrame where all elements are sorted.
#' @rdname sortDF
#' @rdname arrange
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' sortDF(df, df$col1)
#' sortDF(df, "col1")
#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
#' arrange(df, df$col1)
#' arrange(df, "col1")
#' arrange(df, asc(df$col1), desc(abs(df$col2)))
#' }
setMethod("sortDF",
setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
@@ -1013,20 +1115,20 @@ setMethod("sortDF",
dataFrame(sdf)
})

#' @rdname sortDF
#' @rdname arrange
#' @aliases orderBy,DataFrame,function-method
setMethod("orderBy",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col) {
sortDF(x, col)
arrange(x, col)
})

#' Filter
#'
#' Filter the rows of a DataFrame according to a given condition.
#'
#' @param x A DataFrame to be sorted.
#' @param condition The condition to sort on. This may either be a Column expression
#' @param condition The condition to filter on. This may either be a Column expression
#' or a string containing a SQL statement
#' @return A DataFrame containing only the rows that meet the condition.
#' @rdname filter
@@ -1106,6 +1208,7 @@ setMethod("join",
#'
#' Return a new DataFrame containing the union of rows in this DataFrame
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
#' Note that this does not remove duplicate rows across the two DataFrames.
#'
#' @param x A Spark DataFrame
#' @param y A Spark DataFrame
32 changes: 28 additions & 4 deletions R/pkg/R/column.R
@@ -131,6 +131,8 @@ createMethods()
#' alias
#'
#' Set a new name for a column

#' @rdname column
setMethod("alias",
signature(object = "Column"),
function(object, data) {
@@ -141,8 +143,12 @@ setMethod("alias",
}
})

#' substr
#'
#' An expression that returns a substring.
#'
#' @rdname column
#'
#' @param start starting position
#' @param stop ending position
setMethod("substr", signature(x = "Column"),
@@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
})

#' Casts the column to a different data type.
#'
#' @rdname column
#'
#' @examples
#' \dontrun{
#' cast(df$age, "string")
@@ -173,8 +182,8 @@

#' Approx Count Distinct
#'
#' Returns the approximate number of distinct items in a group.
#'
#' @rdname column
#' @return the approximate number of distinct items in a group.
setMethod("approxCountDistinct",
signature(x = "Column"),
function(x, rsd = 0.95) {
@@ -184,8 +193,8 @@ setMethod("approxCountDistinct",

#' Count Distinct
#'
#' returns the number of distinct items in a group.
#'
#' @rdname column
#' @return the number of distinct items in a group.
setMethod("countDistinct",
signature(x = "Column"),
function(x, ...) {
@@ -197,3 +206,18 @@ setMethod("countDistinct",
column(jc)
})

#' @rdname column
#' @aliases countDistinct
setMethod("n_distinct",
signature(x = "Column"),
function(x, ...) {
countDistinct(x, ...)
})

#' @rdname column
#' @aliases count
setMethod("n",
signature(x = "Column"),
function(x) {
count(x)
})
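
A quick sketch of the two dplyr-style helpers side by side (assuming `df` has `dept` and `name` columns; both names are hypothetical). `n` wraps `count` and `n_distinct` wraps `countDistinct`, so either spelling works inside an aggregation:

grouped <- group_by(df, df$dept)
agg(grouped, n(df$name))           # rows per department
agg(grouped, n_distinct(df$name))  # distinct names per department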
