Skip to content

Commit

Permalink
[SPARK-20726][SPARKR] wrapper for SQL broadcast
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

- Adds R wrapper for `o.a.s.sql.functions.broadcast`.
- Renames `broadcast` to `broadcast_`.

## How was this patch tested?

Unit tests, check `check-cran.sh`.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17965 from zero323/SPARK-20726.
  • Loading branch information
zero323 authored and Felix Cheung committed May 14, 2017
1 parent aa3df15 commit 5a799fd
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 4 deletions.
1 change: 1 addition & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ exportClasses("SparkDataFrame")
exportMethods("arrange",
"as.data.frame",
"attach",
"broadcast",
"cache",
"checkpoint",
"coalesce",
Expand Down
29 changes: 29 additions & 0 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -3769,3 +3769,32 @@ setMethod("alias",
sdf <- callJMethod(object@sdf, "alias", data)
dataFrame(sdf)
})

#' broadcast
#'
#' Return a new SparkDataFrame marked as small enough for use in broadcast joins.
#'
#' Equivalent to \code{hint(x, "broadcast")}.
#'
#' @param x a SparkDataFrame.
#' @return a SparkDataFrame.
#'
#' @aliases broadcast,SparkDataFrame-method
#' @family SparkDataFrame functions
#' @rdname broadcast
#' @name broadcast
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(mtcars)
#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg")
#'
#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl))
#' }
#' @note broadcast since 2.3.0
setMethod("broadcast",
signature(x = "SparkDataFrame"),
function(x) {
sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf)
dataFrame(sdf)
})
4 changes: 2 additions & 2 deletions R/pkg/R/context.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ includePackage <- function(sc, pkg) {
#'
#' # Large Matrix object that we want to broadcast
#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000))
#' randomMatBr <- broadcast(sc, randomMat)
#' randomMatBr <- broadcastRDD(sc, randomMat)
#'
#' # Use the broadcast variable inside the function
#' useBroadcast <- function(x) {
#' sum(value(randomMatBr) * x)
#' }
#' sumRDD <- lapply(rdd, useBroadcast)
#'}
broadcast <- function(sc, object) {
broadcastRDD <- function(sc, object) {
objName <- as.character(substitute(object))
serializedObj <- serialize(object, connection = NULL)

Expand Down
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,10 @@ setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.d
#' @export
setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") })

#' @rdname broadcast
#' @export
setGeneric("broadcast", function(x) { standardGeneric("broadcast") })

###################### Column Methods ##########################

#' @rdname columnfunctions
Expand Down
2 changes: 1 addition & 1 deletion R/pkg/inst/tests/testthat/test_broadcast.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ test_that("using broadcast variable", {
skip_on_cran()

randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100))
randomMatBr <- broadcast(sc, randomMat)
randomMatBr <- broadcastRDD(sc, randomMat)

useBroadcast <- function(x) {
sum(SparkR:::value(randomMatBr) * x)
Expand Down
5 changes: 5 additions & 0 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,11 @@ test_that("join(), crossJoin() and merge() on a DataFrame", {
explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id))
)
expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint)))

execution_plan_broadcast <- capture.output(
explain(join(df1, broadcast(df2), df1$id == df2$id))
)
expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast)))
})

test_that("toJSON() on DataFrame", {
Expand Down
2 changes: 1 addition & 1 deletion R/pkg/inst/tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ test_that("cleanClosure on R functions", {

# Test for broadcast variables.
a <- matrix(nrow = 10, ncol = 10, data = rnorm(100))
aBroadcast <- broadcast(sc, a)
aBroadcast <- broadcastRDD(sc, a)
normMultiply <- function(x) { norm(aBroadcast$value) * x }
newnormMultiply <- SparkR:::cleanClosure(normMultiply)
env <- environment(newnormMultiply)
Expand Down

0 comments on commit 5a799fd

Please sign in to comment.