include grouping columns in agg()

add docs for groupBy() and agg()
davies committed Mar 12, 2015
1 parent 09ff163 commit 9a6be746efc9fafad88122fa2267862ef87aa0e1
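
In short, this change makes agg() on grouped data return the grouping columns alongside the aggregate columns. A minimal sketch of the intended behavior, assuming a SparkR DataFrame df with columns name and age (mirroring the tests below):

gd <- groupBy(df, "name")
df3 <- agg(gd, age = sum(df$age))
columns(df3)  # now c("name", "age"), i.e. the grouping column is included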
DataFrame.R
@@ -697,8 +697,23 @@ setMethod("toRDD",
#'
#' Groups the DataFrame using the specified columns, so we can run aggregation on them.
#'
#' @param x a DataFrame
#' @return a GroupedData
#' @seealso GroupedData
#' @rdname DataFrame
#' @export
#' @examples
#' \dontrun{
#' # Compute the average for all numeric columns grouped by department.
#' avg(groupBy(df, "department"))
#'
#' # Compute the max age and average salary, grouped by department and gender.
#' agg(groupBy(df, "department", "gender"), salary = "avg", age = "max")
#' }
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
#' @rdname DataFrame
#' @export
setMethod("groupBy",
signature(x = "DataFrame"),
function(x, ...) {
@@ -712,7 +727,12 @@ setMethod("groupBy",
groupedData(sgd)
})
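A brief usage sketch of the API documented above; df and the column name are illustrative:

gd <- groupBy(df, "department")  # returns a GroupedData
avg(gd)                          # average of all numeric columns, per department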
#' Agg
#'
#' Compute aggregates by specifying a list of columns.
#'
#' @rdname DataFrame
#' @export
setMethod("agg",
signature(x = "DataFrame"),
function(x, ...) {
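For reference, the two calling conventions agg() accepts, per the docs in group.R below (column names are illustrative; the generated aggregate column names follow the 'SUM(age#0)' pattern noted there):

df2 <- agg(df, age = "sum")         # name = "aggFunction" pair
df2 <- agg(df, age2 = max(df$age))  # newColName = aggFunction(column); result column "age2"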
group.R
@@ -1,19 +1,39 @@
############################## GroupedData ########################################
# group.R - GroupedData class and methods implemented in S4 OO classes
setOldClass("jobj")
#' @title S4 class that represents a GroupedData
#' @description A GroupedData object can be created by calling groupBy() on a DataFrame
#' @rdname GroupedData
#' @seealso groupBy
#'
#' @param sgd A Java object reference to the backing Scala GroupedData
#' @export
setClass("GroupedData",
slots = list(env = "environment",
sgd = "jobj"))
slots = list(sgd = "jobj"))
setMethod("initialize", "GroupedData", function(.Object, sgd) {
.Object@env <- new.env()
.Object@sgd <- sgd
.Object
})
#' @rdname DataFrame
groupedData <- function(sgd) {
new("GroupedData", sgd)
}
#' Count
#'
#' Count the number of rows for each group.
#' The resulting DataFrame will also contain the grouping columns.
#'
#' @param x a GroupedData
#' @return a DataFrame
#' @export
#' @examples
#' \dontrun{
#'   count(groupBy(df, "name"))
#' }
setMethod("count",
signature(x = "GroupedData"),
function(x) {
@@ -23,9 +43,13 @@ setMethod("count",
#' Agg
#'
#' Compute aggregations by specifying a list of columns.
#' When applied to grouped data, the resulting DataFrame will also contain the grouping columns.
#'
#' df2 <- agg(df, <column> = <aggFunction>)
#' df2 <- agg(df, newColName = aggFunction(column))
#'
#' @param x a GroupedData
#' @return a DataFrame
#' @examples
#' \dontrun{
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
@@ -51,15 +75,17 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
# GroupedData.agg(col, cols*) does not include the grouping columns, so call the SQLUtils helper, which does
sdf <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "aggWithGrouping",
x@sgd, listToSeq(jcols))
} else {
stop("agg can only support Column or character")
}
dataFrame(sdf)
})
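Both branches above now keep the grouping columns; a sketch of the string-function form (the Column form is exercised in the tests below):

df3 <- agg(groupBy(df, "name"), age = "max")
columns(df3)  # the grouping column "name" plus the max aggregate column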
#' sum/mean/avg/min/max
# sum/mean/avg/min/max
methods <- c("sum", "mean", "avg", "min", "max")
createMethod <- function(name) {
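These generated shortcuts apply one aggregate function to a single column of a GroupedData, as the tests below exercise; a sketch, presumably equivalent to calling agg() with the matching function name:

gd <- groupBy(df, "name")
df4 <- sum(gd, "age")  # roughly agg(gd, age = "sum")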
@@ -368,7 +368,7 @@ test_that("group by", {
expect_true(1 == count(df1))
df1 <- agg(df, age2 = max(df$age))
expect_true(1 == count(df1))
expect_true(columns(df1) == c("age2"))
expect_equal(columns(df1), c("age2"))
gd <- groupBy(df, "name")
expect_true(inherits(gd, "GroupedData"))
@@ -380,6 +380,11 @@ test_that("group by", {
expect_true(inherits(df3, "DataFrame"))
expect_true(3 == count(df3))
df3 <- agg(gd, age = sum(df$age))
expect_true(inherits(df3, "DataFrame"))
expect_true(3 == count(df3))
expect_equal(columns(df3), c("name", "age"))
df4 <- sum(gd, "age")
expect_true(inherits(df4, "DataFrame"))
expect_true(3 == count(df4))
@@ -1,13 +1,10 @@
package edu.berkeley.cs.amplab.sparkr
import java.io.ByteArrayOutputStream
import java.io.DataOutputStream
import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.{SQLContext, DataFrame, Row, SaveMode}
import edu.berkeley.cs.amplab.sparkr.SerDe._
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
object SQLUtils {
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -18,6 +15,22 @@ object SQLUtils {
arr.toSeq
}
// A helper to include the grouping columns in agg()
def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
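// Column's underlying Catalyst expression is not publicly accessible at this
// Spark version, so read the 'expr' field reflectively; unnamed expressions
// are aliased so they can serve as named output columns.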
val aggExprs = exprs.map{ col =>
val f = col.getClass.getDeclaredField("expr")
f.setAccessible(true)
val expr = f.get(col).asInstanceOf[Expression]
expr match {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.simpleString)()
}
}
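// GroupedData.toDF is likewise non-public; invoking it reflectively yields a
// DataFrame containing the grouping columns followed by these aggregates.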
val toDF = gd.getClass.getDeclaredMethods.filter(f => f.getName == "toDF").head
toDF.setAccessible(true)
toDF.invoke(gd, aggExprs).asInstanceOf[DataFrame]
}
def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
df.map(r => rowToRBytes(r))
}
@@ -80,29 +80,28 @@ class SparkRBackendHandler(server: SparkRBackend)
dis: DataInputStream,
dos: DataOutputStream) {
var obj: Object = null
var cls: Option[Class[_]] = None
try {
if (isStatic) {
cls = Some(Class.forName(objId))
val cls = if (isStatic) {
Class.forName(objId)
} else {
JVMObjectTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
case Some(o) =>
cls = Some(o.getClass)
obj = o
o.getClass
}
}
val args = readArgs(numArgs, dis)
val methods = cls.get.getMethods
val methods = cls.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
val methods = selectedMethods.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}
if (methods.isEmpty) {
System.err.println(s"cannot find matching method ${cls.get}.$methodName. "
System.err.println(s"cannot find matching method ${cls}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
System.err.println(s"$methodName(${method.getParameterTypes.mkString(",")})")
@@ -116,7 +115,7 @@ class SparkRBackendHandler(server: SparkRBackend)
writeObject(dos, ret.asInstanceOf[AnyRef])
} else if (methodName == "<init>") {
// methodName should be "<init>" for constructor
val ctor = cls.get.getConstructors.filter { x =>
val ctor = cls.getConstructors.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}.head
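From the R side, the method and constructor lookups this handler performs originate from calls like the following sketch (names taken from this codebase; sc denotes the R-side handle to the JavaSparkContext):

sqlCtx <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "createSQLContext", sc)  # static method dispatch
n <- callJMethod(df@sdf, "count")  # instance method dispatch on a tracked JVM object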
