[SPARK-26830][SQL][R] Vectorized R dapply() implementation #23787

Closed · wants to merge 6 commits
23 changes: 23 additions & 0 deletions R/pkg/R/DataFrame.R
@@ -1437,6 +1437,29 @@ dapplyInternal <- function(x, func, schema) {
schema <- structType(schema)
}

arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
if (arrowEnabled) {
# This is a hack to avoid the CRAN check; Arrow is not on CRAN yet. See ARROW-3204.
requireNamespace1 <- requireNamespace
if (!requireNamespace1("arrow", quietly = TRUE)) {
stop("'arrow' package should be installed.")
}
# Currently, Arrow optimization does not support raw types.
# It also does not support an explicit float type set by users.
if (inherits(schema, "structType")) {
if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "FloatType"))) {
stop("Arrow optimization with dapply do not support FloatType yet.")
}
if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "BinaryType"))) {
stop("Arrow optimization with dapply do not support BinaryType yet.")
}
} else if (is.null(schema)) {
stop(paste0("Arrow optimization does not support 'dapplyCollect' yet. Please disable ",
"Arrow optimization or use 'collect' and 'dapply' APIs instead."))
} else {
stop("'schema' should be DDL-formatted string or structType.")
}
}

packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)

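For context, here is a minimal usage sketch of the Arrow-optimized `dapply()` path this guard enables, modeled on the tests added later in this PR. The session setup via `sparkConfig` is an assumption; `createDataFrame`, `dapply`, `schema`, and `collect` are the APIs the tests exercise.

```r
library(SparkR)

# Assumed setup: start a session with Arrow optimization turned on.
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

df <- createDataFrame(mtcars)

# The function runs once per partition on an R data.frame; the declared output
# schema must not contain FloatType or BinaryType while the Arrow path is limited.
ret <- dapply(df,
              function(rdf) {
                stopifnot(is.data.frame(rdf))
                rdf
              },
              schema(df))
head(collect(ret))
```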
16 changes: 10 additions & 6 deletions R/pkg/R/deserialize.R
@@ -247,17 +247,21 @@ readDeserializeInArrow <- function(inputCon) {
batches <- RecordBatchStreamReader(arrowData)$batches()

# Read all grouped batches. Tibble -> data.frame conversion is cheap.
data <- lapply(batches, function(batch) as.data.frame(as_tibble(batch)))

# Read keys to map with each groupped batch.
keys <- readMultipleObjects(inputCon)

list(keys = keys, data = data)
lapply(batches, function(batch) as.data.frame(as_tibble(batch)))
} else {
stop("'arrow' package should be installed.")
}
}

readDeserializeWithKeysInArrow <- function(inputCon) {
data <- readDeserializeInArrow(inputCon)

keys <- readMultipleObjects(inputCon)

# Read keys to map with each grouped batch later.
list(keys = keys, data = data)
}

readRowList <- function(obj) {
# readRowList is meant for use inside an lapply. As a result, it is
# necessary to open a standalone connection for the row and consume
15 changes: 15 additions & 0 deletions R/pkg/R/serialize.R
@@ -220,3 +220,18 @@ writeArgs <- function(con, args) {
}
}
}

writeSerializeInArrow <- function(conn, df) {
# This is a hack to avoid the CRAN check. Arrow is not uploaded to CRAN yet. See ARROW-3204.
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)

# There seems to be no way to send each batch in streaming format via a socket
# connection. See ARROW-4512.
# So, the whole Arrow streaming-formatted binary is written at once for now.
writeRaw(conn, write_arrow(df, raw()))
} else {
stop("'arrow' package should be installed.")
}
}
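As an aside, the write/read pair in serialize.R and deserialize.R can be sketched as a local round trip. This is only an illustration, under the assumption of the same arrow R API the PR already calls (`write_arrow`, `RecordBatchStreamReader`, `as_tibble`); it is not part of the change.

```r
library(arrow)   # assumes the same arrow R API this PR targets
df <- mtcars

# Serialize the whole data.frame as one Arrow streaming-formatted raw vector,
# mirroring writeSerializeInArrow() (one blob at a time per the ARROW-4512 note).
buf <- write_arrow(df, raw())

# Deserialize it the way readDeserializeInArrow() does:
# record batches -> tibbles -> data.frames.
batches <- RecordBatchStreamReader(buf)$batches()
rdf <- do.call("rbind", lapply(batches, function(batch) as.data.frame(as_tibble(batch))))

stopifnot(nrow(rdf) == nrow(df))
```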
33 changes: 15 additions & 18 deletions R/pkg/inst/worker/worker.R
@@ -76,6 +76,8 @@ outputResult <- function(serializer, output, outputCon) {
SparkR:::writeRawSerialize(outputCon, output)
} else if (serializer == "row") {
SparkR:::writeRowSerialize(outputCon, output)
} else if (serializer == "arrow") {
SparkR:::writeSerializeInArrow(outputCon, output)
} else {
# write lines one-by-one with flag
lapply(output, function(line) SparkR:::writeString(outputCon, line))
@@ -172,9 +174,15 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readMultipleObjects(inputCon)
} else if (deserializer == "arrow" && mode == 2) {
dataWithKeys <- SparkR:::readDeserializeInArrow(inputCon)
dataWithKeys <- SparkR:::readDeserializeWithKeysInArrow(inputCon)
keys <- dataWithKeys$keys
data <- dataWithKeys$data
} else if (deserializer == "arrow" && mode == 1) {
data <- SparkR:::readDeserializeInArrow(inputCon)
# See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
# rbind.fill might be an alternative to make it faster if plyr is installed.
# Also, note that 'dapply' applies the function to each partition.
data <- do.call("rbind", data)
}

# Timing reading input data for execution
@@ -192,7 +200,7 @@ if (isEmpty != 0) {
output <- compute(mode, partition, serializer, deserializer, keys[[i]],
colNames, computeFunc, data[[i]])
computeElap <- elapsedSecs()
if (deserializer == "arrow") {
if (serializer == "arrow") {
outputs[[length(outputs) + 1L]] <- output
} else {
outputResult(serializer, output, outputCon)
@@ -202,22 +210,11 @@ if (isEmpty != 0) {
outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap)
}

if (deserializer == "arrow") {
# This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204.
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)
# See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
# rbind.fill might be an anternative to make it faster if plyr is installed.
combined <- do.call("rbind", outputs)

# Likewise, there looks no way to send each batch in streaming format via socket
# connection. See ARROW-4512.
# So, it writes the whole Arrow streaming-formatted binary at once for now.
SparkR:::writeRaw(outputCon, write_arrow(combined, raw()))
} else {
stop("'arrow' package should be installed.")
}
if (serializer == "arrow") {
# See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
# rbind.fill might be an alternative to make it faster if plyr is installed.
combined <- do.call("rbind", outputs)
SparkR:::writeSerializeInArrow(outputCon, combined)
}
}
} else {
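To make the combining step above concrete, here is a tiny standalone illustration of what the worker does with the per-batch data.frames of a partition. The sample data is hypothetical; `plyr::rbind.fill` is the alternative the comment mentions.

```r
# Hypothetical stand-ins for the per-batch data.frames produced by the Arrow reader.
batches <- list(head(mtcars, 2), tail(mtcars, 3))

# What the worker does today: combine all batches of a partition with base rbind.
combined <- do.call("rbind", batches)
stopifnot(nrow(combined) == 5)

# If plyr is installed, plyr::rbind.fill(batches) is the alternative the comment
# suggests might be faster; it also tolerates batches with differing columns.
```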
99 changes: 99 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -3300,6 +3300,105 @@ test_that("dapplyCollect() on DataFrame with a binary column", {

})

test_that("dapply() Arrow optimization", {
skip_if_not_installed("arrow")
df <- createDataFrame(mtcars)

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
tryCatch({
ret <- dapply(df,
function(rdf) {
stopifnot(class(rdf) == "data.frame")
rdf
},
schema(df))
expected <- collect(ret)
},
finally = {
# Resetting the conf back to its default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
ret <- dapply(df,
function(rdf) {
stopifnot(class(rdf) == "data.frame")
# mtcars' hp is more than 50.
stopifnot(all(rdf$hp > 50))
rdf
},
schema(df))
actual <- collect(ret)
expect_equal(actual, expected)
expect_equal(count(ret), nrow(mtcars))
},
finally = {
# Resetting the conf back to its default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("dapply() Arrow optimization - type specification", {
skip_if_not_installed("arrow")
# Note that regular dapply() does not seem to support dates and timestamps,
# whereas Arrow-optimized dapply() does.
rdf <- data.frame(list(list(a = 1,
b = "a",
c = TRUE,
d = 1.1,
e = 1L)))
# numPartitions is intentionally set to 8 to test empty partitions as well.
df <- createDataFrame(rdf, numPartitions = 8)

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
tryCatch({
ret <- dapply(df, function(rdf) { rdf }, schema(df))
expected <- collect(ret)
},
finally = {
# Resetting the conf back to its default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
ret <- dapply(df, function(rdf) { rdf }, schema(df))
actual <- collect(ret)
expect_equal(actual, expected)
},
finally = {
# Resetting the conf back to its default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("dapply() Arrow optimization - type specification (date and timestamp)", {
skip_if_not_installed("arrow")
rdf <- data.frame(list(list(a = as.Date("1990-02-24"),
b = as.POSIXct("1990-02-24 12:34:56"))))
df <- createDataFrame(rdf)

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
ret <- dapply(df, function(rdf) { rdf }, schema(df))
expect_equal(collect(ret), rdf)
},
finally = {
# Resetting the conf back to its default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("repartition by columns on DataFrame", {
# The tasks here launch R workers with shuffles. So, we decrease the number of shuffle
# partitions to reduce the number of the tasks to speed up the test. This is particularly
@@ -123,16 +123,25 @@ object MapPartitionsInR {
schema: StructType,
encoder: ExpressionEncoder[Row],
child: LogicalPlan): LogicalPlan = {
val deserialized = CatalystSerde.deserialize(child)(encoder)
val mapped = MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
CatalystSerde.generateObjAttr(RowEncoder(schema)),
deserialized)
CatalystSerde.serialize(mapped)(RowEncoder(schema))
if (SQLConf.get.arrowEnabled) {
MapPartitionsInRWithArrow(
func,
packageNames,
broadcastVars,
encoder.schema,
schema.toAttributes,
child)
} else {
val deserialized = CatalystSerde.deserialize(child)(encoder)
CatalystSerde.serialize(MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
CatalystSerde.generateObjAttr(RowEncoder(schema)),
deserialized))(RowEncoder(schema))
}
}
}

@@ -154,6 +163,28 @@ case class MapPartitionsInR(
outputObjAttr, child)
}

/**
* Similar to `MapPartitionsInR`, but serializes and deserializes the input/output in
* Arrow format.
*
* This is somewhat similar to `org.apache.spark.sql.execution.python.ArrowEvalPython`.
*/
case class MapPartitionsInRWithArrow(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
inputSchema: StructType,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
// This operator always needs all columns of its child, even those it doesn't reference.
override def references: AttributeSet = child.outputSet

override protected def stringArgs: Iterator[Any] = Iterator(
inputSchema, StructType.fromAttributes(output), child)

override val producedAttributes = AttributeSet(output)
}

object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
@@ -599,6 +599,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.FlatMapGroupsInRWithArrow(f, p, b, is, ot, key, grouping, child) =>
execution.FlatMapGroupsInRWithArrowExec(
f, p, b, is, ot, key, grouping, planLater(child)) :: Nil
case logical.MapPartitionsInRWithArrow(f, p, b, is, ot, child) =>
execution.MapPartitionsInRWithArrowExec(
f, p, b, is, ot, planLater(child)) :: Nil
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>