
Commit 09dc409

Merge branch 'master' of https://github.com/apache/spark into multiItems_2
zhangjiajin committed Jul 31, 2015
2 parents ae8c02d + 83670fc commit 09dc409
Showing 526 changed files with 19,735 additions and 8,959 deletions.
4 changes: 3 additions & 1 deletion R/pkg/NAMESPACE
@@ -12,7 +12,8 @@ export("print.jobj")

# MLlib integration
exportMethods("glm",
"predict")
"predict",
"summary")

# Job group lifecycle management methods
export("setJobGroup",
@@ -26,6 +27,7 @@ exportMethods("arrange",
"collect",
"columns",
"count",
"crosstab",
"describe",
"distinct",
"dropna",
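The two entry points newly exported above, summary() for MLlib models and crosstab() for DataFrames, become visible to users once the package is attached. A minimal sketch of how one might verify this from an R session, assuming a SparkR build that includes this commit:

library(SparkR)
# Both new entry points should now be registered as S4 methods:
existsMethod("crosstab", signature("DataFrame", "character", "character"))  # TRUE
existsMethod("summary", "PipelineModel")                                     # TRUE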
28 changes: 28 additions & 0 deletions R/pkg/R/DataFrame.R
@@ -1554,3 +1554,31 @@ setMethod("fillna",
}
dataFrame(sdf)
})

#' crosstab
#'
#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
#' non-zero pair frequencies will be returned.
#'
#' @param col1 name of the first column. Distinct items will make the first item of each row.
#' @param col2 name of the second column. Distinct items will make the column names of the output.
#' @return a local R data.frame representing the contingency table. The first column of each row
#' will be the distinct values of `col1` and the column names will be the distinct values
#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
#' occurrences will have zero as their counts.
#'
#' @rdname statfunctions
#' @export
#' @examples
#' \dontrun{
#' df <- jsonFile(sqlCtx, "/path/to/file.json")
#' ct = crosstab(df, "title", "gender")
#' }
setMethod("crosstab",
signature(x = "DataFrame", col1 = "character", col2 = "character"),
function(x, col1, col2) {
statFunctions <- callJMethod(x@sdf, "stat")
sct <- callJMethod(statFunctions, "crosstab", col1, col2)
collect(dataFrame(sct))
})
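As a quick orientation for the new API, here is a hedged usage sketch with made-up data; it mirrors the crosstab() test added to test_sparkSQL.R further down, so the "$col1_$col2" column naming and zero-filled counts follow the documentation above.

df <- createDataFrame(sqlContext, data.frame(a = c("a0", "a1", "a2", "a0"),
                                             b = c("b0", "b1", "b0", "b1"),
                                             stringsAsFactors = FALSE))
ct <- crosstab(df, "a", "b")   # returns a local R data.frame
# ct$a_b holds the distinct values of "a"; columns "b0" and "b1" hold the
# pair counts, with 0 for combinations that never occur.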
4 changes: 3 additions & 1 deletion R/pkg/R/backend.R
@@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) {

# TODO: check the status code to output error information
returnStatus <- readInt(conn)
stopifnot(returnStatus == 0)
if (returnStatus != 0) {
stop(readString(conn))
}
readObject(conn)
}
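The net effect of this change is that a failing JVM call now surfaces the JVM's own error string as an R error rather than an opaque stopifnot() failure. A rough sketch of the observable behavior, matching the new "SQL error message is returned from JVM" test added below:

msg <- tryCatch(sql(sqlContext, "select * from blah"),
                error = function(e) conditionMessage(e))
grepl("Table Not Found: blah", msg)   # TRUE once the backend forwards the message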
2 changes: 1 addition & 1 deletion R/pkg/R/client.R
@@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack
jars <- paste("--jars", jars)
}

if (packages != "") {
if (!identical(packages, "")) {
packages <- paste("--packages", packages)
}

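The switch to identical() matters when packages is a character vector: the old test packages != "" is vectorized, so using it as an if() condition triggers R's "condition has length > 1" warning. A small sketch of what the new test in test_client.R pins down:

# Completes without a "condition has length > 1" warning now:
args <- generateSparkSubmitArgs("", "", "", "", c("A", "B"))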
4 changes: 2 additions & 2 deletions R/pkg/R/deserialize.R
@@ -102,11 +102,11 @@ readList <- function(con) {

readRaw <- function(con) {
dataLen <- readInt(con)
data <- readBin(con, raw(), as.integer(dataLen), endian = "big")
readBin(con, raw(), as.integer(dataLen), endian = "big")
}

readRawLen <- function(con, dataLen) {
data <- readBin(con, raw(), as.integer(dataLen), endian = "big")
readBin(con, raw(), as.integer(dataLen), endian = "big")
}

readDeserialize <- function(con) {
18 changes: 11 additions & 7 deletions R/pkg/R/generics.R
@@ -59,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") })
# @export
setGeneric("countByValue", function(x) { standardGeneric("countByValue") })

# @rdname statfunctions
# @export
setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") })

# @rdname distinct
# @export
setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") })
@@ -250,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues")

# @rdname intersection
# @export
setGeneric("intersection", function(x, other, numPartitions = 1) {
standardGeneric("intersection") })
setGeneric("intersection",
function(x, other, numPartitions = 1) {
standardGeneric("intersection")
})

# @rdname keys
# @export
@@ -485,9 +491,7 @@ setGeneric("sample",
#' @rdname sample
#' @export
setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) {
standardGeneric("sample_frac")
})
function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })

#' @rdname saveAsParquetFile
#' @export
@@ -549,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn

#' @rdname withColumnRenamed
#' @export
setGeneric("withColumnRenamed", function(x, existingCol, newCol) {
standardGeneric("withColumnRenamed") })
setGeneric("withColumnRenamed",
function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") })


###################### Column Methods ##########################
28 changes: 27 additions & 1 deletion R/pkg/R/mllib.R
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~' and '+'.
#' operators are supported, including '~', '+', '-', and '.'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})

#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
#' @param object A fitted MLlib model
#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
#' summary.glm for more information.
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(object = "PipelineModel"),
function(object) {
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", object@model)
weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelWeights", object@model)
coefficients <- as.matrix(unlist(weights))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
})
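Taken together with the widened formula support documented above ('~', '+', '-', '.'), the new summary() method enables an R-like workflow. A hedged sketch using the iris data from the tests below (SparkR converts the dots in iris column names to underscores):

training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
stats <- summary(model)
stats$coefficients   # one-column "Estimate" matrix; row names are the model features
# '-' and '.' now work too, e.g. all predictors except Species, without intercept:
model2 <- glm(Sepal_Width ~ . - Species + 0, data = training)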
4 changes: 2 additions & 2 deletions R/pkg/R/pairRDD.R
@@ -202,8 +202,8 @@ setMethod("partitionBy",

packageNamesArr <- serialize(.sparkREnv$.packages,
connection = NULL)
broadcastArr <- lapply(ls(.broadcastNames), function(name) {
get(name, .broadcastNames) })
broadcastArr <- lapply(ls(.broadcastNames),
function(name) { get(name, .broadcastNames) })
jrdd <- getJRDD(x)

# We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])],
12 changes: 6 additions & 6 deletions R/pkg/R/sparkR.R
@@ -22,7 +22,8 @@
connExists <- function(env) {
tryCatch({
exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]])
}, error = function(err) {
},
error = function(err) {
return(FALSE)
})
}
@@ -104,16 +105,13 @@ sparkR.init <- function(
return(get(".sparkRjsc", envir = .sparkREnv))
}

sparkMem <- Sys.getenv("SPARK_MEM", "1024m")
jars <- suppressWarnings(normalizePath(as.character(sparkJars)))

# Classpath separator is ";" on Windows
# URI needs four /// as from http://stackoverflow.com/a/18522792
if (.Platform$OS.type == "unix") {
collapseChar <- ":"
uriSep <- "//"
} else {
collapseChar <- ";"
uriSep <- "////"
}

@@ -156,7 +154,8 @@ sparkR.init <- function(
.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort)
}, error = function(err) {
},
error = function(err) {
stop("Failed to connect JVM\n")
})

@@ -267,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) {
ssc <- callJMethod(sc, "sc")
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
}, error = function(err) {
},
error = function(err) {
stop("Spark SQL is not built with Hive support")
})

4 changes: 4 additions & 0 deletions R/pkg/inst/tests/test_client.R
@@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", {
expect_equal(gsub("[[:space:]]", "", args),
"")
})

test_that("multiple packages don't produce a warning", {
expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning()))
})
25 changes: 22 additions & 3 deletions R/pkg/inst/tests/test_mllib.R
@@ -35,8 +35,27 @@ test_that("glm and predict", {

test_that("predictions match with native glm", {
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
vals <- collect(select(predict(model, training), "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})

test_that("dot minus and intercept vs native glm", {
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ . - Species + 0, data = training)
vals <- collect(select(predict(model, training), "prediction"))
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})

test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
coefs <- as.vector(stats$coefficients)
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
expect_true(all(abs(rCoefs - coefs) < 1e-6))
expect_true(all(
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
})
24 changes: 22 additions & 2 deletions R/pkg/inst/tests/test_sparkSQL.R
@@ -112,7 +112,8 @@ test_that("create DataFrame from RDD", {
df <- jsonFile(sqlContext, jsonPathNa)
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
}, error = function(err) {
},
error = function(err) {
skip("Hive is not built with SparkSQL, skipped")
})
sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")
@@ -602,7 +603,8 @@ test_that("write.df() as parquet file", {
test_that("test HiveContext", {
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
}, error = function(err) {
},
error = function(err) {
skip("Hive is not built with SparkSQL, skipped")
})
df <- createExternalTable(hiveCtx, "json", jsonPath, "json")
@@ -987,6 +989,24 @@ test_that("fillna() on a DataFrame", {
expect_identical(expected, actual)
})

test_that("crosstab() on a DataFrame", {
rdd <- lapply(parallelize(sc, 0:3), function(x) {
list(paste0("a", x %% 3), paste0("b", x %% 2))
})
df <- toDF(rdd, list("a", "b"))
ct <- crosstab(df, "a", "b")
ordered <- ct[order(ct$a_b),]
row.names(ordered) <- NULL
expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0),
stringsAsFactors = FALSE, row.names = NULL)
expect_identical(expected, ordered)
})

test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table Not Found: blah", retError), TRUE)
})

unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
2 changes: 1 addition & 1 deletion R/run-tests.sh
@@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE

SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))

if [[ $FAILED != 0 ]]; then
2 changes: 1 addition & 1 deletion bin/pyspark
@@ -82,4 +82,4 @@ fi

export PYSPARK_DRIVER_PYTHON
export PYSPARK_DRIVER_PYTHON_OPTS
exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@"
exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@"
2 changes: 1 addition & 1 deletion bin/pyspark2.cmd
@@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py

call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %*
call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %*
4 changes: 2 additions & 2 deletions bin/spark-shell
@@ -47,11 +47,11 @@ function main() {
# (see https://github.com/sbt/sbt/issues/562).
stty -icanon min 1 -echo > /dev/null 2>&1
export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix"
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
stty icanon echo > /dev/null 2>&1
else
export SPARK_SUBMIT_OPTS
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
fi
}

2 changes: 1 addition & 1 deletion bin/spark-shell2.cmd
@@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" (
set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true"

:run_shell
%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %*
8 changes: 6 additions & 2 deletions build/sbt-launch-lib.bash
@@ -51,9 +51,13 @@ acquire_sbt_jar () {
printf "Attempting to fetch sbt\n"
JAR_DL="${JAR}.part"
if [ $(command -v curl) ]; then
(curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
(curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
(rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
mv "${JAR_DL}" "${JAR}"
elif [ $(command -v wget) ]; then
(wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
(wget --quiet ${URL1} -O "${JAR_DL}" ||\
(rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
4 changes: 4 additions & 0 deletions conf/log4j.properties.template
@@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN
log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO

# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR