Skip to content

Commit

Permalink
include grouping columns in agg()
Browse files Browse the repository at this point in the history
add docs for groupBy() and agg()
  • Loading branch information
davies committed Mar 12, 2015
1 parent 09ff163 commit 9a6be74
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 21 deletions.
22 changes: 21 additions & 1 deletion pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,23 @@ setMethod("toRDD",
#'
#' Groups the DataFrame using the specified columns, so we can run aggregation on them.
#'
#' @param x a DataFrame
#' @return a GroupedData
#' @seealso GroupedData
#' @rdname DataFrame
#' @export
#' @examples
#' \dontrun{
#'   # Compute the average for all numeric columns grouped by department.
#'   avg(groupBy(df, "department"))
#'
#'   # Compute the max age and average salary, grouped by department and gender.
#'   agg(groupBy(df, "department", "gender"), salary = "avg", age = "max")
#' }
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })

#' @rdname DataFrame
#' @export
setMethod("groupBy",
signature(x = "DataFrame"),
function(x, ...) {
Expand All @@ -712,7 +727,12 @@ setMethod("groupBy",
groupedData(sgd)
})


#' Agg
#'
#' Compute aggregates by specifying a list of columns
#'
#' @rdname DataFrame
#' @export
setMethod("agg",
signature(x = "DataFrame"),
function(x, ...) {
Expand Down
38 changes: 32 additions & 6 deletions pkg/R/group.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
############################## GroupedData ########################################
# group.R - GroupedData class and methods implemented in S4 OO classes

setOldClass("jobj")

#' @title S4 class that represents a GroupedData
#' @description GroupedDatas can be created using groupBy() on a DataFrame
#' @rdname GroupedData
#' @seealso groupBy
#'
#' @param sgd A Java object reference to the backing Scala GroupedData
#' @export
setClass("GroupedData",
slots = list(env = "environment",
sgd = "jobj"))
slots = list(sgd = "jobj"))

# Initialize a GroupedData wrapper around the backing Scala GroupedData jobj.
# NOTE: the class declares only the `sgd` slot, so assigning `.Object@env`
# (a leftover from the removed `env` slot) would fail with
# "no slot of name env"; only `sgd` is stored.
setMethod("initialize", "GroupedData", function(.Object, sgd) {
  .Object@sgd <- sgd
  .Object
})

#' @rdname DataFrame
# Internal constructor: wraps a Java object reference (jobj) to the backing
# Scala GroupedData in an S4 GroupedData instance. Not exported; callers
# reach GroupedData via groupBy() on a DataFrame.
groupedData <- function(sgd) {
  new("GroupedData", sgd)
}


#' Count
#'
#' Count the number of rows for each group.
#' The resulting DataFrame will also contain the grouping columns.
#'
#' @param x a GroupedData
#' @return a DataFrame
#' @export
#' @examples
#' \dontrun{
#'   count(groupBy(df, "name"))
#' }
setMethod("count",
signature(x = "GroupedData"),
function(x) {
Expand All @@ -23,9 +43,13 @@ setMethod("count",
#' Agg
#'
#' Compute aggregates for each group of a grouped DataFrame by specifying a list of columns.
#' The resulting DataFrame will also contain the grouping columns.
#'
#' df2 <- agg(df, <column> = <aggFunction>)
#' df2 <- agg(df, newColName = aggFunction(column))
#'
#' @param x a GroupedData
#' @return a DataFrame
#' @examples
#' \dontrun{
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
Expand All @@ -51,15 +75,17 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
# the GroupedData.agg(col, cols*) API does not contain grouping Column
sdf <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "aggWithGrouping",
x@sgd, listToSeq(jcols))
} else {
stop("agg can only support Column or character")
}
dataFrame(sdf)
})

#' sum/mean/avg/min/max

# sum/mean/avg/min/max
methods <- c("sum", "mean", "avg", "min", "max")

createMethod <- function(name) {
Expand Down
7 changes: 6 additions & 1 deletion pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ test_that("group by", {
expect_true(1 == count(df1))
df1 <- agg(df, age2 = max(df$age))
expect_true(1 == count(df1))
expect_true(columns(df1) == c("age2"))
expect_equal(columns(df1), c("age2"))

gd <- groupBy(df, "name")
expect_true(inherits(gd, "GroupedData"))
Expand All @@ -380,6 +380,11 @@ test_that("group by", {
expect_true(inherits(df3, "DataFrame"))
expect_true(3 == count(df3))

df3 <- agg(gd, age = sum(df$age))
expect_true(inherits(df3, "DataFrame"))
expect_true(3 == count(df3))
expect_equal(columns(df3), c("name", "age"))

df4 <- sum(gd, "age")
expect_true(inherits(df4, "DataFrame"))
expect_true(3 == count(df4))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package edu.berkeley.cs.amplab.sparkr

import java.io.ByteArrayOutputStream
import java.io.DataOutputStream
import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.{SQLContext, DataFrame, Row, SaveMode}

import edu.berkeley.cs.amplab.sparkr.SerDe._
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}

object SQLUtils {
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
Expand All @@ -18,6 +15,22 @@ object SQLUtils {
arr.toSeq
}

// A helper to include grouping columns in Agg()
def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
val aggExprs = exprs.map{ col =>
val f = col.getClass.getDeclaredField("expr")
f.setAccessible(true)
val expr = f.get(col).asInstanceOf[Expression]
expr match {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.simpleString)()
}
}
val toDF = gd.getClass.getDeclaredMethods.filter(f => f.getName == "toDF").head
toDF.setAccessible(true)
toDF.invoke(gd, aggExprs).asInstanceOf[DataFrame]
}

def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
df.map(r => rowToRBytes(r))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,29 +80,28 @@ class SparkRBackendHandler(server: SparkRBackend)
dis: DataInputStream,
dos: DataOutputStream) {
var obj: Object = null
var cls: Option[Class[_]] = None
try {
if (isStatic) {
cls = Some(Class.forName(objId))
val cls = if (isStatic) {
Class.forName(objId)
} else {
JVMObjectTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
case Some(o) =>
cls = Some(o.getClass)
obj = o
o.getClass
}
}

val args = readArgs(numArgs, dis)

val methods = cls.get.getMethods
val methods = cls.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
val methods = selectedMethods.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}
if (methods.isEmpty) {
System.err.println(s"cannot find matching method ${cls.get}.$methodName. "
System.err.println(s"cannot find matching method ${cls}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
System.err.println(s"$methodName(${method.getParameterTypes.mkString(",")})")
Expand All @@ -116,7 +115,7 @@ class SparkRBackendHandler(server: SparkRBackend)
writeObject(dos, ret.asInstanceOf[AnyRef])
} else if (methodName == "<init>") {
// methodName should be "<init>" for constructor
val ctor = cls.get.getConstructors.filter { x =>
val ctor = cls.getConstructors.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}.head

Expand Down

0 comments on commit 9a6be74

Please sign in to comment.