Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10049][SPARKR] Support collecting data of ArraryType in DataFrame. #8458

Closed
wants to merge 10 commits into from
34 changes: 17 additions & 17 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ setMethod("names<-",
signature(x = "DataFrame"),
function(x, value) {
if (!is.null(value)) {
sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value)))
sdf <- callJMethod(x@sdf, "toDF", as.list(value))
dataFrame(sdf)
}
})
Expand Down Expand Up @@ -661,15 +661,15 @@ setMethod("collect",
# listCols is a list of columns
listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
stopifnot(length(listCols) == ncol)

# An empty data.frame with 0 columns and number of rows as collected
nrow <- length(listCols[[1]])
if (nrow <= 0) {
df <- data.frame()
} else {
df <- data.frame(row.names = 1 : nrow)
df <- data.frame(row.names = 1 : nrow)
}

# Append columns one by one
for (colIndex in 1 : ncol) {
# Note: appending a column of list type into a data.frame so that
Expand All @@ -683,7 +683,7 @@ setMethod("collect",
# TODO: more robust check on column of primitive types
vec <- do.call(c, col)
if (class(vec) != "list") {
df[[names[colIndex]]] <- vec
df[[names[colIndex]]] <- vec
} else {
# For columns of complex type, be careful to access them.
# Get a column of complex type returns a list.
Expand Down Expand Up @@ -843,10 +843,10 @@ setMethod("groupBy",
function(x, ...) {
cols <- list(...)
if (length(cols) >= 1 && class(cols[[1]]) == "character") {
sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1])
} else {
jcol <- lapply(cols, function(c) { c@jc })
sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
sgd <- callJMethod(x@sdf, "groupBy", jcol)
}
groupedData(sgd)
})
Expand Down Expand Up @@ -1053,7 +1053,7 @@ setMethod("[", signature(x = "DataFrame", i = "Column"),
#' }
setMethod("select", signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
sdf <- callJMethod(x@sdf, "select", col, toSeq(...))
sdf <- callJMethod(x@sdf, "select", col, list(...))
dataFrame(sdf)
})

Expand All @@ -1064,7 +1064,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"),
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
sdf <- callJMethod(x@sdf, "select", listToSeq(jcols))
sdf <- callJMethod(x@sdf, "select", jcols)
dataFrame(sdf)
})

Expand All @@ -1080,7 +1080,7 @@ setMethod("select",
col(c)@jc
}
})
sdf <- callJMethod(x@sdf, "select", listToSeq(cols))
sdf <- callJMethod(x@sdf, "select", cols)
dataFrame(sdf)
})

Expand All @@ -1107,7 +1107,7 @@ setMethod("selectExpr",
signature(x = "DataFrame", expr = "character"),
function(x, expr, ...) {
exprList <- list(expr, ...)
sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList))
sdf <- callJMethod(x@sdf, "selectExpr", exprList)
dataFrame(sdf)
})

Expand Down Expand Up @@ -1272,12 +1272,12 @@ setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
sdf <- callJMethod(x@sdf, "sort", col, toSeq(...))
sdf <- callJMethod(x@sdf, "sort", col, list(...))
} else if (class(col) == "Column") {
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols))
sdf <- callJMethod(x@sdf, "sort", jcols)
}
dataFrame(sdf)
})
Expand Down Expand Up @@ -1624,7 +1624,7 @@ setMethod("describe",
signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
colList <- list(col, ...)
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})

Expand All @@ -1634,7 +1634,7 @@ setMethod("describe",
signature(x = "DataFrame"),
function(x) {
colList <- as.list(c(columns(x)))
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})

Expand Down Expand Up @@ -1691,7 +1691,7 @@ setMethod("dropna",

naFunctions <- callJMethod(x@sdf, "na")
sdf <- callJMethod(naFunctions, "drop",
as.integer(minNonNulls), listToSeq(as.list(cols)))
as.integer(minNonNulls), as.list(cols))
dataFrame(sdf)
})

Expand Down Expand Up @@ -1775,7 +1775,7 @@ setMethod("fillna",
sdf <- if (length(cols) == 0) {
callJMethod(naFunctions, "fill", value)
} else {
callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
callJMethod(naFunctions, "fill", value, as.list(cols))
}
dataFrame(sdf)
})
Expand Down
4 changes: 2 additions & 2 deletions R/pkg/R/SQLContext.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ infer_type <- function(x) {
stopifnot(length(x) > 0)
names <- names(x)
if (is.null(names)) {
list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE)
paste0("array<", infer_type(x[[1]]), ">")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify this is to support vectors of the form c(1, 2, 3) etc. ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for list. Next changed one is for vector.

} else {
# StructType
types <- lapply(x, infer_type)
Expand All @@ -59,7 +59,7 @@ infer_type <- function(x) {
do.call(structType, fields)
}
} else if (length(x) > 1) {
list(type = "array", elementType = type, containsNull = TRUE)
paste0("array<", infer_type(x[[1]]), ">")
} else {
type
}
Expand Down
3 changes: 1 addition & 2 deletions R/pkg/R/column.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ setMethod("cast",
setMethod("%in%",
signature(x = "Column"),
function(x, table) {
table <- listToSeq(as.list(table))
jc <- callJMethod(x@jc, "in", table)
jc <- callJMethod(x@jc, "in", as.list(table))
return(column(jc))
})

Expand Down
12 changes: 6 additions & 6 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ setMethod("countDistinct",
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
listToSeq(jcol))
jcol)
column(jc)
})

Expand All @@ -1348,7 +1348,7 @@ setMethod("concat",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols))
jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols)
column(jc)
})

Expand All @@ -1366,7 +1366,7 @@ setMethod("greatest",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols))
jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols)
column(jc)
})

Expand All @@ -1384,7 +1384,7 @@ setMethod("least",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols))
jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols)
column(jc)
})

Expand Down Expand Up @@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
#' @export
setMethod("concat_ws", signature(sep = "character", x = "Column"),
function(sep, x, ...) {
jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc }))
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols)
column(jc)
})
Expand Down Expand Up @@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"),
#' @export
setMethod("format_string", signature(format = "character", x = "Column"),
function(format, x, ...) {
jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc }))
jcols <- lapply(list(x, ...), function(arg) { arg@jc })
jc <- callJStatic("org.apache.spark.sql.functions",
"format_string",
format, jcols)
Expand Down
4 changes: 2 additions & 2 deletions R/pkg/R/group.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
} else {
stop("agg can only support Column or character")
}
Expand All @@ -124,7 +124,7 @@ createMethod <- function(name) {
setMethod(name,
signature(x = "GroupedData"),
function(x, ...) {
sdf <- callJMethod(x@sgd, name, toSeq(...))
sdf <- callJMethod(x@sgd, name, list(...))
dataFrame(sdf)
})
}
Expand Down
54 changes: 34 additions & 20 deletions R/pkg/R/schema.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) {
})
stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructType",
listToSeq(sfObjList))
sfObjList)
structType(stObj)
}

Expand Down Expand Up @@ -114,6 +114,35 @@ structField.jobj <- function(x) {
obj
}

checkType <- function(type) {
primtiveTypes <- c("byte",
"integer",
"float",
"double",
"numeric",
"character",
"string",
"binary",
"raw",
"logical",
"boolean",
"timestamp",
"date")
if (type %in% primtiveTypes) {
return()
} else {
m <- regexec("^array<(.*)>$", type)
matchedStrings <- regmatches(type, m)
if (length(matchedStrings[[1]]) >= 2) {
elemType <- matchedStrings[[1]][2]
checkType(elemType)
return()
}
}

stop(paste("Unsupported type for Dataframe:", type))
}

structField.character <- function(x, type, nullable = TRUE) {
if (class(x) != "character") {
stop("Field name must be a string.")
Expand All @@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) {
if (class(nullable) != "logical") {
stop("nullable must be either TRUE or FALSE")
}
options <- c("byte",
"integer",
"float",
"double",
"numeric",
"character",
"string",
"binary",
"raw",
"logical",
"boolean",
"timestamp",
"date")
dataType <- if (type %in% options) {
type
} else {
stop(paste("Unsupported type for Dataframe:", type))
}

checkType(type)

sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructField",
x,
dataType,
type,
nullable)
structField(sfObj)
}
Expand Down
10 changes: 0 additions & 10 deletions R/pkg/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -360,16 +360,6 @@ numToInt <- function(num) {
as.integer(num)
}

# create a Seq in JVM
toSeq <- function(...) {
callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...))
}

# create a Seq in JVM from a list
listToSeq <- function(l) {
callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l)
}

# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
Expand Down
12 changes: 6 additions & 6 deletions R/pkg/inst/tests/test_Serde.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ test_that("SerDe of primitive types", {
x <- callJStatic("SparkRHandler", "echo", 1L)
expect_equal(x, 1L)
expect_equal(class(x), "integer")

x <- callJStatic("SparkRHandler", "echo", 1)
expect_equal(x, 1)
expect_equal(class(x), "numeric")

x <- callJStatic("SparkRHandler", "echo", TRUE)
expect_true(x)
expect_equal(class(x), "logical")

x <- callJStatic("SparkRHandler", "echo", "abc")
expect_equal(x, "abc")
expect_equal(class(x), "character")
expect_equal(class(x), "character")
})

test_that("SerDe of list of primitive types", {
Expand All @@ -47,17 +47,17 @@ test_that("SerDe of list of primitive types", {
y <- callJStatic("SparkRHandler", "echo", x)
expect_equal(x, y)
expect_equal(class(y[[1]]), "numeric")

x <- list(TRUE, FALSE)
y <- callJStatic("SparkRHandler", "echo", x)
expect_equal(x, y)
expect_equal(class(y[[1]]), "logical")

x <- list("a", "b", "c")
y <- callJStatic("SparkRHandler", "echo", x)
expect_equal(x, y)
expect_equal(class(y[[1]]), "character")

# Empty list
x <- list()
y <- callJStatic("SparkRHandler", "echo", x)
Expand Down
Loading