Skip to content

Commit

Permalink
[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.
Browse files Browse the repository at this point in the history
This PR:
1. supports transferring arbitrary nested array from JVM to R side in SerDe;
2. based on 1, the collect() implementation is improved. It can now collect data of complex types
   from a DataFrame.

Author: Sun Rui <rui.sun@intel.com>

Closes #8276 from sun-rui/SPARK-10048.
  • Loading branch information
Sun Rui authored and shivaram committed Aug 25, 2015
1 parent 16a2be1 commit 71a138c
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 127 deletions.
55 changes: 43 additions & 12 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -652,18 +652,49 @@ setMethod("dim",
setMethod("collect",
          signature(x = "DataFrame"),
          function(x, stringsAsFactors = FALSE) {
            # Build a local data.frame from the columns that
            # SQLUtils.dfToCols returns for this DataFrame.
            #
            # x: the DataFrame to collect.
            # stringsAsFactors: if TRUE, character columns are converted to
            #   factors; defaults to FALSE.
            # Returns a data.frame whose columns are named after columns(x).
            colNames <- columns(x)
            numCols <- length(colNames)
            if (numCols <= 0) {
              # empty data.frame with 0 columns and 0 rows
              return(data.frame())
            }

            # listCols is a list of columns
            listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                                    "dfToCols", x@sdf)
            stopifnot(length(listCols) == numCols)

            # Start from a data.frame with 0 columns and as many rows as
            # were collected; the first column determines the row count.
            numRows <- length(listCols[[1]])
            if (numRows <= 0) {
              df <- data.frame()
            } else {
              df <- data.frame(row.names = seq_len(numRows))
            }

            # Append columns one by one.
            for (colIndex in seq_len(numCols)) {
              colName <- colNames[colIndex]
              col <- listCols[[colIndex]]
              if (length(col) <= 0) {
                df[[colName]] <- col
                next
              }

              # Flatten the cells. If every cell is atomic the result is an
              # atomic vector; if any cell is itself a list the result stays
              # a list, which marks the column as a complex type.
              # TODO: more robust check on column of primitive types
              vec <- do.call(c, col)

              # is.list() rather than class(vec) != "list": class() can have
              # length > 1 (e.g. c("POSIXct", "POSIXt")), which makes the
              # comparison an invalid if() condition.
              if (is.list(vec)) {
                # Complex type: keep the column as a list so each cell can
                # hold a nested value. Reading a cell from such a column
                # returns a list instead of a vector.
                df[[colName]] <- col
              } else {
                # Honour the stringsAsFactors argument for character columns.
                if (isTRUE(stringsAsFactors) && is.character(vec)) {
                  vec <- as.factor(vec)
                }
                df[[colName]] <- vec
              }
            }
            df
          })

#' Limit
#'
Expand Down
72 changes: 31 additions & 41 deletions R/pkg/R/deserialize.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ readTypedObject <- function(con, type) {
"r" = readRaw(con),
"D" = readDate(con),
"t" = readTime(con),
"a" = readArray(con),
"l" = readList(con),
"n" = NULL,
"j" = getJobj(readString(con)),
Expand Down Expand Up @@ -85,8 +86,7 @@ readTime <- function(con) {
as.POSIXct(t, origin = "1970-01-01")
}

# We only support lists where all elements are of same type
readList <- function(con) {
readArray <- function(con) {
type <- readType(con)
len <- readInt(con)
if (len > 0) {
Expand All @@ -100,6 +100,25 @@ readList <- function(con) {
}
}

# Read a length-prefixed list from `con`. Elements are read one at a time
# with readObject(), so they need not share a type.
# A NULL element is stored as NA.
readList <- function(con) {
  count <- readInt(con)
  if (count <= 0) {
    return(list())
  }
  lapply(seq_len(count), function(position) {
    item <- readObject(con)
    if (is.null(item)) NA else item
  })
}

readRaw <- function(con) {
dataLen <- readInt(con)
readBin(con, raw(), as.integer(dataLen), endian = "big")
Expand Down Expand Up @@ -132,18 +151,19 @@ readDeserialize <- function(con) {
}
}

readDeserializeRows <- function(inputCon) {
# readDeserializeRows will deserialize a DataOutputStream composed of
# a list of lists. Since the DOS is one continuous stream and
# the number of rows varies, we put the readRow function in a while loop
# that terminates when the next row is empty.
readMultipleObjects <- function(inputCon) {
# readMultipleObjects reads multiple consecutive objects from
# a DataOutputStream. No preceding field gives the number of
# objects, so the count is not known in advance; we keep reading
# objects in a loop until the end of the stream.
data <- list()
while(TRUE) {
row <- readRow(inputCon)
if (length(row) == 0) {
# If reaching the end of the stream, type returned should be "".
type <- readType(inputCon)
if (type == "") {
break
}
# Append what was just read to the end of the accumulated list.
data[[length(data) + 1L]] <- row
data[[length(data) + 1L]] <- readTypedObject(inputCon, type)
}
data # this is a list of named lists now
}
Expand All @@ -155,35 +175,5 @@ readRowList <- function(obj) {
# deserialize the row.
rawObj <- rawConnection(obj, "r+")
on.exit(close(rawObj))
readRow(rawObj)
}

# Read one row from `inputCon`: a column count followed by that many
# objects. Returns the row as a list in which NULL values are replaced by
# NA, or an empty list when no positive column count could be read.
readRow <- function(inputCon) {
  columnCount <- readInt(inputCon)
  hasColumns <- length(columnCount) > 0 && columnCount > 0
  if (!hasColumns) {
    return(list())
  }
  fields <- vector("list", columnCount)
  for (idx in seq_len(columnCount)) {
    value <- readObject(inputCon)
    fields[[idx]] <- if (is.null(value)) NA else value
  }
  fields # each row is a list now
}

# Take a single column as Array[Byte] and deserialize it into an atomic vector
readCol <- function(inputCon, numRows) {
if (numRows > 0) {
# sapply can not work with POSIXlt
do.call(c, lapply(1:numRows, function(x) {
value <- readObject(inputCon)
# Replace NULL with NA so we can coerce to vectors
if (is.null(value)) NA else value
}))
} else {
vector()
}
readObject(rawObj)
}
10 changes: 1 addition & 9 deletions R/pkg/R/serialize.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) {
# Serialize a single row into a raw vector: the row is written to an
# in-memory raw connection and the bytes accumulated there are returned.
serializeRow <- function(row) {
rawObj <- rawConnection(raw(0), "wb")
# Close the connection however the function exits.
on.exit(close(rawObj))
writeRow(rawObj, row)
writeGenericList(rawObj, row)
rawConnectionValue(rawObj)
}

# Write one row to `con`: first the number of columns (via writeInt), then
# each column value in order (via writeObject).
#
# con: the connection to write to.
# row: a list (or vector) holding the column values; may be empty.
writeRow <- function(con, row) {
  numCols <- length(row)
  writeInt(con, numCols)
  # seq_len() instead of 1:numCols: for an empty row 1:0 is c(1, 0), which
  # would try to access row[[1]] and fail with "subscript out of bounds".
  for (i in seq_len(numCols)) {
    writeObject(con, row[[i]])
  }
}

writeRaw <- function(con, batch) {
writeInt(con, length(batch))
writeBin(batch, con, endian = "big")
Expand Down
77 changes: 77 additions & 0 deletions R/pkg/inst/tests/test_Serde.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

context("SerDe functionality")

sc <- sparkR.init()

test_that("SerDe of primitive types", {
  # Send a value to the "echo" method of SparkRHandler and return the reply.
  echo <- function(value) {
    callJStatic("SparkRHandler", "echo", value)
  }

  # Integer
  intBack <- echo(1L)
  expect_equal(intBack, 1L)
  expect_equal(class(intBack), "integer")

  # Double
  dblBack <- echo(1)
  expect_equal(dblBack, 1)
  expect_equal(class(dblBack), "numeric")

  # Logical
  lglBack <- echo(TRUE)
  expect_true(lglBack)
  expect_equal(class(lglBack), "logical")

  # Character
  chrBack <- echo("abc")
  expect_equal(chrBack, "abc")
  expect_equal(class(chrBack), "character")
})

test_that("SerDe of list of primitive types", {
  # Send a value to the "echo" method of SparkRHandler and return the reply.
  echo <- function(value) {
    callJStatic("SparkRHandler", "echo", value)
  }

  # Each case pairs a list to send with the class expected for the first
  # element of the list that comes back.
  cases <- list(
    list(input = list(1L, 2L, 3L), elementClass = "integer"),
    list(input = list(1, 2, 3), elementClass = "numeric"),
    list(input = list(TRUE, FALSE), elementClass = "logical"),
    list(input = list("a", "b", "c"), elementClass = "character")
  )
  for (case in cases) {
    sent <- case[["input"]]
    received <- echo(sent)
    expect_equal(sent, received)
    expect_equal(class(received[[1]]), case[["elementClass"]])
  }

  # Empty list
  emptyList <- list()
  expect_equal(emptyList, echo(emptyList))
})

test_that("SerDe of list of lists", {
  # Nested lists to echo: first a list of non-empty lists of each
  # primitive type, then a list of empty lists.
  nestedInputs <- list(
    list(list(1L, 2L, 3L), list(1, 2, 3),
         list(TRUE, FALSE), list("a", "b", "c")),
    list(list(), list())
  )
  for (sent in nestedInputs) {
    received <- callJStatic("SparkRHandler", "echo", sent)
    expect_equal(sent, received)
  }
})
4 changes: 2 additions & 2 deletions R/pkg/inst/worker/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ if (isEmpty != 0) {
} else if (deserializer == "string") {
data <- as.list(readLines(inputCon))
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
data <- SparkR:::readMultipleObjects(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()
Expand All @@ -120,7 +120,7 @@ if (isEmpty != 0) {
} else if (deserializer == "string") {
data <- readLines(inputCon)
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
data <- SparkR:::readMultipleObjects(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)

if (objId == "SparkRHandler") {
methodName match {
// This function is for test-purpose only
case "echo" =>
val args = readArgs(numArgs, dis)
assert(numArgs == 1)

writeInt(dos, 0)
writeObject(dos, args(0))
case "stopBackend" =>
writeInt(dos, 0)
writeType(dos, "void")
Expand Down
Loading

0 comments on commit 71a138c

Please sign in to comment.