diff --git a/.rat-excludes b/.rat-excludes index 0240e81c45ea2..236c2db05367c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -91,3 +91,5 @@ help/* html/* INDEX .lintr +gen-java.* +.*avpr diff --git a/R/README.md b/R/README.md index d7d65b4f0eca5..005f56da1670c 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ``` build/mvn -DskipTests -Psparkr package ``` diff --git a/R/install-dev.bat b/R/install-dev.bat index 008a5c668bc45..f32670b67de96 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd diff --git a/R/install-dev.sh b/R/install-dev.sh index 1edd551f8d243..4972bb9217072 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -pushd $FWDIR +pushd $FWDIR > /dev/null # Generate Rd files if devtools is installed Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' @@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo # Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ -popd +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR + +popd > /dev/null diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index efc85bbc4b316..4949d86d20c91 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,7 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' - 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f857222452d4..a329e14f25aeb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,11 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict", + "summary") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", @@ -22,6 +27,7 @@ exportMethods("arrange", "collect", "columns", "count", + "crosstab", "describe", "distinct", "dropna", @@ -77,6 +83,7 @@ exportMethods("abs", "atan", "atan2", "avg", + "between", "cast", "cbrt", "ceiling", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 60702824acb46..f4c93d3c7dd67 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1314,7 +1314,7 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1328,7 +1328,7 @@ setMethod("write.df", jmode <- 
callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] = path + options[["path"]] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) @@ -1337,7 +1337,7 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) }) @@ -1375,8 +1375,8 @@ setMethod("saveDF", #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", - signature(df = "DataFrame", tableName = 'character', source = 'character', - mode = 'character'), + signature(df = "DataFrame", tableName = "character", source = "character", + mode = "character"), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1554,3 +1554,31 @@ setMethod("fillna", } dataFrame(sdf) }) + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have zero as their counts. +#' +#' @rdname statfunctions +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlCtx, "/path/to/file.json") +#' ct = crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 89511141d3ef7..d2d096709245d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), serializedFuncArr, rdd@env$prev_serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } else { @@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), rdd@env$prev_serializedMode, serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9a743a3411533..110117a18ccbc 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -86,7 +86,9 @@ infer_type <- function(x) { createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type @@ -455,7 +457,7 @@ dropTempTable <- function(sqlContext, tableName) { read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) 
if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -504,7 +506,7 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28c..49162838b8d1a 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) { # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 78c7a3037ffac..c811d1dac3bd5 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) { determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { - sparkSubmitBinName = "spark-submit" + sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName = "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit.cmd" } sparkSubmitBinName } @@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack jars <- paste("--jars", jars) } - if (packages != "") { + if (!identical(packages, "")) { packages <- paste("--packages", packages) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 8e4b0f5bf1c4d..2892e1416cc65 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"), column(jc) }) +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname column +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + #' Casts the column to a different data type. 
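A short sketch of the new between() column method defined above, mirroring the test added later in this patch; `df` is assumed to be a DataFrame with a numeric `age` column and a string `name` column:

```r
# Sketch only: between() expects a length-2 vector of (lower, upper) bounds, inclusive.
df2 <- select(df, between(df$age, c(20, 30)), between(df$name, c("Apache", "Spark")))
collect(df2)
# Anything other than a length-2 vector triggers the error above:
# between(df$age, 20)   # "bounds should be a vector of lower and upper bounds"
```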
#' #' @rdname column diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d961bbc383688..6d364f77be7ee 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -23,6 +23,7 @@ # Int -> integer # String -> character # Boolean -> logical +# Float -> double # Double -> double # Long -> double # Array[Byte] -> raw @@ -101,11 +102,11 @@ readList <- function(con) { readRaw <- function(con) { dataLen <- readInt(con) - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readRawLen <- function(con, dataLen) { - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readDeserialize <- function(con) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 79055b7f18558..a3a121058e165 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export @@ -58,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") }) # @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) +# @rdname statfunctions +# @export +setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -249,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") # @rdname intersection # @export -setGeneric("intersection", function(x, other, numPartitions = 1) { - standardGeneric("intersection") }) +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) # @rdname keys # @export @@ -484,9 +491,7 @@ setGeneric("sample", #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname saveAsParquetFile #' @export @@ -548,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn #' @rdname withColumnRenamed #' @export -setGeneric("withColumnRenamed", function(x, existingCol, newCol) { - standardGeneric("withColumnRenamed") }) +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) ###################### Column Methods ########################## @@ -566,6 +571,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) +#' @rdname column +#' @export +setGeneric("between", function(x, bounds) { standardGeneric("between") }) + #' @rdname column #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) @@ -656,3 +665,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname glm +#' @export +setGeneric("glm") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 8f1c68f7c4d28..576ac72f40fc0 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -87,7 +87,7 @@ setMethod("count", setMethod("agg", signature(x = "GroupedData"), function(x, ...) 
{ - cols = list(...) + cols <- list(...) stopifnot(length(cols) > 0) if (is.character(cols[[1]])) { cols <- varargsToEnv(...) @@ -97,7 +97,7 @@ setMethod("agg", if (!is.null(ns)) { for (n in ns) { if (n != "") { - cols[[n]] = alias(cols[[n]], n) + cols[[n]] <- alias(cols[[n]], n) } } } diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 0000000000000..efddcc1d8d71c --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,99 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '+', '-', and '.'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' +#' @param model A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. 
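Putting the glm(), predict() and summary() wrappers together, a sketch of the intended workflow (based on the examples and tests in this patch; it assumes an existing `sqlContext`, and createDataFrame rewrites the dots in the iris column names as underscores):

```r
# Sketch only: fit a Gaussian GLM on iris, score it, and inspect the coefficients.
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, family = "gaussian")
predictions <- predict(model, training)       # DataFrame with a "prediction" column
showDF(select(predictions, "prediction"))
summary(model)$coefficients                   # Estimate per feature, incl. Species dummies
```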
See +#' summary.glm for more information. +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7f902ba8e683e..83801d3209700 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -202,8 +202,8 @@ setMethod("partitionBy", packageNamesArr <- serialize(.sparkREnv$.packages, connection = NULL) - broadcastArr <- lapply(ls(.broadcastNames), function(name) { - get(name, .broadcastNames) }) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) jrdd <- getJRDD(x) # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], @@ -215,7 +215,6 @@ setMethod("partitionBy", serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, - as.character(.sparkREnv$libname), broadcastArr, callJMethod(jrdd, "classTag")) @@ -560,8 +559,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -597,8 +596,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +633,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 15e2bdbd55d79..79c744ef29c23 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -69,11 +69,14 @@ structType.structField <- function(x, ...) { #' @param ... further arguments passed to or from other methods print.structType <- function(x, ...) 
{ cat("StructType\n", - sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), - "\", type = \"", field$dataType.toString(), - "\", nullable = ", field$nullable(), "\n", - sep = "") }) - , sep = "") + sapply(x$fields(), + function(field) { + paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") + }), + sep = "") } #' structField @@ -123,6 +126,7 @@ structField.character <- function(x, type, nullable = TRUE) { } options <- c("byte", "integer", + "float", "double", "numeric", "character", diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 78535eff0d2f6..311021e5d8473 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -140,8 +140,8 @@ writeType <- function(con, class) { jobj = "j", environment = "e", Date = "D", - POSIXlt = 't', - POSIXct = 't', + POSIXlt = "t", + POSIXct = "t", stop(paste("Unsupported type for serialization", class))) writeBin(charToRaw(type), con) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 86233e01db365..e83104f116422 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -17,16 +17,13 @@ .sparkREnv <- new.env() -sparkR.onLoad <- function(libname, pkgname) { - .sparkREnv$libname <- libname -} - # Utility function that returns TRUE if we have an active connection to the # backend and FALSE otherwise connExists <- function(env) { tryCatch({ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) - }, error = function(err) { + }, + error = function(err) { return(FALSE) }) } @@ -80,7 +77,6 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes. #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkRLibDir The path where R is installed on the worker nodes. #' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples @@ -101,24 +97,21 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "", sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. 
Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows # URI needs four /// as from http://stackoverflow.com/a/18522792 if (.Platform$OS.type == "unix") { - collapseChar <- ":" uriSep <- "//" } else { - collapseChar <- ";" uriSep <- "////" } @@ -145,7 +138,7 @@ sparkR.init <- function( if (!file.exists(path)) { stop("JVM is not ready after 10 seconds") } - f <- file(path, open='rb') + f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) close(f) @@ -161,7 +154,8 @@ sparkR.init <- function( .sparkREnv$backendPort <- backendPort tryCatch({ connectBackend("localhost", backendPort) - }, error = function(err) { + }, + error = function(err) { stop("Failed to connect JVM\n") }) @@ -169,10 +163,6 @@ sparkR.init <- function( sparkHome <- normalizePath(sparkHome) } - if (nchar(sparkRLibDir) != 0) { - .sparkREnv$libname <- sparkRLibDir - } - sparkEnvirMap <- new.env() for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] @@ -180,14 +170,16 @@ sparkR.init <- function( sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints @@ -274,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { stop("Spark SQL is not built with Hive support") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 13cec0f712fb4..3f45589a50443 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, if (isInstanceOf(obj, "scala.Tuple2")) { # JavaPairRDD[Array[Byte], Array[Byte]]. 
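For context, a sketch of a typical initialization after this change; the `sparkRLibDir` argument is gone because workers now locate the package via the distributed sparkr.zip / SPARKR_PACKAGE_DIR, and the master, memory and jar values below are placeholders:

```r
# Sketch only: all values are placeholders.
library(SparkR)
sc <- sparkR.init(master = "local[2]", appName = "SparkR-example",
                  sparkEnvir = list(spark.executor.memory = "1g"),
                  sparkJars = "/path/to/extra.jar")
sqlContext <- sparkRSQL.init(sc)
```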
- keyBytes = callJMethod(obj, "_1") - valBytes = callJMethod(obj, "_2") + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") res <- list(unserialize(keyBytes), unserialize(valBytes)) } else { @@ -334,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -387,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else { # if node[[1]] is length of 1, check for some R special functions. + } else { + # if node[[1]] is length of 1, check for some R special functions. nodeChar <- as.character(node[[1]]) - if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + if (nodeChar == "{" || nodeChar == "(") { + # Skip start symbol. for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else if (nodeChar == "<-" || nodeChar == "=" || - nodeChar == "<<-") { # Assignment Ops. + nodeChar == "<<-") { + # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { # Add the defined variable name into defVars. @@ -405,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "function") { # Function definition. + } else if (nodeChar == "function") { + # Function definition. # Add parameter names. 
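The storage-level lookup refactored above is normally reached through persist(); a sketch using the internal RDD API, called directly here in the same way the unit tests do:

```r
# Sketch only: getStorageLevel() and the RDD methods are internal helpers.
level <- getStorageLevel("MEMORY_AND_DISK")   # resolves the JVM StorageLevel singleton
rdd <- parallelize(sc, 1:100, 2L)
persist(rdd, "MEMORY_AND_DISK")
count(rdd)                                    # 100
```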
newArgs <- names(node[[2]]) lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "$") { # Skip the field. + } else if (nodeChar == "$") { + # Skip the field. processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) } else if (nodeChar == "::" || nodeChar == ":::") { processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) @@ -426,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) - if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + if (!nodeChar %in% defVars$data) { + # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) # Search in function environment, and function's enclosing environments @@ -436,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && - !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { + # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) - if (is.function(obj)) { # If the node is a function call. + if (is.function(obj)) { + # If the node is a function call. funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) }) - if (sum(found) > 0) { # If function has been examined, ignore. + if (sum(found) > 0) { + # If function has been examined, ignore. break } # Function has not been examined, record it and recursively clean its closure. @@ -492,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) - for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + for (i in 1:(length(argNames) - 1)) { + # Remove the ending NULL in pairlist. addItemToAccumulator(defVars, argNames[i]) } # Recursively examine variables in the function body. @@ -545,9 +559,11 @@ mergePartitions <- function(rdd, zip) { lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. 
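To make the closure handling above concrete, a sketch of what cleanClosure() captures for a function that reads a global variable; this mirrors the test in test_utils.R further down:

```r
# Sketch only: cleanClosure() is internal; it copies referenced globals (here `t`,
# which shadows base::t) into the serialized function's environment.
t <- 4
f <- function(x) { x > t }
newF <- cleanClosure(f)
get("t", envir = environment(newF))   # 4
```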
if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } if (lengthOfKeys > 1) { diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8fe711b622086..2a8a8213d0849 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,7 +16,7 @@ # .First <- function() { - home <- Sys.getenv("SPARK_HOME") - .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") + .libPaths(c(packageDir, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ccaea18ecab2a..f2452ed97d2ea 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -20,7 +20,7 @@ context("functions on binary files") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 3be8c65a6c1a0..dca0657c57e0d 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", { expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 30b05c1a2afcd..8a20991f89af8 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", { expect_equal(gsub("[[:space:]]", "", args), "") }) + +test_that("multiple packages don't produce a warning", { + expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 844d86f3cc97f..cc1faeabffe30 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -18,8 +18,8 @@ context("include an external JAR in SparkContext") runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") - jarPath <- paste("--jars", - shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"))) + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 0000000000000..f272de78ad4a6 --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index fc3c01d837de4..6c3aaab8c711e 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", { expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", { actual <- collect(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 
1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0e4235ea8b4b3..61c8a7ec7d837 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,9 +57,9 @@ test_that("infer types", { expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") expect_equal(infer_type(c(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) + list(type = "array", elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) + list(type = "array", elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) @@ -108,6 +108,33 @@ test_that("create DataFrame from RDD", { expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), 
c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) }) test_that("convert NAs to null type in DataFrames", { @@ -391,7 +418,7 @@ test_that("collect() and take() on a DataFrame return the same number of rows an expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) -test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { +test_that("multiple pipeline transformations result in an RDD with the correct values", { df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -576,7 +603,8 @@ test_that("write.df() as parquet file", { test_that("test HiveContext", { hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") @@ -612,6 +640,18 @@ test_that("column functions", { c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) c9 <- toDegrees(c) + toRadians(c) + + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) }) test_that("column binary mathfunctions", { @@ -756,7 +796,14 @@ test_that("toJSON() returns an RDD of the correct values", { test_that("showDF()", { df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) - expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) }) test_that("isLocal()", { @@ -942,6 +989,24 @@ test_that("fillna() on a DataFrame", { expect_identical(expected, actual) }) +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b),] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 58318dfef71ab..a9cf83dbdbdb1 100644 --- 
a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -20,7 +20,7 @@ context("the textFile() function") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { fileName <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index aa0d2a66b9082..12df4cf4f65b7 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", { # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) - t = 4 # Override base::t in .GlobalEnv. + t <- 4 # Override base::t in .GlobalEnv. f <- function(x) { x > t } newF <- cleanClosure(f) env <- environment(newF) diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..18a1e13bdc655 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/bin/pyspark b/bin/pyspark index f9dbddfa53560..8f2a3b5a7717b 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -82,4 +82,4 @@ fi export PYSPARK_DRIVER_PYTHON export PYSPARK_DRIVER_PYTHON_OPTS -exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@" +exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 45e9e3def5121..3c6169983e76b 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %* +call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* diff --git a/bin/spark-shell b/bin/spark-shell index a6dc863d83fc6..00ab7afd118b5 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -47,11 +47,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). 
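The mockFile fixture used by these tests corresponds to a simple textFile() round trip, roughly as follows (a sketch; `sc` is an existing SparkContext and the RDD API is internal):

```r
# Sketch only: write a small local file and read it back as an RDD of lines.
mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern = "spark-test", fileext = ".tmp")
writeLines(mockFile, fileName)
rdd <- textFile(sc, fileName)
collect(rdd)        # list("Spark is pretty.", "Spark is awesome.")
unlink(fileName)
```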
stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 251309d67f860..b9b0f510d7f5d 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/build/mvn b/build/mvn index e8364181e8230..f62f61ee1c416 100755 --- a/build/mvn +++ b/build/mvn @@ -112,10 +112,17 @@ install_scala() { # the environment ZINC_PORT=${ZINC_PORT:-"3030"} +# Check for the `--force` flag dictating that `mvn` should be downloaded +# regardless of whether the system already has a `mvn` install +if [ "$1" == "--force" ]; then + FORCE_MVN=1 + shift +fi + # Install Maven if necessary MVN_BIN="$(command -v mvn)" -if [ ! "$MVN_BIN" ]; then +if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then install_mvn fi @@ -139,5 +146,7 @@ fi # Set any `mvn` options if not already present export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} +echo "Using \`mvn\` from path: $MVN_BIN" + # Last, call the `mvn` command as usual ${MVN_BIN} "$@" diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 504be48b358fa..7930a38b9674a 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -51,9 +51,13 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\ + (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\ + mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + (wget --quiet ${URL1} -O "${JAR_DL}" ||\ + (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\ + mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 3a2a88219818f..27006e45e932b 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/pom.xml b/core/pom.xml index aee0d92620606..202678779150b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,11 @@ Spark 
Project Core http://spark.apache.org/ + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + com.google.guava guava @@ -261,7 +266,7 @@ com.fasterxml.jackson.module - jackson-module-scala_2.10 + jackson-module-scala_${scala.binary.version} org.apache.derby @@ -281,7 +286,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.0 org.apache.hadoop @@ -292,36 +297,12 @@ curator-recipes - org.eclipse.jetty - jetty-jsp - - - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet - - - junit - junit + org.tachyonproject + tachyon-underfs-glusterfs - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 @@ -342,6 +323,16 @@ xml-apis test + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + org.mockito mockito-core @@ -358,18 +349,13 @@ test - org.hamcrest - hamcrest-core - test - - - org.hamcrest - hamcrest-library + com.novocode + junit-interface test - com.novocode - junit-interface + org.apache.curator + curator-test test diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 646496f313507..fa9acf0a15b88 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -17,23 +17,7 @@ package org.apache.spark; -import org.apache.spark.scheduler.SparkListener; -import org.apache.spark.scheduler.SparkListenerApplicationEnd; -import org.apache.spark.scheduler.SparkListenerApplicationStart; -import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; -import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; -import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorAdded; -import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorRemoved; -import org.apache.spark.scheduler.SparkListenerJobEnd; -import org.apache.spark.scheduler.SparkListenerJobStart; -import org.apache.spark.scheduler.SparkListenerStageCompleted; -import org.apache.spark.scheduler.SparkListenerStageSubmitted; -import org.apache.spark.scheduler.SparkListenerTaskEnd; -import org.apache.spark.scheduler.SparkListenerTaskGettingResult; -import org.apache.spark.scheduler.SparkListenerTaskStart; -import org.apache.spark.scheduler.SparkListenerUnpersistRDD; +import org.apache.spark.scheduler.*; /** * Java clients should extend this class instead of implementing @@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } @Override public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index fbc5666959055..1214d05ba6063 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { onEvent(executorRemoved); } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) 
{ + onEvent(blockUpdated); + } + } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java similarity index 91% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 3f746b886bc9b..0399abc63c235 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.serializer; import java.io.IOException; import java.io.InputStream; @@ -24,9 +24,7 @@ import scala.reflect.ClassTag; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.PlatformDependent; /** @@ -35,7 +33,8 @@ * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work * around this, we pass a dummy no-op serializer. */ -final class DummySerializerInstance extends SerializerInstance { +@Private +public final class DummySerializerInstance extends SerializerInstance { public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d3d6280284beb..0b8b604e18494 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final Serializer serializer; /** Array of file writers, one for each partition */ - private BlockObjectWriter[] partitionWriters; + private DiskBlockObjectWriter[] partitionWriters; public BypassMergeSortShuffleWriter( SparkConf conf, @@ -101,7 +101,7 @@ public void insertAll(Iterator> records) throws IOException { } final SerializerInstance serInstance = serializer.newInstance(); final long openStartTime = System.nanoTime(); - partitionWriters = new BlockObjectWriter[numPartitions]; + partitionWriters = new DiskBlockObjectWriter[numPartitions]; for (int i = 0; i < numPartitions; i++) { final Tuple2 tempShuffleBlockIdPlusFile = blockManager.diskBlockManager().createTempShuffleBlock(); @@ -121,7 +121,7 @@ public void insertAll(Iterator> records) throws IOException { partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } } @@ -169,7 +169,7 @@ public void stop() throws IOException { if (partitionWriters != null) { try { final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { // This method explicitly does _not_ throw exceptions: writer.revertPartialWritesAndClose(); if (!diskBlockManager.getFile(writer.blockId()).delete()) { diff --git 
a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 9e9ed94b7890c..1aa6ba4201261 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; @@ -58,14 +59,14 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); - private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final int initialSize; private final int numPartitions; + private final int pageSizeBytes; + @VisibleForTesting + final int maxRecordSizeBytes; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; @@ -108,7 +109,10 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - + this.pageSizeBytes = (int) Math.min( + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); + this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); } @@ -156,7 +160,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. - BlockObjectWriter writer; + DiskBlockObjectWriter writer; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer @@ -271,7 +275,11 @@ void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; } private long freeMemory() { @@ -345,23 +353,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. 
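As a rough, self-contained sketch of the sizing arithmetic in this hunk: the page size is the configured spark.buffer.pageSize (64m by default) clamped to the packed-pointer limit, and the largest record that fits in a page is the page size minus the 4-byte length header. The 128 MB cap and the hard-coded 64 MB below are assumed stand-ins for PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES and for the value read from the configuration, not values taken from this patch.

public class PageSizingSketch {
  public static void main(String[] args) {
    // Assumed stand-in for PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES (128 MB).
    final long maximumPageSizeBytes = 128L * 1024 * 1024;
    // Stand-in for conf.getSizeAsBytes("spark.buffer.pageSize", "64m").
    final long configuredPageSize = 64L * 1024 * 1024;
    final long pageSizeBytes = Math.min(maximumPageSizeBytes, configuredPageSize);
    // Four bytes of each record slot are reserved for the record length.
    final long maxRecordSizeBytes = pageSizeBytes - 4;
    System.out.println("pageSizeBytes=" + pageSizeBytes
        + ", maxRecordSizeBytes=" + maxRecordSizeBytes);
  }
}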
- if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 764578b181422..d47d6fc9c2ac4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -129,6 +129,11 @@ public UnsafeShuffleWriter( open(); } + @VisibleForTesting + public int maxRecordSizeBytes() { + return sorter.maxRecordSizeBytes; + } + /** * This convenience method should only be called in test code. */ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 0000000000000..45b78829e4cf7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. 
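To illustrate the contract just described, a minimal subclass might compare signed long prefixes in ascending order. The class name below is hypothetical; PrefixComparator itself is the abstract class this patch adds, and the concrete comparators shipped with the patch follow in PrefixComparators.

package org.apache.spark.util.collection.unsafe.sort;

// Hypothetical example of a type-specific prefix comparator (ascending signed longs).
// Ties (equal prefixes) are resolved later by a RecordComparator over the full records.
final class AscendingLongPrefixComparator extends PrefixComparator {
  @Override
  public int compare(long prefix1, long prefix2) {
    return Long.compare(prefix1, prefix2);
  }
}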
+ */ +@Private +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 0000000000000..4d7e5b3dfba6e --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.primitives.UnsignedLongs; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.Utils; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final StringPrefixComparator STRING = new StringPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); + + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + + public static long computePrefix(UTF8String value) { + return value == null ? 0L : value.getPrefix(); + } + } + + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } + + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); + } + + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } + + public static final class DoublePrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); + } + + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..09e4258792204 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 0000000000000..0c4ebde407cfc --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
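Tying the two comparator types together: when the 8-byte prefix fully encodes the sort key (for example, sorting purely by a long value), the RecordComparator described above can simply return 0, since equal prefixes already imply equal keys. The sketch below is a hypothetical illustration of that case, not part of the patch.

package org.apache.spark.util.collection.unsafe.sort;

// Hypothetical no-op record comparator for keys that fit entirely in the 8-byte prefix;
// equal prefixes mean the records compare as equal, so no record bytes are inspected.
final class PrefixOnlyRecordComparator extends RecordComparator {
  @Override
  public int compare(
      Object leftBaseObject,
      long leftBaseOffset,
      Object rightBaseObject,
      long rightBaseOffset) {
    return 0;
  }
}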
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 0000000000000..866e0b4151577 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.LinkedList; + +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. + */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private final long pageSizeBytes; + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. 
The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + private final LinkedList spillWriters = new LinkedList<>(); + + public UnsafeExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + SparkConf conf) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); + initializeForWriting(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + freeMemory(); + return null; + } + }); + } + + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = + new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + public void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? 
" times" : " time"); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + sorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + initializeForWriting(); + } + + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + + public long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. 
We should also do this for + // BytesToBytesMap. + if (requiredSpace > pageSizeBytes) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + pageSizeBytes + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(pageSizeBytes); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = pageSizeBytes; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + freeSpaceInCurrentPage -= totalSpaceRequired; + sorter.insertRecord(recordAddress, prefix); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + if (inMemoryIterator.hasNext()) { + spillMerger.addSpill(inMemoryIterator); + } + return spillMerger.getSortedIterator(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 0000000000000..fc34ad9cff369 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
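The in-page record layout written by insertRecord above can be imitated with an ordinary byte buffer: a 4-byte length header followed by the record bytes, with the write position advancing by four plus the record length. The sketch below uses java.nio.ByteBuffer in place of the off-heap memory page, purely for illustration.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Sketch: the [length (int)][record bytes] layout that insertRecord writes into a page.
public class RecordLayoutSketch {
  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(64);        // stand-in for an off-heap memory page
    byte[] record = "hello".getBytes(StandardCharsets.UTF_8);
    int positionBefore = page.position();
    page.putInt(record.length);                        // 4-byte length header
    page.put(record);                                  // record payload
    // The cursor advances by 4 + record.length, matching currentPagePosition in the sorter.
    System.out.println(page.position() - positionBefore == 4 + record.length);
  }
}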
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] pointerArray; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + /** + * @return the number of records that have been inserted into this sorter. 
+ */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 2 < pointerArray.length; + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandPointerArray(); + } + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; + } + + private static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length + recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 0000000000000..d09c728a7a638 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
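A note on the growth strategy in expandPointerArray above: doubling an int length can overflow to a non-positive number, in which case the new length falls back to Integer.MAX_VALUE. The standalone sketch below illustrates the same guard with made-up lengths.

public class GrowthGuardSketch {
  // Mirrors the overflow guard in UnsafeInMemorySorter.expandPointerArray:
  // if doubling the length overflows an int, clamp to Integer.MAX_VALUE.
  static int grownLength(int oldLength) {
    return oldLength * 2 > 0 ? oldLength * 2 : Integer.MAX_VALUE;
  }

  public static void main(String[] args) {
    System.out.println(grownLength(1024));     // 2048
    System.out.println(grownLength(1 << 30));  // doubling overflows, clamps to Integer.MAX_VALUE
  }
}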
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. + * Used in {@link UnsafeInMemorySorter}. + *
<p>
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 0000000000000..16ac2e8d821ba --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
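The paired-slot layout used by UnsafeSortDataFormat above can be pictured with a small standalone sketch: record pointers live at even indices and key prefixes at the following odd indices, so swapping two entries moves both longs together. The array contents here are made-up values for illustration only.

public class PairedSlotLayoutSketch {
  public static void main(String[] args) {
    // Two (recordPointer, keyPrefix) entries stored in one long[]:
    // entry i occupies positions 2*i (pointer) and 2*i + 1 (prefix).
    long[] data = new long[] {0xAAAAL, 7L, 0xBBBBL, 3L};
    swapEntries(data, 0, 1);
    // After the swap, entry 0 is (0xBBBB, 3) and entry 1 is (0xAAAA, 7).
    System.out.println(data[0] + " " + data[1] + " " + data[2] + " " + data[3]);
  }

  // Mirrors the swap in UnsafeSortDataFormat: both longs of an entry move together.
  static void swapEntries(long[] data, int pos0, int pos1) {
    long tempPointer = data[pos0 * 2];
    long tempPrefix = data[pos0 * 2 + 1];
    data[pos0 * 2] = data[pos1 * 2];
    data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
    data[pos1 * 2] = tempPointer;
    data[pos1 * 2 + 1] = tempPrefix;
  }
}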
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 0000000000000..8272c2a5be0d1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + final int numSpills) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue(numSpills, comparator); + } + + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + spillReader.loadNext(); + } + priorityQueue.add(spillReader); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return 
spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 0000000000000..29e9e0f30f934 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). + */ +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private InputStream in; + private DataInputStream din; + + // Variables that change with every record read: + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public void loadNext() throws IOException { + recordLength = din.readInt(); + keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 0000000000000..71eed29563d4a --- /dev/null +++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] + */ +final class UnsafeSorterSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private DiskBlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeLong. 
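For reference, the spill-file layout documented above can be decoded with plain java.io streams as in the sketch below. This deliberately ignores the compression wrapping that the real reader (UnsafeSorterSpillReader) applies via the BlockManager, so it only works on an uncompressed file; the file path is taken from the command line and is not a path used by the patch.

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

// Sketch: walk an (uncompressed) spill file using the layout
// [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
public class SpillFileDumpSketch {
  public static void main(String[] args) throws IOException {
    try (DataInputStream in = new DataInputStream(new FileInputStream(args[0]))) {
      final int numRecords = in.readInt();
      for (int i = 0; i < numRecords; i++) {
        final int recordLength = in.readInt();
        final long keyPrefix = in.readLong();
        final byte[] record = new byte[recordLength];
        in.readFully(record);
        System.out.println("record " + i + ": length=" + recordLength + ", prefix=" + keyPrefix);
      }
    }
  }
}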
+ private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + * @param keyPrefix a sort key prefix + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + writer.recordWritten(); + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index b146f8a784127..689afea64f8db 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties 
b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 3a2a88219818f..27006e45e932b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 0b450dc76bc38..3c8ddddf07b1e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -19,6 +19,9 @@ * to be registered after the page loads. */ $(function() { $("span.expand-additional-metrics").click(function(){ + var status = window.localStorage.getItem("expand-additional-metrics") == "true"; + status = !status; + // Expand the list of additional metrics. var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); $(additionalMetricsDiv).toggleClass('collapsed'); @@ -26,17 +29,31 @@ $(function() { // Switch the class of the arrow from open to closed. $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-additional-metrics", "" + status); }); + if (window.localStorage.getItem("expand-additional-metrics") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-additional-metrics", "false"); + $("span.expand-additional-metrics").trigger("click"); + } + stripeSummaryTable(); $('input[type="checkbox"]').click(function() { - var column = "table ." + $(this).attr("name"); + var name = $(this).attr("name") + var column = "table ." + name; + var status = window.localStorage.getItem(name) == "true"; + status = !status; $(column).toggle(); stripeSummaryTable(); + window.localStorage.setItem(name, "" + status); }); $("#select-all-metrics").click(function() { + var status = window.localStorage.getItem("select-all-metrics") == "true"; + status = !status; if (this.checked) { // Toggle all un-checked options. $('input[type="checkbox"]:not(:checked)').trigger('click'); @@ -44,6 +61,21 @@ $(function() { // Toggle all checked options. 
$('input[type="checkbox"]:checked').trigger('click'); } + window.localStorage.setItem("select-all-metrics", "" + status); + }); + + if (window.localStorage.getItem("select-all-metrics") == "true") { + $("#select-all-metrics").attr('checked', status); + } + + $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() { + var name = $(this).attr("name") + // If name is undefined, then skip it because it's the "select-all-metrics" checkbox + if (name && window.localStorage.getItem(name) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(name, "false"); + $(this).trigger("click") + } }); // Trigger a click on the checkbox if a user clicks the label next to it. diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 9fa53baaf4212..4a893bc0189aa 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -72,6 +72,14 @@ var StagePageVizConstants = { rankSep: 40 }; +/* + * Return "expand-dag-viz-arrow-job" if forJob is true. + * Otherwise, return "expand-dag-viz-arrow-stage". + */ +function expandDagVizArrowKey(forJob) { + return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage"; +} + /* * Show or hide the RDD DAG visualization. * @@ -79,6 +87,9 @@ var StagePageVizConstants = { * This is the narrow interface called from the Scala UI code. */ function toggleDagViz(forJob) { + var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true"; + status = !status; + var arrowSelector = ".expand-dag-viz-arrow"; $(arrowSelector).toggleClass('arrow-closed'); $(arrowSelector).toggleClass('arrow-open'); @@ -93,8 +104,24 @@ function toggleDagViz(forJob) { // Save the graph for later so we don't have to render it again graphContainer().style("display", "none"); } + + window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status); } +$(function (){ + if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(false), "false"); + toggleDagViz(false); + } + + if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(true), "false"); + toggleDagViz(true); + } +}); + /* * Render the RDD DAG visualization. * diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index ca74ef9d7e94e..f4453c71df1ea 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { setupJobEventAction(); $("span.expand-application-timeline").click(function() { + var status = window.localStorage.getItem("expand-application-timeline") == "true"; + status = !status; + $("#application-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. 
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-application-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-application-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-application-timeline", "false"); + $("span.expand-application-timeline").trigger('click'); + } +}); + function drawJobTimeline(groupArray, eventObjArray, startTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { setupStageEventAction(); $("span.expand-job-timeline").click(function() { + var status = window.localStorage.getItem("expand-job-timeline") == "true"; + status = !status; + $("#job-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-job-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-job-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-job-timeline", "false"); + $("span.expand-job-timeline").trigger('click'); + } +}); + function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); $("span.expand-task-assignment-timeline").click(function() { + var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true"; + status = !status; + $("#task-assignment-timeline").toggleClass("collapsed"); // Switch the class of the arrow from open to closed. 
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + + window.localStorage.setItem("expand-task-assignment-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-task-assignment-timeline", "false"); + $("span.expand-task-assignment-timeline").trigger('click'); + } +}); + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5a8d17bd99933..eb75f26718e19 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -20,7 +20,8 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import scala.ref.WeakReference import scala.reflect.ClassTag @@ -39,25 +40,44 @@ import org.apache.spark.util.Utils * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` * @param name human-readable name for use in Spark's web UI + * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported + * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be + * thread safe so that they can be reported correctly. * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] ( +class Accumulable[R, T] private[spark] ( @transient initialValue: R, param: AccumulableParam[R, T], - val name: Option[String]) + val name: Option[String], + internal: Boolean) extends Serializable { + private[spark] def this( + @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { + this(initialValue, param, None, internal) + } + + def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false) + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) val id: Long = Accumulators.newId - @transient private var value_ = initialValue // Current value on master + @volatile @transient private var value_ : R = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false - Accumulators.register(this, true) + Accumulators.register(this) + + /** + * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver + * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be + * reported correctly. 
+ */ + private[spark] def isInternal: Boolean = internal /** * Add more data to this accumulator / accumulable @@ -132,7 +152,8 @@ class Accumulable[R, T] ( in.defaultReadObject() value_ = zero deserialized = true - Accumulators.register(this, false) + val taskContext = TaskContext.get() + taskContext.registerAccumulator(this) } override def toString: String = if (value_ == null) "null" else value_.toString @@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging { * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. */ - val originals = Map[Long, WeakReference[Accumulable[_, _]]]() - - /** - * This thread-local map holds per-task copies of accumulators; it is used to collect the set - * of accumulator updates to send back to the driver when tasks complete. After tasks complete, - * this map is cleared by `Accumulators.clear()` (see Executor.scala). - */ - private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() - } + val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() private var lastId: Long = 0 @@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging { lastId } - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) - } else { - localAccums.get()(a.id) = a - } - } - - // Clear the local (non-original) accumulators for the current thread - def clear() { - synchronized { - localAccums.get.clear() - } + def register(a: Accumulable[_, _]): Unit = synchronized { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) } def remove(accId: Long) { @@ -323,15 +324,6 @@ private[spark] object Accumulators extends Logging { } } - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue - } - return ret - } - // Add values to the original accumulators with some given IDs def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { @@ -349,7 +341,4 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) - - def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 443830f8d03b6..842bfdbadc948 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -24,11 +24,23 @@ package org.apache.spark private[spark] trait ExecutorAllocationClient { /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. 
+ * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ - private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + private[spark] def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean /** * Request an additional number of executors from the cluster manager. diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 49329423dca76..1877aaf2cac55 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -102,7 +103,7 @@ private[spark] class ExecutorAllocationManager( "spark.dynamicAllocation.executorIdleTimeout", "60s") private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s") + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -160,6 +161,12 @@ private[spark] class ExecutorAllocationManager( // (2) an executor idle timeout has elapsed. @volatile private var initializing: Boolean = true + // Number of locality aware tasks, used for executor placement. + private var localityAwareTasks = 0 + + // Host to the number of tasks that could run on it, used for executor placement. + private var hostToLocalTaskCount: Map[String, Int] = Map.empty + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception.
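The new `requestTotalExecutors` contract above carries two placement hints in addition to the target executor count. As a rough, self-contained illustration of how such hints can be derived, the Scala sketch below folds per-stage task locality preferences into a `localityAwareTasks` count and a `hostToLocalTaskCount` map, in the spirit of the listener changes later in this patch; the object and method names are illustrative only, not the actual `ExecutorAllocationManager` code.

```scala
import scala.collection.mutable

// Illustrative sketch only: aggregates per-stage locality preferences into the
// two hints that the new requestTotalExecutors signature accepts.
object PlacementHintsSketch {

  // stageId -> (number of tasks with locality preferences,
  //             host -> number of tasks preferring that host)
  private val stageIdToHints = mutable.HashMap[Int, (Int, Map[String, Int])]()

  /** Record the preferred hosts of every task in a newly submitted stage. */
  def onStageSubmitted(stageId: Int, taskPreferredHosts: Seq[Seq[String]]): Unit = {
    var tasksWithPrefs = 0
    val hostCounts = mutable.HashMap[String, Int]()
    taskPreferredHosts.foreach { hosts =>
      if (hosts.nonEmpty) {
        tasksWithPrefs += 1
        hosts.foreach(h => hostCounts(h) = hostCounts.getOrElse(h, 0) + 1)
      }
    }
    stageIdToHints(stageId) = (tasksWithPrefs, hostCounts.toMap)
  }

  /** Drop the hints of a completed stage. */
  def onStageCompleted(stageId: Int): Unit = {
    stageIdToHints -= stageId
  }

  /** Aggregate the hints of all active stages, as the allocation manager does. */
  def aggregate(): (Int, Map[String, Int]) = {
    var localityAwareTasks = 0
    val hostToLocalTaskCount = mutable.HashMap[String, Int]()
    stageIdToHints.values.foreach { case (numTasks, hostCounts) =>
      localityAwareTasks += numTasks
      hostCounts.foreach { case (host, n) =>
        hostToLocalTaskCount(host) = hostToLocalTaskCount.getOrElse(host, 0) + n
      }
    }
    (localityAwareTasks, hostToLocalTaskCount.toMap)
  }

  def main(args: Array[String]): Unit = {
    onStageSubmitted(0, Seq(Seq("host1", "host2"), Seq("host1"), Seq.empty))
    onStageSubmitted(1, Seq(Seq("host2")))
    val (tasks, hosts) = aggregate()
    // Expected hints: 3 locality-aware tasks; host1 -> 2, host2 -> 2
    println(s"localityAwareTasks=$tasks, hostToLocalTaskCount=$hosts")
  }
}
```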
@@ -211,7 +218,16 @@ private[spark] class ExecutorAllocationManager( listenerBus.addListener(listener) val scheduleTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + override def run(): Unit = { + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } @@ -285,7 +301,7 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { - client.requestTotalExecutors(numExecutorsTarget) + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") } @@ -339,7 +355,8 @@ private[spark] class ExecutorAllocationManager( return 0 } - val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) + val addRequestAcknowledged = testing || + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) if (addRequestAcknowledged) { val executorsString = "executor" + { if (delta > 1) "s" else "" } logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + @@ -509,6 +526,12 @@ private[spark] class ExecutorAllocationManager( // Number of tasks currently running on the cluster. Should be 0 when no stages are active. private var numRunningTasks: Int = _ + // Map from stageId to a tuple of (the number of tasks with locality preferences, a map where + // each pair is a node and the number of tasks that would like to be scheduled on that node). + // This maintains the executor placement hints for each stage, which the resource framework + // uses to better place executors.
+ private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])] + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { initializing = false val stageId = stageSubmitted.stageInfo.stageId @@ -516,6 +539,24 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() + + // Compute the number of tasks requested by the stage on each host + var numTasksPending = 0 + val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]() + stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality => + if (!locality.isEmpty) { + numTasksPending += 1 + locality.foreach { location => + val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1 + hostToLocalTaskCountPerStage(location.host) = count + } + } + } + stageIdToExecutorPlacementHints.put(stageId, + (numTasksPending, hostToLocalTaskCountPerStage.toMap)) + + // Update the executor placement hints + updateExecutorPlacementHints() } } @@ -524,6 +565,10 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToExecutorPlacementHints -= stageId + + // Update the executor placement hints + updateExecutorPlacementHints() // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason @@ -627,6 +672,29 @@ private[spark] class ExecutorAllocationManager( def isExecutorIdle(executorId: String): Boolean = { !executorIdToTaskIds.contains(executorId) } + + /** + * Update the Executor placement hints (the number of tasks with locality preferences, + * a map where each pair is a node and the number of tasks that would like to be scheduled + * on that node). + * + * These hints are updated when stages arrive and complete, so are not up-to-date at task + * granularity within stages. 
+ */ + def updateExecutorPlacementHints(): Unit = { + var localityAwareTasks = 0 + val localityToCount = new mutable.HashMap[String, Int]() + stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) => + localityAwareTasks += numTasksPending + localities.foreach { case (hostname, count) => + val updatedCount = localityToCount.getOrElse(hostname, 0) + count + localityToCount(hostname) = updatedCount + } + } + + allocationManager.localityAwareTasks = localityAwareTasks + allocationManager.hostToLocalTaskCount = localityToCount.toMap + } } /** diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 221b1dab43278..43dd4a170731d 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -181,7 +181,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - sc.killExecutor(executorId) + // Note: we want to get an executor back after expiring this one, + // so do not simply call `sc.killExecutor` here (SPARK-8119) + sc.killAndReplaceExecutor(executorId) } }) } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 7fcb7830e7b0b..f0598816d6c07 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,6 +121,7 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { + // scalastyle:off println if (Utils.isInInterpreter) { val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { @@ -141,6 +142,7 @@ trait Logging { System.err.println(s"Spark was unable to load $defaultLogProps") } } + // scalastyle:on println } } Logging.initialized = true @@ -157,7 +159,7 @@ private object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. 
- val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 862ffe868f58f..92218832d256f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,14 +21,14 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage @@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and output sizes of the map outputs of - * a given shuffle. + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given reduce task. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. */ - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") + val startTime = System.currentTimeMillis + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") @@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } + logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + s"${System.currentTimeMillis - startTime} ms") + if (fetchedStatuses != null) { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) @@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging { } } - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. + /** + * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block + * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that + * block manager. + * + * If any of the statuses is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. 
+ * + * @param shuffleId Identifier for the shuffle + * @param reduceId Identifier for the reduce task + * @param statuses List of map statuses, indexed by map ID. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ private def convertMapStatuses( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - statuses.map { - status => - if (status == null) { - logError("Missing an output location for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) - } else { - (status.location, status.getSizeForBlock(reduceId)) - } + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.zipWithIndex) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + } else { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + } } + + splitsByAddress.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 82889bcd30988..4b9d59975bdc2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,7 +56,7 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse - for (r <- bySize if r.partitioner.isDefined) { + for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { @@ -76,6 +76,8 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { + require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6cf36fbbd6254..4161792976c7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,11 +18,12 @@ package org.apache.spark import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet +import org.apache.avro.{SchemaNormalization, Schema} + import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -161,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + private final val avroNamespace = "avro.schema." 
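For the `MapOutputTracker` change above: `convertMapStatuses` now groups shuffle outputs by the block manager that holds them, instead of returning one flat `(location, size)` pair per map output. The sketch below mirrors that grouping so it can be compiled and run on its own; `BlockManagerId`, `ShuffleBlockId`, and `MapStatus` here are simplified stand-in case classes, not Spark's real types.

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

// Simplified stand-ins for Spark's BlockManagerId / ShuffleBlockId / MapStatus.
case class BlockManagerId(executorId: String, host: String)
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
case class MapStatus(location: BlockManagerId, sizeForReduce: Int => Long)

object ConvertMapStatusesSketch {

  /** Group the shuffle blocks needed by one reduce task by the executor holding them. */
  def convert(
      shuffleId: Int,
      reduceId: Int,
      statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(ShuffleBlockId, Long)])] = {
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(ShuffleBlockId, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex) {
      if (status == null) {
        // The real code throws MetadataFetchFailedException here.
        throw new IllegalStateException(s"Missing an output location for shuffle $shuffleId")
      }
      splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
        ((ShuffleBlockId(shuffleId, mapId, reduceId), status.sizeForReduce(reduceId)))
    }
    splitsByAddress.toSeq
  }

  def main(args: Array[String]): Unit = {
    val exec1 = BlockManagerId("exec-1", "host1")
    val exec2 = BlockManagerId("exec-2", "host2")
    val statuses = Array(
      MapStatus(exec1, _ => 100L),
      MapStatus(exec2, _ => 200L),
      MapStatus(exec1, _ => 300L))
    // exec-1 ends up with the blocks from map outputs 0 and 2; exec-2 with map output 1.
    convert(shuffleId = 0, reduceId = 3, statuses = statuses).foreach { case (bm, blocks) =>
      println(s"$bm -> $blocks")
    }
  }
}
```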
+ + /** + * Use Kryo serialization and register the given set of Avro schemas so that the generic + * record serializer can decrease network IO + */ + def registerAvroSchemas(schemas: Schema*): SparkConf = { + for (schema <- schemas) { + set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString) + } + this + } + + /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */ + def getAvroSchema: Map[Long, String] = { + getAll.filter { case (k, v) => k.startsWith(avroNamespace) } + .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) } + .toMap + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d2547eeff2b4e..ac6ac6c216767 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -471,7 +471,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli .orElse(Option(System.getenv("SPARK_MEM")) .map(warnSparkMem)) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(1024) // Convert java options to env vars as a work around // since we can't set env vars directly in sbt. @@ -532,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _executorAllocationManager = if (dynamicAllocationEnabled) { assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") + "Dynamic allocation of executors is currently only supported in YARN and Mesos mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -853,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -1364,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. + * this application is supported. This is currently only available for YARN + * and Mesos coarse-grained mode. */ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + private[spark] def supportDynamicAllocation: Boolean = { + (master.contains("yarn") + || master.contains("mesos") + || _conf.getBoolean("spark.dynamicAllocation.testing", false)) + } /** * :: DeveloperApi :: @@ -1379,16 +1382,29 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. - * This is currently only supported in YARN mode. Return whether the request is received. - */ - private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. 
+ * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] + ): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.requestTotalExecutors(numExecutors) + b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) case _ => logWarning("Requesting executors is only supported in coarse-grained mode") false @@ -1403,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1416,12 +1432,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executors it kills + * through this method with new ones, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") + "Killing executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) @@ -1433,12 +1455,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: - * Request that cluster manager the kill the specified executor. - * This is currently only supported in Yarn mode. Return whether the request is received. + * Request that the cluster manager kill the specified executor. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executor it kills + * through this method with a new one, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) + /** + * Request that the cluster manager kill the specified executor without adjusting the + * application resource requirements.
+ * + * The effect is that a new executor will be launched in place of the one killed by + * this request. This assumes the cluster manager will automatically and eventually + * fulfill all missing application resource requests. + * + * Note: The replace is by no means guaranteed; another application on the same cluster + * can steal the window of opportunity and acquire this application's resources in the + * mean time. + * + * This is currently only supported in YARN mode. Return whether the request is received. + */ + private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = true) + case _ => + logWarning("Killing executors is only supported in coarse-grained mode") + false + } + } + /** The version of Spark on which this application is running. */ def version: String = SPARK_VERSION @@ -1719,16 +1771,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). + * handler function. This is the main entry point for all actions in Spark. */ def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { + resultHandler: (Int, U) => Unit): Unit = { if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } @@ -1738,54 +1787,104 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (conf.getBoolean("spark.logLineage", false)) { logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) } - dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, - resultHandler, localProperties.get) + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and return the results as an array. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int]): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res) + results + } + + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: Iterator[T] => U, + partitions: Seq[Int]): Array[U] = { + val cleanedFunc = clean(func) + runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) + } + + + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. 
+ */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit): Unit = { + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions, resultHandler) + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * + * The allowLocal argument is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val cleanedFunc = clean(func) - runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal) + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** * Run a job on all partitions in an RDD and return the results in an array. 
*/ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** @@ -1796,7 +1895,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler) } /** @@ -1808,7 +1907,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit) { val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler) } /** @@ -1853,7 +1952,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (context: TaskContext, iter: Iterator[T]) => cleanF(iter), partitions, callSite, - allowLocal = false, resultHandler, localProperties.get) new SimpleFutureAction(waiter, resultFunc) @@ -1965,7 +2063,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli for (className <- listenerClassNames) { // Use reflection to find the right constructor val constructors = { - val listenerClass = Class.forName(className) + val listenerClass = Utils.classForName(className) listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] } val constructorTakingSparkConf = constructors.find { c => @@ -2500,7 +2598,7 @@ object SparkContext extends Logging { "\"yarn-standalone\" is deprecated as of Spark 1.0. 
Use \"yarn-cluster\" instead.") } val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { @@ -2512,7 +2610,7 @@ object SparkContext extends Logging { } val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2525,8 +2623,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { - val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2538,7 +2635,7 @@ object SparkContext extends Logging { val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index d18fc599e9890..adfece4d6e7c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -261,7 +261,7 @@ object SparkEnv extends Logging { // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) + val cls = Utils.classForName(className) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea0911..b48836d5c8897 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -32,7 +33,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. 
+ */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc eq null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. @@ -135,8 +149,34 @@ abstract class TaskContext extends Serializable { @DeveloperApi def taskMetrics(): TaskMetrics + /** + * ::DeveloperApi:: + * Returns all metrics sources with the given name which are associated with the instance + * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + */ + @DeveloperApi + def getMetricsSources(sourceName: String): Seq[Source] + /** * Returns the manager for this task's managed memory. */ private[spark] def taskMemoryManager(): TaskMemoryManager + + /** + * Register an accumulator that belongs to this task. Accumulators must call this method when + * deserializing in executors. + */ + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit + + /** + * Return the local values of internal accumulators that belong to this task. The key of the Map + * is the accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectInternalAccumulators(): Map[Long, Any] + + /** + * Return the local values of accumulators that belong to this task. The key of the Map is the + * accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectAccumulators(): Map[Long, Any] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index b4d572cb52313..9ee168ae016f8 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,18 +17,21 @@ package org.apache.spark +import scala.collection.mutable.{ArrayBuffer, HashMap} + import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import scala.collection.mutable.ArrayBuffer - private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + @transient private val metricsSystem: MetricsSystem, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -94,5 +97,21 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = runningLocally override def isInterrupted(): Boolean = interrupted -} + override def getMetricsSources(sourceName: String): Seq[Source] = + metricsSystem.getSourcesByName(sourceName) + + @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] + + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { + accumulators(a.id) = a + } + + private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { + accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + } + + private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { + accumulators.mapValues(_.localValue).toMap + } +} diff --git 
a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c95615a5a9307..829fae1d1d9bf 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -364,7 +364,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. import scala.collection.JavaConversions._ - val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true) + val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) res.map(x => new java.util.ArrayList(x.toSeq)).toArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index dc9f62f39e6d5..55e563ee968be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -207,6 +207,7 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -263,11 +264,6 @@ private[spark] class PythonRDD( if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } @@ -358,12 +354,11 @@ private[spark] object PythonRDD extends Logging { def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int], - allowLocal: Boolean): Int = { + partitions: JArrayList[Int]): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 1a5f2bca26c2b..b7e72d4d0ed0b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -95,7 +95,9 @@ private[spark] class RBackend { private[spark] object RBackend extends Logging { def main(args: Array[String]): Unit = { if (args.length < 1) { + // scalastyle:off println System.err.println("Usage: RBackend ") + // scalastyle:on println System.exit(-1) } val sparkRBackend = new RBackend() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 4b8f7fe9242e0..14dac4ed28ce3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,12 +20,14 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import 
scala.collection.mutable.HashMap +import scala.language.existentials import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ +import org.apache.spark.util.Utils /** * Handler for RBackend @@ -67,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -88,21 +93,6 @@ private[r] class RBackendHandler(server: RBackend) ctx.close() } - // Looks up a class given a class name. This function first checks the - // current class loader and if a class is not found, it looks up the class - // in the context class loader. Address [SPARK-5185] - def getStaticClass(objId: String): Class[_] = { - try { - val clsCurrent = Class.forName(objId) - clsCurrent - } catch { - // Use contextLoader if we can't find the JAR in the system class loader - case e: ClassNotFoundException => - val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) - clsContext - } - } - def handleMethodCall( isStatic: Boolean, objId: String, @@ -113,7 +103,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - getStaticClass(objId) + Utils.classForName(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) @@ -159,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. + writeString(dos, Utils.exceptionString(e.getCause)) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 524676544d6f5..1cf2824f862ee 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { protected var dataStream: DataInputStream = _ @@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // The stdout/stderr is shared by multiple tasks, because we use one daemon // to launch child process as worker. - val errThread = RRDD.createRWorker(rLibDir, listenPort) + val errThread = RRDD.createRWorker(listenPort) // We use two sockets to separate input and output, then it's easy to manage // the lifecycle of them to avoid deadlock. 
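A recurring theme in this patch, including the removal of `getStaticClass` above, is replacing ad-hoc `Class.forName` calls with `Utils.classForName`. The snippet below is only a minimal sketch of that kind of helper, written under the assumption that the goal is to resolve against a preferred class loader first and fall back to the thread's context class loader, as the deleted R backend code did by hand; it is not Spark's actual `Utils.classForName` implementation.

```scala
// Minimal sketch of a classForName-style helper (not Spark's actual Utils.classForName).
object ClassForNameSketch {

  /** Resolve a class, preferring this object's loader, falling back to the context loader. */
  def classForName(className: String): Class[_] = {
    val preferredLoader = getClass.getClassLoader
    try {
      // initialize = true matches the behaviour of the one-argument Class.forName
      Class.forName(className, true, preferredLoader)
    } catch {
      case _: ClassNotFoundException =>
        Class.forName(className, true, Thread.currentThread().getContextClassLoader)
    }
  }

  def main(args: Array[String]): Unit = {
    // Reflectively instantiate a class by name, in the style of the BroadcastManager
    // and scheduler-backend lookups elsewhere in this patch.
    val cls = classForName("java.util.ArrayList")
    println(cls.getDeclaredConstructor().newInstance())  // prints "[]"
  }
}
```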
@@ -113,6 +112,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( partition: Int): Unit = { val env = SparkEnv.get + val taskContext = TaskContext.get() val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val stream = new BufferedOutputStream(output, bufferSize) @@ -120,6 +120,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( override def run(): Unit = { try { SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) val dataOut = new DataOutputStream(stream) dataOut.writeInt(partition) @@ -161,7 +162,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { // write string(for StringRRDD) + // scalastyle:off println printOut.println(elem) + // scalastyle:on println } } @@ -233,11 +236,10 @@ private class PairwiseRRDD[T: ClassTag]( hashFunc: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, (Int, Array[Byte])]( parent, numPartitions, hashFunc, deserializer, - SerializationFormats.BYTE, packageNames, rLibDir, + SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): (Int, Array[Byte]) = { @@ -264,10 +266,9 @@ private class RRDD[T: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, Array[Byte]]( - parent, -1, func, deserializer, serializer, packageNames, rLibDir, + parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): Array[Byte] = { @@ -291,10 +292,9 @@ private class StringRRDD[T: ClassTag]( func: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, String]( - parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): String = { @@ -390,9 +390,10 @@ private[r] object RRDD { thread } - private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + private def createRProcess(port: Int, script: String): BufferedStreamThread = { val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. @@ -411,7 +412,7 @@ private[r] object RRDD { /** * ProcessBuilder used to launch worker R processes. 
*/ - def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + def createRWorker(port: Int): BufferedStreamThread = { val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) if (!Utils.isWindows && useDaemon) { synchronized { @@ -419,7 +420,7 @@ private[r] object RRDD { // we expect one connections val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + errThread = createRProcess(daemonPort, "daemon.R") // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() @@ -441,7 +442,7 @@ private[r] object RRDD { errThread } } else { - createRProcess(rLibDir, port, "worker.R") + createRProcess(port, "worker.R") } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala new file mode 100644 index 0000000000000..d53abd3408c55 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.File + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object RUtils { + /** + * Get the SparkR package path in the local spark distribution. + */ + def localSparkRPackagePath: Option[String] = { + val sparkHome = sys.env.get("SPARK_HOME") + sparkHome.map( + Seq(_, "R", "lib").mkString(File.separator) + ) + } + + /** + * Get the SparkR package path in various deployment modes. + * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` + * and environment variable `SPARK_HOME` are set. + */ + def sparkRPackagePath(isDriver: Boolean): String = { + val (master, deployMode) = + if (isDriver) { + (sys.props("spark.master"), sys.props("spark.submit.deployMode")) + } else { + val sparkConf = SparkEnv.get.conf + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + } + + val isYarnCluster = master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master.contains("yarn") && deployMode == "client" + + // In YARN mode, the SparkR package is distributed as an archive symbolically + // linked to the "sparkr" file in the current directory. Note that this does not apply + // to the driver in client mode because it is run outside of the cluster. + if (isYarnCluster || (isYarnClient && !isDriver)) { + new File("sparkr").getAbsolutePath + } else { + // Otherwise, assume the package is local + // TODO: support this for Mesos + localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. 
Can't locate SparkR package.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 56adc857d4ce0..d5b4260bf4529 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -179,6 +179,7 @@ private[spark] object SerDe { // Int -> integer // String -> character // Boolean -> logical + // Float -> double // Double -> double // Long -> double // Array[Byte] -> raw @@ -215,6 +216,9 @@ private[spark] object SerDe { case "long" | "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 685313ac009ba..fac6666bb3410 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -42,7 +43,7 @@ private[spark] class BroadcastManager( conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 71f7e2129116f..f03875a3e8c89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -118,26 +118,26 @@ private class ClientEndpoint( def pollAndReportStatus(driverId: String) { // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread // is fine. - println("... waiting before polling master for driver state") + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") + logInfo("... 
polling master for driver state") val statusResponse = activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -148,7 +148,7 @@ private class ClientEndpoint( override def receive: PartialFunction[Any, Unit] = { case SubmitDriverResponse(master, success, driverId, message) => - println(message) + logInfo(message) if (success) { activeMasterEndpoint = master pollAndReportStatus(driverId.get) @@ -158,7 +158,7 @@ private class ClientEndpoint( case KillDriverResponse(master, driverId, success, message) => - println(message) + logInfo(message) if (success) { activeMasterEndpoint = master pollAndReportStatus(driverId) @@ -169,13 +169,13 @@ private class ClientEndpoint( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") + logError(s"Error connecting to master $remoteAddress.") lostMasters += remoteAddress // Note that this heuristic does not account for the fact that a Master can recover within // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This // is not currently a concern, however, because this client does not retry submissions. 
if (lostMasters.size >= masterEndpoints.size) { - println("No master is available, exiting.") + logError("No master is available, exiting.") System.exit(-1) } } @@ -183,18 +183,18 @@ private class ClientEndpoint( override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") lostMasters += remoteAddress if (lostMasters.size >= masterEndpoints.size) { - println("No master is available, exiting.") + logError("No master is available, exiting.") System.exit(-1) } } } override def onError(cause: Throwable): Unit = { - println(s"Error processing messages, exiting.") + logError(s"Error processing messages, exiting.") cause.printStackTrace() System.exit(-1) } @@ -209,10 +209,12 @@ private class ClientEndpoint( */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 42d3296062e6d..72cc330a398da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. 
hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } @@ -110,7 +112,9 @@ private[deploy] class ClientArguments(args: Array[String]) { | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f41..ccffb36652988 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e99779f299785..c0cab22fa8252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.RBackend +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** @@ -71,9 +71,10 @@ object RRunner { val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) - val sparkHome = System.getenv("SPARK_HOME") + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir) env.put("R_PROFILE_USER", - Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() @@ -85,7 +86,9 @@ object RRunner { } System.exit(returnCode) } else { + // scalastyle:off println System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 6d14590a1d192..e06b06e06fb4a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -25,6 +25,7 @@ import java.util.{Arrays, Comparator} import scala.collection.JavaConversions._ import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.control.NonFatal 
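// Illustrative sketch (not part of the patch): how the RRunner change above resolves the
// driver-side R profile once the SparkR package directory is known. The directory value is
// a placeholder for whatever RUtils.sparkRPackagePath(isDriver = true) returns.
import java.io.File

val rPackageDir = "/opt/spark/R/lib"
val rProfile = Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)
// e.g. "/opt/spark/R/lib/SparkR/profile/general.R", exported to R through R_PROFILE_USER,
// while SPARKR_PACKAGE_DIR lets worker-side code locate the same package.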
import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration @@ -178,7 +179,7 @@ class SparkHadoopUtil extends Logging { private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") statisticsDataClass.getDeclaredMethod(methodName) } @@ -238,6 +239,14 @@ class SparkHadoopUtil extends Logging { }.getOrElse(Seq.empty[Path]) } + def globPathIfNecessary(pattern: Path): Seq[Path] = { + if (pattern.toString.exists("{}[]*?\\".toSet.contains)) { + globPath(pattern) + } else { + Seq(pattern) + } + } + /** * Lists all the files in a directory with the specified prefix, and does not end with the * given suffix. The returned {{FileStatus}} instances are sorted by the modification times of @@ -248,19 +257,25 @@ class SparkHadoopUtil extends Logging { dir: Path, prefix: String, exclusionSuffix: String): Array[FileStatus] = { - val fileStatuses = remoteFs.listStatus(dir, - new PathFilter { - override def accept(path: Path): Boolean = { - val name = path.getName - name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + try { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) } }) - Arrays.sort(fileStatuses, new Comparator[FileStatus] { - override def compare(o1: FileStatus, o2: FileStatus): Int = { - Longs.compare(o1.getModificationTime, o2.getModificationTime) - } - }) - fileStatuses + fileStatuses + } catch { + case NonFatal(e) => + logWarning("Error while attempting to list files from application staging dir", e) + Array.empty + } } /** @@ -356,7 +371,7 @@ object SparkHadoopUtil { System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") .newInstance() .asInstanceOf[SparkHadoopUtil] } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b1d6ec209d62b..0b39ee8fe3ba0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.spark.api.r.RUtils import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -79,9 +80,11 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) 
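// Illustrative sketch (not part of the patch): expected behaviour of the globPathIfNecessary
// helper added to SparkHadoopUtil above. Paths are placeholders; the second call may touch
// the file system because the pattern contains glob characters.
import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil

val hadoopUtil = SparkHadoopUtil.get
// No glob characters ("{}[]*?\"), so the path is returned as-is in a one-element Seq.
val literal = hadoopUtil.globPathIfNecessary(new Path("/data/events/part-00000"))
// Contains "*", so it is expanded through globPath against the underlying FileSystem.
val expanded = hadoopUtil.globPathIfNecessary(new Path("/data/events/part-*"))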
private[spark] var printStream: PrintStream = System.err @@ -102,11 +105,14 @@ object SparkSubmit { printStream.println("Type --help for more information.") exitFn(0) } + // scalastyle:on println def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs) @@ -160,7 +166,9 @@ object SparkSubmit { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { + // scalastyle:off println printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println exitFn(1) } else { throw e @@ -178,7 +186,9 @@ object SparkSubmit { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { + // scalastyle:off println printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println doRunMain() } catch { // Fail over to use the legacy submission gateway @@ -254,6 +264,12 @@ object SparkSubmit { } } + // Update args.deployMode if it is null. It will be passed down as a Spark property later. + (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER @@ -339,6 +355,23 @@ object SparkSubmit { } } + // In YARN mode for an R app, add the SparkR package archive to archives + // that can be distributed with the job + if (args.isR && clusterManager == YARN) { + val rPackagePath = RUtils.localSparkRPackagePath + if (rPackagePath.isEmpty) { + printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + } + val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + } + val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + + // Assigns a symbol link name "sparkr" to the shipped package. + args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + } + // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -367,6 +400,8 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), @@ -473,8 +508,14 @@ object SparkSubmit { } // Let YARN know it's a pyspark app, so it distributes needed libraries. 
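// Illustrative sketch (not part of the patch): the --archives entry the R-on-YARN branch
// above appends. The SPARK_HOME location is a placeholder; the real code goes through
// RUtils.localSparkRPackagePath and Utils.resolveURI, File.toURI is used here as a stand-in.
import java.io.File

val sparkRLibDir = Seq("/opt/spark", "R", "lib").mkString(File.separator)
val rPackageFile = new File(sparkRLibDir, "sparkr.zip")        // SPARKR_PACKAGE_ARCHIVE
val archiveEntry = rPackageFile.toURI.toString + "#sparkr"
// e.g. "file:/opt/spark/R/lib/sparkr.zip#sparkr"; YARN localizes the zip in each container
// under the symlink "sparkr", the path RUtils.sparkRPackagePath resolves to on executors.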
- if (clusterManager == YARN && args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + if (clusterManager == YARN) { + if (args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when the keytab is specified") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -558,6 +599,7 @@ object SparkSubmit { sysProps: Map[String, String], childMainClass: String, verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -565,6 +607,7 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println val loader = if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { @@ -587,13 +630,15 @@ object SparkSubmit { var mainClass: Class[_] = null try { - mainClass = Class.forName(childMainClass, true, loader) + mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { + // scalastyle:off println printStream.println(s"Failed to load main class $childMainClass.") printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -766,7 +811,9 @@ private[spark] object SparkSubmitUtils { brr.setRoot(repo) brr.setName(s"repo-${i + 1}") cr.add(brr) + // scalastyle:off println printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println } } @@ -829,7 +876,9 @@ private[spark] object SparkSubmitUtils { val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println md.addDependency(dd) } } @@ -896,9 +945,11 @@ private[spark] object SparkSubmitUtils { ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } + // scalastyle:off println printStream.println( s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 73ab18332feb4..b3710073e330c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -79,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => @@ -86,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } @@ -162,6 +164,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) @@ -451,6 +454,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) @@ -540,6 +544,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("CLI options:") outStream.println(getSqlShellOptions()) } + // scalastyle:on println SparkSubmit.exitFn(exitCode) } @@ -571,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setSecurityManager(sm) try { - Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + Utils.classForName(mainClass).getMethod("main", classOf[Array[String]]) .invoke(null, Array(HELP)) } catch { case e: InvocationTargetException => diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d3..a98b1fa8f83a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2cc465e55fceb..e3060ac3fa1a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -407,8 +407,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Comparison function that defines the sort order for application attempts within the same - * application. Order is: running attempts before complete attempts, running attempts sorted - * by start time, completed attempts sorted by end time. + * application. Order is: attempts are sorted by descending start time. + * Most recent attempt state matches with current state of the app. * * Normally applications should have a single running attempt; but failure to call sc.stop() * may cause multiple running attempts to show up. 
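// Illustrative sketch (not part of the patch): the ordering the updated comment describes —
// attempts sorted purely by descending start time, so the most recent attempt leads and its
// state reflects the current state of the app. Attempt values below are placeholders.
case class Attempt(id: String, startTime: Long, completed: Boolean)

val attempts = Seq(
  Attempt("attempt-1", startTime = 1000L, completed = true),
  Attempt("attempt-2", startTime = 2000L, completed = false))
// Newer start time first, regardless of whether the attempt has completed.
val ordered = attempts.sortWith((a1, a2) => a1.startTime >= a2.startTime)
// ordered.head.id == "attempt-2"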
@@ -418,11 +418,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private def compareAttemptInfo( a1: FsApplicationAttemptInfo, a2: FsApplicationAttemptInfo): Boolean = { - if (a1.completed == a2.completed) { - if (a1.completed) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime - } else { - !a1.completed - } + a1.startTime >= a2.startTime } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 10638afb74900..a076a9c3f984d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -228,7 +228,7 @@ object HistoryServer extends Logging { val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Class.forName(providerName) + val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 4692d22651c93..18265df9faa2c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index f459ed5b3a1a1..aa379d4cd61e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,8 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} import org.apache.spark.util.Utils @@ -32,11 +31,11 @@ import org.apache.spark.util.Utils * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. 
*/ private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null Utils.tryWithSafeFinally { - out.write(serialized) + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) } { - out.close() + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 48070768f6edb..51b3f0dead73e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.language.postfixOps import scala.util.Random -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, @@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} @@ -58,9 +56,6 @@ private[master] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - // TODO Remove it once we don't use akka.serialization.Serialization - private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -161,20 +156,21 @@ private[master] class Master( masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" 
=> logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => - val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(actorSystem)) + val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -213,7 +209,7 @@ private[master] class Master( override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { @@ -545,6 +541,7 @@ private[master] class Master( /** * Schedule executors to be launched on the workers. + * Returns an array containing number of cores assigned to each worker. * * There are two modes of launching executors. The first attempts to spread out an application's * executors on as many workers as possible, while the second does the opposite (i.e. launch them @@ -555,39 +552,77 @@ private[master] class Master( * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the * worker by default, in which case only one executor may be launched on each worker. + * + * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core + * at a time). Consider the following example: cluster has 4 workers with 16 cores each. + * User requests 3 executors (spark.cores.max = 48, spark.executor.cores = 16). If 1 core is + * allocated at a time, 12 cores from each worker would be assigned to each executor. + * Since 12 < 16, no executors would launch [SPARK-8881]. */ - private def startExecutorsOnWorkers(): Unit = { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc. 
- if (spreadOutApps) { - // Try to spread out each app among all the workers, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 + private def scheduleExecutorsOnWorkers( + app: ApplicationInfo, + usableWorkers: Array[WorkerInfo], + spreadOutApps: Boolean): Array[Int] = { + // If the number of cores per executor is not specified, then we can just schedule + // 1 core at a time since we expect a single executor to be launched on each worker + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val numUsable = usableWorkers.length + val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker + val assignedMemory = new Array[Int](numUsable) // Amount of memory to give to each worker + var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) + var freeWorkers = (0 until numUsable).toIndexedSeq + + def canLaunchExecutor(pos: Int): Boolean = { + usableWorkers(pos).coresFree - assignedCores(pos) >= coresPerExecutor && + usableWorkers(pos).memoryFree - assignedMemory(pos) >= memoryPerExecutor + } + + while (coresToAssign >= coresPerExecutor && freeWorkers.nonEmpty) { + freeWorkers = freeWorkers.filter(canLaunchExecutor) + freeWorkers.foreach { pos => + var keepScheduling = true + while (keepScheduling && canLaunchExecutor(pos) && coresToAssign >= coresPerExecutor) { + coresToAssign -= coresPerExecutor + assignedCores(pos) += coresPerExecutor + // If cores per executor is not set, we are assigning 1 core at a time + // without actually meaning to launch 1 executor for each core assigned + if (app.desc.coresPerExecutor.isDefined) { + assignedMemory(pos) += memoryPerExecutor + } + + // Spreading out an application means spreading out its executors across as + // many workers as possible. If we are not spreading out, then we should keep + // scheduling executors on this worker until we use all of its resources. + // Otherwise, just move on to the next worker. + if (spreadOutApps) { + keepScheduling = false } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable if assigned(pos) > 0) { - allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } - } else { - // Pack each app into as few workers as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - allocateWorkerResourceToExecutors(app, app.coresLeft, worker) - } + } + assignedCores + } + + /** + * Schedule and launch executors on workers + */ + private def startExecutorsOnWorkers(): Unit = { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. 
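// Illustrative sketch (not part of the patch): the SPARK-8881 arithmetic from the comment on
// scheduleExecutorsOnWorkers above, modelled with plain arrays rather than the real
// WorkerInfo/ApplicationInfo types. 4 workers x 16 cores, spark.cores.max = 48,
// spark.executor.cores = 16.
val workerCores = Array.fill(4)(16)
val coresPerExecutor = 16
val coresToAssign = math.min(48, workerCores.sum)            // 48

// Old behaviour: hand out one core at a time, round robin. Every worker ends up with
// 48 / 4 = 12 assigned cores, and since 12 < 16 not a single full executor fits.
val oneAtATime = Array.fill(workerCores.length)(0)
var remaining = coresToAssign
var pos = 0
while (remaining > 0) {
  if (workerCores(pos) - oneAtATime(pos) > 0) {
    oneAtATime(pos) += 1
    remaining -= 1
  }
  pos = (pos + 1) % workerCores.length
}
val oldExecutors = oneAtATime.map(_ / coresPerExecutor).sum  // 0

// New behaviour: hand out coresPerExecutor cores at a time, so three complete executors
// are placed before the 48-core budget runs out.
val wholeExecutor = Array.fill(workerCores.length)(0)
remaining = coresToAssign
var i = 0
while (remaining >= coresPerExecutor && i < workerCores.length) {
  if (workerCores(i) - wholeExecutor(i) >= coresPerExecutor) {
    wholeExecutor(i) += coresPerExecutor
    remaining -= coresPerExecutor
  }
  i += 1
}
val newExecutors = wholeExecutor.map(_ / coresPerExecutor).sum  // 3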
+ for (app <- waitingApps if app.coresLeft > 0) { + val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) } } } @@ -595,19 +630,22 @@ private[master] class Master( /** * Allocate a worker's resources to one or more executors. * @param app the info of the application which the executors belong to - * @param coresToAllocate cores on this worker to be allocated to this application + * @param assignedCores number of cores on this worker for this application + * @param coresPerExecutor number of cores per executor * @param worker the worker info */ private def allocateWorkerResourceToExecutors( app: ApplicationInfo, - coresToAllocate: Int, + assignedCores: Int, + coresPerExecutor: Option[Int], worker: WorkerInfo): Unit = { - val memoryPerExecutor = app.desc.memoryPerExecutorMB - val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) - var coresLeft = coresToAllocate - while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { - val exec = app.addExecutor(worker, coresPerExecutor) - coresLeft -= coresPerExecutor + // If the number of cores per executor is specified, we divide the cores assigned + // to this worker evenly among the executors with no remainder. + // Otherwise, we launch a single executor that grabs all the assignedCores on this worker. + val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1) + val coresToAssign = coresPerExecutor.getOrElse(assignedCores) + for (i <- 1 to numExecutors) { + val exec = app.addExecutor(worker, coresToAssign) launchExecutor(worker, exec) app.state = ApplicationState.RUNNING } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 435b9b12f83b8..44cefbc77f08e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
*/ private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index a03d460509e03..58a00bceee6af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -80,8 +81,11 @@ abstract class PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 351db8fab2041..c4c3283fb73f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. 
*/ -private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") @@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: } } -private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { def createPersistenceEngine(): PersistenceEngine = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 471811037e5e2..f751966605206 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -105,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 328d95a7a0c68..563831cc6b8dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine with Logging { @@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 
894cb78d8591a..5accaf78d0a51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { + // scalastyle:off println System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + // scalastyle:on println System.exit(1) } masterUrl = value.stripPrefix("mesos://") @@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case Nil => { if (masterUrl == null) { + // scalastyle:off println System.err.println("--master is required") + // scalastyle:on println printUsageAndExit(1) } } @@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: } private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off println System.err.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + @@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index e6615a3174ce1..ef5a7e35ad562 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage { */ def fromJson(json: String): SubmitRestProtocolMessage = { val className = parseAction(json) - val clazz = Class.forName(packagePrefix + "." + className) + val clazz = Utils.classForName(packagePrefix + "." 
+ className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index d1a12b01e78f7..6799f78ec0c19 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -53,14 +53,16 @@ object DriverWrapper { Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(mainClass, true, loader) + val clazz = Utils.classForName(mainClass) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) rpcEnv.shutdown() case _ => + // scalastyle:off println System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1d2ecab517613..5181142c5f80e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. */ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -147,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val ibmVendor = System.getProperty("java.vendor").contains("IBM") var totalMb = 0 try { + // scalastyle:off classforname val bean = ManagementFactory.getOperatingSystemMXBean() if (ibmVendor) { val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") @@ -157,10 +160,13 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt } + // scalastyle:on classforname } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. 
Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 34d4cfdca7732..fcd76ec52742a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -235,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { argv = tail case Nil => case tail => + // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println printUsageAndExit() } } @@ -249,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } private def printUsageAndExit() = { + // scalastyle:off println System.err.println( """ |"Usage: CoarseGrainedExecutorBackend [options] @@ -262,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --worker-url | --user-class-path |""".stripMargin) + // scalastyle:on println System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8f916e0502ecb..7bc7fce7ae8dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -209,15 +209,19 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = try { - task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + var threwException = true + val (value, accumUpdates) = try { + val res = task.run( + taskAttemptId = taskId, + attemptNumber = attemptNumber, + metricsSystem = env.metricsSystem) + threwException = false + res } finally { - // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; - // when changing this, make sure to update both copies. 
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" - if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { throw new SparkException(errMsg) } else { logError(errMsg) @@ -247,7 +251,6 @@ private[spark] class Executor( m.setResultSerializationTime(afterSerialization - beforeSerialization) } - val accumUpdates = Accumulators.values val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -310,12 +313,6 @@ private[spark] class Executor( } } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() - // Release memory used by this thread for accumulators - Accumulators.clear() runningTasks.remove(taskId) } } @@ -356,7 +353,7 @@ private[spark] class Executor( logInfo("Using REPL class URI: " + classUri) try { val _userClassPathFirst: java.lang.Boolean = userClassPathFirst - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) @@ -424,6 +421,7 @@ private[spark] class Executor( metrics.updateShuffleReadMetrics() metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + metrics.updateAccumulators() if (isLocal) { // JobProgressListener will hold an reference of it during @@ -443,7 +441,7 @@ private[spark] class Executor( try { val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") + logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3b4561b07e7f..42207a9553592 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,11 +17,15 @@ package org.apache.spark.executor +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -210,10 +214,42 @@ class TaskMetrics extends Serializable { private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + // Get the hostname from cached data, since hostname is the order of number of nodes in + // cluster, so using cached hostname will decrease the object number and alleviate the GC + // overhead. 
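// Illustrative sketch (not part of the patch): the interning behaviour the comment above
// relies on and that getCachedHostName (added further down) implements —
// ConcurrentHashMap.putIfAbsent returns null on first insert and the previously stored
// value afterwards, so equal hostname strings collapse to one canonical object.
import java.util.concurrent.ConcurrentHashMap

val hostNameCache = new ConcurrentHashMap[String, String]()
def cached(host: String): String = {
  val canonical = hostNameCache.putIfAbsent(host, host)
  if (canonical != null) canonical else host
}
val first = cached(new String("worker-01"))   // stored and returned
val second = cached(new String("worker-01"))  // returns the stored instance
// first eq second — only one String object per distinct hostname survives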
+ _hostname = TaskMetrics.getCachedHostName(_hostname) + } + + private var _accumulatorUpdates: Map[Long, Any] = Map.empty + @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + + private[spark] def updateAccumulators(): Unit = synchronized { + _accumulatorUpdates = _accumulatorsUpdater() + } + + /** + * Return the latest updates of accumulators in this task. + */ + def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates + + private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { + _accumulatorsUpdater = accumulatorsUpdater + } } private[spark] object TaskMetrics { + private val hostNameCache = new ConcurrentHashMap[String, String]() + def empty: TaskMetrics = new TaskMetrics + + def getCachedHostName(host: String): String = { + val canonicalHost = hostNameCache.putIfAbsent(host, host) + if (canonicalHost != null) canonicalHost else host + } } /** diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa9..532850dd57716 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0d8ac1f80a9f4..607d5a321efca 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -63,8 +63,7 @@ private[spark] object CompressionCodec { def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) val codec = try { - val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) - .getConstructor(classOf[SparkConf]) + val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { case e: ClassNotFoundException => None diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 818f7a4c8d422..87df42748be44 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => 
MapReduceOutputCommitter} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.util.{Utils => SparkUtils} private[spark] trait SparkHadoopMapRedUtil { @@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + SparkUtils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + SparkUtils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 390d148bc97f9..943ebcb7bd0a1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} +import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { @@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil { isMap: Boolean, taskId: Int, attemptId: Int): TaskAttemptID = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") + val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap // (not available in YARN) @@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil { } catch { case exc: NoSuchMethodException => { // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( taskTypeClass, if (isMap) "MAP" else "REDUCE") @@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + Utils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + Utils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index ed5131c79fdc5..4517f465ebd3b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,6 +20,8 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit +import org.apache.spark.util.Utils + import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -140,6 +142,9 @@ private[spark] class MetricsSystem private ( } else { defaultName } } + def getSourcesByName(sourceName: String): Seq[Source] = + sources.filter(_.sourceName == sourceName) + def registerSource(source: Source) { sources += source try { @@ -166,7 +171,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Class.forName(classPath).newInstance() + val source = Utils.classForName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => 
logError("Source class " + classPath + " cannot be instantiated", e) @@ -182,7 +187,7 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Class.forName(classPath) + val sink = Utils.classForName(classPath) .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 67a376102994c..79cb0640c8672 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -57,16 +57,6 @@ private[nio] class BlockMessage() { } def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ typ = buffer.getInt() val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -138,18 +128,6 @@ private[nio] class BlockMessage() { buffers += data } - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 7d0806f0c2580..f1c9ea8b64ca3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) val newBlockMessages = new ArrayBuffer[BlockMessage]() val buffer = bufferMessage.buffers(0) buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ while (buffer.remaining() > 0) { val size = buffer.getInt() logDebug("Creating block message of size " + size + " bytes") @@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } } -private[nio] object BlockMessageArray { +private[nio] object BlockMessageArray extends Logging { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -123,10 +101,10 @@ private[nio] object BlockMessageArray { } } val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") + logDebug("Block message array created") val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") + logDebug("Converted to buffer message") val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) @@ -138,10 +116,11 @@ private[nio] object BlockMessageArray { }) newBuffer.flip val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) + logDebug("Copied to new buffer message, size = " + newBufferMessage.size) val 
newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") + logDebug("Converted back to block message array") + // scalastyle:off println newBlockMessageArray.foreach(blockMessage => { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { @@ -154,6 +133,7 @@ private[nio] object BlockMessageArray { } } }) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index c0bca2c4bc994..9143918790381 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -1016,7 +1016,9 @@ private[spark] object ConnectionManager { val conf = new SparkConf val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // scalastyle:off println println("Received [" + msg + "] from [" + id + "]") + // scalastyle:on println None }) @@ -1033,6 +1035,7 @@ private[spark] object ConnectionManager { System.gc() } + // scalastyle:off println def testSequentialSending(manager: ConnectionManager) { println("--------------------------") println("Sequential Sending") @@ -1150,4 +1153,5 @@ private[spark] object ConnectionManager { println() } } + // scalastyle:on println } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 658e8c8b89318..130b58882d8ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -94,13 +94,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: Product2[K, _]] => + rdds.map { rdd: RDD[_] => if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner]( + rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer) } } } @@ -133,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- dependencies.zipWithIndex) dep match { - case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked => val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 663eebb8e4191..90d9735cb3f69 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition( * the preferred location of each new partition overlaps with as many preferred locations of its * parent partitions * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the 
coalesced RDD + * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive) * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ private[spark] class CoalescedRDD[T: ClassTag]( @@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag]( balanceSlack: Double = 0.10) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies + require(maxPartitions > 0 || maxPartitions == prev.partitions.length, + s"Number of partitions ($maxPartitions) must be positive.") + override def getPartitions: Array[Partition] = { val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index bee59a437f120..f1c17369cb48c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging { private[spark] class SplitInfoReflections { val inputSplitWithLocationInfo = - Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") val isInMemory = splitLocationInfo.getMethod("isInMemory") val getLocation = splitLocationInfo.getMethod("getLocation") } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f827270ee6a44..f83a051f5da11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -128,7 +128,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -141,6 +141,12 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -159,18 +165,23 @@ class NewHadoopRDD[K, V]( private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. 
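The NewHadoopRDD iterator above now calls close() as soon as nextKeyValue reports the end of the split, instead of holding the reader until task completion. A minimal sketch of that early-release pattern, using a stand-in Reader type rather than Hadoop's RecordReader:

```scala
import java.io.Closeable

// Sketch of the "close as soon as the input is exhausted" pattern: the wrapper
// releases the underlying resource the moment hasNext first observes the end.
trait Reader[T] extends Closeable { def nextValue(): Option[T] }

class EagerlyClosingIterator[T](private var reader: Reader[T]) extends Iterator[T] {
  private var nextElem: Option[T] = None
  private var finished = false

  override def hasNext: Boolean = {
    if (!finished && nextElem.isEmpty) {
      nextElem = reader.nextValue()
      if (nextElem.isEmpty) {
        finished = true
        close() // release file handles early; calling close() again later is safe
      }
    }
    nextElem.isDefined
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    val v = nextElem.get
    nextElem = None
    v
  }

  def close(): Unit = if (reader != null) { reader.close(); reader = null }
}

object EagerlyClosingIteratorDemo extends App {
  val data = Iterator("a", "b", "c")
  val reader = new Reader[String] {
    def nextValue(): Option[String] = if (data.hasNext) Some(data.next()) else None
    def close(): Unit = println("reader closed")
  }
  new EagerlyClosingIterator(reader).foreach(println) // a, b, c, then "reader closed"
}
```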
- try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + // Close reader and release it + reader.close() + reader = null + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 91a6a2d039852..326fafb230a40 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -881,7 +881,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } buf } : Seq[V] - val res = self.context.runJob(self, process, Array(index), false) + val res = self.context.runJob(self, process, Array(index)) res(0) case None => self.filter(_._1 == key).map(_._2).collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dc60d48927624..3bb9998e1db44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -131,8 +133,10 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { + TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,6 +148,7 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9f7ebae3e9af3..6d61d227382d7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -897,7 +897,7 @@ abstract class RDD[T: ClassTag]( */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head + sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } @@ -1082,7 +1082,9 @@ abstract class RDD[T: ClassTag]( val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce // the wall-clock time, we stop tree aggregation. 
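The adjusted treeAggregate loop (continued just below) only adds another aggregation level while `scale + ceil(numPartitions / scale)` is smaller than the current partition count, that is, while the extra level actually saves wall-clock work. A small sketch that traces the partition counts per level for some assumed inputs:

```scala
// Sketch of the level-reduction check in treeAggregate: keep adding a level
// only while it reduces the work left to combine at the end.
object TreeAggregateLevels extends App {
  def levels(initialPartitions: Int, depth: Int): Seq[Int] = {
    val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
    var numPartitions = initialPartitions
    val counts = scala.collection.mutable.ArrayBuffer(numPartitions)
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      counts += numPartitions
    }
    counts.toSeq
  }

  // e.g. 1000 partitions with depth 2 -> scale 32, so one extra level (1000 -> 31)
  println(levels(1000, 2))
  // A tiny RDD never triggers an extra level because it would not save time.
  println(levels(4, 2))
}
```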
- while (numPartitions > scale + numPartitions / scale) { + + // Don't trigger TreeAggregation when it doesn't save wall-clock time + while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) { numPartitions /= scale val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { @@ -1273,7 +1275,7 @@ abstract class RDD[T: ClassTag]( val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) partsScanned += numPartsToTry diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala similarity index 81% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 2bdc341021256..35e44cb59c1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -15,28 +15,28 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.broadcast.Broadcast - -import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.{RDD, HadoopRDD} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -63,7 +63,7 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be * folded into core. */ -private[sql] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], @@ -129,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -148,7 +154,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -161,6 +167,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -179,18 +191,24 @@ private[sql] class SqlNewHadoopRDD[K, V]( private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + reader.close() + reader = null + + SqlNewHadoopRDD.unsetInputFileName() + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { @@ -241,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V]( } private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. 
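The relocated SqlNewHadoopRDD keeps the name of the file currently being read in a thread-local so that Spark SQL's InputFileName expression can read it from the task thread. A simplified sketch of that holder, using a plain String where the real code uses UTF8String:

```scala
// Sketch of the thread-local "current input file" holder: the reading thread
// sets the name before consuming a split and clears it in close(), and any
// code running on the same thread can observe it in between.
object CurrentInputFile {
  private val inputFileName: ThreadLocal[String] = new ThreadLocal[String] {
    override protected def initialValue(): String = ""
  }

  def get(): String = inputFileName.get()
  def set(file: String): Unit = inputFileName.set(file)
  def unset(): Unit = inputFileName.remove()
}

object CurrentInputFileDemo extends App {
  CurrentInputFile.set("hdfs://nn/data/part-00000.parquet")
  try {
    // Code reading records on this thread can ask which file they came from.
    println(CurrentInputFile.get())
  } finally {
    CurrentInputFile.unset()
  }
}
```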
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 523aaf2b860b5..e277ae28d588f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L prev.context.runJob( prev, Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - allowLocal = false + 0 until n - 1 // do not need to count the last partition ).scanLeft(0L)(_ + _) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 1709bdf560b6f..29debe8081308 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -39,8 +39,7 @@ private[spark] object RpcEnv { val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). - newInstance().asInstanceOf[RpcEnvFactory] + Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } def create( @@ -140,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. 
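RpcEnv.deserialize gives implementations a place to install whatever context their serializer needs while RpcEndpointRef instances are being read back; the Akka implementation further below does this with JavaSerializer.currentSystem.withValue, a dynamically scoped variable. A hedged sketch of the same scoping idea using Scala's DynamicVariable and a placeholder context type (RpcContext is not a real Spark class):

```scala
import scala.util.DynamicVariable

// Sketch of scoping a deserialization context around an action. `RpcContext`
// stands in for whatever the deserializer needs in order to resolve references.
final case class RpcContext(name: String)

object CurrentRpcContext {
  private val current = new DynamicVariable[Option[RpcContext]](None)

  def withContext[T](ctx: RpcContext)(deserializationAction: () => T): T =
    current.withValue(Some(ctx)) { deserializationAction() }

  // Called from inside readObject/readResolve-style hooks.
  def get: RpcContext =
    current.value.getOrElse(throw new IllegalStateException("no RpcContext in scope"))
}

object CurrentRpcContextDemo extends App {
  val result = CurrentRpcContext.withContext(RpcContext("driver")) { () =>
    // Pretend this is ser.deserialize[...](bytes) and the hook reads the context.
    CurrentRpcContext.get.name
  }
  println(result) // driver
}
```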
+ */ + def deserialize[T](deserializationAction: () => T): T } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f2d87f68341af..fc17542abf81d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import com.google.common.util.concurrent.MoreExecutors +import akka.serialization.JavaSerializer import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ @@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef( override def toString: String = s"${getClass.getSimpleName}($actorRef)" + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6841fa835747f..c4fa277c21254 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -22,7 +22,8 @@ import java.util.Properties import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} +import scala.collection.Map +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -37,7 +38,6 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -127,10 +127,6 @@ class DAGScheduler( // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - - /** If enabled, we may run certain actions like take() and first() locally. */ - private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) - /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. 
*/ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) @@ -514,7 +510,6 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. @@ -534,7 +529,7 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, + jobId, rdd, func2, partitions.toArray, callSite, waiter, SerializationUtils.clone(properties))) waiter } @@ -544,11 +539,10 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): Unit = { val start = System.nanoTime - val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format @@ -556,6 +550,9 @@ class DAGScheduler( case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. + val callerStackTrace = Thread.currentThread().getStackTrace.tail + exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) throw exception } } @@ -572,8 +569,7 @@ class DAGScheduler( val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, - SerializationUtils.clone(properties))) + jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } @@ -650,73 +646,6 @@ class DAGScheduler( } } - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. - */ - protected def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.jobId) { - override def run() { - runLocallyWithinThread(job) - } - }.start() - } - - // Broken out for easier testing in DAGSchedulerSuite. 
- protected def runLocallyWithinThread(job: ActiveJob) { - var jobResult: JobResult = JobSucceeded - try { - val rdd = job.finalStage.rdd - val split = rdd.partitions(job.partitions(0)) - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) - val taskContext = - new TaskContextImpl( - job.finalStage.id, - job.partitions(0), - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - runningLocally = true) - TaskContext.setTaskContext(taskContext) - try { - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - job.listener.taskSucceeded(0, result) - } finally { - taskContext.markTaskCompleted() - TaskContext.unset() - // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, - // make sure to update both copies. - val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - if (freedMemory > 0) { - if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { - throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") - } else { - logError(s"Managed memory leak detected; size = $freedMemory bytes") - } - } - } - } catch { - case e: Exception => - val exception = new SparkDriverExecutionException(e) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - case oom: OutOfMemoryError => - val exception = new SparkException("Local job aborted due to out of memory error", oom) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - } finally { - val s = job.finalStage - // clean up data structures that were populated for a local job, - // but that won't get cleaned up via the normal paths through - // completion events or stage abort - stageIdToStage -= s.id - jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) - } - } - /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). @@ -779,7 +708,6 @@ class DAGScheduler( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties) { @@ -797,29 +725,20 @@ class DAGScheduler( if (finalStage != null) { val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.shortForm, partitions.length, allowLocal)) + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val shouldRunLocally = - localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 val jobSubmissionTime = clock.getTimeMillis() - if (shouldRunLocally) { - // Compute very short actions like first() or take() with no parent stages locally. 
- listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) - runLocally(job) - } else { - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) - } + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) } submitWaitingStages() } @@ -853,7 +772,6 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() - // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = { stage match { @@ -872,8 +790,28 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) outputCommitCoordinator.stageStart(stage.id) + val taskIdToLocations = try { + stage match { + case s: ShuffleMapStage => + partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap + case s: ResultStage => + val job = s.resultOfJob.get + partitionsToCompute.map { id => + val p = job.partitions(id) + (id, getPreferredLocs(stage.rdd, p)) + }.toMap + } + } catch { + case NonFatal(e) => + stage.makeNewStageAttempt(partitionsToCompute.size) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + + stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. 
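submitMissingTasks above now resolves the preferred locations for every partition before the new stage attempt is created, so a failure inside getPreferredLocs aborts the stage in one place and the locations can be attached to the StageInfo. A compact sketch of that compute-the-whole-map-first shape, with placeholder types:

```scala
import scala.util.control.NonFatal

// Sketch of resolving per-partition locations up front: build the whole map
// before creating any tasks, and surface a failure as a single stage-level
// error instead of a half-submitted stage.
object PreferredLocationsSketch extends App {
  type TaskLocation = String // placeholder for org.apache.spark.scheduler.TaskLocation

  def getPreferredLocs(partition: Int): Seq[TaskLocation] =
    if (partition == 2) Seq("host-b", "host-c") else Seq("host-a")

  def locationsFor(partitionsToCompute: Seq[Int]): Either[String, Map[Int, Seq[TaskLocation]]] =
    try {
      Right(partitionsToCompute.map { id => (id, getPreferredLocs(id)) }.toMap)
    } catch {
      case NonFatal(e) => Left(s"Task creation failed: $e")
    }

  println(locationsFor(Seq(0, 1, 2)))
}
```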
@@ -912,9 +850,9 @@ class DAGScheduler( stage match { case stage: ShuffleMapStage => partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) + val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) } case stage: ResultStage => @@ -922,8 +860,8 @@ class DAGScheduler( partitionsToCompute.map { id => val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) + val locs = taskIdToLocations(id) + new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) } } } catch { @@ -937,8 +875,8 @@ class DAGScheduler( logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) - taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -978,11 +916,9 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") } } } catch { @@ -1009,7 +945,7 @@ class DAGScheduler( // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { - val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val attemptId = task.stageAttemptId listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) } @@ -1065,10 +1001,11 @@ class DAGScheduler( val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { shuffleStage.addOutputLoc(smt.partitionId, status) } + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") @@ -1128,38 +1065,48 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is possible - // the fetch failure has already been handled by the scheduler. 
- if (runningStages.contains(failedStage)) { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + - s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) - } + if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + } else { - if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) - } - failedStages += failedStage - failedStages += mapStage - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is + // possible the fetch failure has already been handled by the scheduler. + if (runningStages.contains(failedStage)) { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, Some(failureMessage)) + } else { + logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + + s"longer running") + } - // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + } else if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. 
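The reworked FetchFailed handling first checks whether the failure came from the stage attempt that is currently registered; a failure reported by an older attempt is logged and ignored rather than marking the stage as failed again. A compact sketch of that guard over simplified types:

```scala
// Sketch of the "ignore fetch failures from stale stage attempts" guard: only a
// failure whose stageAttemptId matches the latest attempt may fail the stage.
object FetchFailureGuard extends App {
  final case class StageState(stageId: Int, latestAttemptId: Int, running: Boolean)
  final case class FetchFailed(stageId: Int, stageAttemptId: Int)

  def handle(stage: StageState, failure: FetchFailed): String =
    if (failure.stageAttemptId != stage.latestAttemptId) {
      s"ignored: failure is from attempt ${failure.stageAttemptId}, " +
        s"latest is ${stage.latestAttemptId}"
    } else if (stage.running) {
      "mark stage as failed and schedule a resubmit"
    } else {
      "stage already finished; nothing to do"
    }

  println(handle(StageState(3, latestAttemptId = 1, running = true), FetchFailed(3, 0)))
  println(handle(StageState(3, latestAttemptId = 1, running = true), FetchFailed(3, 1)))
}
```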
+ // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage + // Mark the map whose fetch failed as broken in the map stage + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + } } case commitDenied: TaskCommitDenied => @@ -1471,9 +1418,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { - case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, - listener, properties) + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2b6f7e4205c32..a213d419cf033 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.Map +import scala.collection.Map import scala.language.existentials import org.apache.spark._ @@ -40,7 +40,6 @@ private[scheduler] case class JobSubmitted( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties = null) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 529a5b2bf1a0d..5a06ef02f5c57 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -140,7 +140,9 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. 
*/ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) @@ -197,6 +199,9 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + // No-op because logging every update would be overkill + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index e55b76c36cc5f..f96eb8ca0ae00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c9a124113961f..9c2606e278c54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], val outputId: Int) - extends Task[U](stageId, partition.index) with Serializable { + extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index bd3dd23dfe1ac..14c8c00961487 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter */ private[spark] class ShuffleMapTask( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, partition.index) with Logging { + extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ def this(partitionId: Int) { - this(0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 9620915f495ab..896f1743332f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -215,6 +218,11 @@ trait SparkListener { * Called when the driver removes an executor. */ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + + /** + * Called when the driver receives a block update info. + */ + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 61e69ecc08387..04afde33f5aad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case blockUpdated: SparkListenerBlockUpdated => + listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index c59d6e4f5bc04..40a333a3e06b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -62,22 +62,31 @@ private[spark] abstract class Stage( var pendingTasks = new HashSet[Task[_]] + /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ - var latestInfo: StageInfo = StageInfo.fromStage(this) + /** + * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized + * here, before any attempts have actually been created, because the DAGScheduler uses this + * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts + * have been created). 
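The comment above explains why _latestInfo must exist before any attempt has run; each call to makeNewStageAttempt then replaces it with a StageInfo carrying the next attempt ID. A stripped-down sketch of that per-attempt bookkeeping with a simplified StageInfo:

```scala
// Simplified sketch of per-attempt StageInfo bookkeeping: the stage starts with
// an attempt-0 placeholder (so listeners have something to report before any
// tasks run) and each makeNewStageAttempt bumps the attempt counter afterwards.
final case class StageInfoSketch(stageId: Int, attemptId: Int, numTasks: Int)

class StageSketch(val id: Int) {
  private var nextAttemptId = 0
  private var _latestInfo = StageInfoSketch(id, nextAttemptId, numTasks = 0)

  def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = {
    _latestInfo = StageInfoSketch(id, nextAttemptId, numPartitionsToCompute)
    nextAttemptId += 1
  }

  def latestInfo: StageInfoSketch = _latestInfo
}

object StageSketchDemo extends App {
  val stage = new StageSketch(7)
  println(stage.latestInfo) // attempt 0 placeholder, no tasks yet
  stage.makeNewStageAttempt(4)
  println(stage.latestInfo) // first real attempt, still attempt id 0
  stage.makeNewStageAttempt(4)
  println(stage.latestInfo) // retry gets attempt id 1
}
```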
+ */ + private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) - /** Return a new attempt id, starting with 0. */ - def newAttemptId(): Int = { - val id = nextAttemptId + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ + def makeNewStageAttempt( + numPartitionsToCompute: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { + _latestInfo = StageInfo.fromStage( + this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) nextAttemptId += 1 - id } - def attemptId: Int = nextAttemptId + /** Returns the StageInfo for the most recent attempt for this stage. */ + def latestInfo: StageInfo = _latestInfo override final def hashCode(): Int = id override final def equals(other: Any): Boolean = other match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index e439d2a7e1229..24796c14300b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -34,7 +34,8 @@ class StageInfo( val numTasks: Int, val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], - val details: String) { + val details: String, + private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -70,16 +71,22 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. 
*/ - def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { + def fromStage( + stage: Stage, + attemptId: Int, + numTasks: Option[Int] = None, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( stage.id, - stage.attemptId, + attemptId, stage.name, numTasks.getOrElse(stage.numTasks), rddInfos, stage.parents.map(_.id), - stage.details) + stage.details, + taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 15101c64f0503..1978305cfefbd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,8 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextImpl, TaskContext} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -43,34 +44,60 @@ import org.apache.spark.util.Utils * @param stageId id of the stage this task belongs to * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { +private[spark] abstract class Task[T]( + val stageId: Int, + val stageAttemptId: Int, + var partitionId: Int) extends Serializable { + + /** + * The key of the Map is the accumulator id and the value of the Map is the latest accumulator + * local value. + */ + type AccumulatorUpdates = Map[Long, Any] /** * Called by [[Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) - * @return the result of the task + * @return the result of the task along with updates of Accumulators. 
*/ - final def run(taskAttemptId: Long, attemptNumber: Int): T = { + final def run( + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem) + : (T, AccumulatorUpdates) = { context = new TaskContextImpl( stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, taskMemoryManager = taskMemoryManager, + metricsSystem = metricsSystem, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) + context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - runTask(context) + (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + } + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } + } finally { + TaskContext.unset() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 8b2a742b96988..b82c7f3fa54f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.executor.TaskMetrics @@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - accumUpdates = Map() + val _accumUpdates = mutable.Map[Long, Any]() for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() + _accumUpdates(in.readLong()) = in.readObject() } + accumUpdates = _accumUpdates } metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ed3dde0fc3055..1705e7f962de2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl( // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. 
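The reworked Task.run above returns the task's value together with the accumulator updates collected on the task thread, and its finally block releases shuffle and unroll memory independently so that one failing release can neither skip the other nor prevent the TaskContext from being cleared. A simplified sketch of that control flow with placeholder release functions:

```scala
// Sketch of the run-task control flow: return (result, accumulator updates),
// and in the finally block attempt each per-task cleanup independently.
object TaskRunSketch extends App {
  type AccumulatorUpdates = Map[Long, Any]

  def tryLogNonFatalError(block: => Unit): Unit =
    try block catch { case scala.util.control.NonFatal(e) => println(s"cleanup failed: $e") }

  def runTask[T](
      body: () => T,
      collectAccumulators: () => AccumulatorUpdates,
      releaseShuffleMemory: () => Unit,
      releaseUnrollMemory: () => Unit): (T, AccumulatorUpdates) = {
    try {
      (body(), collectAccumulators())
    } finally {
      try {
        tryLogNonFatalError(releaseShuffleMemory())
        tryLogNonFatalError(releaseUnrollMemory())
      } finally {
        // In the real code this is TaskContext.unset(); it must always run.
        println("task context cleared")
      }
    }
  }

  val (value, updates) = runTask(
    body = () => 42,
    collectAccumulators = () => Map(1L -> 10),
    releaseShuffleMemory = () => (),
    releaseUnrollMemory = () => sys.error("simulated release failure"))
  println((value, updates))
}
```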
- val activeTaskSets = new HashMap[String, TaskSetManager] + private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] - val taskIdToTaskSetId = new HashMap[Long, String] + private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl( logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) - activeTaskSets(taskSet.id) = manager + val stage = taskSet.stageId + val stageTaskSets = + taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) + stageTaskSets(taskSet.stageAttemptId) = manager + val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => + ts.taskSet != taskSet && !ts.isZombie + } + if (conflictingTaskSet) { + throw new IllegalStateException(s"more than one active taskSet for stage $stage:" + + s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}") + } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) if (!isLocal && !hasReceivedTask) { @@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId, interruptThread) + } + tsm.abort("Stage %s cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) } - tsm.abort("Stage %s cancelled".format(stageId)) - logInfo("Stage %d was cancelled".format(stageId)) } } @@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl( * cleaned up. 
*/ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - activeTaskSets -= manager.taskSet.id + taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage => + taskSetsForStage -= manager.taskSet.stageAttemptId + if (taskSetsForStage.isEmpty) { + taskSetsByStageIdAndAttempt -= manager.taskSet.stageId + } + } manager.parent.removeSchedulable(manager) logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" .format(manager.taskSet.id, manager.parent.name)) @@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK @@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl( failedExecutor = Some(execId) } } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => + taskIdToTaskSetManager.get(tid) match { + case Some(taskSet) => if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) + taskIdToTaskSetManager.remove(tid) taskIdToExecutorId.remove(tid) } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") - .format(state, tid)) + "likely the result of receiving duplicate task finished status updates)") + .format(state, tid)) } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl( val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { taskMetrics.flatMap { case (id, metrics) => - taskIdToTaskSetId.get(id) - .flatMap(activeTaskSets.get) - .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) + taskIdToTaskSetManager.get(id).map { taskSetMgr => + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + } } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) @@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.nonEmpty) { + if (taskSetsByStageIdAndAttempt.nonEmpty) { // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { + for { + attempts <- taskSetsByStageIdAndAttempt.values + manager <- attempts.values + } { try { manager.abort(message) } catch { @@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl( override def applicationAttemptId(): Option[String] = backend.applicationAttemptId() + private[scheduler] def taskSetManagerForAttempt( + stageId: Int, + stageAttemptId: Int): 
Option[TaskSetManager] = { + for { + attempts <- taskSetsByStageIdAndAttempt.get(stageId) + manager <- attempts.get(stageAttemptId) + } yield { + manager + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156f53..be8526ba9b94f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -26,10 +26,10 @@ import java.util.Properties private[spark] class TaskSet( val tasks: Array[Task[_]], val stageId: Int, - val attempt: Int, + val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + attempt + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 4be1eda2e9291..06f5438433b6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -86,7 +86,11 @@ private[spark] object CoarseGrainedClusterMessages { // Request executors by specifying the new total number of executors desired // This includes executors already pending or running - case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + case class RequestExecutors( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]) + extends CoarseGrainedClusterMessage case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7c7f70d8a193b..bd89160af4ffa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -66,6 +66,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] + // A map to store hostname with its possible task number running on it + protected var hostToLocalTaskCount: Map[String, Int] = Map.empty + + // The number of pending tasks which is locality required + protected var localityAwareTasks = 0 + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -169,9 +175,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { - launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + // Filter out executors under killing + val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq)) + }.toSeq + launchTasks(scheduler.resourceOffers(workOffers)) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -181,9 +190,13 @@ 
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on just one executor private def makeOffers(executorId: String) { - val executorData = executorDataMap(executorId) - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) + // Filter out executors under killing + if (!executorsPendingToRemove.contains(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = Seq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + launchTasks(scheduler.resourceOffers(workOffers)) + } } // Launch tasks returned by a set of resource offers @@ -191,15 +204,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) - scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => + scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + "spark.akka.frameSize or using broadcast variables for large values." msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, AkkaUtils.reservedSizeBytes) - taskSet.abort(msg) + taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) } @@ -229,7 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, SlaveLost(reason)) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -333,6 +345,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors // Account for executors pending to be added or removed val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size @@ -340,16 +353,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } /** - * Express a preference to the cluster manager for a given total number of executors. This can - * result in canceling pending requests or filing additional requests. - * @return whether the request is acknowledged. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host.
+ * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. */ - final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + final override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int] + ): Boolean = synchronized { if (numExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + s"$numExecutors from the cluster manager. Please specify a positive number!") } + + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + numPendingExecutors = math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) doRequestTotalExecutors(numExecutors) @@ -371,26 +401,36 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. - * Return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { + killExecutors(executorIds, replace = false) + } + + /** + * Request that the cluster manager kill the specified executors. + * + * @param executorIds identifiers of executors to kill + * @param replace whether to replace the killed executors with new ones + * @return whether the kill request is acknowledged. + */ + final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") - val filteredExecutorIds = new ArrayBuffer[String] - executorIds.foreach { id => - if (executorDataMap.contains(id)) { - filteredExecutorIds += id - } else { - logWarning(s"Executor to kill $id does not exist!") - } + val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) + unknownExecutors.foreach { id => + logWarning(s"Executor to kill $id does not exist!") + } + + // If we do not wish to replace the executors we kill, sync the target number of executors + // with the cluster manager to avoid allocating new ones. When computing the new target, + // take into account executors that are pending to be added or removed. + if (!replace) { + doRequestTotalExecutors(numExistingExecutors + numPendingExecutors + - executorsPendingToRemove.size - knownExecutors.size) } - // Killing executors means effectively that we want less executors than before, so also update - // the target number of executors to avoid having the backend allocate new ones. - val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size - - filteredExecutorIds.size) - doRequestTotalExecutors(newTotal) - executorsPendingToRemove ++= filteredExecutorIds - doKillExecutors(filteredExecutorIds) + executorsPendingToRemove ++= knownExecutors + doKillExecutors(knownExecutors) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bc67abb5df446..044f6288fabdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -53,7 +53,8 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. 
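The localityAwareTasks / hostToLocalTaskCount pair that RequestExecutors now carries is essentially a summary of the tasks' preferred hosts. A rough standalone illustration of how such a summary could be derived (hypothetical helper, not code from this patch):

```
// Given each pending task's preferred hosts, derive the pair of values the
// scheduler backend forwards to the cluster manager: how many tasks have any
// locality preference at all, and how many tasks would like each host.
object LocalityHints {
  def summarize(taskHostPrefs: Seq[Seq[String]]): (Int, Map[String, Int]) = {
    val withPrefs = taskHostPrefs.filter(_.nonEmpty)
    val counts = withPrefs.flatten.groupBy(identity).mapValues(_.size).toMap
    (withPrefs.size, counts)
  }

  def main(args: Array[String]): Unit = {
    val (localityAware, hostCounts) =
      summarize(Seq(Seq("host1", "host2"), Seq("host1"), Seq.empty))
    println(s"$localityAware locality-aware tasks, host counts: $hostCounts")
    // 2 locality-aware tasks, host counts: Map(host1 -> 2, host2 -> 1)
  }
}
```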
*/ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } /** @@ -108,6 +109,8 @@ private[spark] abstract class YarnSchedulerBackend( case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 6b8edca5aa485..b7fde0d9b3265 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,11 +18,13 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File +import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.rpc.RpcAddress @@ -60,12 +62,34 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. + */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + + // private lock object protecting mutable state above. 
Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -78,11 +102,12 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) + val driver = createSchedulerDriver( + master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -116,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -129,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -138,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -151,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend( command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { @@ -168,15 +201,19 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. 
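For context on the `mem >= calculateTotalMemory(sc)` check used in the offer evaluation below: the helper added to MesosSchedulerUtils further down sizes each executor as its heap plus a container overhead. A standalone sketch of that sizing rule (names are illustrative; the real helper reads the overhead from spark.mesos.executor.memoryOverhead when it is set):

```
// Executor memory plus container overhead, in MB: use the configured overhead if
// present, otherwise max(10% of executor memory, a 384 MB floor).
object MemorySizing {
  private val OverheadFraction = 0.10
  private val OverheadMinimumMb = 384

  def totalMemoryMb(executorMemoryMb: Int, configuredOverheadMb: Option[Int] = None): Int = {
    val overhead = configuredOverheadMb.getOrElse(
      math.max(OverheadFraction * executorMemoryMb, OverheadMinimumMb).toInt)
    executorMemoryMb + overhead
  }

  def main(args: Array[String]): Unit = {
    println(totalMemoryMb(1024))   // 1024 + 384 = 1408 (the minimum overhead applies)
    println(totalMemoryMb(8192))   // 8192 + 819 = 9011 (10% overhead)
  }
}
```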
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -187,45 +224,44 @@ private[spark] class CoarseMesosSchedulerBackend( taskIdToSlaveId(taskId) = slaveId slaveIdsWithExecutors += slaveId coresByTaskId(taskId) = cpusToUse - val task = MesosTaskInfo.newBuilder() + // Gather cpu resources from the available resources and use them in the task. + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) - .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addAllResources(cpuResourcesToUse) + .addAllResources(memResourcesToUse) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) } + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState - logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + logInfo(s"Mesos task $taskId is now $state") + stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) 
slaveIdsWithExecutors -= slaveId @@ -239,18 +275,19 @@ private[spark] class CoarseMesosSchedulerBackend( if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { - logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + + logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } override def error(d: SchedulerDriver, message: String) { - logError("Mesos error: " + message) + logError(s"Mesos error: $message") scheduler.error(message) } @@ -263,18 +300,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. + */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.contains(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo(s"Mesos slave lost: ${slaveId.getValue}") + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -285,4 +343,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. 
+ logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.contains(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4c..f078547e71352 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem @@ -294,20 +295,24 @@ private[spark] class MesosClusterScheduler( def start(): Unit = { // TODO: Implement leader election to make sure only one framework running in the cluster. 
val fwId = schedulerState.fetch[String]("frameworkId") - val builder = FrameworkInfo.newBuilder() - .setUser(Utils.getCurrentUserName()) - .setName(appName) - .setWebuiUrl(frameworkUrl) - .setCheckpoint(true) - .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash fwId.foreach { id => - builder.setId(FrameworkID.newBuilder().setValue(id).build()) frameworkId = id } recoverState() metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) metricsSystem.start() - startScheduler(master, MesosClusterScheduler.this, builder.build()) + val driver = createSchedulerDriver( + master, + MesosClusterScheduler.this, + Utils.getCurrentUserName(), + appName, + conf, + Some(frameworkUrl), + Some(true), + Some(Integer.MAX_VALUE), + fwId) + + startScheduler(driver) ready = true } @@ -448,12 +453,8 @@ private[spark] class MesosClusterScheduler( offer.cpu -= driverCpu offer.mem -= driverMem val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() - val cpuResource = Resource.newBuilder() - .setName("cpus").setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() - val memResource = Resource.newBuilder() - .setName("mem").setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val cpuResource = createResource("cpus", driverCpu) + val memResource = createResource("mem", driverMem) val commandInfo = buildDriverCommand(submission) val appName = submission.schedulerProperties("spark.app.name") val taskInfo = TaskInfo.newBuilder() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 49de85ef48ada..3f63ec1c5832f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,14 +23,15 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} + /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -45,8 +46,8 @@ private[spark] class MesosSchedulerBackend( with MScheduler with MesosSchedulerUtils { - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // Stores the slave ids that has launched a Mesos executor. 
+ val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo] val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks @@ -59,20 +60,33 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() classLoader = Thread.currentThread.getContextClassLoader - startScheduler(master, MesosSchedulerBackend.this, fwInfo) + val driver = createSchedulerDriver( + master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createExecutorInfo(execId: String): MesosExecutorInfo = { + /** + * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * @param availableResources Available resources that is offered by Mesos + * @param execId The executor id to assign to this new executor. + * @return A tuple of the new mesos executor info and the remaining available resources. + */ + def createExecutorInfo( + availableResources: JList[Resource], + execId: String): (MesosExecutorInfo, JList[Resource]) = { val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -111,32 +125,25 @@ private[spark] class MesosSchedulerBackend( command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } - val cpus = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) - .build() - val memory = Resource.newBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar( - Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) - .build() - val executorInfo = MesosExecutorInfo.newBuilder() + val builder = MesosExecutorInfo.newBuilder() + val (resourcesAfterCpu, usedCpuResources) = + partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + val (resourcesAfterMem, usedMemResources) = + partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc)) + + builder.addAllResources(usedCpuResources) + builder.addAllResources(usedMemResources) + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - .addResources(cpus) - .addResources(memory) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) } - executorInfo.build() + (executorInfo.build(), resourcesAfterMem) } /** @@ -179,6 +186,18 @@ private[spark] class MesosSchedulerBackend( override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { + 
val builder = new StringBuilder + tasks.foreach { t => + builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") + .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") + .append("Task resources: ").append(t.getResourcesList).append("\n") + .append("Executor resources: ").append(t.getExecutor.getResourcesList) + .append("---------------------------------------------\n") + } + builder.toString() + } + /** * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that @@ -191,15 +210,33 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfied + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => - val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { // If the Mesos executor has not been started on this slave yet, set aside a few @@ -214,6 +251,10 @@ private[spark] class MesosSchedulerBackend( val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap + val slaveIdToResources = new HashMap[String, JList[Resource]]() + usableOffers.foreach { o => + slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList + } val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] @@ -225,11 +266,15 @@ private[spark] class MesosSchedulerBackend( .foreach { offer => offer.foreach { taskDesc => val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId slavesIdsOfAcceptedOffers += slaveId taskIdToSlaveId(taskDesc.taskId) = slaveId + val (mesosTask, remainingResources) = createMesosTask( + taskDesc, + slaveIdToResources(slaveId), + slaveId) mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) + .add(mesosTask) + slaveIdToResources(slaveId) = remainingResources } } @@ -242,6 +287,7 @@ private[spark] class MesosSchedulerBackend( // TODO: Add support for log urls for Mesos new ExecutorInfo(o.host, o.cores, Map.empty))) )
+ logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}") d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -250,28 +296,32 @@ private[spark] class MesosSchedulerBackend( for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { d.declineOffer(o.getId) } - - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } - /** Turn a Spark TaskDescription into a Mesos task */ - def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { + /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ + def createMesosTask( + task: TaskDescription, + resources: JList[Resource], + slaveId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val cpuResource = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build()) - .build() - MesosTaskInfo.newBuilder() + val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) { + (slaveIdToExecutorInfo(slaveId), resources) + } else { + createExecutorInfo(resources, slaveId) + } + slaveIdToExecutorInfo(slaveId) = executorInfo + val (finalResources, cpuResources) = + partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK) + val taskInfo = MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(createExecutorInfo(slaveId)) + .setExecutor(executorInfo) .setName(task.name) - .addResources(cpuResource) + .addAllResources(cpuResources) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() + (taskInfo, finalResources) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { @@ -317,7 +367,7 @@ private[spark] class MesosSchedulerBackend( private def removeExecutor(slaveId: String, reason: String) = { synchronized { listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) - slaveIdsWithExecutors -= slaveId + slaveIdToExecutorInfo -= slaveId } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016a..c04920e4f5873 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,16 +17,21 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} +import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} import org.apache.spark.util.Utils + /** * Shared trait for implementing a Mesos Scheduler. 
This holds common state and helper * methods and Mesos scheduler will use. @@ -36,16 +41,66 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null /** - * Starts the MesosSchedulerDriver with the provided information. This method returns - * only after the scheduler has registered with Mesos. - * @param masterUrl Mesos master connection URL - * @param scheduler Scheduler object - * @param fwInfo FrameworkInfo to pass to the Mesos master + * Creates a new MesosSchedulerDriver that communicates to the Mesos master. + * @param masterUrl The url to connect to Mesos master + * @param scheduler the scheduler class to receive scheduler callbacks + * @param sparkUser User to impersonate with when running tasks + * @param appName The framework name to display on the Mesos UI + * @param conf Spark configuration + * @param webuiUrl The WebUI url to link from Mesos UI + * @param checkpoint Option to checkpoint tasks for failover + * @param failoverTimeout Duration Mesos master expect scheduler to reconnect on disconnect + * @param frameworkId The id of the new framework */ - def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { + val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) + val credBuilder = Credential.newBuilder() + webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } + checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } + failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } + frameworkId.foreach { id => + fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) + } + conf.getOption("spark.mesos.principal").foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret").foreach { secret => + credBuilder.setSecret(ByteString.copyFromUtf8(secret)) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + conf.getOption("spark.mesos.role").foreach { role => + fwInfoBuilder.setRole(role) + } + if (credBuilder.hasPrincipal) { + new MesosSchedulerDriver( + scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) + } else { + new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl) + } + } + + /** + * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. + * This driver is expected to not be running. + * This method returns only after the scheduler has registered with Mesos. 
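A bit further down, MesosSchedulerUtils also gains partitionResources, which carves a requested amount out of a list of Mesos scalar resources (possibly spread across roles) and returns both the remainder and the slice that was taken. A simplified standalone sketch of the same idea using plain (role, amount) tuples instead of protobufs (hypothetical, for illustration only):

```
// Take `amount` out of a list of (role, availableAmount) entries, returning
// what is left over and what was actually taken; depleted entries are dropped.
object PartitionDemo {
  def partition(
      resources: List[(String, Double)],
      amount: Double): (List[(String, Double)], List[(String, Double)]) = {
    var remain = amount
    val used = scala.collection.mutable.ArrayBuffer.empty[(String, Double)]
    val remaining = resources.map { case (role, available) =>
      if (remain > 0 && available > 0) {
        val take = math.min(remain, available)
        remain -= take
        used += ((role, take))
        (role, available - take)
      } else {
        (role, available)
      }
    }.filter(_._2 > 0)
    (remaining, used.toList)
  }

  def main(args: Array[String]): Unit = {
    println(partition(List(("*", 3.0), ("spark", 4.0)), amount = 5.0))
    // (List((spark,2.0)), List((*,3.0), (spark,2.0)))
  }
}
```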
+ */ + def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { if (mesosDriver != null) { registerLatch.await() @@ -56,11 +111,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { - mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + mesosDriver = newDriver try { val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) - if (ret.equals(Status.DRIVER_ABORTED)) { + if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { System.exit(1) } } catch { @@ -79,17 +134,201 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Signal that the scheduler has registered with Mesos. */ + protected def getResource(res: JList[Resource], name: String): Double = { + // A resource can have multiple values in the offer since it can either be from + // a specific role or wildcard. + res.filter(_.getName == name).map(_.getScalar.getValue).sum + } + protected def markRegistered(): Unit = { registerLatch.countDown() } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { + val builder = Resource.newBuilder() + .setName(name) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(amount).build()) + + role.foreach { r => builder.setRole(r) } + + builder.build() + } + + /** + * Partition the existing set of resources into two groups, those remaining to be + * scheduled and those requested to be used for a new task. + * @param resources The full list of available resources + * @param resourceName The name of the resource to take from the available resources + * @param amountToUse The amount of resources to take from the available resources + * @return The remaining resources list and the used resources list. + */ + def partitionResources( + resources: JList[Resource], + resourceName: String, + amountToUse: Double): (List[Resource], List[Resource]) = { + var remain = amountToUse + var requestedResources = new ArrayBuffer[Resource] + val remainingResources = resources.map { + case r => { + if (remain > 0 && + r.getType == Value.Type.SCALAR && + r.getScalar.getValue > 0.0 && + r.getName == resourceName) { + val usage = Math.min(remain, r.getScalar.getValue) + requestedResources += createResource(resourceName, usage, Some(r.getRole)) + remain -= usage + createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole)) + } else { + r + } + } + } + + // Filter any resource that has depleted. 
+ val filteredResources = + remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0) + + (filteredResources.toList, requestedResources.toList) + } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. + * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + /** - * Get the amount of resources for the specified type from the resource list + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). 
For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. */ - protected def getResource(res: List[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { + case (k, v) => + if (v == null || v.isEmpty) { + (k, Set[String]()) + } else { + (k, v.split(',').toSet) + } + } + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } } - 0.0 } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3078a1b10be8b..4d48fcfea44e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -17,13 +17,16 @@ package org.apache.spark.scheduler.local +import java.io.File +import java.net.URL import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() @@ -40,6 +43,7 @@ private case class StopExecutor() */ private[spark] class LocalEndpoint( override val rpcEnv: RpcEnv, + userClassPath: Seq[URL], scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) @@ -47,11 +51,11 @@ private[spark] class LocalEndpoint( private var freeCores = totalCores - private val localExecutorId = SparkContext.DRIVER_IDENTIFIER - private val localExecutorHostname = "localhost" + 
val localExecutorId = SparkContext.DRIVER_IDENTIFIER + val localExecutorHostname = "localhost" private val executor = new Executor( - localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) + localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => @@ -96,11 +100,28 @@ private[spark] class LocalBackend( extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localEndpoint: RpcEndpointRef = null + private var localEndpoint: RpcEndpointRef = null + private val userClassPath = getUserClasspath(conf) + private val listenerBus = scheduler.sc.listenerBus + + /** + * Returns a list of URLs representing the user classpath. + * + * @param conf Spark configuration. + */ + def getUserClasspath(conf: SparkConf): Seq[URL] = { + val userClassPathStr = conf.getOption("spark.executor.extraClassPath") + userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) + } override def start() { - localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( - "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) + val rpcEnv = SparkEnv.get.rpcEnv + val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) + localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + listenerBus.post(SparkListenerExecutorAdded( + System.currentTimeMillis, + executorEndpoint.localExecutorId, + new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala new file mode 100644 index 0000000000000..62f8aae7f2126 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.io.CompressionCodec + +/** + * Custom serializer used for generic Avro records. 
If the user registers the schemas + * ahead of time, then the schema's fingerprint will be sent with each message instead of the actual + * schema, so as to reduce network IO. + * Actions like parsing or compressing schemas are computationally expensive, so the serializer + * caches all previously seen values so as to reduce the amount of work that needs to be done. + * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the + * string representation of the Avro schema, used to decrease the amount of data + * that needs to be serialized. + */ +private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) + extends KSerializer[GenericRecord] { + + /** Used to reduce the amount of effort to compress the schema */ + private val compressCache = new mutable.HashMap[Schema, Array[Byte]]() + private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]() + + /** Reuses the same datum reader/writer since the same schema will be used many times */ + private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]() + private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]() + + /** Fingerprinting is very expensive so this alleviates most of the work */ + private val fingerprintCache = new mutable.HashMap[Schema, Long]() + private val schemaCache = new mutable.HashMap[Long, Schema]() + + // GenericAvroSerializer can't take a SparkConf in the constructor b/c then it would become + // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make + // the codec lazy here just b/c in some unit tests, we use a KryoSerializer w/out having + // the SparkEnv set (note those tests would fail if they tried to serialize avro data). + private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + + /** + * Used to compress Schemas when they are being sent over the wire. + * The compression results are memoized to reduce the compression time since the + * same schema is compressed many times over + */ + def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { + val bos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(bos) + out.write(schema.toString.getBytes("UTF-8")) + out.close() + bos.toByteArray + }) + + /** + * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already + * seen values to limit the number of times that decompression has to be done. + */ + def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { + val bis = new ByteArrayInputStream(schemaBytes.array()) + val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + new Schema.Parser().parse(new String(bytes, "UTF-8")) + }) + + /** + * Serializes a record to the given output stream. 
It caches a lot of the internal data so as + * to not redo work. + */ + def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = { + val encoder = EncoderFactory.get.binaryEncoder(output, null) + val schema = datum.getSchema + val fingerprint = fingerprintCache.getOrElseUpdate(schema, { + SchemaNormalization.parsingFingerprint64(schema) + }) + schemas.get(fingerprint) match { + case Some(_) => + output.writeBoolean(true) + output.writeLong(fingerprint) + case None => + output.writeBoolean(false) + val compressedSchema = compress(schema) + output.writeInt(compressedSchema.length) + output.writeBytes(compressedSchema) + } + + writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) + .asInstanceOf[DatumWriter[R]] + .write(datum, encoder) + encoder.flush() + } + + /** + * Deserializes generic records into their in-memory form. There is internal + * state to keep a cache of already seen schemas and datum readers. + */ + def deserializeDatum(input: KryoInput): GenericRecord = { + val schema = { + if (input.readBoolean()) { + val fingerprint = input.readLong() + schemaCache.getOrElseUpdate(fingerprint, { + schemas.get(fingerprint) match { + case Some(s) => new Schema.Parser().parse(s) + case None => + throw new SparkException( + "Error attempting to read avro data -- encountered an unknown " + + s"fingerprint: $fingerprint, not sure what schema to use. This could happen " + + "if you registered additional schemas after starting your spark context.") + } + }) + } else { + val length = input.readInt() + decompress(ByteBuffer.wrap(input.readBytes(length))) + } + } + val decoder = DecoderFactory.get.directBinaryDecoder(input, null) + readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema)) + .asInstanceOf[DatumReader[GenericRecord]] + .read(null, decoder) + } + + override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit = + serializeDatum(datum, output) + + override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord = + deserializeDatum(input) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 698d1384d580d..4a5274b46b7a0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index ed35cffe968f8..0ff7562e912ca 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,6 +27,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import 
com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ @@ -73,6 +74,8 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) + private val avroSchemas = conf.getAvroSchema + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { @@ -101,7 +104,11 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) + kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) + try { + // scalastyle:off classforname // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. @@ -111,6 +118,7 @@ class KryoSerializer(conf: SparkConf) userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } + // scalastyle:on classforname } catch { case e: Exception => throw new SparkException(s"Failed to register classes with Kryo", e) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cc2f0506817d3..a1b1e1631eafb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging { /** ObjectStreamClass$ClassDataSlot.desc field */ val DescField: Field = { + // scalastyle:off classforname val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + // scalastyle:on classforname f.setAccessible(true) f } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6c3b3080d2605..f6a96d81e7aa9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] + val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. 
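Because newKryo() above pulls the registered schema map from conf.getAvroSchema, the fingerprint-only path of GenericAvroSerializer is taken only for schemas the application registered up front. A minimal usage sketch, assuming the companion SparkConf.registerAvroSchemas helper that populates that map:

import org.apache.avro.SchemaBuilder
import org.apache.spark.{SparkConf, SparkContext}

// Register the schema once so only its 64-bit fingerprint, not the compressed
// schema text, is shipped with each GenericRecord.
val userSchema = SchemaBuilder.record("User").fields()
  .requiredString("name")
  .requiredInt("age")
  .endRecord()

val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .registerAvroSchemas(userSchema)   // assumed counterpart of conf.getAvroSchema
val sc = new SparkContext(conf)
// RDD[GenericRecord] shuffles from this context now go through GenericAvroSerializer.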
*/ def releaseWriters(success: Boolean) @@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6e7bbb9..fae69551e7330 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8b..f038b722957b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,95 +19,101 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory * from this pool and release it as it spills data out. When a task ends, all its memory will be * released by the Executor. * - * This class tries to ensure that each thread gets a reasonable share of memory, instead of some - * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * This class tries to ensure that each task gets a reasonable share of memory, instead of some + * task ramping up to a large amount first and then causing others to spill to disk repeatedly. 
+ * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. */ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * Try to acquire up to numBytes memory for the current task, and return the number of bytes * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active threads) before it is forced to spill. This can - * happen if the number of threads increases but an older thread had a lot of memory already. + * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active tasks) before it is forced to spill. This can + * happen if the number of tasks increases but an older task had a lot of memory already. */ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numThreads again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). 
while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. */ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currentTaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). 
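A worked example of the 1/(2N) .. 1/N policy implemented above; this is a sketch with made-up numbers, and ShuffleMemoryManager is Spark-internal:

// maxMemory = 1000 bytes, 4 active tasks (N = 4):
//   per-task cap       = maxMemory / N       = 250 bytes
//   guaranteed minimum = maxMemory / (2 * N) = 125 bytes before tryToAcquire blocks
val manager = new ShuffleMemoryManager(1000L)
val granted = manager.tryToAcquire(300L)  // capped at 250 once three other tasks are active
// ... spill or consume the granted memory ...
manager.release(granted)                  // wakes any tasks waiting in tryToAcquire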
*/ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala deleted file mode 100644 index 9d8e7e9f03aea..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.InputStream - -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.{Failure, Success} - -import org.apache.spark._ -import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, - ShuffleBlockId} - -private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetchBlockStreams( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - : Iterator[(BlockId, InputStream)] = - { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - - val startTime = System.currentTimeMillis - val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) - } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - blocksByAddress, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - - // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler - blockFetcherItr.map { blockPair => - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Success(inputStream) => { - (blockId, inputStream) - } - case Failure(e) => { - blockId match { - case ShuffleBlockId(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new 
FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) - } - } - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index d5c9880659dd3..de79fa56f017b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -31,8 +31,8 @@ private[spark] class HashShuffleReader[K, C]( context: TaskContext, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) - extends ShuffleReader[K, C] -{ + extends ShuffleReader[K, C] with Logging { + require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") @@ -40,11 +40,16 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) // Wrap the streams for compression based on configuration - val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => blockManager.wrapForCompression(blockId, inputStream) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index eb87cee15903c..41df70c602c30 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( shuffleBlockResolver: FileShuffleBlockResolver, @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). 
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1beafa1771448..86493673d958d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -648,7 +648,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 68ed9096731c5..5dc0c537cbb62 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -60,10 +60,11 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case UpdateBlockInfo( + case _updateBlockInfo @ UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => context.reply(updateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => context.reply(getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala new file mode 100644 index 0000000000000..2789e25b8d3ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +private[spark] case class BlockUIData( + blockId: BlockId, + location: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +/** + * The aggregated status of stream blocks in an executor + */ +private[spark] case class ExecutorStreamBlockStatus( + executorId: String, + location: String, + blocks: Seq[BlockUIData]) { + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum + + def numStreamBlocks: Int = blocks.size + +} + +private[spark] class BlockStatusListener extends SparkListener { + + private val blockManagers = + new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val blockId = blockUpdated.blockUpdatedInfo.blockId + if (!blockId.isInstanceOf[StreamBlockId]) { + // Now we only monitor StreamBlocks + return + } + val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize + + synchronized { + // Drop the update info if the block manager is not registered + blockManagers.get(blockManagerId).foreach { blocksInBlockManager => + if (storageLevel.isValid) { + blocksInBlockManager.put(blockId, + BlockUIData( + blockId, + blockManagerId.hostPort, + storageLevel, + memSize, + diskSize, + externalBlockStoreSize) + ) + } else { + // If isValid is not true, it means we should drop the block. + blocksInBlockManager -= blockId + } + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + synchronized { + blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { + blockManagers -= blockManagerRemoved.blockManagerId + } + + def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { + blockManagers.map { case (blockManagerId, blocks) => + ExecutorStreamBlockStatus( + blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) + }.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala new file mode 100644 index 0000000000000..a5790e4454a89 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
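The listener above is fed by the SparkListenerBlockUpdated events that BlockManagerMasterEndpoint now posts. A minimal sketch of how Spark-internal code (for example a UI page) could wire it up and read the aggregated view, assuming an existing SparkContext sc:

val blockListener = new BlockStatusListener
sc.addSparkListener(blockListener)   // receives block manager added/removed and block updated events

// Later, summarize the tracked stream blocks per executor.
blockListener.allExecutorStreamBlockStatus.foreach { exec =>
  println(s"${exec.executorId} @ ${exec.location}: ${exec.numStreamBlocks} stream blocks, " +
    s"${exec.totalMemSize} bytes in memory, ${exec.totalDiskSize} bytes on disk")
}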
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo + +/** + * :: DeveloperApi :: + * Stores information about a block status in a block manager. + */ +@DeveloperApi +case class BlockUpdatedInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +private[spark] object BlockUpdatedInfo { + + private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = { + BlockUpdatedInfo( + updateBlockInfo.blockManagerId, + updateBlockInfo.blockId, + updateBlockInfo.storageLevel, + updateBlockInfo.memSize, + updateBlockInfo.diskSize, + updateBlockInfo.externalBlockStoreSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 91ef86389a0c3..5f537692a16c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") + Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 7eeabd1e0489c..49d9154f95a5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils /** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. * - * This interface does not support concurrent writes. Also, once the writer has - * been opened, it cannot be reopened again. 
- */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes a key-value pair. - */ - def write(key: Any, value: Any) - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - def recordWritten() - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. */ private[spark] class DiskBlockObjectWriter( - blockId: BlockId, + val blockId: BlockId, file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who + // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ + extends OutputStream + with Logging { /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 - override def open(): BlockObjectWriter = { + def open(): DiskBlockObjectWriter = { if (hasBeenClosed) { throw new IllegalStateException("Writer already closed. Cannot be reopened.") } @@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter( } } - override def isOpen: Boolean = objOut != null + def isOpen: Boolean = objOut != null - override def commitAndClose(): Unit = { + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. @@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter( commitAndCloseHasBeenCalled = true } - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. - override def revertPartialWritesAndClose() { + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. 
try { if (initialized) { writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) @@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(key: Any, value: Any) { + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { if (!initialized) { open() } @@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter( bs.write(kvBytes, offs, len) } - override def recordWritten(): Unit = { + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter( } } - override def fileSegment(): FileSegment = { + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. + */ + def fileSegment(): FileSegment = { if (!commitAndCloseHasBeenCalled) { throw new IllegalStateException( "fileSegment() is only valid after commitAndClose() has been called") diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 291394ed34816..db965d54bafd6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) try { - val instance = Class.forName(clsName) + val instance = Utils.classForName(clsName) .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index ed609772e6979..6f27f00307f8c 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a thread + // Pending unroll memory refers to the intermediate memory occupied by a task // after the unroll but before the actual putting of the block in the cache. // This chunk of memory is expected to be released *as soon as* we finish // caching the corresponding block as opposed to until after the task finishes. 
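For reference, the open / write / commit lifecycle of the renamed DiskBlockObjectWriter as the shuffle writers above use it; a sketch only, with blockManager, blockId, file, serializerInstance, writeMetrics and records assumed to be in scope:

val writer = blockManager.getDiskWriter(blockId, file, serializerInstance, 32 * 1024, writeMetrics)
try {
  records.foreach { case (k, v) => writer.write(k, v) }  // opens the file lazily on first write
  writer.commitAndClose()                                // flush and commit as one atomic segment
  val segment = writer.fileSegment()                     // only valid after commitAndClose()
} catch {
  case t: Throwable =>
    writer.revertPartialWritesAndClose()                 // truncate back to the initial position
    throw t
}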
@@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + // Previous unroll memory held by this task, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // later when the task finishes. 
if (keepUnrolling) { accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) - reservePendingUnrollMemoryForThisThread(amountToRelease) + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) } } } @@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisThread() + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Take into account the amount of memory currently occupied by unrolling blocks // and minus the pending unroll memory for that block on current thread. - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(threadId, 0L) + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } } /** - * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. 
*/ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory - // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory + // If this task claims no more unroll memory, release it completely + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Reserve the unroll memory of current unroll successful block used by this thread + * Reserve the unroll memory of current unroll successful block used by this task * until actually put the block into memory entry. */ - def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { - val threadId = Thread.currentThread().getId + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } } /** - * Release pending unroll memory of current unroll successful block used by this thread + * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisThread(): Unit = { - val threadId = Thread.currentThread().getId + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(threadId) + pendingUnrollMemoryMap.remove(taskAttemptId) } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. 
" + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e49e39679e940..a759ceb96ec1e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,18 +21,19 @@ import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Try} +import scala.util.control.NonFatal -import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.{Logging, SparkException, TaskContext} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid @@ -53,7 +54,7 @@ final class ShuffleBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[InputStream])] with Logging { + extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -115,7 +116,7 @@ final class ShuffleBlockFetcherIterator( private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } currentResult = null @@ -132,7 +133,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } } @@ -157,7 +158,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. 
buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } @@ -166,7 +167,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), e)) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } } ) @@ -238,12 +239,12 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, e)) + results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) return } } @@ -275,12 +276,14 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch /** - * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers * underlying each InputStream will be freed by the cleanup() method registered with the * TaskCompletionListener. However, callers should close() these InputStreams * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. */ - override def next(): (BlockId, Try[InputStream]) = { + override def next(): (BlockId, InputStream) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -289,7 +292,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size case _ => } // Send fetch requests up to maxBytesInFlight @@ -298,19 +301,28 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[InputStream] = result match { - case FailureFetchResult(_, e) => - Failure(e) - case SuccessFetchResult(blockId, _, buf) => - // There is a chance that createInputStream can fail (e.g. fetching a local file that does - // not exist, SPARK-4085). In that case, we should propagate the right exception so - // the scheduler gets a FetchFailedException. 
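Aside from the bookkeeping above, the user-visible change in this file is the iterator's element type: callers now receive `(BlockId, InputStream)` pairs, and fetch failures surface as exceptions (a `FetchFailedException` for shuffle blocks) rather than as `Failure` values to unwrap. A minimal sketch of a caller loop under the new contract, with simplified types and a hypothetical `consume` helper (not the actual shuffle reader):

```scala
import java.io.InputStream

// Illustrative caller only: block ids are plain Strings here to keep the sketch
// self-contained; in Spark the iterator yields (BlockId, InputStream).
def consume(fetches: Iterator[(String, InputStream)]): Unit = {
  fetches.foreach { case (blockId, in) =>
    try {
      // ... deserialize records from `in`; a failed fetch never reaches this point,
      // because next() already threw before handing the pair out ...
    } finally {
      in.close() // close as soon as possible so the underlying buffer can be released
    }
  }
}
```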
- Try(buf.createInputStream()).map { inputStream => - new BufferReleasingInputStream(inputStream, this) + result match { + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + + case SuccessFetchResult(blockId, address, _, buf) => + try { + (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) + } catch { + case NonFatal(t) => + throwFetchFailedException(blockId, address, t) } } + } - (result.blockId, iteratorTry) + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) + } } } @@ -366,16 +378,22 @@ object ShuffleBlockFetcherIterator { */ private[storage] sealed trait FetchResult { val blockId: BlockId + val address: BlockManagerId } /** * Result of a fetch from a remote block successfully. * @param blockId block id + * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. */ - private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + private[storage] case class SuccessFetchResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer) extends FetchResult { require(buf != null) require(size >= 0) @@ -384,8 +402,12 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. * @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from * @param e the failure exception */ - private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + private[storage] case class FailureFetchResult( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c706..c8356467fab87 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") @@ -210,10 +212,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new 
InetSocketAddress(hostName, currentPort)) diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala new file mode 100644 index 0000000000000..17d7b39c2d951 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.{Node, Unparsed} + +/** + * A data source that provides data for a page. + * + * @param pageSize the number of rows in a page + */ +private[ui] abstract class PagedDataSource[T](val pageSize: Int) { + + if (pageSize <= 0) { + throw new IllegalArgumentException("Page size must be positive") + } + + /** + * Return the size of all data. + */ + protected def dataSize: Int + + /** + * Slice a range of data. + */ + protected def sliceData(from: Int, to: Int): Seq[T] + + /** + * Slice the data for this page + */ + def pageData(page: Int): PageData[T] = { + val totalPages = (dataSize + pageSize - 1) / pageSize + if (page <= 0 || page > totalPages) { + throw new IndexOutOfBoundsException( + s"Page $page is out of range. Please select a page number between 1 and $totalPages.") + } + val from = (page - 1) * pageSize + val to = dataSize.min(page * pageSize) + PageData(totalPages, sliceData(from, to)) + } + +} + +/** + * The data returned by `PagedDataSource.pageData`, including the page number, the number of total + * pages and the data in this page. + */ +private[ui] case class PageData[T](totalPage: Int, data: Seq[T]) + +/** + * A paged table that will generate a HTML table for a specified page and also the page navigation. + */ +private[ui] trait PagedTable[T] { + + def tableId: String + + def tableCssClass: String + + def dataSource: PagedDataSource[T] + + def headers: Seq[Node] + + def row(t: T): Seq[Node] + + def table(page: Int): Seq[Node] = { + val _dataSource = dataSource + try { + val PageData(totalPages, data) = _dataSource.pageData(page) +
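The arithmetic in `pageData` above is the whole contract: a ceiling division for the page count, a bounds check, and a `slice` over the half-open range of the requested page. A concrete data source is therefore tiny; the following is an illustrative one over an in-memory sequence, not part of the patch:

```scala
// Pages over a plain Seq using the dataSize/sliceData hooks defined above.
class SeqDataSource[T](seq: Seq[T], pageSize: Int) extends PagedDataSource[T](pageSize) {
  override protected def dataSize: Int = seq.size
  override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to)
}

// With 95 rows and pageSize = 10: totalPages = (95 + 10 - 1) / 10 = 10, and
// new SeqDataSource((1 to 95).toSeq, 10).pageData(10) == PageData(10, Seq(91, 92, 93, 94, 95))
```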

+      <div>
+        {pageNavigation(page, _dataSource.pageSize, totalPages)}
+        <table class={tableCssClass} id={tableId}>
+          {headers}
+          <tbody>
+            {data.map(row)}
+          </tbody>
+        </table>
+      </div>
+ } catch { + case e: IndexOutOfBoundsException => + val PageData(totalPages, _) = _dataSource.pageData(1) +
+ {pageNavigation(1, _dataSource.pageSize, totalPages)} +
+          <div class="alert alert-error">{e.getMessage}</div>
+
+    }
+  }
+
+  /**
+   * Return a page navigation.
+   * <ul>
+   *   <li>If the totalPages is 1, the page navigation will be empty</li>
+   *   <li>
+   *     If the totalPages is more than 1, it will create a page navigation including a group of
+   *     page numbers and a form to submit the page number.
+   *   </li>
+   * </ul>
+ * + * Here are some examples of the page navigation: + * {{{ + * << < 11 12 13* 14 15 16 17 18 19 20 > >> + * + * This is the first group, so "<<" is hidden. + * < 1 2* 3 4 5 6 7 8 9 10 > >> + * + * This is the first group and the first page, so "<<" and "<" are hidden. + * 1* 2 3 4 5 6 7 8 9 10 > >> + * + * Assume totalPages is 19. This is the last group, so ">>" is hidden. + * << < 11 12 13* 14 15 16 17 18 19 > + * + * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden. + * << < 11 12 13 14 15 16 17 18 19* + * + * * means the current page number + * << means jumping to the first page of the previous group. + * < means jumping to the previous page. + * >> means jumping to the first page of the next group. + * > means jumping to the next page. + * }}} + */ + private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = { + if (totalPages == 1) { + Nil + } else { + // A group includes all page numbers will be shown in the page navigation. + // The size of group is 10 means there are 10 page numbers will be shown. + // The first group is 1 to 10, the second is 2 to 20, and so on + val groupSize = 10 + val firstGroup = 0 + val lastGroup = (totalPages - 1) / groupSize + val currentGroup = (page - 1) / groupSize + val startPage = currentGroup * groupSize + 1 + val endPage = totalPages.min(startPage + groupSize - 1) + val pageTags = (startPage to endPage).map { p => + if (p == page) { + // The current page should be disabled so that it cannot be clicked. +
+            <li class="disabled"><a href="#">{p}</a></li>
+          } else {
+            <li><a href={pageLink(p)}>{p}</a></li>
  • + } + } + val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction + // When clicking the "Go" button, it will call this javascript method and then call + // "goButtonJsFuncName" + val formJs = + s"""$$(function(){ + | $$( "#form-task-page" ).submit(function(event) { + | var page = $$("#form-task-page-no").val() + | var pageSize = $$("#form-task-page-size").val() + | pageSize = pageSize ? pageSize: 100; + | if (page != "") { + | ${goButtonJsFuncName}(page, pageSize); + | } + | event.preventDefault(); + | }); + |}); + """.stripMargin + +
    +
    +
    + + + + + + +
    +
    + + +
    + } + } + + /** + * Return a link to jump to a page. + */ + def pageLink(page: Int): String + + /** + * Only the implementation knows how to create the url with a page number and the page size, so we + * leave this one to the implementation. The implementation should create a JavaScript method that + * accepts a page number along with the page size and jumps to the page. The return value is this + * method name and its JavaScript codes. + */ + def goButtonJavascriptFunction: (String, String) +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7898039519201..718aea7e1dc22 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. @@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging { fixedWidth: Boolean = false, id: Option[String] = None, headerClasses: Seq[String] = Seq.empty, - stripeRowsWithCss: Boolean = true): Seq[Node] = { + stripeRowsWithCss: Boolean = true, + sortable: Boolean = true): Seq[Node] = { - val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + val listingTableClass = { + val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + if (sortable) { + _tableClass + " sortable" + } else { + _tableClass + } + } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index ba03acdb38cc5..5a8c2914314c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 3) { + // scalastyle:off println println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") + // scalastyle:on println System.exit(1) } @@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator { for ((desc, job) <- jobs) { new Thread { override def run() { + // scalastyle:off println try { setProperties(desc) job() @@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator { } finally { barrier.release() } + // scalastyle:on println } }.start Thread.sleep(INTER_JOB_WAIT_MS) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 2c84e4485996e..61449847add3d 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -107,6 +107,25 @@ private[spark] abstract class WebUI( } } + /** + * Add a handler for static content. + * + * @param resourceBase Root of where to find resources to serve. 
+ * @param path Path in UI where to mount the resources. + */ + def addStaticHandler(resourceBase: String, path: String): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + } + + /** + * Remove a static content handler. + * + * @param path Path in UI to unmount. + */ + def removeStaticHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) + } + /** Initialize all components of the server. */ def initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2ce670ad02e97..e72547df7254b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -79,6 +79,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { case JobExecutionStatus.SUCCEEDED => "succeeded" case JobExecutionStatus.FAILED => "failed" case JobExecutionStatus.RUNNING => "running" + case JobExecutionStatus.UNKNOWN => "unknown" } // The timeline library treats contents as HTML, so we have to escape them; for the diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 60e3c6343122c..cf04b5e59239b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.jobs +import java.net.URLEncoder import java.util.Date import javax.servlet.http.HttpServletRequest @@ -27,13 +28,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} -import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} +import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ -import org.apache.spark.ui.scope.RDDOperationGraph import org.apache.spark.util.{Utils, Distribution} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { + import StagePage._ + private val progressListener = parent.progressListener private val operationGraphListener = parent.operationGraphListener @@ -74,6 +76,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterAttempt = request.getParameter("attempt") require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + val parameterTaskPage = request.getParameter("task.page") + val parameterTaskSortColumn = request.getParameter("task.sort") + val parameterTaskSortDesc = request.getParameter("task.desc") + val parameterTaskPageSize = request.getParameter("task.pageSize") + + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) + val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") + val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) + val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + // If this is set, expand the dag visualization by default val expandDagVizParam = request.getParameter("expandDagViz") val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean @@ -231,52 +243,47 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { accumulableRow, accumulables.values.toSeq) - val taskHeadersAndCssClasses: Seq[(String, String)] = - Seq( - ("Index", ""), ("ID", 
""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (stageData.hasShuffleRead) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) - } else { - Nil - }} ++ - {if (stageData.hasShuffleWrite) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) - } else { - Nil - }} ++ - {if (stageData.hasBytesSpilled) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) - } else { - Nil - }} ++ - Seq(("Errors", "")) - - val unzipped = taskHeadersAndCssClasses.unzip - val currentTime = System.currentTimeMillis() - val taskTable = UIUtils.listingTable( - unzipped._1, - taskRow( + val (taskTable, taskTableHTML) = try { + val _taskTable = new TaskPagedTable( + UIUtils.prependBaseUri(parent.basePath) + + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + tasks, hasAccumulators, stageData.hasInput, stageData.hasOutput, stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled, - currentTime), - tasks, - headerClasses = unzipped._2) + currentTime, + pageSize = taskPageSize, + sortColumn = taskSortColumn, + desc = taskSortDesc + ) + (_taskTable, _taskTable.table(taskPage)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + (null,
<div class="alert alert-error">{e.getMessage}</div>
    ) + } + + val jsForScrollingDownToTaskTable = + + + val taskIdsInPage = if (taskTable == null) Set.empty[Long] + else taskTable.dataSource.slicedTaskIds + // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -332,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info).toDouble + getGettingResultTime(info, currentTime).toDouble } val gettingResultQuantiles = @@ -346,7 +353,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - getSchedulerDelay(info, metrics.get).toDouble + getSchedulerDelay(info, metrics.get, currentTime).toDouble } val schedulerDelayTitle = Scheduler Delay @@ -499,12 +506,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { dagViz ++ maybeExpandDagViz ++ showAdditionalMetrics ++ - makeTimeline(stageData.taskData.values.toSeq, currentTime) ++ + makeTimeline( + // Only show the tasks in the table + stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)), + currentTime) ++

       <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
       <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
       <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++
       maybeAccumulableTable ++
-      <h4>Tasks</h4> ++ taskTable
+      <h4>Tasks</h4>
    ++ taskTableHTML ++ jsForScrollingDownToTaskTable UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } @@ -537,20 +547,27 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { (metricsOpt.flatMap(_.shuffleWriteMetrics .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) - val executorComputingTime = metricsOpt.map(_.executorRunTime).getOrElse(0L) - - shuffleReadTime - shuffleWriteTime - val executorComputingTimeProportion = toProportion(executorComputingTime) + val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskUIData.taskInfo) + val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = totalExecutionTime - - (executorComputingTime + shuffleReadTime + shuffleWriteTime + - serializationTime + deserializationTime + gettingResultTime) - val schedulerDelayProportion = - (100 - executorComputingTimeProportion - shuffleReadTimeProportion - + val schedulerDelay = + metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelayProportion = toProportion(schedulerDelay) + + val executorOverhead = serializationTime + deserializationTime + val executorRunTime = if (taskInfo.running) { + totalExecutionTime - executorOverhead - gettingResultTime + } else { + metricsOpt.map(_.executorRunTime).getOrElse( + totalExecutionTime - executorOverhead - gettingResultTime) + } + val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime + val executorComputingTimeProportion = + (100 - schedulerDelayProportion - shuffleReadTimeProportion - shuffleWriteTimeProportion - serializationTimeProportion - deserializationTimeProportion - gettingResultTimeProportion) @@ -672,162 +689,619 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } - def taskRow( - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, - currentTime: Long)(taskData: TaskUIData): Seq[Node] = { - taskData match { case TaskUIData(info, metrics, errorMessage) => - val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) - else metrics.map(_.executorRunTime).getOrElse(1L) - val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) - else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info) - - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} - - val maybeInput = metrics.flatMap(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)} 
(${m.readMethod.toString.toLowerCase()})") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.flatMap(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead - .map(_.fetchWaitTime.toString).getOrElse("") - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.shuffleRecordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) - val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("") - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("") - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - - {info.index} - {info.taskId} - { - if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString - } - {info.status} - {info.taskLocality} - {info.executorId} / {info.host} - {UIUtils.formatDate(new Date(info.launchTime))} - - {formatDuration} - - - {UIUtils.formatDuration(schedulerDelay.toLong)} - - - {UIUtils.formatDuration(taskDeserializationTime.toLong)} - - - {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - - - {UIUtils.formatDuration(serializationTime)} - - - {UIUtils.formatDuration(gettingResultTime)} - - {if (hasAccumulators) { - - {Unparsed(accumulatorsReadable.mkString("
    "))} - - }} - {if (hasInput) { - - {s"$inputReadable / $inputRecords"} - - }} - {if (hasOutput) { - - {s"$outputReadable / $outputRecords"} - - }} +} + +private[ui] object StagePage { + private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { + if (info.gettingResult) { + if (info.finished) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + currentTime - info.gettingResultTime + } + } else { + 0L + } + } + + private[ui] def getSchedulerDelay( + info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { + if (info.finished) { + val totalExecutionTime = info.finishTime - info.launchTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + } else { + // The task is still running and the metrics like executorRunTime are not available. + 0L + } + } +} + +private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) + +private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) + +private[ui] case class TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable: Long, + shuffleReadBlockedTimeReadable: String, + shuffleReadSortable: Long, + shuffleReadReadable: String, + shuffleReadRemoteSortable: Long, + shuffleReadRemoteReadable: String) + +private[ui] case class TaskTableRowShuffleWriteData( + writeTimeSortable: Long, + writeTimeReadable: String, + shuffleWriteSortable: Long, + shuffleWriteReadable: String) + +private[ui] case class TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable: Long, + memoryBytesSpilledReadable: String, + diskBytesSpilledSortable: Long, + diskBytesSpilledReadable: String) + +/** + * Contains all data that needs for sorting and generating HTML. Using this one rather than + * TaskUIData to avoid creating duplicate contents during sorting the data. 
+ */ +private[ui] case class TaskTableRowData( + index: Int, + taskId: Long, + attempt: Int, + speculative: Boolean, + status: String, + taskLocality: String, + executorIdAndHost: String, + launchTime: Long, + duration: Long, + formatDuration: String, + schedulerDelay: Long, + taskDeserializationTime: Long, + gcTime: Long, + serializationTime: Long, + gettingResultTime: Long, + accumulators: Option[String], // HTML + input: Option[TaskTableRowInputData], + output: Option[TaskTableRowOutputData], + shuffleRead: Option[TaskTableRowShuffleReadData], + shuffleWrite: Option[TaskTableRowShuffleWriteData], + bytesSpilled: Option[TaskTableRowBytesSpilledData], + error: String) + +private[ui] class TaskDataSource( + tasks: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[TaskTableRowData](pageSize) { + import StagePage._ + + // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table + // so that we can avoid creating duplicate contents during sorting the data + private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) + + private var _slicedTaskIds: Set[Long] = null + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { + val r = data.slice(from, to) + _slicedTaskIds = r.map(_.taskId).toSet + r + } + + def slicedTaskIds: Set[Long] = _slicedTaskIds + + private def taskRow(taskData: TaskUIData): TaskTableRowData = { + val TaskUIData(info, metrics, errorMessage) = taskData + val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) + else metrics.map(_.executorRunTime).getOrElse(1L) + val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) + else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) + val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) + val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = getGettingResultTime(info, currentTime) + + val maybeAccumulators = info.accumulables + val accumulatorsReadable = maybeAccumulators.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } + + val maybeInput = metrics.flatMap(_.inputMetrics) + val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) + val inputReadable = maybeInput + .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") + .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") + + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") + + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) + val shuffleReadBlockedTimeReadable = + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") + + val totalShuffleBytes = 
maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L) + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") + + val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) + val writeTimeSortable = maybeWriteTime.getOrElse(0L) + val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => + if (ms == 0) "" else UIUtils.formatDuration(ms) + }.getOrElse("") + + val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) + val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) + val memoryBytesSpilledReadable = + maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) + val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) + val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val input = + if (hasInput) { + Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) + } else { + None + } + + val output = + if (hasOutput) { + Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) + } else { + None + } + + val shuffleRead = + if (hasShuffleRead) { + Some(TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable, + shuffleReadBlockedTimeReadable, + shuffleReadSortable, + s"$shuffleReadReadable / $shuffleReadRecords", + shuffleReadRemoteSortable, + shuffleReadRemoteReadable + )) + } else { + None + } + + val shuffleWrite = + if (hasShuffleWrite) { + Some(TaskTableRowShuffleWriteData( + writeTimeSortable, + writeTimeReadable, + shuffleWriteSortable, + s"$shuffleWriteReadable / $shuffleWriteRecords" + )) + } else { + None + } + + val bytesSpilled = + if (hasBytesSpilled) { + Some(TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable, + memoryBytesSpilledReadable, + diskBytesSpilledSortable, + diskBytesSpilledReadable + )) + } else { + None + } + + TaskTableRowData( + info.index, + info.taskId, + info.attempt, + info.speculative, + info.status, + info.taskLocality.toString, + s"${info.executorId} / ${info.host}", + info.launchTime, + duration, + formatDuration, + schedulerDelay, + taskDeserializationTime, + gcTime, + serializationTime, + gettingResultTime, + if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + input, + output, + shuffleRead, + shuffleWrite, + bytesSpilled, + errorMessage.getOrElse("") + ) + } + + /** + * Return Ordering according to sortColumn and desc + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { + val ordering = sortColumn match { + case "Index" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.index, y.index) + } + case "ID" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskId, y.taskId) + } + case "Attempt" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.attempt, y.attempt) + } + case "Status" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.status, y.status) + } + case "Locality Level" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.taskLocality, y.taskLocality) + } + case "Executor ID / Host" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost) + } + case "Launch Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.launchTime, y.launchTime) + } + case "Duration" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.duration, y.duration) + } + case "Scheduler Delay" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay) + } + case "Task Deserialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime) + } + case "GC Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gcTime, y.gcTime) + } + case "Result Serialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.serializationTime, y.serializationTime) + } + case "Getting Result Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) + } + case "Accumulators" => + if (hasAccumulators) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.accumulators.get, y.accumulators.get) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Accumulators because of no accumulators") + } + case "Input Size / Records" => + if (hasInput) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Input Size / Records because of no inputs") + } + case "Output Size / Records" => + if (hasOutput) { + new 
Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Output Size / Records because of no outputs") + } + // ShuffleRead + case "Shuffle Read Blocked Time" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable, + y.shuffleRead.get.shuffleReadBlockedTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") + } + case "Shuffle Read Size / Records" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable, + y.shuffleRead.get.shuffleReadSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") + } + case "Shuffle Remote Reads" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable, + y.shuffleRead.get.shuffleReadRemoteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Remote Reads because of no shuffle reads") + } + // ShuffleWrite + case "Write Time" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable, + y.shuffleWrite.get.writeTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Write Time because of no shuffle writes") + } + case "Shuffle Write Size / Records" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable, + y.shuffleWrite.get.shuffleWriteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") + } + // BytesSpilled + case "Shuffle Spill (Memory)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable, + y.bytesSpilled.get.memoryBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Memory) because of no spills") + } + case "Shuffle Spill (Disk)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable, + y.bytesSpilled.get.diskBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Disk) because of no spills") + } + case "Errors" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.error, y.error) + } + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } + +} + +private[ui] class 
TaskPagedTable( + basePath: String, + data: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[TaskTableRowData]{ + + override def tableId: String = "" + + override def tableCssClass: String = "table table-bordered table-condensed table-striped" + + override val dataSource: TaskDataSource = new TaskDataSource( + data, + hasAccumulators, + hasInput, + hasOutput, + hasShuffleRead, + hasShuffleWrite, + hasBytesSpilled, + currentTime, + pageSize, + sortColumn, + desc + ) + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + + s"&task.pageSize=${pageSize}" + } + + override def goButtonJavascriptFunction: (String, String) = { + val jsFuncName = "goToTaskPage" + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + val jsFunc = s""" + |currentTaskPageSize = ${pageSize} + |function goToTaskPage(page, pageSize) { + | // Set page to 1 if the page size changes + | page = pageSize == currentTaskPageSize ? page : 1; + | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + + | "&task.page=" + page + "&task.pageSize=" + pageSize; + | window.location.href = url; + |} + """.stripMargin + (jsFuncName, jsFunc) + } + + def headers: Seq[Node] = { + val taskHeadersAndCssClasses: Seq[(String, String)] = + Seq( + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + ("GC Time", ""), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ {if (hasShuffleRead) { - - {shuffleReadBlockedTimeReadable} - - - {s"$shuffleReadReadable / $shuffleReadRecords"} - - - {shuffleReadRemoteReadable} - - }} + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read Size / Records", ""), + ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + } else { + Nil + }} ++ {if (hasShuffleWrite) { - - {writeTimeReadable} - - - {s"$shuffleWriteReadable / $shuffleWriteRecords"} - - }} + Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + } else { + Nil + }} ++ {if (hasBytesSpilled) { - - {memoryBytesSpilledReadable} - - - {diskBytesSpilledReadable} - - }} - {errorMessageCell(errorMessage)} - + Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + } else { + Nil + }} ++ + Seq(("Errors", "")) + + if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { + new IllegalArgumentException(s"Unknown column: $sortColumn") + } + + val headerRow: Seq[Node] = { + taskHeadersAndCssClasses.map { case (header, cssClass) => + if (header == sortColumn) { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + + s"&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + 
val arrow = if (desc) "▾" else "▴" // UP or DOWN + + {header} +  {Unparsed(arrow)} + + } else { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + + {header} + + } + } } + {headerRow} } - private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { - val error = errorMessage.getOrElse("") + def row(task: TaskTableRowData): Seq[Node] = { + + {task.index} + {task.taskId} + {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} + {task.status} + {task.taskLocality} + {task.executorIdAndHost} + {UIUtils.formatDate(new Date(task.launchTime))} + {task.formatDuration} + + {UIUtils.formatDuration(task.schedulerDelay)} + + + {UIUtils.formatDuration(task.taskDeserializationTime)} + + + {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + + + {UIUtils.formatDuration(task.serializationTime)} + + + {UIUtils.formatDuration(task.gettingResultTime)} + + {if (task.accumulators.nonEmpty) { + {Unparsed(task.accumulators.get)} + }} + {if (task.input.nonEmpty) { + {task.input.get.inputReadable} + }} + {if (task.output.nonEmpty) { + {task.output.get.outputReadable} + }} + {if (task.shuffleRead.nonEmpty) { + + {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + + {task.shuffleRead.get.shuffleReadReadable} + + {task.shuffleRead.get.shuffleReadRemoteReadable} + + }} + {if (task.shuffleWrite.nonEmpty) { + {task.shuffleWrite.get.writeTimeReadable} + {task.shuffleWrite.get.shuffleWriteReadable} + }} + {if (task.bytesSpilled.nonEmpty) { + {task.bytesSpilled.get.memoryBytesSpilledReadable} + {task.bytesSpilled.get.diskBytesSpilledReadable} + }} + {errorMessageCell(task.error)} + + } + + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default val errorSummary = StringEscapeUtils.escapeHtml4( @@ -851,33 +1325,4 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } {errorSummary}{details} } - - private def getGettingResultTime(info: TaskInfo): Long = { - if (info.gettingResultTime > 0) { - if (info.finishTime > 0) { - info.finishTime - info.gettingResultTime - } else { - // The task is still fetching the result. 
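The `task.sort` and `task.desc` parameters carried by the header links above end up selecting one of the `Ordering` instances built in `TaskDataSource.ordering`: pick a comparator by column name, then reverse it when descending. An illustrative reduction of that dispatch on a simplified row type (not the patch's code, which spells each comparator out with `Ordering.Long.compare` and friends):

```scala
// Simplified stand-in for TaskTableRowData with three sortable columns.
case class Row(index: Int, duration: Long, status: String)

def orderingFor(column: String, desc: Boolean): Ordering[Row] = {
  val ascending: Ordering[Row] = column match {
    case "Index"    => Ordering.by((r: Row) => r.index)
    case "Duration" => Ordering.by((r: Row) => r.duration)
    case "Status"   => Ordering.by((r: Row) => r.status)
    case other      => throw new IllegalArgumentException(s"Unknown column: $other")
  }
  if (desc) ascending.reverse else ascending
}

// rows.sorted(orderingFor("Duration", desc = true)) puts the slowest tasks first.
```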
- System.currentTimeMillis - info.gettingResultTime - } - } else { - 0L - } - } - - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { - val totalExecutionTime = - if (info.gettingResult) { - info.gettingResultTime - info.launchTime - } else if (info.finished) { - info.finishTime - info.launchTime - } else { - 0 - } - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) - } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 07db783c572cf..04f584621e71e 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage._ import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -30,13 +30,25 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rdds = listener.rddInfoList - val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table")) + val content = rddTable(listener.rddInfoList) ++ + receiverBlockTables(listener.allExecutorStreamBlockStatus.sortBy(_.executorId)) UIUtils.headerSparkPage("Storage", content, parent) } + private[storage] def rddTable(rdds: Seq[RDDInfo]): Seq[Node] = { + if (rdds.isEmpty) { + // Don't show the rdd table if there is no RDD persisted. + Nil + } else { +
+      <div>
+        <h4>RDDs</h4>
+        {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))}
+      </div>
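`UIUtils.listingTable`, used here and throughout these pages, boils down to a header list plus a row-rendering function applied to each element. A stripped-down illustration of that shape (the real helper also handles CSS classes, column widths, optional ids and the new `sortable` flag):

```scala
import scala.xml.Node

// Minimal stand-in for the listingTable pattern; not UIUtils' actual implementation.
def simpleListingTable[T](headers: Seq[String], makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = {
  <table class="table table-bordered table-condensed">
    <thead><tr>{headers.map(h => <th>{h}</th>)}</tr></thead>
    <tbody>{rows.map(makeRow)}</tbody>
  </table>
}
```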
    + } + } + /** Header fields for the RDD table */ - private def rddHeader = Seq( + private val rddHeader = Seq( "RDD Name", "Storage Level", "Cached Partitions", @@ -56,7 +68,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.storageLevel.description} - {rdd.numCachedPartitions} + {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)} {Utils.bytesToString(rdd.externalBlockStoreSize)} @@ -64,4 +76,130 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:on } + + private[storage] def receiverBlockTables(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { + if (statuses.map(_.numStreamBlocks).sum == 0) { + // Don't show the tables if there is no stream block + Nil + } else { + val blocks = statuses.flatMap(_.blocks).groupBy(_.blockId).toSeq.sortBy(_._1.toString) + +
+      <div>
+        <h4>Receiver Blocks</h4>
+        {executorMetricsTable(statuses)}
+        {streamBlockTable(blocks)}
+      </div>
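The `groupBy(_.blockId)` a few lines earlier is what lets a stream block replicated on several executors show up as one multi-row entry in the block table below. In isolation the grouping is just this (types are illustrative only):

```scala
// Each executor reports the blocks it holds; replicas of a block share its id.
case class ReportedBlock(blockId: String, executorId: String, sizeBytes: Long)

def groupReplicas(reports: Seq[ReportedBlock]): Seq[(String, Seq[ReportedBlock])] =
  reports.groupBy(_.blockId).toSeq.sortBy(_._1)

// groupReplicas(Seq(
//   ReportedBlock("input-0-1", "1", 1024),
//   ReportedBlock("input-0-1", "2", 1024),
//   ReportedBlock("input-0-2", "1", 2048)))
// returns the two replicas of input-0-1 under one key and input-0-2 under another.
```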
    + } + } + + private def executorMetricsTable(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { +
+    <div>
+      <h5>Aggregated Block Metrics by Executor</h5>
+      {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, statuses,
+        id = Some("storage-by-executor-stream-blocks"))}
+    </div>
    + } + + private val executorMetricsTableHeader = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + + private def executorMetricsTableRow(status: ExecutorStreamBlockStatus): Seq[Node] = { + + + {status.executorId} + + + {status.location} + + + {Utils.bytesToString(status.totalMemSize)} + + + {Utils.bytesToString(status.totalExternalBlockStoreSize)} + + + {Utils.bytesToString(status.totalDiskSize)} + + + {status.numStreamBlocks.toString} + + + } + + private def streamBlockTable(blocks: Seq[(BlockId, Seq[BlockUIData])]): Seq[Node] = { + if (blocks.isEmpty) { + Nil + } else { +
+      <div>
+        <h5>Blocks</h5>
+        {UIUtils.listingTable(
+          streamBlockTableHeader,
+          streamBlockTableRow,
+          blocks,
+          id = Some("storage-by-block-table"),
+          sortable = false)}
+      </div>
    + } + } + + private val streamBlockTableHeader = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + + /** Render a stream block */ + private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { + val replications = block._2 + assert(replications.size > 0) // This must be true because it's the result of "groupBy" + if (replications.size == 1) { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) + } else { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++ + replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten + } + } + + private def streamBlockTableSubrow( + blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = { + val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block) + + + { + if (firstSubrow) { + + {block.blockId.toString} + + + {replication.toString} + + } + } + {block.location} + {storageLevel} + {Utils.bytesToString(size)} + + } + + private[storage] def streamBlockStorageLevelDescriptionAndSize( + block: BlockUIData): (String, Long) = { + if (block.storageLevel.useDisk) { + ("Disk", block.diskSize) + } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) { + ("Memory", block.memSize) + } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { + ("Memory Serialized", block.memSize) + } else if (block.storageLevel.useOffHeap) { + ("External", block.externalBlockStoreSize) + } else { + throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 0351749700962..22e2993b3b5bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,7 +39,8 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi -class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { +class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { + private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index c179833e5b06a..78e7ddc27d1c7 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -128,7 +128,7 @@ private[spark] object AkkaUtils extends Logging { /** Returns the configured max frame size for Akka messages in bytes. 
*/ def maxFrameSizeBytes(conf: SparkConf): Int = { - val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10) + val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128) if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { throw new IllegalArgumentException( s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 305de4c75539d..ebead830c6466 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging { cls.getName.contains("$anonfun$") } - // Get a list of the classes of the outer objects of a given closure object, obj; + // Get a list of the outer objects and their classes of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching // for outer objects beyond that because cloning the user's object is probably // not a good idea (whereas we can clone closure objects just fine since we // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) val outer = f.get(obj) // The outer pointer may be null if we have cleaned this closure before if (outer != null) { if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(outer) + val recurRet = getOuterClassesAndObjects(outer) + return (f.getType :: recurRet._1, outer :: recurRet._2) } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure + return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure } } } - Nil + (Nil, Nil) } - - // Get a list of the outer objects for a given closure object. - private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - val outer = f.get(obj) - // The outer pointer may be null if we have cleaned this closure before - if (outer != null) { - if (isClosure(f.getType)) { - return outer :: getOuterObjects(outer) - } else { - return outer :: Nil // Stop at the first $outer that is not a closure - } - } - } - Nil - } - /** * Return a list of classes that represent closures enclosed in the given closure object. */ @@ -205,8 +188,7 @@ private[spark] object ClosureCleaner extends Logging { // A list of enclosing objects and their respective classes, from innermost to outermost // An outer object at a given index is of type outer class at the same index - val outerClasses = getOuterClasses(func) - val outerObjects = getOuterObjects(func) + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) // For logging purposes only val declaredFields = func.getClass.getDeclaredFields @@ -448,10 +430,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? 
&& argTypes(0).getInternalName == myName) { + // scalastyle:off classforname output += Class.forName( owner.replace('/', '.'), false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname } } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 1bab707235b89..950b69f7db641 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -52,9 +52,11 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va } def showQuantiles(out: PrintStream = System.out): Unit = { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) @@ -64,8 +66,10 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va * @param out */ def summary(out: PrintStream = System.out) { + // scalastyle:off println out.println(statCounter) showQuantiles(out) + // scalastyle:on println } } @@ -80,8 +84,10 @@ private[spark] object Distribution { } def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") quantiles.foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index adf69a4e78e71..c600319d9ddb4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -92,8 +92,10 @@ private[spark] object JsonProtocol { executorRemovedToJson(executorRemoved) case logStart: SparkListenerLogStart => logStartToJson(logStart) - // These aren't used, but keeps compiler happy - case SparkListenerExecutorMetricsUpdate(_, _) => JNothing + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + executorMetricsUpdateToJson(metricsUpdate) + case blockUpdated: SparkListenerBlockUpdated => + throw new MatchError(blockUpdated) // TODO(ekl) implement this } } @@ -224,6 +226,19 @@ private[spark] object JsonProtocol { ("Spark Version" -> SPARK_VERSION) } + def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { + val execId = metricsUpdate.execId + val taskMetrics = metricsUpdate.taskMetrics + ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Executor ID" -> execId) ~ + ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Task ID" -> taskId) ~ + ("Stage ID" -> stageId) ~ + ("Stage Attempt ID" -> stageAttemptId) ~ + ("Task Metrics" -> taskMetricsToJson(metrics)) + }) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -463,6 +478,7 @@ private[spark] object JsonProtocol { val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) + val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) (json \ "Event").extract[String] match { case 
`stageSubmitted` => stageSubmittedFromJson(json) @@ -481,6 +497,7 @@ private[spark] object JsonProtocol { case `executorAdded` => executorAddedFromJson(json) case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) + case `metricsUpdate` => executorMetricsUpdateFromJson(json) } } @@ -598,6 +615,18 @@ private[spark] object JsonProtocol { SparkListenerLogStart(sparkVersion) } + def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { + val execInfo = (json \ "Executor ID").extract[String] + val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val taskId = (json \ "Task ID").extract[Long] + val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] + val metrics = taskMetricsFromJson(json \ "Task Metrics") + (taskId, stageId, stageAttemptId, metrics) + } + SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala index 30bcf1d2f24d5..3354a923273ff 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -20,8 +20,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils - private[spark] class SerializableConfiguration(@transient var value: Configuration) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala index afbcc6efc850c..cadae472b3f85 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -21,8 +21,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.mapred.JobConf -import org.apache.spark.util.Utils - private[spark] class SerializableJobConf(@transient var value: JobConf) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 0180399c9dad5..14b1f2a17e707 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -124,9 +124,11 @@ object SizeEstimator extends Logging { val server = ManagementFactory.getPlatformMBeanServer() // NOTE: This should throw an exception in non-Sun JVMs + // scalastyle:off classforname val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", Class.forName("java.lang.String")) + // scalastyle:on classforname val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) @@ -215,10 +217,10 @@ object SizeEstimator extends Logging { var arrSize: Long = alignSize(objectSize + INT_SIZE) if 
(elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) + arrSize += alignSize(length.toLong * primitiveSize(elementClass)) state.size += arrSize } else { - arrSize += alignSize(length * pointerSize) + arrSize += alignSize(length.toLong * pointerSize) state.size += arrSize if (length <= ARRAY_SIZE_FOR_SAMPLING) { @@ -334,7 +336,7 @@ object SizeEstimator extends Logging { // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp var alignedSize = shellSize for (size <- fieldSizes if sizeCount(size) > 0) { - val count = sizeCount(size) + val count = sizeCount(size).toLong // If there are internal gaps, smaller field can fit in. alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count) shellSize += size * count diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 944560a91354a..c4012d0e83f7d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -113,8 +113,11 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } ois.readObject.asInstanceOf[T] } @@ -177,12 +180,16 @@ private[spark] object Utils extends Logging { /** Determines whether the provided class is loadable in the current thread. */ def classIsLoadable(clazz: String): Boolean = { + // scalastyle:off classforname Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess + // scalastyle:on classforname } + // scalastyle:off classforname /** Preferred alternative to Class.forName(className) */ def classForName(className: String): Class[_] = { Class.forName(className, true, getContextOrSparkClassLoader) + // scalastyle:on classforname } /** @@ -436,11 +443,11 @@ private[spark] object Utils extends Logging { val lockFileName = s"${url.hashCode}${timestamp}_lock" val localDir = new File(getLocalDir(conf)) val lockFile = new File(localDir, lockFileName) - val raf = new RandomAccessFile(lockFile, "rw") + val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. // The FileLock is only used to control synchronization for executors download file, // it's always safe regardless of lock type (mandatory or advisory). - val lock = raf.getChannel().lock() + val lock = lockFileChannel.lock() val cachedFile = new File(localDir, cachedFileName) try { if (!cachedFile.exists()) { @@ -448,6 +455,7 @@ private[spark] object Utils extends Logging { } } finally { lock.release() + lockFileChannel.close() } copyFile( url, @@ -733,7 +741,12 @@ private[spark] object Utils extends Logging { localRootDirs } - private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + /** + * Return the configured local directories where Spark can write files. This + * method does not create any directories on its own, it only encapsulates the + * logic of locating the local directories according to deployment mode. 
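The SizeEstimator change above widens the array length to Long before multiplying by the element or pointer size; with plain Int arithmetic the product silently overflows once an array holds a few hundred million elements. A small illustration (the element count and pointer size are arbitrary examples, not values taken from Spark):

    // Int multiplication overflows silently; widening one operand to Long avoids it.
    val length = 300 * 1000 * 1000            // 300 million array elements
    val pointerSize = 8                       // assumed bytes per reference

    val overflowed = length * pointerSize     // Int math: wraps to a negative number
    val correct = length.toLong * pointerSize // Long math

    println(overflowed)   // -1894967296
    println(correct)      // 2400000000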
+ */ + def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has @@ -749,27 +762,29 @@ private[spark] object Utils extends Logging { Option(conf.getenv("SPARK_LOCAL_DIRS")) .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) .split(",") - .flatMap { root => - try { - val rootDir = new File(root) - if (rootDir.exists || rootDir.mkdirs()) { - val dir = createTempDir(root) - chmod700(dir) - Some(dir.getAbsolutePath) - } else { - logError(s"Failed to create dir in $root. Ignoring this directory.") - None - } - } catch { - case e: IOException => - logError(s"Failed to create local root dir in $root. Ignoring this directory.") - None - } - } - .toArray } } + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + getConfiguredLocalDirs(conf).flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + }.toArray + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -1572,6 +1587,34 @@ private[spark] object Utils extends Logging { hashAbs } + /** + * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN double. + */ + def nanSafeCompareDoubles(x: Double, y: Double): Int = { + val xIsNan: Boolean = java.lang.Double.isNaN(x) + val yIsNan: Boolean = java.lang.Double.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + + /** + * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN float. + */ + def nanSafeCompareFloats(x: Float, y: Float): Int = { + val xIsNan: Boolean = java.lang.Float.isNaN(x) + val yIsNan: Boolean = java.lang.Float.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + /** Returns the system properties map that is thread-safe to iterator over. It gets the * properties which have been set explicitly, as well as those for which only a default value * has been defined. 
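The nanSafeCompareDoubles and nanSafeCompareFloats helpers above give NaN a consistent place in the ordering: NaN equals NaN and sorts above every other value, something the primitive comparison operators cannot express because every comparison involving NaN is false. A self-contained demonstration of the double variant:

    // Local copy of the NaN-safe ordering above, for experimentation.
    def nanSafeCompareDoubles(x: Double, y: Double): Int = {
      val xIsNan = java.lang.Double.isNaN(x)
      val yIsNan = java.lang.Double.isNaN(y)
      if ((xIsNan && yIsNan) || (x == y)) 0
      else if (xIsNan) 1
      else if (yIsNan) -1
      else if (x > y) 1
      else -1
    }

    // Plain operators treat every comparison against NaN as false ...
    println(Double.NaN == Double.NaN)                 // false
    println(Double.NaN > Double.PositiveInfinity)     // false
    // ... whereas the helper orders NaN deterministically.
    println(nanSafeCompareDoubles(Double.NaN, Double.NaN))              // 0
    println(nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity)) // 1
    println(nanSafeCompareDoubles(0.0, Double.NaN))                     // -1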
*/ @@ -2259,7 +2302,7 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { case Success(shmClass) => val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() .asInstanceOf[Int] diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index 516aaa44d03fc..ae60f3b0cb555 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) { private var _size: Long = 0 /** - * Feed bytes from this buffer into a BlockObjectWriter. + * Feed bytes from this buffer into a DiskBlockObjectWriter. * * @param pos Offset in the buffer to read from. * @param os OutputStream to read into. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1e4531ef395ae..d166037351c31 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} @@ -470,14 +470,27 @@ class ExternalAppendOnlyMap[K, V, C]( item } - // TODO: Ensure this gets called even if the iterator isn't drained. private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() + if (ds != null) { + ds.close() + deserializeStream = null + } + if (fileStream != null) { + fileStream.close() + fileStream = null + } + if (file.exists()) { + file.delete() + } + } + + val context = TaskContext.get() + // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in + // a TaskContext. 
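The ExternalAppendOnlyMap change above, together with the listener registration that follows, makes spill-file cleanup run on task completion instead of relying on the iterator being fully drained. A hedged sketch of the same pattern for an arbitrary resource (the helper name is invented; TaskContext.get() and addTaskCompletionListener are the real API, and the null check matters for code that also runs outside a task, exactly as the comment above notes):

    import org.apache.spark.TaskContext

    // Open a resource on an executor and guarantee it is released when the task ends,
    // even if the caller never finishes consuming whatever is built on top of it.
    def openWithTaskCleanup[T](open: () => T)(close: T => Unit): T = {
      val resource = open()
      val context = TaskContext.get()
      if (context != null) {
        // Invoked when the task completes, successfully or not.
        context.addTaskCompletionListener(_ => close(resource))
      }
      resource
    }

A caller inside, say, an RDD.mapPartitions body would open its input stream through a helper like this and let the listener close it.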
+ if (context != null) { + context.addTaskCompletionListener(context => cleanup()) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 757dec66c203b..ba7ec834d622d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -30,7 +30,7 @@ import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} -import org.apache.spark.storage.{BlockId, BlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C]( // These variables are reset after each flush var objectsWritten: Long = 0 var spillMetrics: ShuffleWriteMetrics = null - var writer: BlockObjectWriter = null + var writer: DiskBlockObjectWriter = null def openWriter(): Unit = { assert (writer == null && spillMetrics == null) spillMetrics = new ShuffleWriteMetrics diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 04bb7fc78c13b..f5844d5353be7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,7 +19,6 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ae9a48729e201..87a786b02d651 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -21,9 +21,8 @@ import java.io.InputStream import java.nio.IntBuffer import java.util.Comparator -import org.apache.spark.SparkEnv import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ /** @@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( // current position in the meta buffer in ints var pos = 0 - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { val keyStart = getKeyStartPos(metaBuffer, pos) val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 7bc59898658e4..38848e9018c6c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -19,7 +19,7 @@ package 
org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that @@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection { } /** - * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: BlockObjectWriter): Unit + def writeNext(writer: DiskBlockObjectWriter): Unit def hasNext(): Boolean diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index c4a7b4441c85c..85fb923cd9bc7 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -70,12 +70,14 @@ private[spark] object XORShiftRandom { * @param args takes one argument - the number of random numbers to generate */ def main(args: Array[String]): Unit = { + // scalastyle:off println if (args.length != 1) { println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") println("Usage: XORShiftRandom number_of_random_numbers_to_generate") System.exit(1) } println(benchmark(args(0).toInt)) + // scalastyle:on println } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index dfd86d3e51e7d..e948ca33471a4 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1011,7 +1011,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @@ -1783,7 +1783,7 @@ public void testGuavaOptional() { // Stop the context created in setUp() and start a local-cluster one, to force usage of the // assembly. 
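Several test changes in this part of the diff, including the one right below, bump local-cluster master URLs from 512 to 1024 MB per executor. For reference, the three bracketed values are the number of workers, cores per worker, and memory per worker in MB; a minimal usage sketch:

    import org.apache.spark.{SparkConf, SparkContext}

    // "local-cluster[2,1,1024]": 2 workers, 1 core each, 1024 MB each. Unlike "local[N]",
    // this launches separate executor JVMs, which is why it needs a built Spark assembly
    // and why mostly the test suites use it.
    val conf = new SparkConf().setMaster("local-cluster[2,1,1024]").setAppName("local-cluster-demo")
    val sc = new SparkContext(conf)
    try {
      assert(sc.parallelize(1 to 100, 4).count() == 100)
    } finally {
      sc.stop()
    }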
sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite"); + JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); try { JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); JavaRDD> rdd2 = rdd1.map( diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 10c3eedbf4b46..04fc09b323dbb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -111,7 +111,7 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf(); + conf = new SparkConf().set("spark.buffer.pageSize", "128m"); taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); @@ -512,12 +512,12 @@ public void close() { } writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); writer.forceSorterToSpill(); // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); writer.forceSorterToSpill(); // Inserting a record that's larger than the max record size should fail: - final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); Product2 hugeRecord = new Tuple2(new byte[0], exceedsMaxRecordSize); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java new file mode 100644 index 0000000000000..0e391b751226d --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.UUID; + +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeExternalSorterSuite { + + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + + File tempDir; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock 
invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { + final int[] arr = new int[] { value }; + sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + } + + @Test + public void testSortingOnlyByPrefix() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + insertNumber(sorter, 5); + insertNumber(sorter, 1); + insertNumber(sorter, 3); + sorter.spill(); + insertNumber(sorter, 4); + sorter.spill(); + insertNumber(sorter, 2); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + // TODO: read rest of value. + } + + // TODO: test for cleanup: + // assert(tempDir.isEmpty) + } + + @Test + public void testSortingEmptyArrays() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(0, iter.getKeyPrefix()); + assertEquals(0, iter.getRecordLength()); + } + } + + @Test + public void testFillingPage() throws Exception { + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + byte[] record = new byte[16]; + while (sorter.getNumberOfAllocatedPages() < 2) { + sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + } + sorter.freeMemory(); + } + +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java new file mode 100644 index 0000000000000..909500930539c --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Arrays; + +import org.junit.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + final byte[] strBytes = new byte[length]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, length); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + mock(RecordComparator.class), + mock(PrefixComparator.class), + 100); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testSortingOnlyByIntegerPrefix() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + prefixComparator, dataToSort.length); + // Given a page of records, insert those records 
into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). + final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + final String str = getStringFromDataPage(baseObject, position + 4, recordLength); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); + position += 4 + recordLength; + } + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + iter.loadNext(); + final String str = + getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); + final long keyPrefix = iter.getKeyPrefix(); + assertThat(str, isIn(Arrays.asList(dataToSort))); + assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); + prevPrefix = keyPrefix; + iterLength++; + } + assertEquals(dataToSort.length, iterLength); + } +} diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index af81e46a657d3..618a5fb24710f 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
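The UnsafeInMemorySorterSuite test above writes each record into its data page as a 4-byte length followed by the payload, and sorts records by an 8-byte key prefix (the partition id). A simplified sketch of that length-prefixed layout using a plain ByteBuffer instead of PlatformDependent and unsafe memory (the names here are illustrative, not Spark's):

    import java.nio.ByteBuffer
    import java.nio.charset.StandardCharsets

    // Pack length-prefixed UTF-8 records into one "page", remembering each record's offset.
    val records = Seq("Boba", "Pearls", "Tapioca")
    val page = ByteBuffer.allocate(1024)
    val offsets = records.map { s =>
      val offset = page.position()
      val bytes = s.getBytes(StandardCharsets.UTF_8)
      page.putInt(bytes.length)   // 4-byte length prefix
      page.put(bytes)             // payload
      offset
    }

    // Read one record back from its offset, mirroring getStringFromDataPage above.
    def readRecord(offset: Int): String = {
      val length = page.getInt(offset)
      val bytes = Array.tabulate[Byte](length)(i => page.get(offset + 4 + i))
      new String(bytes, StandardCharsets.UTF_8)
    }

    println(readRecord(offsets(1)))   // Pearls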
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 501fe186bfd7c..26858ef2774fc 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -292,7 +292,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") @@ -370,7 +370,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 9c191ed52206d..600c1403b0344 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -29,7 +29,7 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() { class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { - val clusterUrl = "local-cluster[2,1,512]" + val clusterUrl = "local-cluster[2,1,1024]" test("task throws not serializable exception") { // Ensures that executors do not crash when an exn is not serializable. If executors crash, @@ -40,7 +40,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val numSlaves = 3 val numPartitions = 10 - sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + sc = new SparkContext("local-cluster[%s,1,1024]".format(numSlaves), "test") val data = sc.parallelize(1 to 100, numPartitions). 
map(x => throw new NotSerializableExn(new NotSerializableClass)) intercept[SparkException] { @@ -50,16 +50,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("local-cluster format") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") + sc = new SparkContext("local-cluster[2 , 1 , 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2, 1, 512]", "test") + sc = new SparkContext("local-cluster[2, 1, 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") + sc = new SparkContext("local-cluster[ 2, 1, 1024 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() } @@ -107,7 +107,9 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) val thrown = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + // scalastyle:on println } assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("failed 4 times")) @@ -274,7 +276,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex DistributedSuite.amMaster = true // Using more than two nodes so we don't have a symmetric communication pattern and might // cache a partially correct list of peers. - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) @@ -292,7 +294,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("unpersist RDDs") { DistributedSuite.amMaster = true - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) data.count diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index b2262033ca238..454b7e607a51b 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -29,7 +29,7 @@ class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val masters = Table("master", "local", "local-cluster[2,1,512]") + val masters = Table("master", "local", "local-cluster[2,1,1024]") forAll(masters) { (master: String) => val process = Utils.executeCommand( Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 803e1831bb269..34caca892891c 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -751,6 +751,42 @@ class 
ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 2) } + test("get pending task number and related locality preference") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + + val localityPreferences1 = Seq( + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host3")), + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host4")), + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host4")), + Seq.empty, + Seq.empty + ) + val stageInfo1 = createStageInfo(1, 5, localityPreferences1) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + + assert(localityAwareTasks(manager) === 3) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 3, "host3" -> 2, "host4" -> 2)) + + val localityPreferences2 = Seq( + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host5")), + Seq(TaskLocation("host3"), TaskLocation("host4"), TaskLocation("host5")), + Seq.empty + ) + val stageInfo2 = createStageInfo(2, 3, localityPreferences2) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + + assert(localityAwareTasks(manager) === 5) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) + + sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + assert(localityAwareTasks(manager) === 2) + assert(hostToLocalTaskCount(manager) === + Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -784,8 +820,13 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val sustainedSchedulerBacklogTimeout = 2L private val executorIdleTimeout = 3L - private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { - new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details") + private def createStageInfo( + stageId: Int, + numTasks: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { + new StageInfo( + stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { @@ -815,6 +856,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty) private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle) private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) + private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) + private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -885,4 +928,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = { manager invokePrivate _onExecutorBusy(id) } + + private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _localityAwareTasks() + } + + private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { + manager invokePrivate _hostToLocalTaskCount() + } } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala 
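The ExecutorAllocationManagerSuite test above expects pending tasks' locality preferences to be rolled up into two pieces of state: the number of tasks that have any preference at all, and a host-to-pending-task-count map. A standalone sketch of that aggregation over plain host names (TaskLocation and the manager's internals are left out):

    // Each inner Seq lists one pending task's preferred hosts; empty means "no preference".
    val localityPreferences = Seq(
      Seq("host1", "host2", "host3"),
      Seq("host1", "host2", "host4"),
      Seq("host2", "host3", "host4"),
      Seq.empty[String],
      Seq.empty[String])

    // Tasks that expressed any locality preference.
    val localityAwareTasks = localityPreferences.count(_.nonEmpty)   // 3

    // How many pending tasks would like to run on each host.
    val hostToLocalTaskCount: Map[String, Int] =
      localityPreferences.flatten.groupBy(identity).mapValues(_.size).toMap
    // Map(host1 -> 2, host2 -> 3, host3 -> 2, host4 -> 2), matching the first assertion above.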
index 140012226fdbb..c38d70252add1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -51,7 +51,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // This test ensures that the external shuffle service is actually in use for the other tests. test("using external shuffle service") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index a8c8c6f73fb5a..69cb4b44cf7ef 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -130,7 +130,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // Non-serializable closure in foreach function val thrown2 = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 2).foreach(x => println(a)) + // scalastyle:on println } assert(thrown2.getClass === classOf[SparkException]) assert(thrown2.getMessage.contains("NotSerializableException") || @@ -139,5 +141,30 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + test("managed memory leak error should not mask other failures (SPARK-9266") { + val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true") + sc = new SparkContext("local[1,1]", "test", conf) + + // If a task leaks memory but fails due to some other cause, then make sure that the original + // cause is preserved + val thrownDueToTaskFailure = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + throw new Exception("intentional task failure") + iter + }.count() + } + assert(thrownDueToTaskFailure.getMessage.contains("intentional task failure")) + + // If the task succeeded but memory was leaked, then the task should fail due to that leak + val thrownDueToMemoryLeak = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + iter + }.count() + } + assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) + } + // TODO: Need to add tests with shuffle fetch failures. 
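The new FailureSuite test above (SPARK-9266) pins down an ordering rule: if a task leaks managed memory but also failed for some other reason, the original failure must be reported, and the leak should only become the failure when the task body itself succeeded. A hedged sketch of that control flow, with invented names rather than Spark's executor internals:

    // Run a task body, then check for leaked memory; only fail on the leak if the body succeeded.
    def runWithLeakCheck[T](leakedBytes: => Long)(body: => T): T = {
      var original: Throwable = null
      try {
        body
      } catch {
        case t: Throwable =>
          original = t
          throw t
      } finally {
        val leaked = leakedBytes
        if (leaked > 0) {
          if (original != null) {
            // The task already failed for another reason: record the leak, don't mask the cause.
            original.addSuppressed(new IllegalStateException(s"Managed memory leak: $leaked bytes"))
          } else {
            throw new IllegalStateException(s"Managed memory leak detected: $leaked bytes")
          }
        }
      }
    }

The test enables spark.unsafe.exceptionOnMemoryLeak so that leaks are surfaced as task failures in the first place.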
} diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 6e65b0a8f6c76..1255e71af6c0b 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -51,7 +51,9 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { val textFile = new File(testTempDir, "FileServerSuite.txt") val pw = new PrintWriter(textFile) + // scalastyle:off println pw.println("100") + // scalastyle:on println pw.close() val jarFile = new File(testTempDir, "test.jar") @@ -137,7 +139,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { @@ -151,7 +153,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => @@ -162,7 +164,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 1d8fade90f398..418763f4e5ffa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -179,6 +179,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test("object files of classes from a JAR") { + // scalastyle:off classforname val original = Thread.currentThread().getContextClassLoader val className = "FileSuiteObjectFileTest" val jar = TestUtils.createJarWithClasses(Seq(className)) @@ -201,6 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { finally { Thread.currentThread().setContextClassLoader(original) } + // scalastyle:on classforname } test("write SequenceFile using new Hadoop API") { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b31b09196608f..139b8dc25f4b4 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark +import java.util.concurrent.{ExecutorService, TimeUnit} + +import scala.collection.mutable import scala.language.postfixOps import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} @@ -25,11 +28,16 @@ import org.mockito.Matchers import org.mockito.Matchers._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler._ +import 
org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.ManualClock +/** + * A test suite for the heartbeating behavior between the driver and the executors. + */ class HeartbeatReceiverSuite extends SparkFunSuite with BeforeAndAfterEach @@ -40,23 +48,40 @@ class HeartbeatReceiverSuite private val executorId2 = "executor-2" // Shared state that must be reset before and after each test - private var scheduler: TaskScheduler = null + private var scheduler: TaskSchedulerImpl = null private var heartbeatReceiver: HeartbeatReceiver = null private var heartbeatReceiverRef: RpcEndpointRef = null private var heartbeatReceiverClock: ManualClock = null + // Helper private method accessors for HeartbeatReceiver + private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) + private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) + private val _killExecutorThread = PrivateMethod[ExecutorService]('killExecutorThread) + + /** + * Before each test, set up the SparkContext and a custom [[HeartbeatReceiver]] + * that uses a manual clock. + */ override def beforeEach(): Unit = { - sc = spy(new SparkContext("local[2]", "test")) - scheduler = mock(classOf[TaskScheduler]) + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.dynamicAllocation.testing", "true") + sc = spy(new SparkContext(conf)) + scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) } + /** + * After each test, clean up all state and stop the [[SparkContext]]. 
+ */ override def afterEach(): Unit = { - resetSparkContext() + super.afterEach() scheduler = null heartbeatReceiver = null heartbeatReceiverRef = null @@ -75,7 +100,7 @@ class HeartbeatReceiverSuite heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) - val trackedExecutors = executorLastSeen(heartbeatReceiver) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) assert(trackedExecutors.size === 2) assert(trackedExecutors.contains(executorId1)) assert(trackedExecutors.contains(executorId2)) @@ -83,15 +108,15 @@ class HeartbeatReceiverSuite test("reregister if scheduler is not ready yet") { heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - // Task scheduler not set in HeartbeatReceiver + // Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister triggerHeartbeat(executorId1, executorShouldReregister = true) } test("reregister if heartbeat from unregistered executor") { heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - // Received heartbeat from unknown receiver, so we ask it to re-register + // Received heartbeat from unknown executor, so we ask it to re-register triggerHeartbeat(executorId1, executorShouldReregister = true) - assert(executorLastSeen(heartbeatReceiver).isEmpty) + assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty) } test("reregister if heartbeat from removed executor") { @@ -104,14 +129,14 @@ class HeartbeatReceiverSuite // A heartbeat from the second executor should require reregistering triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = true) - val trackedExecutors = executorLastSeen(heartbeatReceiver) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) } test("expire dead hosts") { - val executorTimeout = executorTimeoutMs(heartbeatReceiver) + val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) @@ -124,12 +149,61 @@ class HeartbeatReceiverSuite heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) // Only the second executor should be expired as a dead host verify(scheduler).executorLost(Matchers.eq(executorId2), any()) - val trackedExecutors = executorLastSeen(heartbeatReceiver) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) } + test("expire dead hosts should kill executors with replacement (SPARK-8119)") { + // Set up a fake backend and cluster manager to simulate killing executors + val rpcEnv = sc.env.rpcEnv + val fakeClusterManager = new FakeClusterManager(rpcEnv) + val fakeClusterManagerRef = rpcEnv.setupEndpoint("fake-cm", fakeClusterManager) + val fakeSchedulerBackend = new FakeSchedulerBackend(scheduler, rpcEnv, fakeClusterManagerRef) + when(sc.schedulerBackend).thenReturn(fakeSchedulerBackend) + + // Register fake executors with our fake scheduler 
backend + // This is necessary because the backend refuses to kill executors it does not know about + fakeSchedulerBackend.start() + val dummyExecutorEndpoint1 = new FakeExecutorEndpoint(rpcEnv) + val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) + val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) + val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + + // Adjust the target number of executors on the cluster manager side + assert(fakeClusterManager.getTargetNumExecutors === 0) + sc.requestTotalExecutors(2, 0, Map.empty) + assert(fakeClusterManager.getTargetNumExecutors === 2) + assert(fakeClusterManager.getExecutorIdsToKill.isEmpty) + + // Expire the executors. This should trigger our fake backend to kill the executors. + // Since the kill request is sent to the cluster manager asynchronously, we need to block + // on the kill thread to ensure that the cluster manager actually received our requests. + // Here we use a timeout of O(seconds), but in practice this whole test takes O(10ms). + val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) + heartbeatReceiverClock.advance(executorTimeout * 2) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + val killThread = heartbeatReceiver.invokePrivate(_killExecutorThread()) + killThread.shutdown() // needed for awaitTermination + killThread.awaitTermination(10L, TimeUnit.SECONDS) + + // The target number of executors should not change! Otherwise, having an expired + // executor means we permanently adjust the target number downwards until we + // explicitly request new executors. For more detail, see SPARK-8119. + assert(fakeClusterManager.getTargetNumExecutors === 2) + assert(fakeClusterManager.getExecutorIdsToKill === Set(executorId1, executorId2)) + } + /** Manually send a heartbeat and return the response. */ private def triggerHeartbeat( executorId: String, @@ -148,14 +222,50 @@ class HeartbeatReceiverSuite } } - // Helper methods to access private fields in HeartbeatReceiver - private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) - private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) - private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = { - receiver invokePrivate _executorLastSeen() +} + +// TODO: use these classes to add end-to-end tests for dynamic allocation! + +/** + * Dummy RPC endpoint to simulate executors. + */ +private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint + +/** + * Dummy scheduler backend to simulate executor allocation requests to the cluster manager. 
+ */ +private class FakeSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + clusterManagerEndpoint: RpcEndpointRef) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + clusterManagerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } - private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = { - receiver invokePrivate _executorTimeoutMs() + + protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) } +} +/** + * Dummy cluster manager to simulate responses to executor allocation requests. + */ +private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private var targetNumExecutors = 0 + private val executorIdsToKill = new mutable.HashSet[String] + + def getTargetNumExecutors: Int = targetNumExecutors + def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestExecutors(requestedTotal, _, _) => + targetNumExecutors = requestedTotal + context.reply(true) + case KillExecutors(executorIds) => + executorIdsToKill ++= executorIds + context.reply(true) + } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 340a9e327107e..1168eb0b802f2 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -64,7 +64,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft test("cluster mode, FIFO scheduler") { val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -75,7 +75,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() conf.set("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. 
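As a quick illustration of the bookkeeping in the new HeartbeatReceiverSuite helpers, here is a minimal sketch (not part of the patch) of driving FakeClusterManager directly over RPC. It uses only the messages and accessors already defined above, assumes the suite's imports (CoarseGrainedClusterMessages._), and assumes an RpcEnv is available, e.g. sc.env.rpcEnv as in the suite; the endpoint name is made up:

```scala
// Sketch only: exercises FakeClusterManager's bookkeeping without the fake scheduler backend.
val rpcEnv = sc.env.rpcEnv                                  // assumes a running SparkContext
val fakeCm = new FakeClusterManager(rpcEnv)
val fakeCmRef = rpcEnv.setupEndpoint("fake-cm-direct", fakeCm)
assert(fakeCm.getTargetNumExecutors === 0)
fakeCmRef.askWithRetry[Boolean](RequestExecutors(3, 0, Map.empty))   // ask for 3 executors total
assert(fakeCm.getTargetNumExecutors === 3)
fakeCmRef.askWithRetry[Boolean](KillExecutors(Seq("executor-1")))
assert(fakeCm.getExecutorIdsToKill === Set("executor-1"))
```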
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7a1961137cce5..af4e68950f75a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark +import scala.collection.mutable.ArrayBuffer + import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -55,9 +57,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), - (BlockManagerId("b", "hostB", 1000), size10000))) + val statuses = tracker.getMapSizesByExecutorId(10, 0) + assert(statuses.toSet === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + .toSet) tracker.stop() rpcEnv.shutdown() } @@ -75,10 +79,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).nonEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).isEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -104,7 +108,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
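The MapOutputTrackerSuite assertions above reflect the new return shape of getMapSizesByExecutorId: a sequence of (BlockManagerId, Seq[(BlockId, size)]) pairs rather than flat (BlockManagerId, size) pairs. A short sketch (assumptions: placeholder shuffle/reduce IDs, and the nested-pair shape shown in the assertions) of recovering plain block sizes, mirroring the statuses.flatMap(_._2.map(_._2)) pattern the ShuffleSuite changes further down use:

```scala
// IDs stand in for whatever shuffle/reduce the surrounding test registered.
val (shuffleId, reduceId) = (10, 0)
// Post-patch shape: Seq[(BlockManagerId, Seq[(BlockId, Long)])]
val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)
val blockSizes: Seq[Long] = statuses.flatMap { case (_, blocks) => blocks.map(_._2) }
assert(blockSizes.forall(_ > 0))   // every block reported for this reduce is non-empty
```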
- intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -126,23 +130,23 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } masterTracker.stop() slaveTracker.stop() diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 3316f561a4949..aa8028792cb41 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -91,13 +91,13 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. 
- implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row): Int = x.value - y.value + implicit object RowOrdering extends Ordering[Item] { + override def compare(x: Item, y: Item): Int = x.value - y.value } - val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) + val rdd = sc.parallelize(1 to 4500).map(x => (Item(x), Item(x))) val partitioner = new RangePartitioner(1500, rdd) - partitioner.getPartition(Row(100)) + partitioner.getPartition(Item(100)) } test("RangPartitioner.sketch") { @@ -252,4 +252,4 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva } -private sealed case class Row(value: Int) +private sealed case class Item(value: Int) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c3c2b1ffc1efa..d91b799ecfc08 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -47,7 +47,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val NUM_BLOCKS = 3 val a = sc.parallelize(1 to 10, 2) @@ -66,14 +66,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - assert(statuses.forall(s => s._2 > 0)) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new NonJavaSerializableClass(x * 2)) @@ -89,7 +89,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -105,8 +105,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -116,7 +116,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks without kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -130,8 +130,8 
@@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -141,7 +141,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -154,7 +154,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -168,7 +168,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) @@ -195,7 +195,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) @@ -210,7 +210,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sc = new SparkContext("local-cluster[2,1,512]", "test", myConf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", myConf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) @@ -223,7 +223,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Java") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new 
SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index f89e3d0a49920..e5a14a69ef05f 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.scalatest.PrivateMethodTester +import org.apache.spark.util.Utils import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -122,7 +123,7 @@ class SparkContextSchedulerCreationSuite } test("local-cluster") { - createTaskScheduler("local-cluster[3, 14, 512]").backend match { + createTaskScheduler("local-cluster[3, 14, 1024]").backend match { case s: SparkDeploySchedulerBackend => // OK case _ => fail() } @@ -131,7 +132,7 @@ class SparkContextSchedulerCreationSuite def testYarn(master: String, expectedClassName: String) { try { val sched = createTaskScheduler(master) - assert(sched.getClass === Class.forName(expectedClassName)) + assert(sched.getClass === Utils.classForName(expectedClassName)) } catch { case e: SparkException => assert(e.getMessage.contains("YARN mode not available")) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 6580139df6c60..48509f0759a3b 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -36,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends SparkFunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") @@ -130,8 +130,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(100) } if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") ThreadingSuiteState.failed.set(true) } number @@ -143,6 +141,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { } sem.acquire(2) if (ThreadingSuiteState.failed.get()) { + logError("Waited 1 second without seeing runningThreads = 4 (it was " + + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c054c718075f8..48e74f06f79b1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -69,7 +69,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = httpConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", 
conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -97,7 +97,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = torrentConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -125,7 +125,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val rdd = sc.parallelize(1 to numSlaves) val results = new DummyBroadcastClass(rdd).doSomething() @@ -308,7 +308,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) _sc diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index ddc92814c0acf..cbd2aee10c0e2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -33,7 +33,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { private val WAIT_TIMEOUT_MILLIS = 10000 test("verify that correct log urls get propagated from workers") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") val listener = new SaveExecutorInfo sc.addSparkListener(listener) @@ -66,7 +66,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { } val conf = new MySparkConf().set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2e05dec99b6bf..aa78bfe30974c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -51,9 +51,11 @@ class SparkSubmitSuite /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } /** Returns true if the script exits and the given search string is 
printed. */ @@ -81,6 +83,7 @@ class SparkSubmitSuite } } + // scalastyle:off println test("prints usage on empty input") { testPrematureExit(Array[String](), "Usage: spark-submit") } @@ -243,7 +246,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 8 + sysProps should have size 9 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") @@ -252,6 +255,7 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.submit.deployMode") sysProps("spark.shuffle.spill") should be ("false") } @@ -333,7 +337,7 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -348,7 +352,7 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", @@ -491,6 +495,7 @@ class SparkSubmitSuite appArgs.executorMemory should be ("2.3g") } } + // scalastyle:on println // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { @@ -536,8 +541,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index c9b435a9228d3..01ece1a10f46d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -41,9 +41,11 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } override def beforeAll() { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2a62450bcdbad..73cff89544dc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -243,13 +243,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc appListAfterRename.size should be (1) } - test("apps with multiple attempts") { + test("apps with multiple 
attempts with order") { val provider = new FsHistoryProvider(createTestConf()) - val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = false) + val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = true) writeFile(attempt1, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")), - SparkListenerApplicationEnd(2L) + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")) ) updateAndCheck(provider) { list => @@ -259,7 +258,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val attempt2 = newLogFile("app1", Some("attempt2"), inProgress = true) writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")) + SparkListenerApplicationStart("app1", Some("app1"), 2L, "test", Some("attempt2")) ) updateAndCheck(provider) { list => @@ -268,22 +267,21 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc list.head.attempts.head.attemptId should be (Some("attempt2")) } - val completedAttempt2 = newLogFile("app1", Some("attempt2"), inProgress = false) - attempt2.delete() - writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")), + val attempt3 = newLogFile("app1", Some("attempt3"), inProgress = false) + writeFile(attempt3, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt3")), SparkListenerApplicationEnd(4L) ) updateAndCheck(provider) { list => list should not be (null) list.size should be (1) - list.head.attempts.size should be (2) - list.head.attempts.head.attemptId should be (Some("attempt2")) + list.head.attempts.size should be (3) + list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt2, true, None, + writeFile(attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -291,7 +289,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc updateAndCheck(provider) { list => list.size should be (2) list.head.attempts.size should be (1) - list.last.attempts.size should be (2) + list.last.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt1")) list.foreach { case app => diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index f4e56632e426a..8c96b0e71dfdd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -19,18 +19,19 @@ // when they are outside of org.apache.spark. 
package other.supplier +import java.nio.ByteBuffer + import scala.collection.mutable import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( conf: SparkConf, - serialization: Serialization -) extends StandaloneRecoveryModeFactory(conf, serialization) { + serializer: Serializer +) extends StandaloneRecoveryModeFactory(conf, serializer) { CustomRecoveryModeFactory.instantiationAttempts += 1 @@ -40,7 +41,7 @@ class CustomRecoveryModeFactory( * */ override def createPersistenceEngine(): PersistenceEngine = - new CustomPersistenceEngine(serialization) + new CustomPersistenceEngine(serializer) /** * Create an instance of LeaderAgent that decides who gets elected as master. @@ -53,7 +54,7 @@ object CustomRecoveryModeFactory { @volatile var instantiationAttempts = 0 } -class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine { +class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine { val data = mutable.HashMap[String, Array[Byte]]() CustomPersistenceEngine.lastInstance = Some(this) @@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - serialization.serialize(obj) match { - case util.Success(bytes) => data += name -> bytes - case util.Failure(cause) => throw new RuntimeException(cause) - } + val serialized = serializer.newInstance().serialize(obj) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + data += name -> bytes } /** @@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield serialization.deserialize(bytes, clazz) - - results.find(_.isFailure).foreach { - case util.Failure(cause) => throw new RuntimeException(cause) - } - - results.flatMap(_.toOption).toSeq + yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + results.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 9cb6dd43bac47..4d7016d1e594b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -25,14 +25,15 @@ import scala.language.postfixOps import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.Matchers +import org.scalatest.{Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ +import org.apache.spark.rpc.RpcEnv -class MasterSuite extends SparkFunSuite with Matchers with Eventually { +class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester { test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) @@ -105,7 +106,7 @@ class MasterSuite extends 
SparkFunSuite with Matchers with Eventually { persistenceEngine.addDriver(driverToPersist) persistenceEngine.addWorker(workerToPersist) - val (apps, drivers, workers) = persistenceEngine.readPersistedData() + val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv) apps.map(_.id) should contain(appToPersist.id) drivers.map(_.id) should contain(driverToPersist.id) @@ -142,4 +143,196 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { } } + test("basic scheduling - spread out") { + testBasicScheduling(spreadOut = true) + } + + test("basic scheduling - no spread out") { + testBasicScheduling(spreadOut = false) + } + + test("scheduling with max cores - spread out") { + testSchedulingWithMaxCores(spreadOut = true) + } + + test("scheduling with max cores - no spread out") { + testSchedulingWithMaxCores(spreadOut = false) + } + + test("scheduling with cores per executor - spread out") { + testSchedulingWithCoresPerExecutor(spreadOut = true) + } + + test("scheduling with cores per executor - no spread out") { + testSchedulingWithCoresPerExecutor(spreadOut = false) + } + + test("scheduling with cores per executor AND max cores - spread out") { + testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut = true) + } + + test("scheduling with cores per executor AND max cores - no spread out") { + testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut = false) + } + + private def testBasicScheduling(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(1024) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + val scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 10) + } + + private def testSchedulingWithMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) + val appInfo2 = makeAppInfo(1024, maxCores = Some(16)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + // With spreading out, each worker should be assigned a few cores + if (spreadOut) { + assert(scheduledCores(0) === 3) + assert(scheduledCores(1) === 3) + assert(scheduledCores(2) === 2) + } else { + // Without spreading out, the cores should be concentrated on the first worker + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 0) + assert(scheduledCores(2) === 0) + } + // Now test the same thing with max cores > cores per worker + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 6) + assert(scheduledCores(1) === 5) + assert(scheduledCores(2) === 5) + } else { + // Without spreading out, the first worker should be fully booked, + // and the leftover cores should spill over to the second worker only. 
+ assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 6) + assert(scheduledCores(2) === 0) + } + } + + private def testSchedulingWithCoresPerExecutor(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, coresPerExecutor = Some(2)) + val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2)) + val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + // Each worker should end up with 4 executors with 2 cores each + // This should be 4 because of the memory restriction on each worker + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 8) + assert(scheduledCores(2) === 8) + // Now test the same thing without running into the worker memory limit + // Each worker should now end up with 5 executors with 2 cores each + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 10) + // Now test the same thing with a cores per executor that 10 is not divisible by + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo3, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 9) + assert(scheduledCores(1) === 9) + assert(scheduledCores(2) === 9) + } + + // Sorry for the long method name! + private def testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(4)) + val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(20)) + val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3), maxCores = Some(20)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + // We should only launch two executors, each with exactly 2 cores + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 2) + assert(scheduledCores(1) === 2) + assert(scheduledCores(2) === 0) + } else { + assert(scheduledCores(0) === 4) + assert(scheduledCores(1) === 0) + assert(scheduledCores(2) === 0) + } + // Test max cores > number of cores per worker + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 6) + assert(scheduledCores(2) === 6) + } else { + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 0) + } + // Test max cores > number of cores per worker AND + // a cores per executor that is 10 is not divisible by + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo3, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 6) + assert(scheduledCores(1) === 6) + assert(scheduledCores(2) === 6) + } else { + assert(scheduledCores(0) === 9) + assert(scheduledCores(1) === 9) + assert(scheduledCores(2) === 0) + } + } + + // 
=============================== + // | Utility methods for testing | + // =============================== + + private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers) + + private def makeMaster(conf: SparkConf = new SparkConf): Master = { + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf) + master + } + + private def makeAppInfo( + memoryPerExecutorMb: Int, + coresPerExecutor: Option[Int] = None, + maxCores: Option[Int] = None): ApplicationInfo = { + val desc = new ApplicationDescription( + "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor) + val appId = System.currentTimeMillis.toString + new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue) + } + + private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { + val workerId = System.currentTimeMillis.toString + new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, 101, "address") + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala new file mode 100644 index 0000000000000..11e87bd1dd8eb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.deploy.master + +import java.net.ServerSocket + +import org.apache.commons.lang3.RandomUtils +import org.apache.curator.test.TestingServer + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.util.Utils + +class PersistenceEngineSuite extends SparkFunSuite { + + test("FileSystemPersistenceEngine") { + val dir = Utils.createTempDir() + try { + val conf = new SparkConf() + testPersistenceEngine(conf, serializer => + new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) + ) + } finally { + Utils.deleteRecursively(dir) + } + } + + test("ZooKeeperPersistenceEngine") { + val conf = new SparkConf() + // TestingServer logs the port conflict exception rather than throwing an exception. + // So we have to find a free port by ourselves. This approach cannot guarantee always starting + // zkTestServer successfully because there is a time gap between finding a free port and + // starting zkTestServer. But the failure possibility should be very low. 
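The comment above describes the port-probing strategy for the ZooKeeperPersistenceEngine test. The underlying idea is simply "bind an ephemeral socket, read its port, close it"; a minimal sketch of that probe follows (the suite's own findFreePort, defined at the bottom of this file, instead picks a random candidate and wraps the bind in Utils.startServiceOnPort to retry on conflicts):

```scala
import java.net.ServerSocket

// Ask the OS for any currently free port, then release it so TestingServer can bind it.
// Note the race the comment above mentions: another process may grab the port between
// close() and TestingServer's own bind, which is why the failure probability is low but not zero.
def probeFreePort(): Int = {
  val socket = new ServerSocket(0)   // port 0 = let the OS pick a free ephemeral port
  try socket.getLocalPort finally socket.close()
}
```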
+ val zkTestServer = new TestingServer(findFreePort(conf)) + try { + testPersistenceEngine(conf, serializer => { + conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString) + new ZooKeeperPersistenceEngine(conf, serializer) + }) + } finally { + zkTestServer.stop() + } + } + + private def testPersistenceEngine( + conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { + val serializer = new JavaSerializer(conf) + val persistenceEngine = persistenceEngineCreator(serializer) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = rpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + rpcEnv.shutdown() + rpcEnv.awaitTermination() + } + } + + private def findFreePort(conf: SparkConf): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, conf)._2 + } +} diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 63947df3d43a2..8a199459c1ddf 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -36,7 +36,7 @@ import 
org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ -class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ @@ -85,7 +85,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl */ test("Correctness of WholeTextFileRecordReader.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, false) @@ -109,7 +109,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl test("Correctness of WholeTextFileRecordReader with GzipCodec.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, true) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 9e4d34fb7d382..d3218a548efc7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -60,7 +60,9 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { + // scalastyle:off println pw.println(RandomUtils.nextInt(0, numBuckets)) + // scalastyle:on println } pw.close() diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 08215a2bafc09..05013fbc49b8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -22,11 +22,12 @@ import java.sql._ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.util.Utils class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { - Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index dfa102f432a02..1321ec84735b5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -282,6 +282,29 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { )) } + // See SPARK-9326 + test("cogroup with empty RDD") { + import scala.reflect.classTag + val intPairCT = classTag[(Int, Int)] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = 
sc.emptyRDD[(Int, Int)](intPairCT) + + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + + // See SPARK-9326 + test("cogroup with groupByed RDD having 0 partitions") { + import scala.reflect.classTag + val intCT = classTag[Int] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[Int](intCT).groupBy((x) => 5) + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 32f04d54eff94..3e8816a4c65be 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null) + val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index f6da9f98ad253..5f718ea9f7be1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -679,7 +679,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("runJob on an invalid partition") { intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 34145691153ce..eef6aafa624ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -26,7 +26,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val conf = new SparkConf conf.set("spark.akka.frameSize", "1") conf.set("spark.default.parallelism", "1") - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf) + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6bc45f249f975..86dff8fb577d5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -101,9 +101,15 @@ class DAGSchedulerSuite /** Length of time to wait while draining listener events. 
*/ val WAIT_TIMEOUT_MILLIS = 10000 val sparkListener = new SparkListener() { + val submittedStageInfos = new HashSet[StageInfo] val successfulStages = new HashSet[Int] val failedStages = new ArrayBuffer[Int] val stageByOrderOfExecution = new ArrayBuffer[Int] + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + submittedStageInfos += stageSubmitted.stageInfo + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { val stageInfo = stageCompleted.stageInfo stageByOrderOfExecution += stageInfo.stageId @@ -147,9 +153,8 @@ class DAGSchedulerSuite } before { - // Enable local execution for this test - val conf = new SparkConf().set("spark.localExecution.enabled", "true") - sc = new SparkContext("local", "DAGSchedulerSuite", conf) + sc = new SparkContext("local", "DAGSchedulerSuite") + sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null @@ -165,12 +170,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } @@ -234,10 +234,9 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - allowLocal: Boolean = false, listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) jobId } @@ -277,37 +276,6 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } - test("local job") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - Array(42 -> 0).iterator - override def getPartitions: Array[Partition] = - Array( new Partition { override def index: Int = 0 } ) - override def getPreferredLocations(split: Partition): List[String] = Nil - override def toString: String = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results === Map(0 -> 42)) - assertDataStructuresEmpty() - } - - test("local job oom") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new java.lang.OutOfMemoryError("test local job oom") - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results.size == 0) - assertDataStructuresEmpty() - } - test("run trivial job w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil) val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -445,12 +413,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester 
= new DAGSchedulerEventProcessLoopTester(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) @@ -476,8 +439,8 @@ class DAGSchedulerSuite complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -503,8 +466,8 @@ class DAGSchedulerSuite // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -520,8 +483,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( @@ -547,6 +510,140 @@ class DAGSchedulerSuite assert(sparkListener.failedStages.size == 1) } + /** + * This tests the case where another FetchFailed comes in while the map stage is getting + * re-run. + */ + test("late fetch failures don't cause multiple concurrent attempts for the same map stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + assert(countSubmittedMapStageAttempts() === 1) + + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // The MapOutputTracker should know about both map output locations. + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.failedStages.contains(1)) + + // Trigger resubmission of the failed map stage. 
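+    // (the DAGScheduler normally posts ResubmitFailedStages to itself after a short delay once a
+    // fetch failure is seen; the test drives the event loop directly instead of waiting)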
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
+    // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
+    // desired event and our check.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
+  /**
+   * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+   * retried and a new reduce stage starts running.
+   */
+  test("extremely late fetch failures don't cause multiple concurrent attempts for " +
+    "the same stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+
+    // The reduce stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each of the map and the reduce stage.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage,
+    // because the FetchFailed should have been ignored.
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored.
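+    // (the failure comes from a stale attempt of the reduce stage that has since been superseded,
+    // so the scheduler drops it rather than resubmitting the map stage yet again)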
+ assert(countSubmittedMapStageAttempts() === 2) + } + test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -572,8 +669,8 @@ class DAGSchedulerSuite taskSet.tasks(1).epoch = newEpoch runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -668,8 +765,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -748,40 +845,23 @@ class DAGSchedulerSuite // Run this on executors sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - // Run this within a local thread - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { - val e1 = intercept[SparkDriverExecutionException] { - val rdd = sc.parallelize(1 to 10, 2) - sc.runJob[Int, Int]( - rdd, - (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0), - allowLocal = true, - (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) - } - assert(e1.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - - val e2 = intercept[SparkDriverExecutionException] { + val e = intercept[SparkDriverExecutionException] { val rdd = sc.parallelize(1 to 10, 2) sc.runJob[Int, Int]( rdd, (context: TaskContext, iter: Iterator[Int]) => iter.size, Seq(0, 1), - allowLocal = false, (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) } - assert(e2.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) + assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { @@ -794,9 +874,8 @@ class DAGSchedulerSuite rdd.reduceByKey(_ + _, 1).count() } - // Make sure we can still run local commands as well as cluster commands. 
+ // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { @@ -810,9 +889,8 @@ class DAGSchedulerSuite } assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("accumulator not calculated for resubmitted result stage") { @@ -840,8 +918,8 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -875,6 +953,21 @@ class DAGSchedulerSuite assertDataStructuresEmpty } + test("Spark exceptions should include call site in stack trace") { + val e = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map { _ => throw new RuntimeException("uh-oh!") }.count() + } + + // Does not include message, ONLY stack trace. + val stackTraceString = e.getStackTraceString + + // should actually include the RDD operation that invoked the method: + assert(stackTraceString.contains("org.apache.spark.rdd.RDD.count")) + + // should include the FunSuite setup: + assert(stackTraceString.contains("org.scalatest.FunSuite")) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index f681f21b6205e..5cb2d4225d281 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -180,7 +180,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // into SPARK-6688. 
val conf = getLoggingConf(testDirPath, compressionCodec) .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") - val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0a7cb69416a08..b3ca150195a5f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs @@ -31,12 +31,16 @@ object FakeTask { * locations for each task (given as varargs) if this sequence is not empty. */ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, 0, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, 0, 0, null) + new TaskSet(tasks, 0, stageAttemptId, 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 9b92f8de56759..383855caefa2f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. 
*/ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0) { + extends Task[Array[Byte]](stageId, 0, 0) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index a9036da9cc93d..e5ecd4b7c2610 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -134,14 +134,14 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only one of two duplicate commit tasks should commit") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } test("If commit fails, if task is retried it should not be locked, and will succeed.") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index ff3fa95ec32ae..103fc19369c97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -52,8 +52,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) + // scalastyle:off println writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println writer.close() val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) @@ -100,7 +102,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { fileSystem.mkdirs(logDirPath) val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName) - val sc = new SparkContext("local-cluster[2,1,512]", "Test replay", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf) // Run a few jobs sc.parallelize(1 to 100, 1).count() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 651295b7344c5..730535ece7878 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -188,7 +188,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) - sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true) + sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d97fba00976d2..d1e23ed527ff1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -34,7 +34,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext val WAIT_TIMEOUT_MILLIS = 10000 before { - sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite") + sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite") } test("SparkListener sends executor added message") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 7c1adc1aef1b6..9201d1e1f328b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -24,11 +24,27 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.metrics.source.JvmSource class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { + test("provide metrics sources") { + val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile + val conf = new SparkConf(loadDefaults = false) + .set("spark.metrics.conf", filePath) + sc = new SparkContext("local", "test", conf) + val rdd = sc.makeRDD(1 to 1) + val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => { + tc.getMetricsSources("jvm").count { + case source: JvmSource => true + case _ => false + } + }).sum + assert(result > 0) + } + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") @@ -41,16 +57,16 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String]( - 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val task = new ResultTask[String, String](0, 0, + sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { - task.run(0, 0) + task.run(0, 0, null) } assert(TaskContextSuite.completed === true) } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a6d5232feb8de..c2edd4c317d6e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends 
SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L assert(taskDescriptions.map(_.executorId) === Seq("executor0")) } + test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + val dagScheduler = new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + taskScheduler.setDAGScheduler(dagScheduler) + val attempt1 = FakeTask.createTaskSet(1, 0) + val attempt2 = FakeTask.createTaskSet(1, 1) + taskScheduler.submitTasks(attempt1) + intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) } + + // OK to submit multiple if previous attempts are all zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + taskScheduler.submitTasks(attempt2) + val attempt3 = FakeTask.createTaskSet(1, 2) + intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) } + taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId) + .get.isZombie = true + taskScheduler.submitTasks(attempt3) + } + + test("don't schedule more tasks after a taskset is zombie") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
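+    // (the DAGScheduler constructor registers itself with the TaskScheduler via setDAGScheduler,
+    // so the instance does not need to be kept in a local val)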
+ new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 1 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // if we schedule another attempt for the same stage, it should get scheduled + val attempt2 = FakeTask.createTaskSet(10, 1) + + // submit attempt 2, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt2) + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions3.length) + val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + + test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 10 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get + mgr1.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // submit attempt 2 + val attempt2 = FakeTask.createTaskSet(10, 1) + taskScheduler.submitTasks(attempt2) + + // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were + // already submitted, and then they finish) + taskScheduler.taskSetFinished(mgr1) + + // now with another resource offer, we should still schedule all the tasks in attempt2 + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions3.length) + + taskDescriptions3.foreach { task => + val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 0060f3396dcde..3abb99c4b2b54 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.scheduler import java.util.Random -import scala.collection.mutable.ArrayBuffer +import scala.collection.Map import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: mutable.Map[Long, Any], + accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { taskScheduler.endedTasks(taskInfo.index) = reason @@ -135,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..4b504df7b8851 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util +import java.util.Collections + +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.scalatest.mock.MockitoSugar +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} + +class CoarseMesosSchedulerBackendSuite extends SparkFunSuite + with LocalSparkContext + with MockitoSugar + with BeforeAndAfter { + + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder() + .setValue(offerId).build()) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } + + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + override protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver + markRegistered() + } + backend.start() + backend + } + + var sparkConf: SparkConf = _ + + before { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + + sc = new SparkContext(sparkConf) + } + + test("mesos supports killing and limiting executors") { + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc) + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) + + val taskID0 = TaskID.newBuilder().setValue("0").build() + + backend.resourceOffers(driver, mesosOffers) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + // simulate the allocation manager down-scaling executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("s1/0"))) + verify(driver, times(1)).killTask(taskID0) + + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) + backend.resourceOffers(driver, mesosOffers2) + + verify(driver, times(1)) + .declineOffer(OfferID.newBuilder().setValue("o2").build()) + + // Verify we didn't launch any new executor + assert(backend.slaveIdsWithExecutors.size === 1) + + backend.doRequestTotalExecutors(2) + 
backend.resourceOffers(driver, mesosOffers2) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + assert(backend.slaveIdsWithExecutors.size === 2) + backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) + assert(backend.slaveIdsWithExecutors.size === 1) + } + + test("mesos supports killing and relaunching tasks with executors") { + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc) + 1024 + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + val offer1 = createOffer("o1", "s1", minMem, minCpu) + mesosOffers.add(offer1) + + val offer2 = createOffer("o2", "s1", minMem, 1); + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer1.getId)), + anyObject(), + anyObject[Filters]) + + // Simulate task killed, executor no longer running + val status = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue("0").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setState(TaskState.TASK_KILLED) + .build + + backend.statusUpdate(driver, status) + assert(!backend.slaveIdsWithExecutors.contains("s1")) + + mesosOffers.clear() + mesosOffers.add(offer2) + backend.resourceOffers(driver, mesosOffers) + assert(backend.slaveIdsWithExecutors.contains("s1")) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer2.getId)), + anyObject(), + anyObject[Filters]) + + verify(driver, times(1)).reviveOffers() + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala deleted file mode 100644 index e72285d03d3ee..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} - -class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { - test("MesosMemoryUtils should always override memoryOverhead when it's set") { - val sparkConf = new SparkConf - - val sc = mock[SparkContext] - when(sc.conf).thenReturn(sparkConf) - - // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 - when(sc.executorMemory).thenReturn(512) - assert(MemoryUtils.calculateTotalMemory(sc) === 896) - - // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 - when(sc.executorMemory).thenReturn(4096) - assert(MemoryUtils.calculateTotalMemory(sc) === 4505) - - // set memoryOverhead - sparkConf.set("spark.mesos.executor.memoryOverhead", "100") - assert(MemoryUtils.calculateTotalMemory(sc) === 4196) - sparkConf.set("spark.mesos.executor.memoryOverhead", "400") - assert(MemoryUtils.calculateTotalMemory(sc) === 4496) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 68df46a41ddc8..5ed30f64d705f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import java.util import java.util.Collections +import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -60,14 +61,17 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val resources = List( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) // uri is null. - val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") + val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. 
conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") - val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") + val (executorInfo1, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } @@ -93,7 +97,8 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val execInfo = backend.createExecutorInfo("mockExecutor") + val (execInfo, _) = backend.createExecutorInfo( + List(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) @@ -149,7 +154,9 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(new SparkConf) when(sc.listenerBus).thenReturn(listenerBus) - val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] @@ -157,8 +164,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi mesosOffers.add(createOffer(2, minMem - 1, minCpu)) mesosOffers.add(createOffer(3, minMem, minCpu)) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, @@ -194,7 +199,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi ) verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) - assert(capture.getValue.size() == 1) + assert(capture.getValue.size() === 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) val cpus = taskInfo.getResourcesList.get(0) @@ -214,4 +219,97 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi backend.resourceOffers(driver, mesosOffers2) verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } + + test("can handle multiple roles") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) + + val id = 1 + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setRole("prod") + .setScalar(Scalar.newBuilder().setValue(500)) + builder.addResourcesBuilder() + .setName("cpus") + .setRole("prod") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(1)) + builder.addResourcesBuilder() + .setName("mem") + .setRole("dev") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(600)) + builder.addResourcesBuilder() + .setName("cpus") + .setRole("dev") + .setType(Value.Type.SCALAR) + 
.setScalar(Scalar.newBuilder().setValue(2)) + val offer = builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() + + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(offer) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](1) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(0).getSlaveId.getValue, + mesosOffers.get(0).getHostname, + 2 // Deducting 1 for executor + )) + + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(1) + + val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + when( + driver.launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + ).thenReturn(Status.valueOf(1)) + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + + assert(capture.getValue.size() === 1) + val taskInfo = capture.getValue.iterator().next() + assert(taskInfo.getName.equals("n1")) + assert(taskInfo.getResourcesCount === 1) + val cpusDev = taskInfo.getResourcesList.get(0) + assert(cpusDev.getName.equals("cpus")) + assert(cpusDev.getScalar.getValue.equals(1.0)) + assert(cpusDev.getRole.equals("dev")) + val executorResources = taskInfo.getExecutor.getResourcesList + assert(executorResources.exists { r => + r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod") + }) + assert(executorResources.exists { r => + r.getName.equals("cpus") && r.getScalar.getValue.equals(1.0) && r.getRole.equals("prod") + }) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala new file mode 100644 index 0000000000000..b354914b6ffd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.Value +import org.mockito.Mockito._ +import org.scalatest._ +import org.scalatest.mock.MockitoSugar +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { + + // scalastyle:off structural.type + // this is the documented way of generating fixtures in scalatest + def fixture: Object {val sc: SparkContext; val sparkConf: SparkConf} = new { + val sparkConf = new SparkConf + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + } + val utils = new MesosSchedulerUtils { } + // scalastyle:on structural.type + + test("use at-least minimum overhead") { + val f = fixture + when(f.sc.executorMemory).thenReturn(512) + utils.calculateTotalMemory(f.sc) shouldBe 896 + } + + test("use overhead if it is greater than minimum value") { + val f = fixture + when(f.sc.executorMemory).thenReturn(4096) + utils.calculateTotalMemory(f.sc) shouldBe 4505 + } + + test("use spark.mesos.executor.memoryOverhead (if set)") { + val f = fixture + when(f.sc.executorMemory).thenReturn(1024) + f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") + utils.calculateTotalMemory(f.sc) shouldBe 1536 + } + + test("parse a non-empty constraint string correctly") { + val expectedMap = Map( + "tachyon" -> Set("true"), + "zone" -> Set("us-east-1a", "us-east-1b") + ) + utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + } + + test("parse an empty constraint string correctly") { + utils.parseConstraintString("") shouldBe Map() + } + + test("throw an exception when the input is malformed") { + an[IllegalArgumentException] should be thrownBy + utils.parseConstraintString("tachyon;zone:us-east") + } + + test("empty values for attributes' constraints matches all values") { + val constraintsStr = "tachyon:" + val parsedConstraints = utils.parseConstraintString(constraintsStr) + + parsedConstraints shouldBe Map("tachyon" -> Set()) + + val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() + val noTachyonOffer = Map("zone" -> zoneSet) + val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + + utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + } + + test("subset match is performed for set attributes") { + val supersetConstraint = Map( + "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "zone" -> Value.Set.newBuilder() + .addItem("us-east-1a") + .addItem("us-east-1b") + .addItem("us-east-1c") + .build()) + + val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) + + utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true + } + + test("less than equal match is performed on scalar attributes") { + val offerAttribs = Map("gpus" -> Value.Scalar.newBuilder().setValue(3).build()) + + val ltConstraint = utils.parseConstraintString("gpus:2") + val eqConstraint = utils.parseConstraintString("gpus:3") + val gtConstraint = utils.parseConstraintString("gpus:4") + + 
utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + } + + test("contains match is performed for range attributes") { + val offerAttribs = Map("ports" -> Value.Range.newBuilder().setBegin(7000).setEnd(8000).build()) + val ltConstraint = utils.parseConstraintString("ports:6000") + val eqConstraint = utils.parseConstraintString("ports:7500") + val gtConstraint = utils.parseConstraintString("ports:8002") + val multiConstraint = utils.parseConstraintString("ports:5000,7500,8300") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(multiConstraint, offerAttribs) shouldBe true + } + + test("equality match is performed for text attributes") { + val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + + val trueConstraint = utils.parseConstraintString("tachyon:true") + val falseConstraint = utils.parseConstraintString("tachyon:false") + + utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala new file mode 100644 index 0000000000000..bc9f3708ed69d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.io.{Output, Input} +import org.apache.avro.{SchemaBuilder, Schema} +import org.apache.avro.generic.GenericData.Record + +import org.apache.spark.{SparkFunSuite, SharedSparkContext} + +class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val schema : Schema = SchemaBuilder + .record("testRecord").fields() + .requiredString("data") + .endRecord() + val record = new Record(schema) + record.put("data", "test data") + + test("schema compression and decompression") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) + } + + test("record serialization and deserialization") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + + val outputStream = new ByteArrayOutputStream() + val output = new Output(outputStream) + genericSer.serializeDatum(record, output) + output.flush() + output.close() + + val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) + assert(genericSer.deserializeDatum(input) === record) + } + + test("uses schema fingerprint to decrease message size") { + val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) + + val output = new Output(new ByteArrayOutputStream()) + + val beginningNormalPosition = output.total() + genericSerFull.serializeDatum(record, output) + output.flush() + val normalLength = output.total - beginningNormalPosition + + conf.registerAvroSchemas(schema) + val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) + val beginningFingerprintPosition = output.total() + genericSerFinger.serializeDatum(record, output) + val fingerprintLength = output.total - beginningFingerprintPosition + + assert(fingerprintLength < normalLength) + } + + test("caches previously seen schemas") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + val compressedSchema = genericSer.compress(schema) + val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + + assert(compressedSchema.eq(genericSer.compress(schema))) + assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 63a8480c9b57b..935a091f14f9b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -35,7 +35,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) @@ -59,7 +59,9 @@ object KryoDistributedTest { class AppJarRegistrator extends KryoRegistrator { override def registerClasses(k: Kryo) { val classLoader = 
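To see why the "uses schema fingerprint to decrease message size" test above expects a smaller payload: once a schema is registered on both sides, each datum only needs to carry a 64-bit fingerprint of the schema rather than its full JSON text. The sketch below uses plain Avro APIs to illustrate that size argument; it is not the `GenericAvroSerializer` code under test:

```
// Size argument behind the fingerprint path: a registered schema can be referenced by its
// 64-bit Rabin fingerprint (8 bytes) instead of by its full JSON text.
import org.apache.avro.{Schema, SchemaBuilder, SchemaNormalization}

val schema: Schema = SchemaBuilder
  .record("testRecord").fields()
  .requiredString("data")
  .endRecord()

val fingerprint = SchemaNormalization.parsingFingerprint64(schema) // deterministic for a schema
val fingerprintBytes = 8                                           // a single Long on the wire
val fullSchemaBytes = schema.toString.getBytes("UTF-8").length     // the full JSON schema text
assert(fingerprintBytes < fullSchemaBytes)
```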
Thread.currentThread.getContextClassLoader + // scalastyle:off classforname k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + // scalastyle:on classforname } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb1..f495b6a037958 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,26 +17,39 @@ package org.apache.spark.shuffle +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.CountDownLatch -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { + + val nextTaskAttemptId = new AtomicInteger() + /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { override def run() { - body + try { + val taskAttemptId = nextTaskAttemptId.getAndIncrement + val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + TaskContext.setTaskContext(mockTaskContext) + body + } finally { + TaskContext.unset() + } } } thread.start() thread } - test("single thread requesting memory") { + test("single task requesting memory") { val manager = new ShuffleMemoryManager(1000L) assert(manager.tryToAcquire(100L) === 100L) @@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(manager.tryToAcquire(300L) === 300L) assert(manager.tryToAcquire(300L) === 200L) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() assert(manager.tryToAcquire(1000L) === 1000L) assert(manager.tryToAcquire(100L) === 0L) } @@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } - test("threads cannot grow past 1 / N") { - // Two threads request 250 bytes first, wait for each other to get it, and then request + test("tasks cannot grow past 1 / N") { + // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request val manager = new ShuffleMemoryManager(1000L) @@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(state.t2Result2 === 250L) } - test("threads can block to get at least 1 / 2N memory") { + test("tasks can block to get at least 1 / 2N memory") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. @@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("releaseMemoryForThisThread") { + test("releaseMemoryForThisTask") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
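The renames above ("thread" to "task") reflect a behavioural change: shuffle memory is now accounted against the task attempt obtained from `TaskContext`, which is why each test thread installs a mocked `TaskContext` with a fresh `taskAttemptId`. A stripped-down sketch of that bookkeeping follows; `PerTaskMemoryTracker` is an illustrative name and deliberately omits the 1/N and 1/2N fairness rules the suite tests, so it is not the real `ShuffleMemoryManager`:

```
// Minimal per-task bookkeeping sketch: grants are keyed by taskAttemptId, not by the current
// thread, so releasing memory for a task frees everything that task acquired regardless of
// which threads did the acquiring.
class PerTaskMemoryTracker(maxMemory: Long) {
  private val acquiredByTask = scala.collection.mutable.HashMap.empty[Long, Long]

  def tryToAcquire(taskAttemptId: Long, numBytes: Long): Long = synchronized {
    val free = maxMemory - acquiredByTask.values.sum
    val granted = math.max(0L, math.min(numBytes, free))
    if (granted > 0) {
      acquiredByTask(taskAttemptId) = acquiredByTask.getOrElse(taskAttemptId, 0L) + granted
    }
    granted
  }

  def releaseMemoryForTask(taskAttemptId: Long): Unit = synchronized {
    acquiredByTask.remove(taskAttemptId)
  }
}
```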
@@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise + // sure the other task blocks for some time otherwise Thread.sleep(300) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() } val t2 = startThread("t2") { @@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { t2.join() } - // Both threads should've been able to acquire their memory; the second one will have waited + // Both tasks should've been able to acquire their memory; the second one will have waited // until the first one acquired 1000 bytes and then released all of it state.synchronized { assert(state.t1Result === 1000L, "t1 could not allocate memory") @@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("threads should not be granted a negative size") { + test("tasks should not be granted a negative size") { val manager = new ShuffleMemoryManager(1000L) manager.tryToAcquire(700L) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index 28ca68698e3dc..db718ecabbdb9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -115,11 +115,15 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val statuses: Array[(BlockManagerId, Long)] = - Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) - when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) + } + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + } // Create a mocked shuffle handle to pass into HashShuffleReader. 
val shuffleHandle = { @@ -134,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null), + new TaskContextImpl(0, 0, 0, 0, null, null), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 542f8f45125a4..cc7342f1ecd78 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[BlockObjectWriter] { - override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + )).thenAnswer(new Answer[DiskBlockObjectWriter] { + override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments new DiskBlockObjectWriter( args(0).asInstanceOf[BlockId], diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bcee901f5dd5f..f480fd107a0c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Reserve - memoryStore.reserveUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 100) - memoryStore.reserveUnrollMemoryForThisThread(200) - assert(memoryStore.currentUnrollMemoryForThisThread === 300) - memoryStore.reserveUnrollMemoryForThisThread(500) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) - memoryStore.reserveUnrollMemoryForThisThread(1000000) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + memoryStore.reserveUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + memoryStore.reserveUnrollMemoryForThisTask(200) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + memoryStore.reserveUnrollMemoryForThisTask(500) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 700) - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 600) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisThread(4400) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) - memoryStore.reserveUnrollMemoryForThisThread(20000) - assert(memoryStore.currentUnrollMemoryForThisThread 
=== 5000) // not granted + memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisThread(1000) - assert(memoryStore.currentUnrollMemoryForThisThread === 4000) - memoryStore.releaseUnrollMemoryForThisThread() // release all - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) } /** @@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) val memoryStore = store.memoryStore val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) - memoryStore.releasePendingUnrollMemoryForThisThread() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisThread() + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. 
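For the "huge block" scenario described in the comment above, the arithmetic works out as follows; the figures are taken from the test itself, with 0.4 treated as the unroll fraction the comment quotes:

```
// Numbers from the test: the MemoryStore holds 12000 bytes, at most 12000 * 0.4 = 4800 bytes
// can be freed for unrolling, and bigList is 40 arrays of 1000 bytes (~40000 bytes), so the
// block cannot be unrolled in memory and unrollSafely hands back an iterator instead.
val storeCapacity = 12000
val ensuredFreeSpace = (storeCapacity * 0.4).toInt   // 4800
val bigListSize = 40 * 1000                          // ~40000, ignoring per-object overhead
assert(bigListSize > ensuredFreeSpace)
```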
@@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) droppedBlocks.clear() @@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with plenty of space. This should succeed and cache both blocks. val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) @@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result2.size > 0) assert(result1.data.isLeft) // unroll did not drop this block to disk assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Re-put these two blocks so block manager knows about them too. Otherwise, block manager // would not know how to drop them from memory later. @@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") store.putIterator("b3", smallIterator, memOnly) @@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } /** @@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) store.putIterator("b1", smallIterator, memAndDisk) store.putIterator("b2", smallIterator, memAndDisk) @@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b3")) memoryStore.remove("b3") store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk // directly in addition to kicking out b2 in the process. 
Memory store should contain only @@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(diskStore.contains("b2")) assert(!diskStore.contains("b3")) assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } test("multiple unrolls by the same thread") { @@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // All unroll memory used is released because unrollSafely returned an array memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll memory is not released because unrollSafely returned an iterator // that still depends on the underlying vector used in the process memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... but only to a certain extent (until we run out of free space to grant new unroll memory) memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala new file mode 100644 index 0000000000000..d7ffde1e7864e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler._ + +class BlockStatusListenerSuite extends SparkFunSuite { + + test("basic functions") { + val blockManagerId = BlockManagerId("0", "localhost", 10000) + val listener = new BlockStatusListener() + + // Add a block manager and a new block status + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0)) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The new block status should be added to the listener + val expectedBlock = BlockUIData( + StreamBlockId(0, 100), + "localhost:10000", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + val expectedExecutorStreamBlockStatus = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus) + + // Add the second block manager + val blockManagerId2 = BlockManagerId("1", "localhost", 10001) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0)) + // Add a new replication of the same block id from the second manager + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + val expectedBlock2 = BlockUIData( + StreamBlockId(0, 100), + "localhost:10001", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + // Each block manager should contain one block + val expectedExecutorStreamBlockStatus2 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2)) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2) + + // Remove a replication of the same block + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.NONE, // StorageLevel.NONE means removing it + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 0))) + // Only the first block manager contains a block + val expectedExecutorStreamBlockStatus3 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3) + + // Remove the second block manager at first but add a new block status + // from this removed block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2)) + 
listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The second block manager is removed so we should not see the new block + val expectedExecutorStreamBlockStatus4 = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4) + + // Remove the last block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId)) + // No block manager now so we should dop all block managers + assert(listener.allExecutorStreamBlockStatus.isEmpty) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala similarity index 98% rename from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala rename to core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7bdea724fea58..66af6e1a79740 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { +class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 9ced4148d7206..cf8bd8ae69625 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.shuffle.FetchFailedException class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { @@ -94,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null), + new TaskContextImpl(0, 0, 0, 0, null, null), transfer, blockManager, blocksByAddress, @@ -106,13 +107,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - assert(inputStream.isSuccess, - s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream - val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] + val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() val delegateAccess = PrivateMethod[InputStream]('delegate) @@ -166,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -175,11 +174,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.get.close() // close() first block's input stream + iterator.next()._2.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2.get + val subIter = iterator.next()._2 // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -228,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -239,9 +238,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Continue only after the mock calls onBlockFetchFailure sem.acquire() - // The first block should be defined, and the last two are not defined (due to failure) - assert(iterator.next()._2.isSuccess) - assert(iterator.next()._2.isFailure) - assert(iterator.next()._2.isFailure) + // The first block should be returned without an exception, and the last two should throw + // FetchFailedExceptions (due to failure) + iterator.next() + intercept[FetchFailedException] { iterator.next() } + intercept[FetchFailedException] { iterator.next() } } } diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala new file mode 100644 index 0000000000000..cc76c141c53cc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.Node + +import org.apache.spark.SparkFunSuite + +class PagedDataSourceSuite extends SparkFunSuite { + + test("basic") { + val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource1.pageData(1) === PageData(3, (1 to 2))) + + val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource2.pageData(2) === PageData(3, (3 to 4))) + + val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource3.pageData(3) === PageData(3, Seq(5))) + + val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e1 = intercept[IndexOutOfBoundsException] { + dataSource4.pageData(4) + } + assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.") + + val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e2 = intercept[IndexOutOfBoundsException] { + dataSource5.pageData(0) + } + assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") + + } +} + +class PagedTableSuite extends SparkFunSuite { + test("pageNavigation") { + // Create a fake PagedTable to test pageNavigation + val pagedTable = new PagedTable[Int] { + override def tableId: String = "" + + override def tableCssClass: String = "" + + override def dataSource: PagedDataSource[Int] = null + + override def pageLink(page: Int): String = page.toString + + override def headers: Seq[Node] = Nil + + override def row(t: Int): Seq[Node] = Nil + + override def goButtonJavascriptFunction: (String, String) = ("", "") + } + + assert(pagedTable.pageNavigation(1, 10, 1) === Nil) + assert( + (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">")) + assert( + (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2")) + + assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) === + (1 to 10).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>")) + + assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString)) + assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">")) + + assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">")) + } +} + +private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int) + extends PagedDataSource[T](pageSize) { + + override protected def dataSize: Int = seq.size + + override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to) +} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala 
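The page-navigation expectations above all follow from the same arithmetic: with `dataSize` items and `pageSize` items per page there are `ceil(dataSize / pageSize)` pages, and page `p` covers the slice `[(p - 1) * pageSize, min(p * pageSize, dataSize))`. The sketch below encodes that behaviour for reference; `pageBounds` is an illustrative helper, not the `PagedDataSource` implementation itself:

```
// Paging arithmetic behind PagedDataSourceSuite: 5 items with pageSize 2 give
// ceil(5 / 2) = 3 pages; page 2 covers indices [2, 4), i.e. elements 3 and 4 of `1 to 5`.
def pageBounds(dataSize: Int, pageSize: Int, page: Int): (Int, Int) = {
  val totalPages = (dataSize + pageSize - 1) / pageSize
  if (page < 1 || page > totalPages) {
    throw new IndexOutOfBoundsException(
      s"Page $page is out of range. Please select a page number between 1 and $totalPages.")
  }
  val from = (page - 1) * pageSize
  (from, math.min(from + pageSize, dataSize))
}

assert(pageBounds(5, 2, 2) == (2, 4))
```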
b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala new file mode 100644 index 0000000000000..3dab15a9d4691 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import scala.xml.Utility + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage._ + +class StoragePageSuite extends SparkFunSuite { + + val storageTab = mock(classOf[StorageTab]) + when(storageTab.basePath).thenReturn("http://localhost:4040") + val storagePage = new StoragePage(storageTab) + + test("rddTable") { + val rdd1 = new RDDInfo(1, + "rdd1", + 10, + StorageLevel.MEMORY_ONLY, + Seq.empty) + rdd1.memSize = 100 + rdd1.numCachedPartitions = 10 + + val rdd2 = new RDDInfo(2, + "rdd2", + 10, + StorageLevel.DISK_ONLY, + Seq.empty) + rdd2.diskSize = 200 + rdd2.numCachedPartitions = 5 + + val rdd3 = new RDDInfo(3, + "rdd3", + 10, + StorageLevel.MEMORY_AND_DISK_SER, + Seq.empty) + rdd3.memSize = 400 + rdd3.diskSize = 500 + rdd3.numCachedPartitions = 10 + + val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + + val headers = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size in ExternalBlockStore", + "Size on Disk") + assert((xmlNodes \\ "th").map(_.text) === headers) + + assert((xmlNodes \\ "tr").size === 3) + assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=1")) + + assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=2")) + + assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B", + "500.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=3")) + } + + test("empty rddTable") { + assert(storagePage.rddTable(Seq.empty).isEmpty) + } + + test("streamBlockStorageLevelDescriptionAndSize") { + val memoryBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory", 100) === 
storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) + + val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory Serialized", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) + + val diskBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) + + val externalBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 100) + assert(("External", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock)) + } + + test("receiverBlockTables") { + val blocksForExecutor0 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10000", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(1, 1), + "localhost:10000", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + ) + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) + + val blocksForExecutor1 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10001", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(2, 2), + "localhost:10001", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 200), + BlockUIData(StreamBlockId(1, 1), + "localhost:10001", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + ) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) + val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) + + val executorTable = (xmlNodes \\ "table")(0) + val executorHeaders = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + assert((executorTable \\ "th").map(_.text) === executorHeaders) + + assert((executorTable \\ "tr").size === 2) + assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2")) + assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3")) + + val blockTable = (xmlNodes \\ "table")(1) + val blockHeaders = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + assert((blockTable \\ "th").map(_.text) === blockHeaders) + + assert((blockTable \\ "tr").size === 5) + assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(0) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(0) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory", "100.0 B")) + + assert(((blockTable \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("input-1-1", "2", "localhost:10000", "Disk", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(2) \\ 
"td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(2) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory Serialized", "100.0 B")) + + assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) === + Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B")) + // Check "rowspan=1" for the first 2 columns + assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1")) + assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1")) + } + + test("empty receiverBlockTables") { + assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) + + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6c40685484ed4..61601016e005e 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound @@ -24,7 +26,7 @@ import akka.actor.ActorNotFound import org.apache.spark._ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} import org.apache.spark.SSLSampleConfigs._ @@ -107,8 +109,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -153,8 +156,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security on and passwords match - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -232,8 +236,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -278,8 +282,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq 
=== - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 1053c6caf7718..480722a5ac182 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -375,6 +375,7 @@ class TestCreateNullValue { // parameters of the closure constructor. This allows us to test whether // null values are created correctly for each type. val nestedClosure = () => { + // scalastyle:off println if (s.toString == "123") { // Don't really output them to avoid noisy println(bo) println(c) @@ -389,6 +390,7 @@ class TestCreateNullValue { val closure = () => { println(getX) } + // scalastyle:on println ClosureCleaner.clean(closure) } nestedClosure() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 3147c937769d2..a829b099025e9 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // Accessors for private methods private val _isClosure = PrivateMethod[Boolean]('isClosure) private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses) - private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses) - private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects) + private val _getOuterClassesAndObjects = + PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects) private def isClosure(obj: AnyRef): Boolean = { ClosureCleaner invokePrivate _isClosure(obj) @@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri ClosureCleaner invokePrivate _getInnerClosureClasses(closure) } - private def getOuterClasses(closure: AnyRef): List[Class[_]] = { - ClosureCleaner invokePrivate _getOuterClasses(closure) - } - - private def getOuterObjects(closure: AnyRef): List[AnyRef] = { - ClosureCleaner invokePrivate _getOuterObjects(closure) + private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = { + ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure) } test("get inner closure classes") { @@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => localValue val closure3 = () => someSerializableValue val closure4 = () => someSerializableMethod() - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) - val outerObjects4 = getOuterObjects(closure4) + + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) + val (outerClasses4, 
outerObjects4) = getOuterClassesAndObjects(closure4) // The classes and objects should have the same size assert(outerClasses1.size === outerObjects1.size) @@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val x = 1 val closure1 = () => 1 val closure2 = () => x - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) // These inner closures only reference local variables, and so do not have $outer pointers @@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => y val closure3 = () => localValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) assert(outerClasses3.size === outerObjects3.size) @@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => localValue val closure3 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false) @@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => a val closure3 = () => localValue val closure4 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) + val (outerClasses4, _) = getOuterClassesAndObjects(closure4) // First, find only fields accessed directly, not transitively, by these closures val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e0ef9c70a5fc3..dde95f3778434 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -83,6 +83,9 @@ 
class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") + val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( + (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, + hasHadoopInput = true, hasOutput = true)))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -102,6 +105,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(applicationEnd, applicationEndJsonString) testEvent(executorAdded, executorAddedJsonString) testEvent(executorRemoved, executorRemovedJsonString) + testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) } test("Dependent Classes") { @@ -440,10 +444,20 @@ class JsonProtocolSuite extends SparkFunSuite { case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) assertEquals(e1.executorInfo, e2.executorInfo) case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) + case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => + assert(e1.execId === e2.execId) + assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { + val (taskId1, stageId1, stageAttemptId1, metrics1) = a + val (taskId2, stageId2, stageAttemptId2, metrics2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertEquals(metrics1, metrics2) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -1598,4 +1612,55 @@ class JsonProtocolSuite extends SparkFunSuite { | "Removed Reason": "test reason" |} """ + + private val executorMetricsUpdateJsonString = + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Input Metrics": { + | "Data Read Method": "Hadoop", + | "Bytes Read": 2100, + | "Records Read": 21 + | }, + | "Output Metrics": { + | "Data Write Method": "Hadoop", + | "Bytes Written": 1200, + | "Records Written": 12 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use ExternalBlockStore": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "ExternalBlockStore Size": 0, + | "Disk Size": 0 + | } + | } + | ] + | } + | }] + |} + """.stripMargin } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 42125547436cb..d3d464e84ffd7 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -84,7 +84,9 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { try { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader + // scalastyle:off classforname Class.forName(className, true, loader).newInstance() + // scalastyle:on classforname Seq().iterator }.count() } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 251a797dc28a2..8f7e402d5f2a6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols @@ -684,7 +685,39 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val buffer = new CircularBuffer(25) val stream = new java.io.PrintStream(buffer, true, "UTF-8") + // scalastyle:off println stream.println("test circular test circular test circular test circular test circular") + // scalastyle:on println assert(buffer.toString === "t circular test circular\n") } + + test("nanSafeCompareDoubles") { + def shouldMatchDefaultOrder(a: Double, b: Double): Unit = { + assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b)) + assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a)) + } + shouldMatchDefaultOrder(0d, 0d) + shouldMatchDefaultOrder(0d, 1d) + shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1) + assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1) + } + + test("nanSafeCompareFloats") { + def shouldMatchDefaultOrder(a: Float, b: Float): Unit = { + assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b)) + assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a)) + } + shouldMatchDefaultOrder(0f, 0f) + shouldMatchDefaultOrder(1f, 1f) + shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) + assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 79eba61a87251..9c362f0de7076 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -244,7 +244,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def testSimpleSpilling(codec: Option[String] = None): Unit = { val conf = 
createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -292,7 +292,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] val collisionPairs = Seq( @@ -341,7 +341,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes @@ -366,7 +366,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] (1 to 100000).foreach { i => map.insert(i, i) } @@ -383,7 +383,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] map.insertAll((1 to 100000).iterator.map(i => (i, i))) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9cefa612f5491..986cd8623d145 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -176,7 +176,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def testSpillingInLocalCluster(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -254,7 +254,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new 
SparkContext("local-cluster[2,1,1024]", "test", conf) // reduceByKey - should spill ~4 times per executor val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -554,7 +554,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -611,7 +611,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) @@ -634,7 +634,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -658,7 +658,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -695,7 +695,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // Using wrongOrdering to show integer overflow introduced exception. 
val rand = new Random(100L) diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index 6d2459d48d326..3b67f6206495a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.util.collection -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ +import org.mockito.Mockito.RETURNS_SMART_NULLS +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.Matchers._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{FileSegment, BlockObjectWriter} +import org.apache.spark.storage.DiskBlockObjectWriter class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { @@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { val struct = SomeStruct("something", 5) buffer.insert(4, 10, struct) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) stream.readObject[AnyRef]() should be (10) stream.readObject[AnyRef]() should be (struct) } @@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { buffer.insert(5, 3, struct3) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) @@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) val iter = stream.asIterator iter.next() should be (2) iter.next() should be (struct2) @@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { iter.next() should be (struct1) assert(!iter.hasNext) } -} - -case class SomeStruct(val str: String, val num: Int) - -class SimpleBlockObjectWriter extends BlockObjectWriter(null) { - val baos = new ByteArrayOutputStream() - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - baos.write(bytes, offs, len) + def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { + val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) + val baos = new ByteArrayOutputStream() + when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val args = invocationOnMock.getArguments + val bytes = args(0).asInstanceOf[Array[Byte]] + val offset = args(1).asInstanceOf[Int] + val 
length = args(2).asInstanceOf[Int] + baos.write(bytes, offset, length) + } + }) + (writer, baos) } - - def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray) - - override def open(): BlockObjectWriter = this - override def close(): Unit = { } - override def isOpen: Boolean = true - override def commitAndClose(): Unit = { } - override def revertPartialWritesAndClose(): Unit = { } - override def fileSegment(): FileSegment = null - override def write(key: Any, value: Any): Unit = { } - override def recordWritten(): Unit = { } - override def write(b: Int): Unit = { } } + +case class SomeStruct(str: String, num: Int) diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 5a5919fca2469..4f382414a8dd7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -103,7 +103,9 @@ private object SizeTrackerSuite { */ def main(args: Array[String]): Unit = { if (args.size < 1) { + // scalastyle:off println println("Usage: SizeTrackerSuite [num elements]") + // scalastyle:on println System.exit(1) } val numElements = args(0).toInt @@ -180,11 +182,13 @@ private object SizeTrackerSuite { baseTimes: Seq[Long], sampledTimes: Seq[Long], unsampledTimes: Seq[Long]): Unit = { + // scalastyle:off println println(s"Average times for $testName (ms):") println(" Base - " + averageTime(baseTimes)) println(" SizeTracker (sampled) - " + averageTime(sampledTimes)) println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes)) println() + // scalastyle:on println } def time(f: => Unit): Long = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index b2f5d9009ee5d..fefa5165db197 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends SparkFunSuite { +class SorterSuite extends SparkFunSuite with Logging { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) @@ -74,7 +74,7 @@ class SorterSuite extends SparkFunSuite { /** Runs an experiment several times. 
*/ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { - println(s"Skipped experiment $name.") + logInfo(s"Skipped experiment $name.") return } @@ -86,11 +86,11 @@ class SorterSuite extends SparkFunSuite { while (i < 10) { val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) next10 += time - println(s"$name: Took $time ms") + logInfo(s"$name: Took $time ms") i += 1 } - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + logInfo(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") } /** diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala new file mode 100644 index 0000000000000..26a2e96edaaa2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort + +import com.google.common.primitives.UnsignedBytes +import org.scalatest.prop.PropertyChecks +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { + + test("String prefix comparator") { + + def testPrefixComparison(s1: String, s2: String): Unit = { + val utf8string1 = UTF8String.fromString(s1) + val utf8string2 = UTF8String.fromString(s2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) + val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + + val cmp = UnsignedBytes.lexicographicalComparator().compare( + utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) + + assert( + (prefixComparisonResult == 0 && cmp == 0) || + (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || + (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) + } + + // scalastyle:off + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + // scalastyle:on + + forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + } + + test("double prefix comparator handles NaNs properly") { + val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) + val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) + } + +} diff --git a/data/mllib/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt index 981da382d6ac8..bd22bea3a59d6 100644 --- a/data/mllib/sample_naive_bayes_data.txt +++ b/data/mllib/sample_naive_bayes_data.txt @@ -1,6 +1,12 @@ 0,1 0 0 0,2 0 0 +0,3 0 0 +0,4 0 0 1,0 1 0 1,0 2 0 +1,0 3 0 +1,0 4 0 2,0 0 1 2,0 0 2 +2,0 0 3 +2,0 0 4 \ No newline at end of file diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index fc03fec9866a6..61d91c70e9709 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -59,3 +60,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala index 0be8e64fbfabd..9f7ae75d0b477 100644 --- a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package main.scala import scala.util.Try @@ -37,3 +38,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala index 24c7f8d667296..2f0b6ef9a5672 100644 --- a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala +++ b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import org.apache.spark.{SparkContext, SparkConf} @@ -51,3 +52,4 @@ object GraphXApp { println("Test succeeded") } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 5111bc0adb772..4a980ec071ae4 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -55,3 +56,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala index 9f85066501472..adc25b57d6aa5 100644 --- a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -31,3 +32,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index cc86ef45858c9..69c1154dc0955 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -57,3 +58,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala index 58a662bd9b2e8..d6a074687f4a1 100644 --- a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala +++ b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -61,3 +62,4 @@ object SparkStreamingExample { ssc.stop() } } +// scalastyle:on println diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh new file mode 100755 index 0000000000000..d7975dfb6475c --- /dev/null +++ b/dev/change-scala-version.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +VALID_VERSIONS=( 2.10 2.11 ) + +usage() { + echo "Usage: $(basename $0) [-h|--help] +where : + -h| --help Display this help text + valid version values : ${VALID_VERSIONS[*]} +" 1>&2 + exit 1 +} + +if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then + usage +fi + +TO_VERSION=$1 + +check_scala_version() { + for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done + echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 + exit 1 +} + +check_scala_version "$TO_VERSION" + +if [ $TO_VERSION = "2.11" ]; then + FROM_VERSION="2.10" +else + FROM_VERSION="2.11" +fi + +sed_i() { + sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2" +} + +export -f sed_i + +BASEDIR=$(dirname $0)/.. +find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \ + -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \; + +# Also update in parent POM +# Match any scala binary version to ensure idempotency +sed_i '1,/[0-9]*\.[0-9]*[0-9]*\.[0-9]*'$TO_VERSION' in parent POM -sed -i -e '0,/2.112.10 in parent POM -sed -i -e '0,/2.102.11 "$PYTHON_LINT_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pep8 at runtime so that we don't rely on it being installed on the build server. @@ -47,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then fi fi +# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should +# be set to the directory. +# dev/pylint should be appended to the PATH variable as well. +# Jenkins by default installs the pylint3 version, so for now this just checks the code quality +# of python3. +export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" +export "PYLINT_HOME=$PYTHONPATH" +export "PATH=$PYTHONPATH:$PATH" + +# if [ ! -d "$PYLINT_HOME" ]; then +# mkdir "$PYLINT_HOME" +# # Redirect the annoying pylint installation output. +# easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" +# easy_install_status="$?" +# +# if [ "$easy_install_status" -ne 0 ]; then +# echo "Unable to install pylint locally in \"$PYTHONPATH\"." +# cat "$PYLINT_INSTALL_INFO" +# exit "$easy_install_status" +# fi +# +# rm "$PYLINT_INSTALL_INFO" +# +# fi + # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -61,13 +88,27 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "Python lint checks failed." - cat "$PYTHON_LINT_REPORT_PATH" + echo "PEP8 checks failed." + cat "$PEP8_REPORT_PATH" else - echo "Python lint checks passed." + echo "PEP8 checks passed." 
fi -# rm "$PEP8_SCRIPT_PATH" -rm "$PYTHON_LINT_REPORT_PATH" +rm "$PEP8_REPORT_PATH" + +# for to_be_checked in "$PATHS_TO_CHECK" +# do +# pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +# done + +# if [ "${PIPESTATUS[0]}" -ne 0 ]; then +# lint_status=1 +# echo "Pylint checks failed." +# cat "$PYLINT_REPORT_PATH" +# else +# echo "Pylint checks passed." +# fi + +# rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/dev/lint-r.R b/dev/lint-r.R index dcb1a184291e1..48bd6246096ae 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -15,15 +15,21 @@ # limitations under the License. # +argv <- commandArgs(TRUE) +SPARK_ROOT_DIR <- as.character(argv[1]) + # Installs lintr from Github. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") } -library(lintr) -argv <- commandArgs(TRUE) -SPARK_ROOT_DIR <- as.character(argv[1]) +library(lintr) +library(methods) +library(testthat) +if (! library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 4a17d48d8171d..ad4b76695c9ff 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -130,7 +130,12 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = distinct_authors[0] + primary_author = raw_input( + "Enter primary author in the format of \"name \" [%s]: " % + distinct_authors[0]) + if primary_author == "": + primary_author = distinct_authors[0] + commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -281,7 +286,7 @@ def get_version_json(version_str): resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0] resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0] asf_jira.transition_issue( - jira_id, resolve["id"], fixVersions = jira_fix_versions, + jira_id, resolve["id"], fixVersions = jira_fix_versions, comment = comment, resolution = {'id': resolution.raw['id']}) print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) @@ -300,7 +305,7 @@ def standardize_jira_ref(text): """ Standardize the [SPARK-XXXXX] [MODULE] prefix Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. 
Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" - + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") @@ -322,11 +327,11 @@ def standardize_jira_ref(text): """ jira_refs = [] components = [] - + # If the string is compliant, no need to process any further if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): return text - + # Extract JIRA ref(s): pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) for ref in pattern.findall(text): @@ -348,18 +353,18 @@ def standardize_jira_ref(text): # Assemble full text (JIRA ref(s), module(s), remaining text) clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() - + # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included clean_text = re.sub(r'\s+', ' ', clean_text.strip()) - + return clean_text def main(): global original_head - + os.chdir(SPARK_HOME) original_head = run_cmd("git rev-parse HEAD")[:8] - + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically @@ -448,5 +453,5 @@ def main(): (failure_count, test_count) = doctest.testmod() if failure_count: exit(-1) - + main() diff --git a/dev/run-tests.py b/dev/run-tests.py index 237fb76c9b3d9..1eff2b4d5c071 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe return [f for f in raw_output.split('\n') if f] +def setup_test_environ(environ): + print("[info] Setup the following environment variables for tests: ") + for (k, v) in environ.items(): + print("%s=%s" % (k, v)) + os.environ[k] = v + + def determine_modules_to_test(changed_modules): """ Given a set of modules that have changed, compute the transitive closure of those modules' @@ -457,6 +464,15 @@ def main(): print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) + # setup environment variables + # note - the 'root' module doesn't collect environment variables for all modules. Because the + # environment variables should not be set if a module is not changed, even if running the 'root' + # module. So here we should use changed_modules rather than test_modules. + test_environ = {} + for m in changed_modules: + test_environ.update(m.environ) + setup_test_environ(test_environ) + test_modules = determine_modules_to_test(changed_modules) # license checks diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b283753f2dfd7..3e273af10ffae 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -29,7 +29,7 @@ class Module(object): changed. """ - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), should_run_r_tests=False): """ @@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= filename strings. 
:param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in order to build and test this module (e.g. '-PprofileName'). + :param environ: A dict of environment variables that should be set when files in this + module are changed. :param sbt_test_goals: A set of SBT test goals for testing this module. :param python_test_goals: A set of Python test goals for testing this module. :param blacklisted_python_implementations: A set of Python implementations that are not @@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.source_file_prefixes = source_file_regexes self.sbt_test_goals = sbt_test_goals self.build_profile_flags = build_profile_flags + self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations self.should_run_r_tests = should_run_r_tests @@ -126,15 +129,22 @@ def contains_file(self, filename): ) +# Don't set the dependencies because changes in other modules should not trigger Kinesis tests. +# Kinesis tests depends on external Amazon kinesis service. We should run these tests only when +# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't +# fail other PRs. streaming_kinesis_asl = Module( name="kinesis-asl", - dependencies=[streaming], + dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", ], build_profile_flags=[ "-Pkinesis-asl", ], + environ={ + "ENABLE_KINESIS_TESTS": "1" + }, sbt_test_goals=[ "kinesis-asl/test", ] @@ -320,7 +330,7 @@ def contains_file(self, filename): "pyspark.mllib.evaluation", "pyspark.mllib.feature", "pyspark.mllib.fpm", - "pyspark.mllib.linalg", + "pyspark.mllib.linalg.__init__", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", @@ -345,6 +355,7 @@ def contains_file(self, filename): python_test_goals=[ "pyspark.ml.feature", "pyspark.ml.classification", + "pyspark.ml.clustering", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 5956d59130fbf..5dbdb8b22a44f 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -17,13 +17,13 @@ FROM ubuntu:precise -RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list - # Upgrade package index -RUN apt-get update - # install a few other useful packages plus Open Jdk 7 -RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server +# Remove unneeded /var/lib/apt/lists/* after install to reduce the +# docker image size (by ~30MB) +RUN apt-get update && \ + apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 diff --git a/docs/building-spark.md b/docs/building-spark.md index 2128fdffecc05..a5da3b39502e2 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -124,7 +124,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip # Building for Scala 2.11 To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - dev/change-version-to-2.11.sh + dev/change-scala-version.sh 2.11 mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package Spark does not yet support its JDBC component for Scala 2.11. 
diff --git a/docs/configuration.md b/docs/configuration.md index bebaf6f62e90a..fd236137cb96e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -31,7 +31,6 @@ which can help detect bugs that only exist when we run in a distributed context. val conf = new SparkConf() .setMaster("local[2]") .setAppName("CountingSheep") - .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} @@ -84,7 +83,7 @@ Running `./bin/spark-submit --help` will show the entire list of these options. each line consists of a key and a value separated by whitespace. For example: spark.master spark://5.6.7.8:7077 - spark.executor.memory 512m + spark.executor.memory 4g spark.eventLog.enabled true spark.serializer org.apache.spark.serializer.KryoSerializer @@ -150,10 +149,9 @@ of the most common options to set are: spark.executor.memory - 512m + 1g - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). + Amount of memory to use per executor process (e.g. 2g, 8g). @@ -205,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful spark.driver.extraClassPath (none) - Extra classpath entries to append to the classpath of the driver. + Extra classpath entries to prepend to the classpath of the driver.
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -252,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily for + Extra classpath entries to prepend to the classpath of executors. This exists primarily for backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -665,7 +663,7 @@ Apart from these, the following properties are also available, and may be useful Initial size of Kryo's serialization buffer. Note that there will be one buffer per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + spark.kryoserializer.buffer.max if needed. @@ -886,11 +884,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.frameSize - 10 + 128 - Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the driver - (e.g. using collect() on a large dataset). + Maximum message size to allow in "control plane" communication; generally only applies to map + output size information sent between executors and the driver. Increase this if you are running + jobs with many thousands of map and reduce tasks and see messages about the frame size. @@ -1007,9 +1005,9 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.numRetries 3 + Number of times to retry before an RPC task gives up. An RPC task will run at most times of this number. - @@ -1029,8 +1027,8 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. @@ -1050,15 +1048,6 @@ Apart from these, the following properties are also available, and may be useful infinite (all available cores) on Mesos. - - spark.localExecution.enabled - false - - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. - - spark.locality.wait 3s @@ -1206,7 +1195,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.cachedExecutorIdleTimeout - 2 * executorIdleTimeout + infinity If dynamic allocation is enabled and an executor which has cached data blocks has been idle for more than this duration, the executor will be removed. For more details, see this @@ -1222,7 +1211,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.maxExecutors - Integer.MAX_VALUE + infinity Upper bound for the number of executors if dynamic allocation is enabled. 
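The configuration keys touched in this section can be set either in `conf/spark-defaults.conf` (as shown above) or programmatically on a `SparkConf`. A minimal illustrative Scala sketch using only keys that appear in this patch (the master, app name, and the 4g/128m values are just examples):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Equivalent to listing these keys in conf/spark-defaults.conf.
val conf = new SparkConf()
  .setMaster("local[2]")                           // illustrative master
  .setAppName("ConfigExample")
  .set("spark.executor.memory", "4g")              // per-executor memory, e.g. 2g, 8g
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryoserializer.buffer.max", "128m")  // upper bound for Kryo's serialization buffer

val sc = new SparkContext(conf)
println(sc.getConf.get("spark.executor.memory"))   // prints: 4g
sc.stop()
```

Note that, as the table above states, `spark.driver.extraClassPath` must not be set this way in client mode, because the driver JVM has already started.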
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 3f10cb2dc3d2a..99f8c827f767f 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -800,7 +800,7 @@ import org.apache.spark.graphx._ // Import random graph generation library import org.apache.spark.graphx.util.GraphGenerators // A graph with edge attributes containing distances -val graph: Graph[Int, Double] = +val graph: Graph[Long, Double] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble) val sourceId: VertexId = 42 // The ultimate source // Initialize the graph such that all vertices except the root have distance infinity. diff --git a/docs/ml-features.md b/docs/ml-features.md index f88c0248c1a8a..54068debe2159 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -288,6 +288,94 @@ for words_label in wordsDataFrame.select("words", "label").take(3): +## $n$-gram + +An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. + +`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. + +
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
    + +[`NGram`](api/scala/index.html#org.apache.spark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight scala %} +import org.apache.spark.ml.feature.NGram + +val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) +)).toDF("label", "words") + +val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") +val ngramDataFrame = ngram.transform(wordDataFrame) +ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) +{% endhighlight %} +
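As a minimal illustrative sketch (assuming the length parameter `n` described above is exposed through a `setN` setter, and reusing the `wordDataFrame` defined just above), the same stage can be configured to emit trigrams instead of the default bigrams:

```scala
import org.apache.spark.ml.feature.NGram

// Same transformer as above, but with n = 3 (trigrams) instead of the default n = 2.
val trigram = new NGram()
  .setN(3)
  .setInputCol("words")
  .setOutputCol("trigrams")

val trigramDataFrame = trigram.transform(wordDataFrame)
trigramDataFrame.select("trigrams").show()
```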
+</div>
+<div data-lang="java" markdown="1">
    + +[`NGram`](api/java/org/apache/spark/ml/feature/NGram.html) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0D, Lists.newArrayList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1D, Lists.newArrayList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2D, Lists.newArrayList("Logistic", "regression", "models", "are", "neat")) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); +NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); +DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); +for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); +} +{% endhighlight %} +
+</div>
+<div data-lang="python" markdown="1">
    + +[`NGram`](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight python %} +from pyspark.ml.feature import NGram + +wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) +], ["label", "words"]) +ngram = NGram(inputCol="words", outputCol="ngrams") +ngramDataFrame = ngram.transform(wordDataFrame) +for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) +{% endhighlight %} +
+</div>
    + + ## Binarizer Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c74cb1f1ef8ea..8c46adf256a9a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,6 +3,24 @@ layout: global title: Spark ML Programming Guide --- +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. @@ -154,6 +172,19 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +# Algorithm Guides + +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. + +**Pipelines API Algorithm Guides** + +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) + +**Algorithms in `spark.ml`** + +* [Linear methods with elastic net regularization](ml-linear-methods.html) + # Code Examples This section gives code examples illustrating the functionality discussed above. diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md new file mode 100644 index 0000000000000..1ac83d94c9e81 --- /dev/null +++ b/docs/ml-linear-methods.md @@ -0,0 +1,129 @@ +--- +layout: global +title: Linear Methods - ML +displayTitle: ML - Linear Methods +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). 
Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +`\[ +\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\]` +By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. + +**Examples** + +
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
    + +{% highlight scala %} + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.mllib.util.MLUtils + +// Load training data +val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for logistic regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +{% endhighlight %} + +
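The text above notes that linear regression with elastic net is available through the same Pipelines API, although only logistic regression is shown. A minimal sketch of the analogous call (assuming the `training` DataFrame from the example above):

```scala
import org.apache.spark.ml.regression.LinearRegression

// Elastic-net linear regression: elasticNetParam = 0.8 mixes L1 and L2 as described above.
val linReg = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

val linRegModel = linReg.fit(training)
println(s"Weights: ${linRegModel.weights} Intercept: ${linRegModel.intercept}")
```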
+</div>
+<div data-lang="java" markdown="1">
    + +{% highlight java %} + +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Logistic Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sql = new SQLContext(sc); + String path = "sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for logistic regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + } +} +{% endhighlight %} +
+</div>
+<div data-lang="python" markdown="1">
    + +{% highlight python %} + +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Load training data +training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for logistic regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) +{% endhighlight %} + +
+</div>
+</div>
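To spell out the special cases claimed above, substituting the endpoints of $\alpha$ into the elastic net penalty (written here with $\mathbf{w}$ for the weight vector, a restatement of the formula above rather than new material) gives:

```latex
\[
\alpha = 1:\quad \|\mathbf{w}\|_1 + 0 \cdot \tfrac{1}{2}\|\mathbf{w}\|_2^2 = \|\mathbf{w}\|_1
\qquad \text{(pure L1, the Lasso penalty)}
\]
\[
\alpha = 0:\quad 0 \cdot \|\mathbf{w}\|_1 + \tfrac{1}{2}\|\mathbf{w}\|_2^2 = \tfrac{1}{2}\|\mathbf{w}\|_2^2
\qquad \text{(pure L2, the ridge penalty)}
\]
```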
+
+### Optimization
+
+The optimization algorithm underlying the implementation is called
+[Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
+(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 3aad4149f99db..bb875ae2ae6cb 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on
 a given dataset, the algorithm returns the best clustering result).
 * *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
 * *epsilon* determines the distance threshold within which we consider k-means to have converged.
+* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed.
 
 **Examples**
 
@@ -447,7 +448,7 @@ It supports different inference algorithms via `setOptimizer` function. EMLDAOpt
 on the likelihood function and yields comprehensive results, while OnlineLDAOptimizer uses iterative mini-batch sampling for [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) and is generally memory friendly. After fitting on the documents, LDA provides:
 * Topics: Inferred topics, each of which is a probability distribution over terms (words).
-* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. (EM only)
+* Topic distributions for documents: For each non-empty document in the training set, LDA gives a probability distribution over topics (EM only). Note that topic distributions are not created for empty documents.
 
 LDA takes the following parameters:
 
@@ -471,7 +472,7 @@ to the algorithm. We then output the topics, represented as probability distribu
    {% highlight scala %} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -491,6 +492,11 @@ for (topic <- Range(0, 3)) { for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } println() } + +// Save and load model. +ldaModel.save(sc, "myLDAModel") +val sameModel = DistributedLDAModel.load(sc, "myLDAModel") + {% endhighlight %}
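The clustering text above says the inference algorithm is chosen via `setOptimizer`. A hedged sketch of switching the example to the memory-friendly online optimizer, assuming the `String` overload of `setOptimizer` accepts `"online"` and reusing the `corpus` RDD prepared in the full example:

```scala
import org.apache.spark.mllib.clustering.LDA

// Fit the same `corpus` RDD with online variational inference
// instead of the default EM optimizer.
val onlineLdaModel = new LDA()
  .setK(3)
  .setOptimizer("online")
  .run(corpus)
```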
    @@ -550,6 +556,9 @@ public class JavaLDAExample { } System.out.println(); } + + ldaModel.save(sc.sc(), "myLDAModel"); + DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel"); } } {% endhighlight %} diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index d824dab1d7f7b..3aa040046fca5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -226,7 +226,8 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") A local matrix has integer-typed row and column indices and double-typed values, stored on a single machine. MLlib supports dense matrices, whose entry values are stored in a single double array in -column major. For example, the following matrix `\[ \begin{pmatrix} +column-major order, and sparse matrices, whose non-zero entry values are stored in the Compressed Sparse +Column (CSC) format in column-major order. For example, the following dense matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ 5.0 & 6.0 @@ -238,28 +239,33 @@ is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the m
    The base class of local matrices is -[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one -implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). +[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseMatrix). We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +val sm: Matrix = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9, 6, 8)) {% endhighlight %}
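The three extra arguments of `Matrices.sparse` are the CSC column pointers, row indices, and non-zero values. As a minimal sketch of how they relate (the local value names below are just for illustration, not part of the API):

{% highlight scala %}
import org.apache.spark.mllib.linalg.Matrices

// The same sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) with the CSC
// arguments spelled out: colPtrs(j) to colPtrs(j + 1) delimits the slice of
// rowIndices/values belonging to column j, so column 0 holds 9.0 at row 0 and
// column 1 holds 6.0 at row 2 and 8.0 at row 1.
val colPtrs = Array(0, 1, 3)
val rowIndices = Array(0, 2, 1)
val values = Array(9.0, 6.0, 8.0)
val sm2 = Matrices.sparse(3, 2, colPtrs, rowIndices, values)
{% endhighlight %}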
    The base class of local matrices is -[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one -implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). +[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide two +implementations: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html), +and [`SparseMatrix`](api/java/org/apache/spark/mllib/linalg/SparseMatrix.html). We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight java %} import org.apache.spark.mllib.linalg.Matrix; @@ -267,6 +273,30 @@ import org.apache.spark.mllib.linalg.Matrices; // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +Matrix sm = Matrices.sparse(3, 2, new int[] {0, 1, 3}, new int[] {0, 2, 1}, new double[] {9, 6, 8}); +{% endhighlight %} +
    + +
+ +The base class of local matrices is +[`Matrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseMatrix). +We recommend using the factory methods implemented +in [`Matrices`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) to create local +matrices. Remember, local matrices in MLlib are stored in column-major order. + +{% highlight python %} +from pyspark.mllib.linalg import Matrix, Matrices + +# Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) +dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6]) + +# Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) {% endhighlight %}
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md new file mode 100644 index 0000000000000..4ca0bb06b26a6 --- /dev/null +++ b/docs/mllib-evaluation-metrics.md @@ -0,0 +1,1497 @@ +--- +layout: global +title: Evaluation Metrics - MLlib +displayTitle: MLlib - Evaluation Metrics +--- + +* Table of contents +{:toc} + +Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance +of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +suite of metrics for evaluating the performance of machine learning models. + +Specific machine learning algorithms fall under broader types of machine learning applications like classification, +regression, clustering, etc. Each of these types has well-established metrics for performance evaluation, and those +metrics that are currently available in Spark's MLlib are detailed in this section. + +## Classification model evaluation + +While there are many different types of classification algorithms, the evaluation of all classification models shares +similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification), +there exists a true output and a model-generated predicted output for each data point. For this reason, the results for +each data point can be assigned to one of four categories: + +* True Positive (TP) - label is positive and prediction is also positive +* True Negative (TN) - label is negative and prediction is also negative +* False Positive (FP) - label is negative but prediction is positive +* False Negative (FN) - label is positive but prediction is negative + +These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering +classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric. The +reason is that a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from +a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier +that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like +[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into +account the *type* of error. In most applications there is some desired balance between precision and recall, which can +be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score). + +### Binary classification + +[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given +dataset into one of two possible groups (e.g. fraud or not fraud); binary classification is a special case of multiclass classification. +Most binary classification metrics can be generalized to multiclass classification metrics. + +#### Threshold tuning + +It is important to understand that many classification models actually output a "score" (often a probability) for +each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for +each class: $P(Y=1|X)$ and $P(Y=0|X)$.
Instead of simply taking the higher probability, there may be some cases where +the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a +credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold* +which determines what the predicted class will be based on the probabilities that the model outputs. + +Tuning the prediction threshold will change the precision and recall of the model and is an important part of model +optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold it is +common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision, +recall) points for different threshold values, while a +[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve +plots (recall, false positive rate) points. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Metric | Definition
Precision (Positive Predictive Value)$PPV=\frac{TP}{TP + FP}$
    Recall (True Positive Rate)$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$
    F-measure$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR} + {\beta^2 \cdot PPV + TPR}\right)$
    Receiver Operating Characteristic (ROC)$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$
    Area Under ROC Curve$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$
    Area Under Precision-Recall Curve$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$
    + + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training) + +// Clear the prediction threshold so the model will return probabilities +model.clearThreshold + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new BinaryClassificationMetrics(predictionAndLabels) + +// Precision by threshold +val precision = metrics.precisionByThreshold +precision.foreach { case (t, p) => + println(s"Threshold: $t, Precision: $p") +} + +// Recall by threshold +val recall = metrics.recallByThreshold +recall.foreach { case (t, r) => + println(s"Threshold: $t, Recall: $r") +} + +// Precision-Recall Curve +val PRC = metrics.pr + +// F-measure +val f1Score = metrics.fMeasureByThreshold +f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 1") +} + +val beta = 0.5 +val fScore = metrics.fMeasureByThreshold(beta) +fScore.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 0.5") +} + +// AUPRC +val auPRC = metrics.areaUnderPR +println("Area under precision-recall curve = " + auPRC) + +// Compute thresholds used in ROC and PR curves +val thresholds = precision.map(_._1) + +// ROC Curve +val roc = metrics.roc + +// AUROC +val auROC = metrics.areaUnderROC +println("Area under ROC = " + auROC) + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class BinaryClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call (Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print "Area under PR = %s" % metrics.areaUnderPR + +# Area under ROC curve +print "Area under ROC = %s" % metrics.areaUnderROC + +{% endhighlight %} + +
    +
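Returning to the threshold discussion above: once a suitable threshold has been chosen from these curves, it can be set directly on models that support it. A minimal Scala sketch, assuming the `model` and `test` set from the Scala example above (the 0.9 cut-off is an arbitrary illustration):

{% highlight scala %}
// Predict the positive class only when the model's score exceeds 0.9,
// e.g. to block a transaction only when fraud is predicted with high confidence.
model.setThreshold(0.9)
val thresholdedPredictions = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}

// Remove the threshold again to go back to raw scores.
model.clearThreshold()
{% endhighlight %}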
+ +### Multiclass classification + +[Multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification +problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary +classification problem). For example, classifying handwriting samples as the digits 0 to 9 involves 10 possible classes. + +For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still +be positive or negative, but they must be considered in the context of a particular class. Each label and prediction +takes on the value of one of the multiple classes and so they are said to be positive for their particular class and negative +for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative +occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be +multiple true negatives for a given data sample. The extension of false negatives and false positives from the former +definitions of positive and negative labels is straightforward. + +#### Label based metrics + +As opposed to binary classification, where there are only two possible labels, multiclass classification problems have many +possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all +labels - the number of times any class was predicted correctly (true positives) normalized by the number of data +points. Precision by label considers only one class, and measures the number of times a specific label was predicted +correctly normalized by the number of times that label appears in the output. + +**Available metrics** + +Define the class, or label, set as + +$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$ + +The true output vector $\mathbf{y}$ consists of $N$ elements + +$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$ + +A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements + +$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$ + +For this section, a modified delta function $\hat{\delta}(x)$ will prove useful + +$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$
Metric | Definition
Confusion Matrix + $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\ + \left( \begin{array}{ccc} + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots & + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) \\ + \vdots & \ddots & \vdots \\ + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots & + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) + \end{array} \right)$ +
    Overall Precision$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - + \mathbf{y}_i\right)$
    Overall Recall$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - + \mathbf{y}_i\right)$
    Overall F1-measure$F1 = 2 \cdot \left(\frac{PPV \cdot TPR} + {PPV + TPR}\right)$
    Precision by label$PPV(\ell) = \frac{TP}{TP + FP} = + \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)} + {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$
    Recall by label$TPR(\ell)=\frac{TP}{P} = + \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)} + {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$
    F-measure by label$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)} + {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$
    Weighted precision$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) + \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$
    Weighted recall$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) + \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$
    Weighted F-measure$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) + \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$
    + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MulticlassClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print "Summary Stats" +print "Precision = %s" % precision +print "Recall = %s" % recall +print "F1 Score = %s" % f1Score + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + +# Weighted stats +print "Weighted recall = %s" % metrics.weightedRecall +print "Weighted precision = %s" % metrics.weightedPrecision +print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() +print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) +print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +{% endhighlight %} + +
    +
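To make the label-based definitions above concrete, here is a small hand-checkable Scala sketch (the toy predictions and the `toy`-prefixed names are made up for illustration):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// (prediction, label) pairs: only the second point, with true label 0.0, is misclassified.
val toyPredictionAndLabels = sc.parallelize(Seq(
  (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (2.0, 2.0)))
val toyMetrics = new MulticlassMetrics(toyPredictionAndLabels)

println(toyMetrics.precision)       // overall precision: 3 correct out of 4 points = 0.75
println(toyMetrics.precision(1.0))  // 1 of the 2 points predicted as class 1 is correct = 0.5
println(toyMetrics.recall(1.0))     // the single true class-1 point is recovered = 1.0
{% endhighlight %}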
    + +### Multilabel classification + +A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping +each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not +mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both +science and politics. + +Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather +than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to +operations on sets. For example, a true positive for a given class now occurs when that class exists in the predicted +set and it exists in the true label set, for a specific data point. + +**Available metrics** + +Here we define a set $D$ of $N$ documents + +$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$ + +Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$ +to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that +correspond to document $d_i$. + +The set of all unique labels is given by + +$$L = \bigcup_{k=0}^{N-1} L_k$$ + +The following definition of indicator function $I_A(x)$ on a set $A$ will be necessary + +$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Metric | Definition
    Precision$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$
    Recall$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$
    Accuracy + $\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|} + {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$ +
    Precision by label$PPV(\ell)=\frac{TP}{TP + FP}= + \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)} + {\sum_{i=0}^{N-1} I_{P_i}(\ell)}$
    Recall by label$TPR(\ell)=\frac{TP}{P}= + \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)} + {\sum_{i=0}^{N-1} I_{L_i}(\ell)}$
    F1-measure by label$F1(\ell) = 2 + \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)} + {PPV(\ell) + TPR(\ell)}\right)$
    Hamming Loss + $\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left|L_i\right| + \left|P_i\right| - 2\left|L_i + \cap P_i\right|$ +
    Subset Accuracy$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$
F1 Measure$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$
    Micro precision$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|} + {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$
    Micro recall$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|} + {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$
    Micro F1 Measure + $2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot + \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1} + \left|P_i - L_i\right|}$ +
+ +**Examples** + +The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples +use the fake prediction and label data for multilabel classification shown below. + +Document predictions: + +* doc 0 - predict 0, 1 - class 0, 2 +* doc 1 - predict 0, 2 - class 0, 1 +* doc 2 - predict none - class 0 +* doc 3 - predict 2 - class 2 +* doc 4 - predict 2, 0 - class 2, 0 +* doc 5 - predict 0, 1, 2 - class 0, 1 +* doc 6 - predict 1 - class 1, 2 + +Predicted classes: + +* class 0 - doc 0, 1, 4, 5 (total 4) +* class 1 - doc 0, 5, 6 (total 3) +* class 2 - doc 1, 3, 4, 5 (total 4) + +True classes: + +* class 0 - doc 0, 1, 2, 4, 5 (total 5) +* class 1 - doc 1, 5, 6 (total 3) +* class 2 - doc 0, 3, 4, 6 (total 4) + +
    + +
    + +{% highlight scala %} +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD; + +val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + +// Instantiate metrics object +val metrics = new MultilabelMetrics(scoreAndLabels) + +// Summary stats +println(s"Recall = ${metrics.recall}") +println(s"Precision = ${metrics.precision}") +println(s"F1 measure = ${metrics.f1Measure}") +println(s"Accuracy = ${metrics.accuracy}") + +// Individual label stats +metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) +metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) +metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + +// Micro stats +println(s"Micro recall = ${metrics.microRecall}") +println(s"Micro precision = ${metrics.microPrecision}") +println(s"Micro F1 measure = ${metrics.microF1Measure}") + +// Hamming loss +println(s"Hamming loss = ${metrics.hammingLoss}") + +// Subset accuracy +println(s"Subset accuracy = ${metrics.subsetAccuracy}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.SparkConf; +import java.util.Arrays; +import java.util.List; + +public class MultilabelClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); + JavaSparkContext sc = new JavaSparkContext(conf); + + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.evaluation import MultilabelMetrics + +scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + +# Instantiate metrics object +metrics = MultilabelMetrics(scoreAndLabels) + +# Summary stats +print "Recall = %s" % metrics.recall() +print "Precision = %s" % metrics.precision() +print "F1 measure = %s" % metrics.f1Measure() +print "Accuracy = %s" % metrics.accuracy + +# Individual label stats +labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() +for label in labels: + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + +# Micro stats +print "Micro precision = %s" % metrics.microPrecision +print "Micro recall = %s" % metrics.microRecall +print "Micro F1 measure = %s" % metrics.microF1Measure + +# Hamming loss +print "Hamming loss = %s" % metrics.hammingLoss + +# Subset accuracy +print "Subset accuracy = %s" % metrics.subsetAccuracy + +{% endhighlight %} + +
    +
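As a hand-worked check of one of the metrics above, the Hamming loss of the toy data can be computed directly from its definition, with $N = 7$ documents and $|L| = 3$ distinct labels (each term is $\left|L_i\right| + \left|P_i\right| - 2\left|L_i \cap P_i\right|$ for documents 0 through 6):

$$HL = \frac{1}{7 \cdot 3}\Big[(2+2-2) + (2+2-2) + (1+0-0) + (1+1-2) + (2+2-4) + (2+3-4) + (2+1-2)\Big] = \frac{7}{21} \approx 0.33$$

This should agree with the value printed for `hammingLoss` in the code examples above.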
+ +### Ranking systems + +The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system)) +is to return to the user a set of relevant items or documents based on some training data. The definition of relevance +may vary and is usually application-specific. Ranking system metrics aim to quantify the effectiveness of these +rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth +set of relevant documents, while other metrics may incorporate numerical ratings explicitly. + +**Available metrics** + +A ranking system usually deals with a set of $M$ users + +$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$ + +Each user ($u_i$) has a set of $N$ ground-truth relevant documents + +$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$ + +and a list of $Q$ recommended documents, in order of decreasing relevance + +$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$ + +The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the +sets and the effectiveness of the algorithms can be measured using the metrics listed below. + +It is necessary to define a function which, provided a recommended document and a set of ground truth relevant +documents, returns a relevance score for the recommended document. + +$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
Metric | Definition | Notes
    + Precision at k + + $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$ + + Precision at k is a measure of + how many of the first k recommended documents are in the set of true relevant documents averaged across all + users. In this metric, the order of the recommendations is not taken into account. +
    Mean Average Precision + $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$ + + MAP is a measure of how + many of the recommended documents are in the set of true relevant documents, where the + order of the recommendations is taken into account (i.e. penalty for highly relevant documents is higher). +
Normalized Discounted Cumulative Gain + $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} + \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\ + \text{Where} \\ + \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\ + \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$ + + NDCG at k is a + measure of how many of the first k recommended documents are in the set of true relevant documents averaged + across all users. In contrast to precision at k, this metric takes into account the order of the recommendations + (documents are assumed to be in order of decreasing relevance). +
+ +**Examples** + +The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation +model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the +methodology is provided below. + +MovieLens ratings are on a scale of 1-5: + + * 5: Must see + * 4: Will enjoy + * 3: It's okay + * 2: Fairly bad + * 1: Awful + +So we should not recommend a movie if the predicted rating is less than 3. +To map ratings to confidence scores, we use: + + * 5 -> 2.5 + * 4 -> 1.5 + * 3 -> 0.5 + * 2 -> -0.5 + * 1 -> -1.5. + +This mapping means that unobserved entries are generally between "It's okay" and "Fairly bad". The semantics of 0 in this +expanded world of non-positive weights are "the same as never having interacted at all." + +
    + +
    + +{% highlight scala %} +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} + +// Read in the ratings data +val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) +}.cache() + +// Map ratings to 1 or 0, 1 indicating a movie that should be recommended +val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() + +// Summarize ratings +val numRatings = ratings.count() +val numUsers = ratings.map(_.user).distinct().count() +val numMovies = ratings.map(_.product).distinct().count() +println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + +// Build the model +val numIterations = 10 +val rank = 10 +val lambda = 0.01 +val model = ALS.train(ratings, rank, numIterations, lambda) + +// Define a function to scale ratings from 0 to 1 +def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) +} + +// Get sorted top ten predictions for each user and then scale from [0, 1] +val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => + (user, recs.map(scaledRating)) +} + +// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document +// Compare with top ten most relevant documents +val userMovies = binarizedRatings.groupBy(_.user) +val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) +} + +// Instantiate metrics object +val metrics = new RankingMetrics(relevantDocuments) + +// Precision at K +Array(1, 3, 5).foreach{ k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") +} + +// Mean average precision +println(s"Mean average precision = ${metrics.meanAveragePrecision}") + +// Normalized discounted cumulative gain +Array(1, 3, 5).foreach{ k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") +} + +// Get predictions for each data point +val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) +val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) +val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => + (predicted, actual) +} + +// Get the RMSE using regression metrics +val regressionMetrics = new RegressionMetrics(predictionsAndLabels) +println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${regressionMetrics.r2}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.Function; +import java.util.*; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.Rating; + +// Read in the ratings data +public class Ranking { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); + JavaSparkContext sc = new JavaSparkContext(conf); + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } + else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", 
metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r){ + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r){ + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + } +} + +{% endhighlight %} + +
    + +
+ +{% highlight python %} +from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics + +# Read in the ratings data +lines = sc.textFile("data/mllib/sample_movielens_data.txt") + +def parseLine(line): + fields = line.split("::") + return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) +ratings = lines.map(lambda r: parseLine(r)) + +# Train a model to predict user-product ratings +model = ALS.train(ratings, 10, 10, 0.01) + +# Get predicted ratings on all existing user-product pairs +testData = ratings.map(lambda p: (p.user, p.product)) +predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) + +ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) +scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) + +# Instantiate regression metrics to compare predicted and actual ratings +metrics = RegressionMetrics(scoreAndLabels) + +# Root mean squared error +print "RMSE = %s" % metrics.rootMeanSquaredError + +# R-squared +print "R-squared = %s" % metrics.r2 + +{% endhighlight %} + +
    +
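Because the MovieLens examples above involve a fair amount of data preparation, a minimal self-contained sketch of the input expected by `RankingMetrics` may also be useful (the document ids and the `toy`-prefixed names are made up for illustration):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RankingMetrics

// Each record pairs one user's ranked list of recommended ids (in decreasing
// relevance) with that user's unordered set of ground-truth relevant ids.
val toyRankings = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))))
val toyRankingMetrics = new RankingMetrics(toyRankings)

println(toyRankingMetrics.precisionAt(5))        // 2 relevant ids in each user's top 5 = 0.4
println(toyRankingMetrics.meanAveragePrecision)
println(toyRankingMetrics.ndcgAt(5))
{% endhighlight %}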
    + +## Regression model evaluation + +[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output +variable from a number of independent variables. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Metric | Definition
    Mean Squared Error (MSE)$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$
    Root Mean Squared Error (RMSE)$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$
Mean Absolute Error (MAE)$MAE=\frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$
    Coefficient of Determination $(R^2)$$R^2=1 - \frac{MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1} + (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$
    Explained Variance$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$
    + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data, +and evaluate the performance of the algorithm by several regression metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils + +// Load the data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + +// Build the model +val numIterations = 100 +val model = LinearRegressionWithSGD.train(data, numIterations) + +// Get predictions +val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) +} + +// Instantiate metrics object +val metrics = new RegressionMetrics(valuesAndPreds) + +// Squared error +println(s"MSE = ${metrics.meanSquaredError}") +println(s"RMSE = ${metrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${metrics.r2}") + +// Mean absolute error +println(s"MAE = ${metrics.meanAbsoluteError}") + +// Explained variance +println(s"Explained variance = ${metrics.explainedVariance}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; + +public class LinearRegression { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector + +# Load and parse the data +def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) + +data = sc.textFile("data/mllib/sample_linear_regression_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LinearRegressionWithSGD.train(parsedData) + +# Get predictions +valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + +# Instantiate metrics object +metrics = RegressionMetrics(valuesAndPreds) + +# Squared Error +print "MSE = %s" % metrics.meanSquaredError +print "RMSE = %s" % metrics.rootMeanSquaredError + +# R-squared +print "R-squared = %s" % metrics.r2 + +# Mean absolute error +print "MAE = %s" % metrics.meanAbsoluteError + +# Explained variance +print "Explained variance = %s" % metrics.explainedVariance + +{% endhighlight %} + +
    +
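The regression metrics above can also be sanity-checked on a tiny hand-computed example. A Scala sketch (the toy values and names are arbitrary illustrations):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RegressionMetrics

// (prediction, observation) pairs with errors -0.5, 0.5 and 0.0.
val toyValuesAndPreds = sc.parallelize(Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0)))
val toyRegressionMetrics = new RegressionMetrics(toyValuesAndPreds)

println(toyRegressionMetrics.meanAbsoluteError)    // (0.5 + 0.5 + 0.0) / 3 = 0.333...
println(toyRegressionMetrics.meanSquaredError)     // (0.25 + 0.25 + 0.0) / 3 = 0.1666...
println(toyRegressionMetrics.rootMeanSquaredError) // sqrt of the MSE above = 0.408...
{% endhighlight %}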
    \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index d2d1cc93fe006..eea864eacf7c4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -48,6 +48,7 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) * FP-growth +* [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3927d65fbf8fb..07655baa414b5 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} @@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib: L1$\|\wv\|_1$$\mathrm{sign}(\wv)$ + + elastic net$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$ + @@ -107,7 +110,7 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization @@ -531,7 +534,7 @@ sameModel = LogisticRegressionModel.load(sc, "myModelPath") ### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. 
+Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -539,8 +542,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is @@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -614,7 +617,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -634,7 +637,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -665,7 +668,7 @@ public class LinearRegression {
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -706,8 +709,8 @@ a dependency. ###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
    -First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
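The elastic net entry added to the regularizers table above is exposed through the DataFrame-based `LinearRegression` estimator in `spark.ml`, which takes the mixing weight directly. A minimal sketch, assuming a DataFrame `training` with "label" and "features" columns (the DataFrame name is a placeholder):

{% highlight scala %}
import org.apache.spark.ml.regression.LinearRegression

// regParam sets the overall regularization strength (lambda);
// elasticNetParam sets the L1/L2 mixing weight (alpha):
//   0.0 = pure L2 (ridge), 1.0 = pure L1 (lasso), values in between = elastic net
val lr = new LinearRegression()
  .setMaxIter(100)
  .setRegParam(0.3)
  .setElasticNetParam(0.5)

val lrModel = lr.fit(training)
println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
{% endhighlight %}

The RDD-based SGD trainers discussed on this page expose L1 or L2 individually, so the estimator above is the natural place to try the mixed penalty.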
    diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 887eae7f4f07b..de5d6485f9b5f 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -283,7 +283,7 @@ approxSample = data.sampleByKey(False, fractions); Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically significant, whether this result occurred by chance or not. MLlib currently supports Pearson's -chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine +chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. @@ -422,6 +422,41 @@ for i, result in enumerate(featureTestResults):
    +Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +for equality of probability distributions. By providing the name of a theoretical distribution +(currently solely supported for the normal distribution) and its parameters, or a function to +calculate the cumulative distribution according to a given theoretical distribution, the user can +test the null hypothesis that their sample is drawn from that distribution. In the case that the +user tests against the normal distribution (`distName="norm"`), but does not provide distribution +parameters, the test initializes to the standard normal distribution and logs an appropriate +message. + +
    +
    +[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.stat.Statistics._ + +val data: RDD[Double] = ... // an RDD of sample data + +// run a KS test for the sample versus a standard normal distribution +val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) +println(testResult) // summary of the test including the p-value, test statistic, + // and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis + +// perform a KS test using a cumulative distribution function of our making +val myCDF: Double => Double = ... +val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) +{% endhighlight %} +
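The `myCDF` placeholder in the snippet above can be any `Double => Double` cumulative distribution function; as an illustrative sketch (continuing with the same `data` RDD), a uniform-on-[0, 1] CDF could be supplied like this:

{% highlight scala %}
// A concrete custom CDF: the uniform distribution on [0, 1]
val uniformCDF: Double => Double = x => math.max(0.0, math.min(1.0, x))

// Test the sample against the uniform CDF instead of the normal distribution
val uniformTestResult = Statistics.kolmogorovSmirnovTest(data, uniformCDF)
println(uniformTestResult)
{% endhighlight %}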
    +
    + + ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f0..debdd2adf22d6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -298,6 +306,42 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.principal + Framework principal to authenticate to Mesos + + Set the principal with which Spark framework will use to authenticate with Mesos. + + + + spark.mesos.secret + Framework secret to authenticate to Mesos + + Set the secret with which Spark framework will use to authenticate with Mesos. + + + + spark.mesos.role + Role for the Spark framework + + Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations + and resource weight sharing. + + + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. + + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
      +
• Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
• Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
• Set constraints are matched with "subset of" semantics i.e. value in the constraint must be a subset of the resource offer's value.
• Text constraints are matched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
• In case there is no value present as a part of the constraint, any offer with the corresponding attribute will be accepted (without value check).
    + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index de22ab557cacf..cac08a91b97d9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -68,9 +68,9 @@ In YARN terminology, executors and application masters run inside "containers". yarn logs -applicationId -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` diff --git a/docs/sparkr.md b/docs/sparkr.md index 095ea4308cfeb..4385a4eeacd5c 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -68,7 +68,7 @@ you can specify the packages with the `packages` argument.
    {% highlight r %} -sc <- sparkR.init(packages="com.databricks:spark-csv_2.11:1.0.3") +sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3") sqlContext <- sparkRSQL.init(sc) {% endhighlight %}
    @@ -116,7 +116,7 @@ sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results <- hiveContext.sql("FROM src SELECT key, value") +results <- sql(hiveContext, "FROM src SELECT key, value") # results is now a DataFrame head(results) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 88c96a9a095b3..95945eb7fc8a0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -828,7 +828,7 @@ using this syntax. {% highlight scala %} val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("json").save("namesAndAges.json") +df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") {% endhighlight %}
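For symmetry with the manual format/save call just above, a short illustrative sketch of loading the saved Parquet output back through the same generic reader (assuming the same `sqlContext`):

{% highlight scala %}
// Read the manually saved Parquet output back with the generic reader
val namesAndAges = sqlContext.read.format("parquet").load("namesAndAges.parquet")
namesAndAges.show()
{% endhighlight %}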
    @@ -1332,13 +1332,8 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.filterPushdown - false - - Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Parquet 1.6.0rc3 (PARQUET-136). - However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn - this feature on. - + true + Enables Parquet filter push-down optimization when set to true. spark.sql.hive.convertMetastoreParquet @@ -1637,7 +1632,7 @@ sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = sqlContext.sql("FROM src SELECT key, value").collect() +results <- collect(sql(sqlContext, "FROM src SELECT key, value")) {% endhighlight %} @@ -1798,7 +1793,7 @@ DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% highlight python %} -df = sqlContext.read.format('jdbc').options(url = 'jdbc:postgresql:dbserver', dbtable='schema.tablename').load() +df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() {% endhighlight %} diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 33d835ba1c381..77d62047c3525 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -854,6 +854,8 @@ it with new information. To use this, you will have to do two steps. 1. Define the state update function - Specify with a function how to update the state using the previous state and the new values from an input stream. +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. + Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We define the update function as: diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 18ccbc0a3edd0..ccf922d9371fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -125,7 +125,7 @@ def setup_external_libs(libs): ) with open(tgz_file_path, "wb") as tgz_file: tgz_file.write(download_stream.read()) - with open(tgz_file_path) as tar: + with open(tgz_file_path, "rb") as tar: if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) sys.exit(1) @@ -242,7 +242,7 @@ def parse_args(): help="Number of EBS volumes to attach to each node as /vol[x]. " + "The volumes will be deleted when the instances terminate. " + "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." + + "EBS volumes are only attached if --ebs-vol-size > 0. 
" + "Only support up to 8 EBS volumes.") parser.add_option( "--placement-group", type="string", default=None, @@ -325,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) @@ -791,7 +793,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] + 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] if opts.hadoop_major_version == "1": modules = list(filter(lambda x: x != "mapreduce", modules)) @@ -1153,8 +1155,8 @@ def ssh(host, opts, command): # If this was an ssh failure, provide the user with hints. if e.returncode == 255: raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + + "Failed to SSH to remote host {0}.\n" + "Please check that you have provided the correct --identity-file and " "--key-pair parameters and try again.".format(host)) else: raise e diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 0000000000000..aa2336e300a91 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. 
+ +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4c129dbe2d12d..d812262fd87dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 023bb3ee2d108..36832f51d2ad4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ + // scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -140,3 +141,4 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb0..96ef3e198e380 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 1f12034ce0f57..d651fe4d6ee75 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.io.File @@ -136,3 +137,4 @@ object DFSReadWriteTest { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc3..c42df2b8845d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import scala.collection.JavaConversions._ @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b29..fa4a3afeecd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 95c96111c9b1f..244742327a907 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin @@ -62,3 +63,4 @@ object HBaseTest { admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f8..124dc9af6390f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003d..af5f216f28ba4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e0..9c8aae53cf48d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 04fc0a033014a..e7b28d38bdfc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index c3fc74a116c0a..4f6b092a59ca5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb2..3d923625f11b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 75c82117cbad2..a80de10f4610a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe13..61ce9db914f9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459e..3b0b00fe4dd0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce13..719e2176fed3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 30c4261551837..69799b7c2bb30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b3..505ea5a4c7a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b514d9123f5e7..c56e1124ad415 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 1e6b4fb0c7514..d265c227f4ed2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index bd7894f184c4c..0fd79660dd196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b66..818d4f2b81f82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 772cd897f5140..95072071ccddb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b6..cfbdae02212a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b100..e46ac655beb58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8fd..8dd6c9706e7df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index f6f8d9f90c275..da3ffca1a6f2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.graphx /** @@ -42,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b784..46e52aacd90bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 6c0af20461d3b..14b358d46f6ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -110,3 +111,4 @@ object CrossValidatorExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 54e4073941056..f28671f7869fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -355,3 +356,4 @@ object DecisionTreeExample { println(s" Root mean squared error (RMSE): $RMSE") } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 7b8cc21ed8982..78f31b4ffe56a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -181,3 +182,4 @@ private class MyLogisticRegressionModel( copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 33905277c7341..f4a15f806ea81 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -236,3 +237,4 @@ object GBTExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b54466fd48bc5..b73299fb12d3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -140,3 +141,4 @@ object LinearRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 3cf193f353fbc..7682557127b51 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -157,3 +158,4 @@ object LogisticRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 25f21113bf622..cd411397a4b9d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -178,3 +179,4 @@ object MovieLensALS { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 6927eb8f275cf..bab31f585b0ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} @@ -183,3 +184,4 @@ object OneVsRestExample { (NANO.toSeconds(t1 - t0), result) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9f7cad68a4594..109178f4137b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -244,3 +245,4 @@ object RandomForestExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index a0561e2573fc9..58d7b67674ff7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -100,3 +101,4 @@ object SimpleParamsExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 1324b066c30c3..960280137cbf9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -89,3 +90,4 @@ object SimpleTextClassificationPipeline { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b93..1a4016f76c2ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e7844..026d4ecc6d10a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4d..69988cc1b9334 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 520893b26d595..dc13f82488af7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 3381941673db8..57ffe3dd2524f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -368,3 +369,4 @@ object DecisionTreeRunner { } // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index f8c71ccabc43b..1fce4ba7efd60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 14cc5cbb679c5..380d85d60e7b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 13f24a1e59610..14b930550d554 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -80,3 +81,4 @@ object FPGrowthExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 7416fb5a40848..e16a6bf033574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -145,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 31d629f853161..75b0f69cf91aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -302,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07b..8878061a0970b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 99588b0984ab2..e43a6f2864c73 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -189,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284b..5f839c75dd581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 6d8b806569dfd..0723223954610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -154,4 +155,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } - +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af99..bee85ba0f9969 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af68..6963f43e082c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5e..f81fc292a3bd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed2..af03724a8ac62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2d..b4a5dca031abd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d78..b42f4cb5f9338 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e309..464fbd385ab5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615a..65b4bc46f0266 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index b11e32047dc34..2cc56f04e5c1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} @@ -73,3 +74,4 @@ object RDDRelation { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b7ba60ec28155..bf40bd1ef13df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -77,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 016de4c63d1d2..e9c9907198769 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -170,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae97..28e9bf520e568 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index fbe394de4a179..bd78526f8c299 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import kafka.serializer.StringDecoder @@ -70,3 +71,4 @@ object DirectKafkaWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1b..91e52e4eff5a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b6..2bdbc37e2a289 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e6..1f282d437dc38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 60416ee343544..b40d17e9c2fa3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.util.HashMap @@ -101,3 +102,4 @@ object KafkaWordCountProducer { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 813c8554f5193..d772ae309f40d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.eclipse.paho.client.mqttv3._ @@ -96,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -113,3 +116,4 @@ object MQTTWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada14..9a57fe286d1ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb6..5322929d177b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 751b30ea15782..9916882e4f94a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -108,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 5a6b9216a3fbc..ed617754cbf1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -99,3 +100,4 @@ object SQLContextSingleton { instance } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc441351..02ba1c2eed0f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index c10de84a80ffe..825c671a929b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ @@ -113,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8af..49826ede70418 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f7..49cee1b43c2dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index e99d1baa72b9f..6ac9a72c37941 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 889f052c70263..bea7a47cb2855 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -108,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690f..ec7d39da8b2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 8565cd83edfa2..13189595d1d6c 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -32,6 +32,7 @@ http://spark.apache.org/ + provided streaming-flume-assembly @@ -40,6 +41,16 @@ org.apache.spark spark-streaming-flume_${scala.binary.version} ${project.version} + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + org.apache.spark @@ -47,89 +58,101 @@ ${project.version} provided + + + commons-codec + commons-codec + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + org.apache.avro avro - ${avro.version} + provided org.apache.avro avro-ipc - ${avro.version} - - - io.netty - netty - - - org.mortbay.jetty - jetty - - - org.mortbay.jetty - jetty-util - - - org.mortbay.jetty - servlet-api - - - org.apache.velocity - velocity - - + provided + + + org.scala-lang + scala-library + provided - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index 17cbc6707b5ea..d87b86932dd41 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -113,7 +113,9 @@ private[sink] object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. 
+ // scalastyle:off classforname val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + // scalastyle:on classforname bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 8059c443827ef..977514fa5a1ec 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -58,6 +58,7 @@ maven-shade-plugin false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kafka-assembly-${project.version}.jar *:* diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 876456c964770..48a1933d92f85 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -119,8 +119,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val inputInfo = InputInfo(id, rdd.count) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. 
+ offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent it from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 3e6b937af57b0..8465432c5850f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -410,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index c5cd2154772ac..1a9d78c0d4f59 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -98,8 +98,7 @@ class KafkaRDD[ val res = context.runJob( this, (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, - parts.keys.toArray, - allowLocal = true) + parts.keys.toArray) res.foreach(buf ++= _) buf.toArray } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0e33362d34acd..f3b01bd60b178 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -670,4 +670,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated " + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 2675042666304..f326e7f1f6f8d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -75,7 +75,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset ->
$untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 8e1715f6dbb95..5b3c79444aa68 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -111,7 +111,7 @@ class DirectKafkaStreamSuite rdd }.foreachRDD { rdd => for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 5289073eb457a..c242e7a57b9ab 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -40,6 +40,13 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index be8b62d3cc6ba..de749626ec09c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer @@ -272,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala new file mode 100644 index 0000000000000..8f144a4d974a8 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConversions._ +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.NextIterator + + +/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */ +private[kinesis] +case class SequenceNumberRange( + streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + +/** Class representing an array of Kinesis sequence number ranges */ +private[kinesis] +case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) { + def isEmpty(): Boolean = ranges.isEmpty + def nonEmpty(): Boolean = ranges.nonEmpty + override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")") +} + +private[kinesis] +object SequenceNumberRanges { + def apply(range: SequenceNumberRange): SequenceNumberRanges = { + new SequenceNumberRanges(Array(range)) + } +} + + +/** Partition storing the information of the ranges of Kinesis sequence numbers to read */ +private[kinesis] +class KinesisBackedBlockRDDPartition( + idx: Int, + blockId: BlockId, + val isBlockIdValid: Boolean, + val seqNumberRanges: SequenceNumberRanges + ) extends BlockRDDPartition(blockId, idx) + +/** + * A BlockRDD where the block data is backed by Kinesis, which can be accessed using the + * sequence numbers of the corresponding blocks. + */ +private[kinesis] +class KinesisBackedBlockRDD( + sc: SparkContext, + regionId: String, + endpointUrl: String, + @transient blockIds: Array[BlockId], + @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges], + @transient isBlockIdValid: Array[Boolean] = Array.empty, + retryTimeoutMs: Int = 10000, + awsCredentialsOption: Option[SerializableAWSCredentials] = None + ) extends BlockRDD[Array[Byte]](sc, blockIds) { + + require(blockIds.length == arrayOfseqNumberRanges.length, + "Number of blockIds is not equal to the number of sequence number ranges") + + override def isValid(): Boolean = true + + override def getPartitions: Array[Partition] = { + Array.tabulate(blockIds.length) { i => + val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) + new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] + val blockId = partition.blockId + + def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + logDebug(s"Read partition data of $this from block manager, block $blockId") + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + } + + def getBlockFromKinesis(): Iterator[Array[Byte]] = { + val credentials = awsCredentialsOption.getOrElse { + new DefaultAWSCredentialsProviderChain().getCredentials() + } + partition.seqNumberRanges.ranges.iterator.flatMap { range => + new KinesisSequenceRangeIterator( + credentials, endpointUrl, regionId, range, retryTimeoutMs) + } + } + if (partition.isBlockIdValid) { + getBlockFromBlockManager().getOrElse { getBlockFromKinesis() } + } else { + getBlockFromKinesis() + } + } +} + + +/** + * An iterator that
returns the Kinesis data based on the given range of sequence numbers. + * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber, + * until the endSequenceNumber is reached. + */ +private[kinesis] +class KinesisSequenceRangeIterator( + credentials: AWSCredentials, + endpointUrl: String, + regionId: String, + range: SequenceNumberRange, + retryTimeoutMs: Int + ) extends NextIterator[Array[Byte]] with Logging { + + private val client = new AmazonKinesisClient(credentials) + private val streamName = range.streamName + private val shardId = range.shardId + + private var toSeqNumberReceived = false + private var lastSeqNumber: String = null + private var internalIterator: Iterator[Record] = null + + client.setEndpoint(endpointUrl, "kinesis", regionId) + + override protected def getNext(): Array[Byte] = { + var nextBytes: Array[Byte] = null + if (toSeqNumberReceived) { + finished = true + } else { + + if (internalIterator == null) { + + // If the internal iterator has not been initialized, + // then fetch records from starting sequence number + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + } else if (!internalIterator.hasNext) { + + // If the internal iterator does not have any more records, + // then fetch more records after the last consumed sequence number + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } + + if (!internalIterator.hasNext) { + + // If the internal iterator still does not have any data, then throw an exception + // and terminate this iterator + finished = true + throw new SparkException( + s"Could not read until the end sequence number of the range: $range") + } else { + + // Get the record, copy the data into a byte array and remember its sequence number + val nextRecord: Record = internalIterator.next() + val byteBuffer = nextRecord.getData() + nextBytes = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(nextBytes) + lastSeqNumber = nextRecord.getSequenceNumber() + + // If this record's sequence number matches the stopping sequence number, then make sure + // the iterator is marked finished next time getNext() is called + if (nextRecord.getSequenceNumber == range.toSeqNumber) { + toSeqNumberReceived = true + } + } + + } + nextBytes + } + + override protected def close(): Unit = { + client.shutdown() + } + + /** + * Get records starting from or after the given sequence number. + */ + private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + val shardIterator = getKinesisIterator(iteratorType, seqNum) + val result = getRecordsAndNextKinesisIterator(shardIterator) + result._1 + } + + /** + * Get the records using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for the next consumption. + */ + private def getRecordsAndNextKinesisIterator( + shardIterator: String): (Iterator[Record], String) = { + val getRecordsRequest = new GetRecordsRequest + getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setShardIterator(shardIterator) + val getRecordsResult = retryOrTimeout[GetRecordsResult]( + s"getting records using shard iterator") { + client.getRecords(getRecordsRequest) + } + (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number.
+ */ + private def getKinesisIterator( + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest + getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setStreamName(streamName) + getShardIteratorRequest.setShardId(shardId) + getShardIteratorRequest.setShardIteratorType(iteratorType.toString) + getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) + val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + s"getting shard iterator from sequence number $sequenceNumber") { + client.getShardIterator(getShardIteratorRequest) + } + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String)(body: => T): T = { + import KinesisSequenceRangeIterator._ + + var startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} + +private[streaming] +object KinesisSequenceRangeIterator { + val MAX_RETRIES = 3 + val MIN_RETRY_WAIT_TIME_MS = 100 +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala new file mode 100644 index 0000000000000..ca39358b75cb6 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Random, Success, Try} + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient +import com.amazonaws.services.dynamodbv2.document.DynamoDB +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark.Logging + +/** + * Shared utility methods for performing Kinesis tests that actually transfer data + */ +private class KinesisTestUtils( + val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com", + _regionName: String = "") extends Logging { + + val regionName = if (_regionName.length == 0) { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } else { + RegionUtils.getRegion(_regionName).getName() + } + + val streamShardCount = 2 + + private val createStreamTimeoutSeconds = 300 + private val describeStreamPollTimeSeconds = 1 + + @volatile + private var streamCreated = false + + @volatile + private var _streamName: String = _ + + private lazy val kinesisClient = { + val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) + client.setEndpoint(endpointUrl) + client + } + + private lazy val dynamoDB = { + val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain()) + dynamoDBClient.setRegion(RegionUtils.getRegion(regionName)) + new DynamoDB(dynamoDBClient) + } + + def streamName: String = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + _streamName + } + + def createStream(): Unit = { + logInfo("Creating stream") + require(!streamCreated, "Stream already created") + _streamName = findNonExistentStreamName() + + // Create a stream. The number of shards determines the provisioned throughput. + val createStreamRequest = new CreateStreamRequest() + createStreamRequest.setStreamName(_streamName) + createStreamRequest.setShardCount(2) + kinesisClient.createStream(createStreamRequest) + + // The stream is now being created. Wait for it to become active. 
+ waitForStreamToBeActive(_streamName) + streamCreated = true + logInfo("Created stream") + } + + /** + * Push data to the Kinesis stream and return a map of + * shardId -> seq of (data, seq number) pushed to the corresponding shard + */ + def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + + testData.foreach { num => + val str = num.toString + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(ByteBuffer.wrap(str.getBytes())) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") + shardIdToSeqNumbers.toMap + } + + def deleteStream(): Unit = { + try { + if (streamCreated) { + kinesisClient.deleteStream(streamName) + } + } catch { + case e: Exception => + logWarning(s"Could not delete stream $streamName") + } + } + + def deleteDynamoDBTable(tableName: String): Unit = { + try { + val table = dynamoDB.getTable(tableName) + table.delete() + table.waitForDelete() + } catch { + case e: Exception => + logWarning(s"Could not delete DynamoDB table $tableName") + } + } + + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + + private def findNonExistentStreamName(): String = { + var testStreamName: String = null + do { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}" + } while (describeStream(testStreamName).nonEmpty) + testStreamName + } + + private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = { + val startTime = System.currentTimeMillis() + val endTime = startTime + TimeUnit.SECONDS.toMillis(createStreamTimeoutSeconds) + while (System.currentTimeMillis() < endTime) { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + describeStream(streamNameToWaitFor).foreach { description => + val streamStatus = description.getStreamStatus() + logDebug(s"\t- current state: $streamStatus\n") + if ("ACTIVE".equals(streamStatus)) { + return + } + } + } + require(false, s"Stream $streamNameToWaitFor never became active") + } +} + +private[kinesis] object KinesisTestUtils { + + val envVarName = "ENABLE_KINESIS_TESTS" + + val shouldRunTests = sys.env.get(envVarName) == Some("1") + + def isAWSCredentialsPresent: Boolean = { + Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess + } + + def getAWSCredentials(): AWSCredentials = { + assert(shouldRunTests, + "Kinesis test not enabled, should not attempt to get AWS credentials") + Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { + case Success(cred) => cred + case Failure(e) => + throw new Exception("Kinesis tests enabled, but could not get AWS credentials") + } + } +} diff --git
a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala new file mode 100644 index 0000000000000..e81fb11e5959f --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.{SparkConf, SparkContext, SparkException} + +class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { + + private val regionId = "us-east-1" + private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com" + private val testData = 1 to 8 + + private var testUtils: KinesisTestUtils = null + private var shardIds: Seq[String] = null + private var shardIdToData: Map[String, Seq[Int]] = null + private var shardIdToSeqNumbers: Map[String, Seq[String]] = null + private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null + private var shardIdToRange: Map[String, SequenceNumberRange] = null + private var allRanges: Seq[SequenceNumberRange] = null + + private var sc: SparkContext = null + private var blockManager: BlockManager = null + + + override def beforeAll(): Unit = { + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils(endpointUrl) + testUtils.createStream() + + shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") + + shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq + shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }} + shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} + shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => + val seqNumRange = SequenceNumberRange( + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + (shardId, seqNumRange) + } + allRanges = shardIdToRange.values.toSeq + + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.deleteStream() + } + if (sc != null) { + sc.stop() + } + } + + testIfEnabled("Basic reading from Kinesis") { + // Verify all data using multiple ranges in a single RDD partition + val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(1), + Array(SequenceNumberRanges(allRanges.toArray)) + ).map 
{ bytes => new String(bytes).toInt }.collect() + assert(receivedData1.toSet === testData.toSet) + + // Verify all data using one range in each of the multiple RDD partitions + val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData2.toSet === testData.toSet) + + // Verify ordering within each partition + val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collectPartitions() + assert(receivedData3.length === allRanges.size) + for (i <- 0 until allRanges.size) { + assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId)) + } + } + + testIfEnabled("Read data available in both block manager and Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available only in block manager, not in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0) + } + + testIfEnabled("Read data available only in Kinesis, not in block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available partially in block manager, rest in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1) + } + + testIfEnabled("Test isBlockValid skips block fetching from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0, + testIsBlockValid = true) + } + + testIfEnabled("Test whether RDD is valid after removing blocks from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, + testBlockRemove = true) + } + + /** + * Test the KinesisBackedBlockRDD by writing some partitions of the data to the block manager + * and the rest to Kinesis, and then reading it all back using the RDD. + * It can also test if the partitions that were read from Kinesis were again stored in + * block manager. + * + * + * + * @param numPartitions Number of partitions in RDD + * @param numPartitionsInBM Number of partitions to write to the BlockManager. + * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager + * @param numPartitionsInKinesis Number of partitions to write to Kinesis.
+ * Partitions (numPartitions - 1 - numPartitionsInKinesis) to + * (numPartitions - 1) will be written to Kinesis + * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching + * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with + * reads falling back to Kinesis + * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInKinesis = 4 + * + * numPartitionsInBM = 3 + * |------------------| + * | | + * 0 1 2 3 4 + * | | + * |-------------------------| + * numPartitionsInKinesis = 4 + */ + private def testRDD( + numPartitions: Int, + numPartitionsInBM: Int, + numPartitionsInKinesis: Int, + testIsBlockValid: Boolean = false, + testBlockRemove: Boolean = false + ): Unit = { + require(shardIds.size > 1, "Need at least 2 shards to test") + require(numPartitionsInBM <= shardIds.size, + "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") + require(numPartitionsInKinesis <= shardIds.size, + "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") + require(numPartitionsInBM <= numPartitions, + "Number of partitions in BlockManager cannot be more than that in RDD") + require(numPartitionsInKinesis <= numPartitions, + "Number of partitions in Kinesis cannot be more than that in RDD") + + // Put necessary blocks in the block manager + val blockIds = fakeBlockIds(numPartitions) + blockIds.foreach(blockManager.removeBlock(_)) + (0 until numPartitionsInBM).foreach { i => + val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() } + blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY) + } + + // Create the necessary ranges to use in the RDD + val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + val realRanges = Array.tabulate(numPartitionsInKinesis) { i => + val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) + SequenceNumberRanges(Array(range)) + } + val ranges = (fakeRanges ++ realRanges) + + + // Make sure that the first `numPartitionsInBM` blocks are in block manager, and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + + require( + blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the last `numPartitionsInKinesis` ranges are configured, and others are not + require( + ranges.takeRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName == testUtils.streamName } + }, "Incorrect configuration of RDD, expected ranges not set" + ) + + require( + ranges.dropRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName != testUtils.streamName } + }, "Incorrect configuration of RDD, unexpected ranges set" + ) + + val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges) + val collectedData = rdd.map { bytes => + new String(bytes).toInt + }.collect() + assert(collectedData.toSet === testData.toSet) + + // Verify that the block fetching is skipped when isBlockValid is set to false. + // This is done by using an RDD whose data is only in memory but is set to skip block fetching + // Using that RDD will throw an exception, as it skips block fetching even if the blocks are + // in BlockManager.
+ if (testIsBlockValid) { + require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") + require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") + val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray, + ranges, isBlockIdValid = Array.fill(blockIds.length)(false)) + intercept[SparkException] { + rdd2.collect() + } + } + + // Verify that the RDD is not invalid after the blocks are removed and can still read data + // from Kinesis + if (testBlockRemove) { + require(numPartitions === numPartitionsInKinesis, + "All partitions must be in Kinesis for this test") + require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test") + rdd.removeBlocks() + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet) + } + } + + /** Generate fake block ids */ + private def fakeBlockIds(num: Int): Array[BlockId] = { + Array.tabulate(num) { i => new StreamBlockId(0, i) } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala new file mode 100644 index 0000000000000..8373138785a89 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.SparkFunSuite + +/** + * Helper trait that runs Kinesis real data transfer tests or + * ignores them based on whether the env variable is set.
+ */ +trait KinesisFunSuite extends SparkFunSuite { + import KinesisTestUtils._ + + /** Run the test if the environment variable is set, otherwise ignore the test */ + def testIfEnabled(testName: String)(testBody: => Unit) { + if (shouldRunTests) { + test(testName)(testBody) + } else { + ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody) + } + } + + /** Run the given body of code only if Kinesis tests are enabled */ + def runIfTestsEnabled(message: String)(body: => Unit): Unit = { + if (shouldRunTests) { + body + } else { + ignore(s"$message [enable by setting env var $envVarName=1]")() + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 6c262624833cd..98f2c7c4f1bfb 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -26,23 +26,18 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionIn import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Mockito._ -// scalastyle:off -// To avoid introducing a dependency on Spark core tests, simply use scalatest's FunSuite -// here instead of our own SparkFunSuite. Introducing the dependency has caused problems -// in the past (SPARK-8781) that are complicated by bugs in the maven shade plugin (MSHADE-148). -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext} +import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase} import org.apache.spark.util.{Clock, ManualClock, Utils} /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ -class KinesisReceiverSuite extends FunSuite with Matchers with BeforeAndAfter - with MockitoSugar { -// scalastyle:on +class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter + with MockitoSugar { val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -62,7 +57,7 @@ class KinesisReceiverSuite extends FunSuite with Matchers with BeforeAndAfter var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ - before { + override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] @@ -70,30 +65,14 @@ class KinesisReceiverSuite extends FunSuite with Matchers with BeforeAndAfter currentClockMock = mock[Clock] } - after { + override def afterFunction(): Unit = { + super.afterFunction() // Since this suite was originally written using EasyMock, add this to preserve the old // mocking semantics (see SPARK-5735 for more details) verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, checkpointStateMock, currentClockMock) } - test("KinesisUtils API") { - val ssc = new StreamingContext("local[2]", getClass.getSimpleName, Seconds(1)) - // Tests the API, does not actually test data receiving - val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", Seconds(2), -
InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - ssc.stop() - } - test("check serializability of SerializableAWSCredentials") { Utils.deserialize[SerializableAWSCredentials]( Utils.serialize(new SerializableAWSCredentials("x", "y"))) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala new file mode 100644 index 0000000000000..b88c9c6478d56 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.concurrent.Eventually +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class KinesisStreamSuite extends KinesisFunSuite + with Eventually with BeforeAndAfter with BeforeAndAfterAll { + + // This is the name that KCL uses to save metadata to DynamoDB + private val kinesisAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + + private var ssc: StreamingContext = _ + private var sc: SparkContext = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) + } + + override def afterAll(): Unit = { + sc.stop() + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + } + + after { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + ssc = null + } + } + + test("KinesisUtils API") { + ssc = new StreamingContext(sc, Seconds(1)) + // Tests the API, does not actually test data receiving + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + } + + + /** + * Test the stream by sending data to a Kinesis stream and receiving from it. + * This test is not run by default as it requires AWS credentials that the test + * environment may not have. Even if there is AWS credentials available, the user + * may not want to run these tests to avoid the Kinesis costs. To enable this test, + * you must have AWS credentials available through the default AWS provider chain, + * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . 
+ */ + testIfEnabled("basic operation") { + val kinesisTestUtils = new KinesisTestUtils() + try { + kinesisTestUtils.createStream() + ssc = new StreamingContext(sc, Seconds(1)) + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, + kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + kinesisTestUtils.pushData(testData) + assert(collected === testData.toSet, "\nData received does not match data sent") + } + ssc.stop() + } finally { + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + } + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 7372dfbd9fe98..70a7592da8ae3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. * * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: @@ -61,26 +61,36 @@ object PartitionStrategy { * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, * P6)` or the last * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be - * replicated to at most `2 * sqrt(numParts) - 1` machines. + * replicated to at most `2 * sqrt(numParts)` machines. * * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the * vertex locations. * - * One of the limitations of this approach is that the number of machines must either be a - * perfect square. We partially address this limitation by computing the machine assignment to - * the next - * largest perfect square and then mapping back down to the actual number of machines. - * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect - * square is used. + * When the number of partitions requested is not a perfect square we use a slightly different + * method where the last column can have a different number of rows than the others while still + * maintaining the same size per block. 
*/ case object EdgePartition2D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt - (col * ceilSqrtNumParts + row) % numParts + if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) { + // Use old method for perfect squared to ensure we get same results + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt + (col * ceilSqrtNumParts + row) % numParts + + } else { + // Otherwise use new method + val cols = ceilSqrtNumParts + val rows = (numParts + cols - 1) / cols + val lastColRows = numParts - rows * (cols - 1) + val col = (math.abs(src * mixingPrime) % numParts / rows).toInt + val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt + col * rows + row + + } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index cfcf7244eaed5..2ca60d51f8331 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -127,28 +127,25 @@ object Pregel extends Logging { var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages. Vertices that didn't get any messages do not appear in newVerts. - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Update the graph with the new vertices. + // Receive the messages and update the vertices. prevG = g - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } - g.cache() + g = g.joinVertices(messages)(vprog).cache() val oldMessages = messages - // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't - // get to send messages. We must cache messages so it can be materialized on the next line, - // allowing us to uncache the previous iteration. - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache() - // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This - // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the - // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages + // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages + // and the vertices of g). 
activeMessages = messages.count() logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking = false) - newVerts.unpersist(blocking = false) prevG.unpersistVertices(blocking = false) prevG.edges.unpersist(blocking = false) // count the iteration diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 90a74d23a26cc..da95314440d86 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -332,9 +332,9 @@ object GraphImpl { edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { val edgeRDD = EdgeRDD.fromEdges(edges)(classTag[ED], classTag[VD]) - .withTargetStorageLevel(edgeStorageLevel).cache() + .withTargetStorageLevel(edgeStorageLevel) val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr) - .withTargetStorageLevel(vertexStorageLevel).cache() + .withTargetStorageLevel(vertexStorageLevel) GraphImpl(vertexRDD, edgeRDD) } @@ -346,9 +346,14 @@ object GraphImpl { def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + + vertices.cache() + // Convert the vertex partitions in edges to the correct type val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) + .cache() + GraphImpl.fromExistingRDDs(vertices, newEdges) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932d..74a7de18d4161 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { @@ -122,7 +121,7 @@ private[graphx] object BytecodeUtils { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { - methodsInvoked.add((Class.forName(owner.replace("/", ".")), name)) + methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 9591c4e9b8f4e..989e226305265 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. 
*/ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 186d0cc2a977b..61e44dcab578c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +// scalastyle:off println class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index d4cfeacb6ef18..c0f89c9230692 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -25,11 +25,12 @@ import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. - *

+ * <p>
  * Use this class to start Spark applications programmatically. The class uses a builder pattern
  * to allow clients to configure the Spark application and launch it as a child process.
+ * </p>
  */
 public class SparkLauncher {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
index 7ed756f4b8591..7c97dba511b28 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
@@ -17,13 +17,17 @@
 /**
  * Library for launching Spark applications.
- * <p/>
+ *
+ * <p>
  * This library allows applications to launch Spark programmatically. There's only one entry
  * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class.
- * <p/>
+ * </p>
+ *
+ * <p>
  * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher}
  * and configure the application to run. For example:
- *
+ * </p>
+ *
  * <pre>
  * {@code
  *   import org.apache.spark.launcher.SparkLauncher;
    diff --git a/make-distribution.sh b/make-distribution.sh
    index 9f063da3a16c0..4789b0e09cc8a 100755
    --- a/make-distribution.sh
    +++ b/make-distribution.sh
    @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
     DISTDIR="$SPARK_HOME/dist"
     
     SPARK_TACHYON=false
    -TACHYON_VERSION="0.6.4"
    +TACHYON_VERSION="0.7.0"
     TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz"
     TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}"
     
    @@ -219,6 +219,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR"
     if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
       mkdir -p "$DISTDIR"/R/lib
       cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib
    +  cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib
     fi
     
     # Download and copy in tachyon, if requested
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
    index a1f3851d804ff..aef2c019d2871 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
    @@ -95,6 +95,8 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
       /** @group setParam */
       def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
     
    +  // Below, we clone stages so that modifications to the list of stages will not change
    +  // the Param value in the Pipeline.
       /** @group getParam */
       def getStages: Array[PipelineStage] = $(stages).clone()
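
The cloning noted in the comment above is easiest to see with a small sketch (the stages and column names below are hypothetical, not from this patch): mutating the array returned by `getStages` leaves the Pipeline's own `stages` Param untouched.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Sketch only: a three-stage pipeline with made-up column wiring.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val lr = new LogisticRegression()

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

val stages = pipeline.getStages      // defensive copy of the Param value
stages(2) = new LogisticRegression() // only the copy is modified
assert(pipeline.getStages(2) eq lr)  // the Pipeline still holds the original stage
```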
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    index 333b42711ec52..19fe039b8fd03 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    @@ -169,10 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
       override def transform(dataset: DataFrame): DataFrame = {
         transformSchema(dataset.schema, logging = true)
         if ($(predictionCol).nonEmpty) {
    -      val predictUDF = udf { (features: Any) =>
    -        predict(features.asInstanceOf[FeaturesType])
    -      }
    -      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +      transformImpl(dataset)
         } else {
           this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
             " since no output columns were set.")
    @@ -180,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
         }
       }
     
    +  protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val predictUDF = udf { (features: Any) =>
    +      predict(features.asInstanceOf[FeaturesType])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       /**
        * Predict label for the given features.
        * This internal method is used to implement [[transform()]] and output [[predictionCol]].
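
The new `transformImpl` hook lets concrete models change how the prediction column is produced without re-implementing the column checks in `transform()`. A minimal sketch of an override that broadcasts the model once per `transform` call (this mirrors what the tree ensemble models later in this patch do; it assumes the surrounding class is a `PredictionModel` subclass with `predict` defined, and that `org.apache.spark.sql.functions.{col, udf}` are imported):

```scala
// Sketch only: body of a hypothetical PredictionModel[Vector, MyModel] subclass.
override protected def transformImpl(dataset: DataFrame): DataFrame = {
  // Ship the model to executors once, instead of capturing it in every task's closure.
  val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
  val predictUDF = udf { (features: Any) =>
    bcastModel.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
```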
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
    index 85c097bc64a4f..581d8fa7749be 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
    @@ -156,5 +156,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
        * This may be overridden to support thresholds which favor particular labels.
        * @return  predicted label
        */
    -  protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
    +  protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    index 2dc1824964a42..36fe1bd40469c 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
     import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
     import org.apache.spark.rdd.RDD
    @@ -75,8 +75,9 @@ final class DecisionTreeClassifier(override val uid: String)
         }
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy = getOldStrategy(categoricalFeatures, numClasses)
    -    val oldModel = OldDecisionTree.train(oldDataset, strategy)
    -    DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
    +      seed = 0L, parentUID = Some(uid))
    +    trees.head.asInstanceOf[DecisionTreeClassificationModel]
       }
     
       /** (private[ml]) Create a Strategy instance to use with the old API. */
    @@ -112,6 +113,12 @@ final class DecisionTreeClassificationModel private[ml] (
       require(rootNode != null,
         "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
     
    +  /**
    +   * Construct a decision tree classification model.
    +   * @param rootNode  Root node of tree, with other nodes attached.
    +   */
    +  def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
    +
       override protected def predict(features: Vector): Double = {
         rootNode.predict(features)
       }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    index 554e3b8e052b2..eb0b1a0a405fc 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    @@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
     import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -177,8 +179,15 @@ final class GBTClassificationModel(
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model: SPARK-7127
         // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
         // Classifies by thresholding sum of weighted tree predictions
         val treePredictions = _trees.map(_.rootNode.predict(features))
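
For readers unfamiliar with the voting scheme referenced in the comment above, here is a tiny standalone sketch of the rule: a weighted sum of the per-tree outputs is thresholded at zero for binary classification. The per-tree outputs and weights below are made up; the real code applies the same idea with BLAS over the trees' root-node predictions.

```scala
// Hypothetical regression-tree outputs and tree weights for one feature vector.
val treePredictions = Array(0.8, -0.3, 0.5)
val treeWeights     = Array(1.0, 0.1, 0.1)

// Weighted sum of tree predictions, thresholded at 0.0 for binary classification.
val margin = (treePredictions, treeWeights).zipped.map(_ * _).sum
val label  = if (margin > 0.0) 1.0 else 0.0   // here margin = 0.82, so label = 1.0
```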
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
    index 2e6eedd45ab07..8fc9199fb4602 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
    @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
     
     import scala.collection.mutable
     
    -import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
    +import breeze.linalg.{DenseVector => BDV}
     import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
     
     import org.apache.spark.{Logging, SparkException}
    @@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
      */
     private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
       with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
    -  with HasThreshold
    +  with HasThreshold with HasStandardization
     
     /**
      * :: Experimental ::
    @@ -98,6 +98,18 @@ class LogisticRegression(override val uid: String)
       def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
       setDefault(fitIntercept -> true)
     
    +  /**
    +   * Whether to standardize the training features before fitting the model.
+   * The coefficients of the models will always be returned on the original scale,
+   * so it will be transparent for users. Note that with no regularization,
+   * the models should always converge to the same solution,
+   * with or without standardization.
+   * Default is true.
+   * @group setParam
+   */
    +  def setStandardization(value: Boolean): this.type = set(standardization, value)
    +  setDefault(standardization -> true)
    +
       /** @group setParam */
       def setThreshold(value: Double): this.type = set(threshold, value)
       setDefault(threshold -> 0.5)
    @@ -116,7 +128,7 @@ class LogisticRegression(override val uid: String)
               case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
               (label: Double, features: Vector)) =>
                 (summarizer.add(features), labelSummarizer.add(label))
    -      },
    +        },
             combOp = (c1, c2) => (c1, c2) match {
               case ((summarizer1: MultivariateOnlineSummarizer,
               classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
    @@ -149,15 +161,28 @@ class LogisticRegression(override val uid: String)
         val regParamL1 = $(elasticNetParam) * $(regParam)
         val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
     
    -    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
    +    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
           featuresStd, featuresMean, regParamL2)
     
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
         } else {
    -      // Remove the L1 penalization on the intercept
           def regParamL1Fun = (index: Int) => {
    -        if (index == numFeatures) 0.0 else regParamL1
    +        // Remove the L1 penalization on the intercept
    +        if (index == numFeatures) {
    +          0.0
    +        } else {
    +          if ($(standardization)) {
    +            regParamL1
    +          } else {
    +            // If `standardization` is false, we still standardize the data
    +            // to improve the rate of convergence; as a result, we have to
    +            // perform this reverse standardization by penalizing each component
    +            // differently to get effectively the same objective function when
    +            // the training dataset is not standardized.
    +            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
    +          }
    +        }
           }
           new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
         }
    @@ -166,18 +191,18 @@ class LogisticRegression(override val uid: String)
           Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
     
         if ($(fitIntercept)) {
    -      /**
    -       * For binary logistic regression, when we initialize the weights as zeros,
    -       * it will converge faster if we initialize the intercept such that
    -       * it follows the distribution of the labels.
    -       *
    -       * {{{
    -       * P(0) = 1 / (1 + \exp(b)), and
    -       * P(1) = \exp(b) / (1 + \exp(b))
    -       * }}}, hence
    -       * {{{
    -       * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
    -       * }}}
    +      /*
    +         For binary logistic regression, when we initialize the weights as zeros,
    +         it will converge faster if we initialize the intercept such that
    +         it follows the distribution of the labels.
    +
    +         {{{
    +         P(0) = 1 / (1 + \exp(b)), and
    +         P(1) = \exp(b) / (1 + \exp(b))
    +         }}}, hence
    +         {{{
    +         b = \log{P(1) / P(0)} = \log{count_1 / count_0}
    +         }}}
            */
           initialWeightsWithIntercept.toArray(numFeatures)
             = math.log(histogram(1).toDouble / histogram(0).toDouble)
    @@ -186,39 +211,48 @@ class LogisticRegression(override val uid: String)
         val states = optimizer.iterations(new CachedDiffFunction(costFun),
           initialWeightsWithIntercept.toBreeze.toDenseVector)
     
    -    var state = states.next()
    -    val lossHistory = mutable.ArrayBuilder.make[Double]
    +    val (weights, intercept, objectiveHistory) = {
    +      /*
    +         Note that in Logistic Regression, the objective history (loss + regularization)
+         is the log-likelihood, which is invariant under feature standardization. As a result,
+         the objective history from the optimizer is the same as the one in the original space.
    +       */
    +      val arrayBuilder = mutable.ArrayBuilder.make[Double]
    +      var state: optimizer.State = null
    +      while (states.hasNext) {
    +        state = states.next()
    +        arrayBuilder += state.adjustedValue
    +      }
     
    -    while (states.hasNext) {
    -      lossHistory += state.value
    -      state = states.next()
    -    }
    -    lossHistory += state.value
    +      if (state == null) {
    +        val msg = s"${optimizer.getClass.getName} failed."
    +        logError(msg)
    +        throw new SparkException(msg)
    +      }
     
    -    // The weights are trained in the scaled space; we're converting them back to
    -    // the original space.
    -    val weightsWithIntercept = {
    +      /*
    +         The weights are trained in the scaled space; we're converting them back to
    +         the original space.
    +         Note that the intercept in scaled space and original space is the same;
    +         as a result, no scaling is needed.
    +       */
           val rawWeights = state.x.toArray.clone()
           var i = 0
    -      // Note that the intercept in scaled space and original space is the same;
    -      // as a result, no scaling is needed.
           while (i < numFeatures) {
             rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
             i += 1
           }
    -      Vectors.dense(rawWeights)
    +
    +      if ($(fitIntercept)) {
    +        (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result())
    +      } else {
    +        (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result())
    +      }
         }
     
         if (handlePersistence) instances.unpersist()
     
    -    val (weights, intercept) = if ($(fitIntercept)) {
    -      (Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
    -        weightsWithIntercept(weightsWithIntercept.size - 1))
    -    } else {
    -      (weightsWithIntercept, 0.0)
    -    }
    -
    -    new LogisticRegressionModel(uid, weights.compressed, intercept)
    +    copyValues(new LogisticRegressionModel(uid, weights, intercept))
       }
     
       override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
    @@ -423,16 +457,12 @@ private class LogisticAggregator(
         require(dim == data.size, s"Dimensions mismatch when adding new sample." +
           s" Expecting $dim but got ${data.size}.")
     
    -    val dataSize = data.size
    -
         val localWeightsArray = weightsArray
         val localGradientSumArray = gradientSumArray
     
         numClasses match {
           case 2 =>
    -        /**
    -         * For Binary Logistic Regression.
    -         */
    +        // For Binary Logistic Regression.
             val margin = - {
               var sum = 0.0
               data.foreachActive { (index, value) =>
    @@ -518,11 +548,13 @@ private class LogisticCostFun(
         data: RDD[(Double, Vector)],
         numClasses: Int,
         fitIntercept: Boolean,
    +    standardization: Boolean,
         featuresStd: Array[Double],
         featuresMean: Array[Double],
         regParamL2: Double) extends DiffFunction[BDV[Double]] {
     
       override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
    +    val numFeatures = featuresStd.length
         val w = Vectors.fromBreeze(weights)
     
         val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
    @@ -534,27 +566,43 @@ private class LogisticCostFun(
               case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
             })
     
    -    // regVal is the sum of weight squares for L2 regularization
    -    val norm = if (regParamL2 == 0.0) {
    -      0.0
    -    } else if (fitIntercept) {
    -      brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
    -    } else {
    -      brzNorm(weights, 2.0)
    -    }
    -    val regVal = 0.5 * regParamL2 * norm * norm
    -
    -    val loss = logisticAggregator.loss + regVal
    -    val gradient = logisticAggregator.gradient
    +    val totalGradientArray = logisticAggregator.gradient.toArray
     
    -    if (fitIntercept) {
    -      val wArray = w.toArray.clone()
    -      wArray(wArray.length - 1) = 0.0
    -      axpy(regParamL2, Vectors.dense(wArray), gradient)
    +    // regVal is the sum of weight squares excluding intercept for L2 regularization.
    +    val regVal = if (regParamL2 == 0.0) {
    +      0.0
         } else {
    -      axpy(regParamL2, w, gradient)
    +      var sum = 0.0
    +      w.foreachActive { (index, value) =>
+        // If `fitIntercept` is true, the last term, which is the intercept, doesn't
+        // contribute to the regularization.
    +        if (index != numFeatures) {
    +          // The following code will compute the loss of the regularization; also
    +          // the gradient of the regularization, and add back to totalGradientArray.
    +          sum += {
    +            if (standardization) {
    +              totalGradientArray(index) += regParamL2 * value
    +              value * value
    +            } else {
    +              if (featuresStd(index) != 0.0) {
    +                // If `standardization` is false, we still standardize the data
    +                // to improve the rate of convergence; as a result, we have to
    +                // perform this reverse standardization by penalizing each component
    +                // differently to get effectively the same objective function when
    +                // the training dataset is not standardized.
    +                val temp = value / (featuresStd(index) * featuresStd(index))
    +                totalGradientArray(index) += regParamL2 * temp
    +                value * temp
    +              } else {
    +                0.0
    +              }
    +            }
    +          }
    +        }
    +      }
    +      0.5 * regParamL2 * sum
         }
     
    -    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
    +    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
       }
     }
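
To make the reverse standardization above concrete, here is a small sketch of the effective per-feature penalty each weight ends up with when `standardization` is false; the `featuresStd` values are hypothetical, and the scaling follows the L1 and L2 branches in the code above.

```scala
val regParamL1  = 0.1
val regParamL2  = 0.1
val featuresStd = Array(2.0, 0.5, 0.0)   // a constant feature gets std 0.0

// L1: each component is penalized by regParamL1 / sigma_j (0.0 for constant features).
val effectiveL1 = featuresStd.map(s => if (s != 0.0) regParamL1 / s else 0.0)
// L2: each component is penalized by regParamL2 / sigma_j^2, matching the gradient code above.
val effectiveL2 = featuresStd.map(s => if (s != 0.0) regParamL2 / (s * s) else 0.0)

// effectiveL1 == Array(0.05, 0.2, 0.0); effectiveL2 == Array(0.025, 0.4, 0.0)
```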
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
    new file mode 100644
    index 0000000000000..1f547e4a98af7
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
    @@ -0,0 +1,178 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
    +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
    +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
    +import org.apache.spark.mllib.linalg._
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +/**
    + * Params for Naive Bayes Classifiers.
    + */
    +private[ml] trait NaiveBayesParams extends PredictorParams {
    +
    +  /**
    +   * The smoothing parameter.
    +   * (default = 1.0).
    +   * @group param
    +   */
    +  final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
    +    ParamValidators.gtEq(0))
    +
    +  /** @group getParam */
    +  final def getLambda: Double = $(lambda)
    +
    +  /**
    +   * The model type which is a string (case-sensitive).
    +   * Supported options: "multinomial" and "bernoulli".
    +   * (default = multinomial)
    +   * @group param
    +   */
    +  final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
    +    "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
    +    ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray))
    +
    +  /** @group getParam */
    +  final def getModelType: String = $(modelType)
    +}
    +
    +/**
    + * Naive Bayes Classifiers.
+ * It supports Multinomial NB
+ * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]),
+ * which can handle finitely supported discrete data. For example, by converting documents into
+ * TF-IDF vectors, it can be used for document classification. By making every vector a
+ * binary (0/1) vector, it can also be used as Bernoulli NB
+ * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
+ * The input feature values must be nonnegative.
    + */
    +class NaiveBayes(override val uid: String)
    +  extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
    +  with NaiveBayesParams {
    +
    +  def this() = this(Identifiable.randomUID("nb"))
    +
    +  /**
    +   * Set the smoothing parameter.
    +   * Default is 1.0.
    +   * @group setParam
    +   */
    +  def setLambda(value: Double): this.type = set(lambda, value)
    +  setDefault(lambda -> 1.0)
    +
    +  /**
    +   * Set the model type using a string (case-sensitive).
    +   * Supported options: "multinomial" and "bernoulli".
    +   * Default is "multinomial"
    +   */
    +  def setModelType(value: String): this.type = set(modelType, value)
    +  setDefault(modelType -> OldNaiveBayes.Multinomial)
    +
    +  override protected def train(dataset: DataFrame): NaiveBayesModel = {
    +    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    +    val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
    +    NaiveBayesModel.fromOld(oldModel, this)
    +  }
    +
    +  override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
    +}
    +
    +/**
    + * Model produced by [[NaiveBayes]]
    + */
    +class NaiveBayesModel private[ml] (
    +    override val uid: String,
    +    val pi: Vector,
    +    val theta: Matrix)
    +  extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
    +
    +  import OldNaiveBayes.{Bernoulli, Multinomial}
    +
    +  /**
    +   * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
    +   * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
    +   * application of this condition (in predict function).
    +   */
    +  private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
    +    case Multinomial => (None, None)
    +    case Bernoulli =>
    +      val negTheta = theta.map(value => math.log(1.0 - math.exp(value)))
    +      val ones = new DenseVector(Array.fill(theta.numCols){1.0})
    +      val thetaMinusNegTheta = theta.map { value =>
    +        value - math.log(1.0 - math.exp(value))
    +      }
    +      (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
    +    case _ =>
    +      // This should never happen.
    +      throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
    +  }
    +
    +  override protected def predict(features: Vector): Double = {
    +    $(modelType) match {
    +      case Multinomial =>
    +        val prob = theta.multiply(features)
    +        BLAS.axpy(1.0, pi, prob)
    +        prob.argmax
    +      case Bernoulli =>
+        features.foreachActive { (index, value) =>
    +          if (value != 0.0 && value != 1.0) {
    +            throw new SparkException(
    +              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
    +          }
    +        }
    +        val prob = thetaMinusNegTheta.get.multiply(features)
    +        BLAS.axpy(1.0, pi, prob)
    +        BLAS.axpy(1.0, negThetaSum.get, prob)
    +        prob.argmax
    +      case _ =>
    +        // This should never happen.
    +        throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): NaiveBayesModel = {
    +    copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
    +  }
    +
    +  override def toString: String = {
    +    s"NaiveBayesModel with ${pi.size} classes"
    +  }
    +
    +}
    +
    +private[ml] object NaiveBayesModel {
    +
    +  /** Convert a model from the old API */
    +  def fromOld(
    +      oldModel: OldNaiveBayesModel,
    +      parent: NaiveBayes): NaiveBayesModel = {
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
    +    val labels = Vectors.dense(oldModel.labels)
    +    val pi = Vectors.dense(oldModel.pi)
    +    val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
    +      oldModel.theta.flatten, true)
    +    new NaiveBayesModel(uid, pi, theta)
    +  }
    +}
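
A minimal usage sketch for the new spark.ml estimator defined above (it assumes an existing `sqlContext`; the column names are the Predictor defaults):

```scala
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors

val training = sqlContext.createDataFrame(Seq(
  (0.0, Vectors.dense(1.0, 0.0, 0.0)),
  (1.0, Vectors.dense(0.0, 1.0, 0.0)),
  (1.0, Vectors.dense(0.0, 0.0, 1.0))
)).toDF("label", "features")

val nb = new NaiveBayes()
  .setLambda(1.0)                // additive smoothing
  .setModelType("multinomial")   // or "bernoulli" for strictly 0/1 features

val model = nb.fit(training)
model.transform(training).select("features", "prediction").show()
```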
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
    index ea757c5e40c76..1741f19dc911c 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
    @@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
     
       /**
        * param for the base binary classifier that we reduce multiclass classification into.
    +   * The base classifier input and output columns are ignored in favor of
    +   * the ones specified in [[OneVsRest]].
        * @group param
        */
       val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
    @@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String)
         set(classifier, value.asInstanceOf[ClassifierType])
       }
     
    +  /** @group setParam */
    +  def setLabelCol(value: String): this.type = set(labelCol, value)
    +
    +  /** @group setParam */
    +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
    +
    +  /** @group setParam */
    +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
    +
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
       }
    @@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String)
           val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
           val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
           val classifier = getClassifier
    -      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
    +      val paramMap = new ParamMap()
    +      paramMap.put(classifier.labelCol -> labelColName)
    +      paramMap.put(classifier.featuresCol -> getFeaturesCol)
    +      paramMap.put(classifier.predictionCol -> getPredictionCol)
    +      classifier.fit(trainingDataset, paramMap)
         }.toArray[ClassificationModel[_, _]]
     
         if (handlePersistence) {
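
A short sketch of the new column setters in action (the DataFrame and column names below are hypothetical); the base classifier keeps its own defaults, and OneVsRest overrides its label, features, and prediction columns for each binary fit, as described above.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10))
  .setLabelCol("category")              // multiclass label column
  .setFeaturesCol("vectorizedText")     // feature vector column
  .setPredictionCol("predictedCategory")

val ovrModel = ovr.fit(trainingData)    // trainingData is an assumed DataFrame
```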
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
    index 38e832372698c..dad451108626d 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
    @@ -173,5 +173,5 @@ private[spark] abstract class ProbabilisticClassificationModel[
        * This may be overridden to support thresholds which favor particular labels.
        * @return  predicted label
        */
    -  protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax
    +  protected def probability2prediction(probability: Vector): Double = probability.argmax
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    index d3c67494a31e4..bc19bd6df894f 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    @@ -20,17 +20,19 @@ package org.apache.spark.ml.classification
     import scala.collection.mutable
     
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
    -import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -41,7 +43,7 @@ import org.apache.spark.sql.DataFrame
      */
     @Experimental
     final class RandomForestClassifier(override val uid: String)
    -  extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
    +  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
       with RandomForestParams with TreeClassifierParams {
     
       def this() = this(Identifiable.randomUID("rfc"))
    @@ -93,9 +95,10 @@ final class RandomForestClassifier(override val uid: String)
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy =
           super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
    -    val oldModel = OldRandomForest.trainClassifier(
    -      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
    -    RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees =
    +      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
    +        .map(_.asInstanceOf[DecisionTreeClassificationModel])
    +    new RandomForestClassificationModel(trees, numClasses)
       }
     
       override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
    @@ -122,12 +125,20 @@ object RandomForestClassifier {
     @Experimental
     final class RandomForestClassificationModel private[ml] (
         override val uid: String,
    -    private val _trees: Array[DecisionTreeClassificationModel])
    -  extends PredictionModel[Vector, RandomForestClassificationModel]
    +    private val _trees: Array[DecisionTreeClassificationModel],
    +    override val numClasses: Int)
    +  extends ClassificationModel[Vector, RandomForestClassificationModel]
       with TreeEnsembleModel with Serializable {
     
       require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
     
    +  /**
    +   * Construct a random forest classification model, with all trees weighted equally.
    +   * @param trees  Component trees
    +   */
    +  def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
    +    this(Identifiable.randomUID("rfc"), trees, numClasses)
    +
       override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
     
       // Note: We may add support for weights (based on tree performance) later on.
    @@ -135,21 +146,28 @@ final class RandomForestClassificationModel private[ml] (
     
       override def treeWeights: Array[Double] = _treeWeights
     
    -  override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model.  SPARK-7127
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
    +  override protected def predictRaw(features: Vector): Vector = {
         // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
         // Classifies using majority votes.
         // Ignore the weights since all are 1.0 for now.
    -    val votes = mutable.Map.empty[Int, Double]
    +    val votes = new Array[Double](numClasses)
         _trees.view.foreach { tree =>
           val prediction = tree.rootNode.predict(features).toInt
    -      votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
    +      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
         }
    -    votes.maxBy(_._2)._1
    +    Vectors.dense(votes)
       }
     
       override def copy(extra: ParamMap): RandomForestClassificationModel = {
    -    copyValues(new RandomForestClassificationModel(uid, _trees), extra)
    +    copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
       }
     
       override def toString: String = {
    @@ -168,7 +186,8 @@ private[ml] object RandomForestClassificationModel {
       def fromOld(
           oldModel: OldRandomForestModel,
           parent: RandomForestClassifier,
    -      categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
    +      categoricalFeatures: Map[Int, Int],
    +      numClasses: Int): RandomForestClassificationModel = {
         require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
           s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
         val newTrees = oldModel.trees.map { tree =>
    @@ -176,6 +195,6 @@ private[ml] object RandomForestClassificationModel {
           DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
         }
         val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
    -    new RandomForestClassificationModel(uid, newTrees)
    +    new RandomForestClassificationModel(uid, newTrees, numClasses)
       }
     }
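
The raw prediction introduced above is just a per-class vote count; a tiny standalone sketch of that bookkeeping (the per-tree outputs are made up):

```scala
import org.apache.spark.mllib.linalg.Vectors

val numClasses = 3
val treePredictions = Seq(0, 2, 2, 1, 2)   // hypothetical class predicted by each tree

val votes = new Array[Double](numClasses)
treePredictions.foreach { c => votes(c) += 1.0 }   // every tree currently has weight 1.0

val rawPrediction = Vectors.dense(votes)   // what predictRaw returns: [1.0, 1.0, 3.0]
val prediction = rawPrediction.argmax      // ClassificationModel turns this into class 2
```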
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
    new file mode 100644
    index 0000000000000..dc192add6ca13
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
    @@ -0,0 +1,205 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap}
    +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed}
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
    +import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{IntegerType, StructType}
    +import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.util.Utils
    +
    +
    +/**
    + * Common params for KMeans and KMeansModel
    + */
    +private[clustering] trait KMeansParams
    +    extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
    +
    +  /**
    +   * Set the number of clusters to create (k). Must be > 1. Default: 2.
    +   * @group param
    +   */
    +  final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
    +
    +  /** @group getParam */
    +  def getK: Int = $(k)
    +
    +  /**
+   * Param for the number of runs of the algorithm to execute in parallel. We initialize the algorithm
    +   * this many times with random starting conditions (configured by the initialization mode), then
    +   * return the best clustering found over any run. Must be >= 1. Default: 1.
    +   * @group param
    +   */
    +  final val runs = new IntParam(this, "runs",
    +    "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1)
    +
    +  /** @group getParam */
    +  def getRuns: Int = $(runs)
    +
    +  /**
+   * Param for the distance threshold within which we consider centers to have converged.
    +   * If all centers move less than this Euclidean distance, we stop iterating one run.
    +   * Must be >= 0.0. Default: 1e-4
    +   * @group param
    +   */
    +  final val epsilon = new DoubleParam(this, "epsilon",
    +    "distance threshold within which we've consider centers to have converge",
    +    (value: Double) => value >= 0.0)
    +
    +  /** @group getParam */
    +  def getEpsilon: Double = $(epsilon)
    +
    +  /**
    +   * Param for the initialization algorithm. This can be either "random" to choose random points as
    +   * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
    +   * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
    +   * @group expertParam
    +   */
    +  final val initMode = new Param[String](this, "initMode", "initialization algorithm",
    +    (value: String) => MLlibKMeans.validateInitMode(value))
    +
    +  /** @group expertGetParam */
    +  def getInitMode: String = $(initMode)
    +
    +  /**
    +   * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    +   * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5.
    +   * @group expertParam
    +   */
    +  final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||",
    +    (value: Int) => value > 0)
    +
    +  /** @group expertGetParam */
    +  def getInitSteps: Int = $(initSteps)
    +
    +  /**
    +   * Validates and transforms the input schema.
    +   * @param schema input schema
    +   * @return output schema
    +   */
    +  protected def validateAndTransformSchema(schema: StructType): StructType = {
    +    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    +    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model fitted by KMeans.
    + *
    + * @param parentModel a model trained by spark.mllib.clustering.KMeans.
    + */
    +@Experimental
    +class KMeansModel private[ml] (
    +    override val uid: String,
    +    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
    +
    +  override def copy(extra: ParamMap): KMeansModel = {
    +    val copied = new KMeansModel(uid, parentModel)
    +    copyValues(copied, extra)
    +  }
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    val predictUDF = udf((vector: Vector) => predict(vector))
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +
    +  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
    +
    +  def clusterCenters: Array[Vector] = parentModel.clusterCenters
    +}
    +
    +/**
    + * :: Experimental ::
    + * K-means clustering with support for multiple parallel runs and a k-means++ like initialization
    + * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
    + * they are executed together with joint passes over the data for efficiency.
    + */
    +@Experimental
    +class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams {
    +
    +  setDefault(
    +    k -> 2,
    +    maxIter -> 20,
    +    runs -> 1,
    +    initMode -> MLlibKMeans.K_MEANS_PARALLEL,
    +    initSteps -> 5,
    +    epsilon -> 1e-4)
    +
    +  override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
    +
    +  def this() = this(Identifiable.randomUID("kmeans"))
    +
    +  /** @group setParam */
    +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
    +
    +  /** @group setParam */
    +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
    +
    +  /** @group setParam */
    +  def setK(value: Int): this.type = set(k, value)
    +
    +  /** @group expertSetParam */
    +  def setInitMode(value: String): this.type = set(initMode, value)
    +
    +  /** @group expertSetParam */
    +  def setInitSteps(value: Int): this.type = set(initSteps, value)
    +
    +  /** @group setParam */
    +  def setMaxIter(value: Int): this.type = set(maxIter, value)
    +
    +  /** @group setParam */
    +  def setRuns(value: Int): this.type = set(runs, value)
    +
    +  /** @group setParam */
    +  def setEpsilon(value: Double): this.type = set(epsilon, value)
    +
    +  /** @group setParam */
    +  def setSeed(value: Long): this.type = set(seed, value)
    +
    +  override def fit(dataset: DataFrame): KMeansModel = {
    +    val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
    +
    +    val algo = new MLlibKMeans()
    +      .setK($(k))
    +      .setInitializationMode($(initMode))
    +      .setInitializationSteps($(initSteps))
    +      .setMaxIterations($(maxIter))
    +      .setSeed($(seed))
    +      .setEpsilon($(epsilon))
    +      .setRuns($(runs))
    +    val parentModel = algo.run(rdd)
    +    val model = new KMeansModel(uid, parentModel)
    +    copyValues(model)
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +}
    +
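A minimal usage sketch for the ml.clustering.KMeans estimator added above (not part of the patch; assumes an existing SparkContext named `sc`):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
// Two well-separated groups of points in a "features" vector column.
val dataset = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(9.0, 8.0)),
  Tuple1(Vectors.dense(8.0, 9.0))
)).toDF("features")

val kmeans = new KMeans()
  .setK(2)
  .setMaxIter(20)
  .setFeaturesCol("features")
  .setPredictionCol("prediction")

val model = kmeans.fit(dataset)        // KMeansModel
model.clusterCenters.foreach(println)  // two centers
model.transform(dataset).show()        // appends the integer "prediction" column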
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
    new file mode 100644
    index 0000000000000..6b77de89a0330
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
    @@ -0,0 +1,82 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.ml.feature
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.UnaryTransformer
    +import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
    +import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
    +
    +/**
    + * :: Experimental ::
    + * Converts a text document to a sparse vector of token counts.
+ * @param vocabulary An Array of terms. Only the terms in the vocabulary will be counted.
    + */
    +@Experimental
    +class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
    +  extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {
    +
    +  def this(vocabulary: Array[String]) =
    +    this(Identifiable.randomUID("cntVec"), vocabulary)
    +
    +  /**
+   * Per-document filter to ignore rare words. For each document, terms with a
+   * frequency (count) less than the given threshold are ignored.
    +   * Default: 1
    +   * @group param
    +   */
    +  val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
    +    "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
    +      "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))
    +
    +  /** @group setParam */
    +  def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)
    +
    +  /** @group getParam */
    +  def getMinTermFreq: Int = $(minTermFreq)
    +
    +  setDefault(minTermFreq -> 1)
    +
    +  override protected def createTransformFunc: Seq[String] => Vector = {
    +    val dict = vocabulary.zipWithIndex.toMap
    +    document =>
    +      val termCounts = mutable.HashMap.empty[Int, Double]
    +      document.foreach { term =>
    +        dict.get(term) match {
    +          case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
    +          case None => // ignore terms not in the vocabulary
    +        }
    +      }
    +      Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
    +  }
    +
    +  override protected def validateInputType(inputType: DataType): Unit = {
    +    require(inputType.sameType(ArrayType(StringType)),
    +      s"Input type must be ArrayType(StringType) but got $inputType.")
    +  }
    +
    +  override protected def outputDataType: DataType = new VectorUDT()
    +
    +  override def copy(extra: ParamMap): CountVectorizerModel = {
    +    val copied = new CountVectorizerModel(uid, vocabulary)
    +    copyValues(copied, extra)
    +  }
    +}
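A minimal usage sketch for the CountVectorizerModel transformer added above (not part of the patch; assumes an existing SparkContext named `sc` and a caller-supplied vocabulary):

import org.apache.spark.ml.feature.CountVectorizerModel
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val docs = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val cvm = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setMinTermFreq(1)

// Each row becomes a sparse vector of term counts over the vocabulary,
// e.g. [a, b, b, c, a] -> (3, [0, 1, 2], [2.0, 2.0, 1.0]).
cvm.transform(docs).show()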
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
    index 3825942795645..9c60d4084ec46 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
    @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
       override def transformSchema(schema: StructType): StructType = {
    -    val is = "_is_"
         val inputColName = $(inputCol)
         val outputColName = $(outputCol)
     
    @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
         val outputAttrNames: Option[Array[String]] = inputAttr match {
           case nominal: NominalAttribute =>
             if (nominal.values.isDefined) {
    -          nominal.values.map(_.map(v => inputColName + is + v))
    +          nominal.values
             } else if (nominal.numValues.isDefined) {
    -          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
    +          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
             } else {
               None
             }
           case binary: BinaryAttribute =>
             if (binary.values.isDefined) {
    -          binary.values.map(_.map(v => inputColName + is + v))
    +          binary.values
             } else {
    -          Some(Array.tabulate(2)(i => inputColName + is + i))
    +          Some(Array.tabulate(2)(_.toString))
             }
           case _: NumericAttribute =>
             throw new RuntimeException(
    @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
     
       override def transform(dataset: DataFrame): DataFrame = {
         // schema transformation
    -    val is = "_is_"
         val inputColName = $(inputCol)
         val outputColName = $(outputCol)
         val shouldDropLast = $(dropLast)
    @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
                 math.max(m0, m1)
               }
             ).toInt + 1
    -      val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
    +      val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
           val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
           val outputAttrs: Array[Attribute] =
             filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
    new file mode 100644
    index 0000000000000..d1726917e4517
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
    @@ -0,0 +1,226 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import scala.collection.mutable
    +import scala.collection.mutable.ArrayBuffer
    +import scala.util.parsing.combinator.RegexParsers
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
    +import org.apache.spark.ml.param.{Param, ParamMap}
    +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.linalg.VectorUDT
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types._
    +
    +/**
    + * Base trait for [[RFormula]] and [[RFormulaModel]].
    + */
    +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
+  /** @group setParam */
    +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
    +
+  /** @group setParam */
    +  def setLabelCol(value: String): this.type = set(labelCol, value)
    +
    +  protected def hasLabelCol(schema: StructType): Boolean = {
    +    schema.map(_.name).contains($(labelCol))
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Implements the transforms required for fitting a dataset against an R model formula. Currently
    + * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
    + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
    + */
    +@Experimental
    +class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
    +
    +  def this() = this(Identifiable.randomUID("rFormula"))
    +
    +  /**
    +   * R formula parameter. The formula is provided in string form.
    +   * @group param
    +   */
    +  val formula: Param[String] = new Param(this, "formula", "R model formula")
    +
    +  private var parsedFormula: Option[ParsedRFormula] = None
    +
    +  /**
    +   * Sets the formula to use for this transformer. Must be called before use.
    +   * @group setParam
    +   * @param value an R formula in string form (e.g. "y ~ x + z")
    +   */
    +  def setFormula(value: String): this.type = {
    +    parsedFormula = Some(RFormulaParser.parse(value))
    +    set(formula, value)
    +    this
    +  }
    +
    +  /** @group getParam */
    +  def getFormula: String = $(formula)
    +
    +  /** Whether the formula specifies fitting an intercept. */
    +  private[ml] def hasIntercept: Boolean = {
    +    require(parsedFormula.isDefined, "Must call setFormula() first.")
    +    parsedFormula.get.hasIntercept
    +  }
    +
    +  override def fit(dataset: DataFrame): RFormulaModel = {
    +    require(parsedFormula.isDefined, "Must call setFormula() first.")
    +    val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
    +    // StringType terms and terms representing interactions need to be encoded before assembly.
    +    // TODO(ekl) add support for feature interactions
    +    val encoderStages = ArrayBuffer[PipelineStage]()
    +    val tempColumns = ArrayBuffer[String]()
    +    val takenNames = mutable.Set(dataset.columns: _*)
    +    val encodedTerms = resolvedFormula.terms.map { term =>
    +      dataset.schema(term) match {
    +        case column if column.dataType == StringType =>
    +          val indexCol = term + "_idx_" + uid
    +          val encodedCol = {
    +            var tmp = term
    +            while (takenNames.contains(tmp)) {
    +              tmp += "_"
    +            }
    +            tmp
    +          }
    +          takenNames.add(indexCol)
    +          takenNames.add(encodedCol)
    +          encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
    +          encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
    +          tempColumns += indexCol
    +          tempColumns += encodedCol
    +          encodedCol
    +        case _ =>
    +          term
    +      }
    +    }
    +    encoderStages += new VectorAssembler(uid)
    +      .setInputCols(encodedTerms.toArray)
    +      .setOutputCol($(featuresCol))
    +    encoderStages += new ColumnPruner(tempColumns.toSet)
    +    val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
    +    copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
    +  }
    +
    +  // optimistic schema; does not contain any ML attributes
    +  override def transformSchema(schema: StructType): StructType = {
    +    if (hasLabelCol(schema)) {
    +      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
    +    } else {
    +      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
    +        StructField($(labelCol), DoubleType, true))
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
    +
    +  override def toString: String = s"RFormula(${get(formula)})"
    +}
    +
    +/**
    + * :: Experimental ::
    + * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
    + * @param resolvedFormula the fitted R formula.
    + * @param pipelineModel the fitted feature model, including factor to index mappings.
    + */
    +@Experimental
    +class RFormulaModel private[feature](
    +    override val uid: String,
    +    resolvedFormula: ResolvedRFormula,
    +    pipelineModel: PipelineModel)
    +  extends Model[RFormulaModel] with RFormulaBase {
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    checkCanTransform(dataset.schema)
    +    transformLabel(pipelineModel.transform(dataset))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    checkCanTransform(schema)
    +    val withFeatures = pipelineModel.transformSchema(schema)
    +    if (hasLabelCol(schema)) {
    +      withFeatures
    +    } else if (schema.exists(_.name == resolvedFormula.label)) {
    +      val nullable = schema(resolvedFormula.label).dataType match {
    +        case _: NumericType | BooleanType => false
    +        case _ => true
    +      }
    +      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
    +    } else {
    +      // Ignore the label field. This is a hack so that this transformer can also work on test
    +      // datasets in a Pipeline.
    +      withFeatures
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): RFormulaModel = copyValues(
    +    new RFormulaModel(uid, resolvedFormula, pipelineModel))
    +
    +  override def toString: String = s"RFormulaModel(${resolvedFormula})"
    +
    +  private def transformLabel(dataset: DataFrame): DataFrame = {
    +    val labelName = resolvedFormula.label
    +    if (hasLabelCol(dataset.schema)) {
    +      dataset
    +    } else if (dataset.schema.exists(_.name == labelName)) {
    +      dataset.schema(labelName).dataType match {
    +        case _: NumericType | BooleanType =>
    +          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
    +        case other =>
    +          throw new IllegalArgumentException("Unsupported type for label: " + other)
    +      }
    +    } else {
    +      // Ignore the label field. This is a hack so that this transformer can also work on test
    +      // datasets in a Pipeline.
    +      dataset
    +    }
    +  }
    +
    +  private def checkCanTransform(schema: StructType) {
    +    val columnNames = schema.map(_.name)
    +    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
    +    require(
    +      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
    +      "Label column already exists and is not of type DoubleType.")
    +  }
    +}
    +
    +/**
    + * Utility transformer for removing temporary columns from a DataFrame.
    + * TODO(ekl) make this a public transformer
    + */
    +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
    +  override val uid = Identifiable.randomUID("columnPruner")
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
    +    dataset.select(columnsToKeep.map(dataset.col) : _*)
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
    +  }
    +
    +  override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
    +}
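A minimal usage sketch for the RFormula estimator added above (not part of the patch; assumes an existing SparkContext named `sc`). The string column "g" is indexed and one-hot encoded before being assembled into the features vector, as described in fit():

import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  (1.0, "a", 3.0),
  (0.0, "b", 4.0),
  (1.0, "a", 5.0)
)).toDF("y", "g", "x")

val formula = new RFormula()
  .setFormula("y ~ g + x")
  .setFeaturesCol("features")
  .setLabelCol("label")

// Fitting determines the factor levels of "g"; transform appends "features" and "label".
formula.fit(df).transform(df).select("features", "label").show()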
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
    new file mode 100644
    index 0000000000000..1ca3b92a7d92a
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
    @@ -0,0 +1,129 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import scala.util.parsing.combinator.RegexParsers
    +
    +import org.apache.spark.mllib.linalg.VectorUDT
    +import org.apache.spark.sql.types._
    +
    +/**
    + * Represents a parsed R formula.
    + */
    +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
    +  /**
    +   * Resolves formula terms into column names. A schema is necessary for inferring the meaning
    +   * of the special '.' term. Duplicate terms will be removed during resolution.
    +   */
    +  def resolve(schema: StructType): ResolvedRFormula = {
    +    var includedTerms = Seq[String]()
    +    terms.foreach {
    +      case Dot =>
    +        includedTerms ++= simpleTypes(schema).filter(_ != label.value)
    +      case ColumnRef(value) =>
    +        includedTerms :+= value
    +      case Deletion(term: Term) =>
    +        term match {
    +          case ColumnRef(value) =>
    +            includedTerms = includedTerms.filter(_ != value)
    +          case Dot =>
    +            // e.g. "- .", which removes all first-order terms
    +            val fromSchema = simpleTypes(schema)
    +            includedTerms = includedTerms.filter(fromSchema.contains(_))
    +          case _: Deletion =>
    +            assert(false, "Deletion terms cannot be nested")
    +          case _: Intercept =>
    +        }
    +      case _: Intercept =>
    +    }
    +    ResolvedRFormula(label.value, includedTerms.distinct)
    +  }
    +
    +  /** Whether this formula specifies fitting with an intercept term. */
    +  def hasIntercept: Boolean = {
    +    var intercept = true
    +    terms.foreach {
    +      case Intercept(enabled) =>
    +        intercept = enabled
    +      case Deletion(Intercept(enabled)) =>
    +        intercept = !enabled
    +      case _ =>
    +    }
    +    intercept
    +  }
    +
    +  // the dot operator excludes complex column types
    +  private def simpleTypes(schema: StructType): Seq[String] = {
    +    schema.fields.filter(_.dataType match {
    +      case _: NumericType | StringType | BooleanType | _: VectorUDT => true
    +      case _ => false
    +    }).map(_.name)
    +  }
    +}
    +
    +/**
    + * Represents a fully evaluated and simplified R formula.
    + */
    +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
    +
    +/**
    + * R formula terms. See the R formula docs here for more information:
    + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
    + */
    +private[ml] sealed trait Term
    +
    +/* R formula reference to all available columns, e.g. "." in a formula */
    +private[ml] case object Dot extends Term
    +
    +/* R formula reference to a column, e.g. "+ Species" in a formula */
    +private[ml] case class ColumnRef(value: String) extends Term
    +
    +/* R formula intercept toggle, e.g. "+ 0" in a formula */
    +private[ml] case class Intercept(enabled: Boolean) extends Term
    +
    +/* R formula deletion of a variable, e.g. "- Species" in a formula */
    +private[ml] case class Deletion(term: Term) extends Term
    +
    +/**
    + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'.
    + */
    +private[ml] object RFormulaParser extends RegexParsers {
    +  def intercept: Parser[Intercept] =
    +    "([01])".r ^^ { case a => Intercept(a == "1") }
    +
    +  def columnRef: Parser[ColumnRef] =
    +    "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
    +
    +  def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
    +
    +  def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
    +    case op ~ list => list.foldLeft(List(op)) {
    +      case (left, "+" ~ right) => left ++ Seq(right)
    +      case (left, "-" ~ right) => left ++ Seq(Deletion(right))
    +    }
    +  }
    +
    +  def formula: Parser[ParsedRFormula] =
    +    (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
    +
    +  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
    +    case Success(result, _) => result
    +    case failure: NoSuccess => throw new IllegalArgumentException(
    +      "Could not parse formula: " + value)
    +  }
    +}
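A small sketch of what the parser above produces for a representative formula (not part of the patch; RFormulaParser and the term case classes are private[ml], so this only runs inside the org.apache.spark.ml.feature package, e.g. from a test suite):

// "." resolves to all simple-typed columns, "- z" deletes a term, "+ 0" drops the intercept.
val parsed = RFormulaParser.parse("y ~ . - z + 0")
// parsed.label        == ColumnRef("y")
// parsed.terms        == Seq(Dot, Deletion(ColumnRef("z")), Intercept(false))
// parsed.hasIntercept == false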
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
    index ca3c1cfb56b7f..72b545e5db3e4 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
    @@ -106,6 +106,12 @@ class StandardScalerModel private[ml] (
         scaler: feature.StandardScalerModel)
       extends Model[StandardScalerModel] with StandardScalerParams {
     
    +  /** Standard deviation of the StandardScalerModel */
    +  val std: Vector = scaler.std
    +
    +  /** Mean of the StandardScalerModel */
    +  val mean: Vector = scaler.mean
    +
       /** @group setParam */
       def setInputCol(value: String): this.type = set(inputCol, value)
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    index 5f9f57a2ebcfa..248288ca73e99 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    @@ -42,7 +42,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
         require(inputType == StringType, s"Input type must be string type but got $inputType.")
       }
     
    -  override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +  override protected def outputDataType: DataType = new ArrayType(StringType, true)
     
       override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
     }
    @@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
     /**
      * :: Experimental ::
      * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split
    - * the text (default) or repeatedly matching the regex (if `gaps` is true).
    + * the text (default) or repeatedly matching the regex (if `gaps` is false).
      * Optional parameters also allow filtering tokens using a minimal length.
      * It returns an array of strings that can be empty.
      */
    @@ -113,7 +113,7 @@ class RegexTokenizer(override val uid: String)
         require(inputType == StringType, s"Input type must be string type but got $inputType.")
       }
     
    -  override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +  override protected def outputDataType: DataType = new ArrayType(StringType, true)
     
       override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
     }
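The doc fix above concerns the `gaps` parameter; a brief sketch contrasting the two modes (not part of the patch):

import org.apache.spark.ml.feature.RegexTokenizer

// gaps = true (the default): the pattern is a delimiter, so this splits on whitespace.
val splitOnWhitespace = new RegexTokenizer()
  .setInputCol("text").setOutputCol("tokens")
  .setPattern("\\s+")

// gaps = false: the pattern is matched repeatedly and each match becomes a token.
val matchWords = new RegexTokenizer()
  .setInputCol("text").setOutputCol("tokens")
  .setGaps(false)
  .setPattern("\\w+")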
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    index 9f83c2ee16178..086917fa680f8 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
         if (schema.fieldNames.contains(outputColName)) {
           throw new IllegalArgumentException(s"Output column $outputColName already exists.")
         }
    -    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
    +    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
       }
     
       override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    index 50c0d855066f8..954aa17e26a02 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    @@ -295,6 +295,22 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
         w(value.asScala.map(_.asInstanceOf[Double]).toArray)
     }
     
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Array[Int]]]] for Java.
    + */
    +@DeveloperApi
    +class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean)
    +  extends Param[Array[Int]](parent, name, doc, isValid) {
    +
    +  def this(parent: Params, name: String, doc: String) =
    +    this(parent, name, doc, ParamValidators.alwaysTrue)
    +
    +  /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
    +  def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
    +    w(value.asScala.map(_.asInstanceOf[Int]).toArray)
    +}
    +
     /**
      * :: Experimental ::
      * A param and its value.
    @@ -341,9 +357,7 @@ trait Params extends Identifiable with Serializable {
        * those are checked during schema validation.
        */
       def validateParams(): Unit = {
    -    params.filter(isDefined).foreach { param =>
    -      param.asInstanceOf[Param[Any]].validate($(param))
    -    }
    +    // Do nothing by default.  Override to handle Param interactions.
       }
     
       /**
    @@ -462,11 +476,14 @@ trait Params extends Identifiable with Serializable {
       /**
        * Sets default values for a list of params.
        *
    +   * Note: Java developers should use the single-parameter [[setDefault()]].
    +   *       Annotating this with varargs can cause compilation failures due to a Scala compiler bug.
    +   *       See SPARK-9268.
    +   *
        * @param paramPairs  a list of param pairs that specify params and their default values to set
        *                    respectively. Make sure that the params are initialized before this method
        *                    gets called.
        */
    -  @varargs
       protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
         paramPairs.foreach { p =>
           setDefault(p.param.asInstanceOf[Param[Any]], p.value)
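A minimal sketch of how the new IntArrayParam could be used from a custom Params implementation (not part of the patch; `MyParams` and the `layers` parameter are hypothetical):

import org.apache.spark.ml.param.{IntArrayParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

class MyParams(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("myParams"))

  // Hypothetical array-valued parameter; Java callers get the java.util.List-based w() above.
  val layers = new IntArrayParam(this, "layers", "sizes of hidden layers")
  setDefault(layers -> Array(10, 5))

  def getLayers: Array[Int] = $(layers)

  override def copy(extra: ParamMap): MyParams = defaultCopy(extra)
}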
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    index b0a6af171c01f..f7ae1de522e01 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    @@ -54,8 +54,7 @@ private[shared] object SharedParamsCodeGen {
             isValid = "ParamValidators.gtEq(1)"),
           ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
           ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
    -        " prior to fitting the model sequence. Note that the coefficients of models are" +
    -        " always returned on the original scale.", Some("true")),
    +        " before fitting the model.", Some("true")),
           ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
           ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
             " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
    @@ -135,7 +134,7 @@ private[shared] object SharedParamsCodeGen {
     
         s"""
           |/**
    -      | * (private[ml]) Trait for shared param $name$defaultValueDoc.
    +      | * Trait for shared param $name$defaultValueDoc.
           | */
           |private[ml] trait Has$Name extends Params {
           |
    @@ -174,7 +173,6 @@ private[shared] object SharedParamsCodeGen {
             |package org.apache.spark.ml.param.shared
             |
             |import org.apache.spark.ml.param._
    -        |import org.apache.spark.util.Utils
             |
             |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
             |
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    index bbe08939b6d75..65e48e4ee5083 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    @@ -18,14 +18,13 @@
     package org.apache.spark.ml.param.shared
     
     import org.apache.spark.ml.param._
    -import org.apache.spark.util.Utils
     
     // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
     
     // scalastyle:off
     
     /**
    - * (private[ml]) Trait for shared param regParam.
    + * Trait for shared param regParam.
      */
     private[ml] trait HasRegParam extends Params {
     
    @@ -40,7 +39,7 @@ private[ml] trait HasRegParam extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param maxIter.
    + * Trait for shared param maxIter.
      */
     private[ml] trait HasMaxIter extends Params {
     
    @@ -55,7 +54,7 @@ private[ml] trait HasMaxIter extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param featuresCol (default: "features").
    + * Trait for shared param featuresCol (default: "features").
      */
     private[ml] trait HasFeaturesCol extends Params {
     
    @@ -72,7 +71,7 @@ private[ml] trait HasFeaturesCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param labelCol (default: "label").
    + * Trait for shared param labelCol (default: "label").
      */
     private[ml] trait HasLabelCol extends Params {
     
    @@ -89,7 +88,7 @@ private[ml] trait HasLabelCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param predictionCol (default: "prediction").
    + * Trait for shared param predictionCol (default: "prediction").
      */
     private[ml] trait HasPredictionCol extends Params {
     
    @@ -106,7 +105,7 @@ private[ml] trait HasPredictionCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
    + * Trait for shared param rawPredictionCol (default: "rawPrediction").
      */
     private[ml] trait HasRawPredictionCol extends Params {
     
    @@ -123,7 +122,7 @@ private[ml] trait HasRawPredictionCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param probabilityCol (default: "probability").
    + * Trait for shared param probabilityCol (default: "probability").
      */
     private[ml] trait HasProbabilityCol extends Params {
     
    @@ -140,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param threshold.
    + * Trait for shared param threshold.
      */
     private[ml] trait HasThreshold extends Params {
     
    @@ -155,7 +154,7 @@ private[ml] trait HasThreshold extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param inputCol.
    + * Trait for shared param inputCol.
      */
     private[ml] trait HasInputCol extends Params {
     
    @@ -170,7 +169,7 @@ private[ml] trait HasInputCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param inputCols.
    + * Trait for shared param inputCols.
      */
     private[ml] trait HasInputCols extends Params {
     
    @@ -185,7 +184,7 @@ private[ml] trait HasInputCols extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
    + * Trait for shared param outputCol (default: uid + "__output").
      */
     private[ml] trait HasOutputCol extends Params {
     
    @@ -202,7 +201,7 @@ private[ml] trait HasOutputCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param checkpointInterval.
    + * Trait for shared param checkpointInterval.
      */
     private[ml] trait HasCheckpointInterval extends Params {
     
    @@ -217,7 +216,7 @@ private[ml] trait HasCheckpointInterval extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param fitIntercept (default: true).
    + * Trait for shared param fitIntercept (default: true).
      */
     private[ml] trait HasFitIntercept extends Params {
     
    @@ -234,15 +233,15 @@ private[ml] trait HasFitIntercept extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param standardization (default: true).
    + * Trait for shared param standardization (default: true).
      */
     private[ml] trait HasStandardization extends Params {
     
       /**
    -   * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale..
    +   * Param for whether to standardize the training features before fitting the model..
        * @group param
        */
    -  final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.")
    +  final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.")
     
       setDefault(standardization, true)
     
    @@ -251,7 +250,7 @@ private[ml] trait HasStandardization extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
    + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
      */
     private[ml] trait HasSeed extends Params {
     
    @@ -268,7 +267,7 @@ private[ml] trait HasSeed extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param elasticNetParam.
    + * Trait for shared param elasticNetParam.
      */
     private[ml] trait HasElasticNetParam extends Params {
     
    @@ -283,7 +282,7 @@ private[ml] trait HasElasticNetParam extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param tol.
    + * Trait for shared param tol.
      */
     private[ml] trait HasTol extends Params {
     
    @@ -298,7 +297,7 @@ private[ml] trait HasTol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param stepSize.
    + * Trait for shared param stepSize.
      */
     private[ml] trait HasStepSize extends Params {
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
    new file mode 100644
    index 0000000000000..f5a022c31ed90
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
    @@ -0,0 +1,70 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.api.r
    +
    +import org.apache.spark.ml.attribute._
    +import org.apache.spark.ml.feature.RFormula
    +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
    +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
    +import org.apache.spark.ml.{Pipeline, PipelineModel}
    +import org.apache.spark.sql.DataFrame
    +
    +private[r] object SparkRWrappers {
    +  def fitRModelFormula(
    +      value: String,
    +      df: DataFrame,
    +      family: String,
    +      lambda: Double,
    +      alpha: Double): PipelineModel = {
    +    val formula = new RFormula().setFormula(value)
    +    val estimator = family match {
    +      case "gaussian" => new LinearRegression()
    +        .setRegParam(lambda)
    +        .setElasticNetParam(alpha)
    +        .setFitIntercept(formula.hasIntercept)
    +      case "binomial" => new LogisticRegression()
    +        .setRegParam(lambda)
    +        .setElasticNetParam(alpha)
    +        .setFitIntercept(formula.hasIntercept)
    +    }
    +    val pipeline = new Pipeline().setStages(Array(formula, estimator))
    +    pipeline.fit(df)
    +  }
    +
    +  def getModelWeights(model: PipelineModel): Array[Double] = {
    +    model.stages.last match {
    +      case m: LinearRegressionModel =>
    +        Array(m.intercept) ++ m.weights.toArray
    +      case _: LogisticRegressionModel =>
    +        throw new UnsupportedOperationException(
    +          "No weights available for LogisticRegressionModel")  // SPARK-9492
    +    }
    +  }
    +
    +  def getModelFeatures(model: PipelineModel): Array[String] = {
    +    model.stages.last match {
    +      case m: LinearRegressionModel =>
    +        val attrs = AttributeGroup.fromStructField(
    +          m.summary.predictions.schema(m.summary.featuresCol))
    +        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
    +      case _: LogisticRegressionModel =>
    +        throw new UnsupportedOperationException(
    +          "No features names available for LogisticRegressionModel")  // SPARK-9492
    +    }
    +  }
    +}
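For reference, a hedged sketch of the pipeline fitRModelFormula assembles for family = "gaussian" (not part of the patch; note that RFormula.hasIntercept is private[ml], so like the wrapper above this only compiles inside the org.apache.spark.ml package):

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

def fitGaussian(formulaStr: String, df: DataFrame, lambda: Double, alpha: Double): PipelineModel = {
  val formula = new RFormula().setFormula(formulaStr)
  val lr = new LinearRegression()
    .setRegParam(lambda)              // lambda maps to regParam
    .setElasticNetParam(alpha)        // alpha maps to elasticNetParam
    .setFitIntercept(formula.hasIntercept)
  new Pipeline().setStages(Array(formula, lr)).fit(df)
}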
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    index be1f8063d41d8..6f3340c2f02be 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
     import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
     import org.apache.spark.rdd.RDD
    @@ -67,8 +67,9 @@ final class DecisionTreeRegressor(override val uid: String)
           MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy = getOldStrategy(categoricalFeatures)
    -    val oldModel = OldDecisionTree.train(oldDataset, strategy)
    -    DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
    +      seed = 0L, parentUID = Some(uid))
    +    trees.head.asInstanceOf[DecisionTreeRegressionModel]
       }
     
       /** (private[ml]) Create a Strategy instance to use with the old API. */
    @@ -102,6 +103,12 @@ final class DecisionTreeRegressionModel private[ml] (
       require(rootNode != null,
         "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
     
    +  /**
    +   * Construct a decision tree regression model.
    +   * @param rootNode  Root node of tree, with other nodes attached.
    +   */
    +  def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
    +
       override protected def predict(features: Vector): Double = {
         rootNode.predict(features)
       }
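Behaviour of the regressor is unchanged by the refactor above (a single tree is now trained through the new ml RandomForest runner with numTrees = 1); a minimal usage sketch for reference (not part of the patch; assumes an existing SparkContext named `sc`):

import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(1.0, 0.0)),
  (2.0, Vectors.dense(2.0, 1.0)),
  (3.0, Vectors.dense(3.0, 2.0))
)).toDF("label", "features")

val dtr = new DecisionTreeRegressor().setMaxDepth(3)
val model = dtr.fit(training)                    // DecisionTreeRegressionModel
model.transform(training).select("prediction").show()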
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    index 47c110d027d67..e38dc73ee0ba7 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    @@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
     import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -167,8 +169,15 @@ final class GBTRegressionModel(
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model. SPARK-7127
         // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
         // Classifies by thresholding sum of weighted tree predictions
         val treePredictions = _trees.map(_.rootNode.predict(features))
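The transformImpl override above broadcasts the fitted model so each executor deserializes it once instead of once per task. A generic, hedged sketch of that pattern (not part of the patch; `MyModel` is a hypothetical serializable model interface):

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

trait MyModel extends Serializable { def predict(features: Vector): Double }

def withPredictions(dataset: DataFrame, model: MyModel): DataFrame = {
  // One broadcast copy per executor, referenced from inside the prediction UDF.
  val bcastModel = dataset.sqlContext.sparkContext.broadcast(model)
  val predictUDF = udf { (features: Vector) => bcastModel.value.predict(features) }
  dataset.withColumn("prediction", predictUDF(col("features")))
}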
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
    new file mode 100644
    index 0000000000000..4ece8cf8cf0b6
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
    @@ -0,0 +1,144 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.regression
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.PredictorParams
    +import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam}
    +import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
    +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
    +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.types.{DoubleType, DataType}
    +import org.apache.spark.sql.{Row, DataFrame}
    +import org.apache.spark.storage.StorageLevel
    +
    +/**
    + * Params for isotonic regression.
    + */
    +private[regression] trait IsotonicRegressionParams extends PredictorParams {
    +
    +  /**
    +   * Param for weight column name.
    +   * TODO: Move weightCol to sharedParams.
    +   *
    +   * @group param
    +   */
    +  final val weightCol: Param[String] =
    +    new Param[String](this, "weightCol", "weight column name")
    +
    +  /** @group getParam */
    +  final def getWeightCol: String = $(weightCol)
    +
    +  /**
    +   * Param for isotonic parameter.
    +   * Isotonic (increasing) or antitonic (decreasing) sequence.
    +   * @group param
    +   */
    +  final val isotonic: BooleanParam =
    +    new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence")
    +
    +  /** @group getParam */
    +  final def getIsotonicParam: Boolean = $(isotonic)
    +}
    +
    +/**
    + * :: Experimental ::
    + * Isotonic regression.
    + *
+ * Currently implemented using the parallelized pool adjacent violators algorithm.
+ * Only the univariate (single feature) algorithm is supported.
    + *
    + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]].
    + */
    +@Experimental
    +class IsotonicRegression(override val uid: String)
    +  extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel]
    +  with IsotonicRegressionParams {
    +
    +  def this() = this(Identifiable.randomUID("isoReg"))
    +
    +  /**
    +   * Set the isotonic parameter.
    +   * Default is true.
    +   * @group setParam
    +   */
    +  def setIsotonicParam(value: Boolean): this.type = set(isotonic, value)
    +  setDefault(isotonic -> true)
    +
    +  /**
    +   * Set weight column param.
+   * Default is "weight".
    +   * @group setParam
    +   */
    +  def setWeightParam(value: String): this.type = set(weightCol, value)
    +  setDefault(weightCol -> "weight")
    +
    +  override private[ml] def featuresDataType: DataType = DoubleType
    +
    +  override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
    +
    +  private[this] def extractWeightedLabeledPoints(
    +      dataset: DataFrame): RDD[(Double, Double, Double)] = {
    +
    +    dataset.select($(labelCol), $(featuresCol), $(weightCol))
    +      .map { case Row(label: Double, features: Double, weights: Double) =>
    +        (label, features, weights)
    +      }
    +  }
    +
    +  override protected def train(dataset: DataFrame): IsotonicRegressionModel = {
    +    SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType)
+    // Extract columns from the dataset.  If it is already persisted, do not persist the extracted instances.
    +    val instances = extractWeightedLabeledPoints(dataset)
    +    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    +    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
    +
    +    val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
    +    val parentModel = isotonicRegression.run(instances)
    +
    +    new IsotonicRegressionModel(uid, parentModel)
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model fitted by IsotonicRegression.
    + * Predicts using a piecewise linear function.
    + *
    + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]].
    + *
    + * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]]
    + *                    model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]].
    + */
    +class IsotonicRegressionModel private[ml] (
    +    override val uid: String,
    +    private[ml] val parentModel: MLlibIsotonicRegressionModel)
    +  extends RegressionModel[Double, IsotonicRegressionModel]
    +  with IsotonicRegressionParams {
    +
    +  override def featuresDataType: DataType = DoubleType
    +
    +  override protected def predict(features: Double): Double = {
    +    parentModel.predict(features)
    +  }
    +
    +  override def copy(extra: ParamMap): IsotonicRegressionModel = {
    +    copyValues(new IsotonicRegressionModel(uid, parentModel), extra)
    +  }
    +}
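The new `ml.regression.IsotonicRegression` estimator added above follows the usual spark.ml fit/transform pattern: it expects Double-typed label, feature, and weight columns and delegates to `mllib.regression.IsotonicRegression`. A minimal usage sketch (the `training` DataFrame and its column names are hypothetical, not part of this patch):

```
import org.apache.spark.ml.regression.IsotonicRegression

// `training` is assumed to be a DataFrame with Double columns
// "label", "features", and "weight".
val ir = new IsotonicRegression()
  .setIsotonicParam(true)    // fit an increasing sequence (the default)
  .setWeightParam("weight")  // weight column, defaults to "weight"

val model = ir.fit(training)
// The fitted model predicts with a piecewise linear function.
model.transform(training).show()
```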
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    index 1b1d7299fb496..3b85ba001b128 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    @@ -22,18 +22,21 @@ import scala.collection.mutable
     import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
     import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
     
    -import org.apache.spark.Logging
    +import org.apache.spark.{Logging, SparkException}
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.PredictorParams
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.param.shared._
     import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.evaluation.RegressionMetrics
     import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.linalg.BLAS._
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.StructField
     import org.apache.spark.storage.StorageLevel
     import org.apache.spark.util.StatCounter
     
    @@ -132,7 +135,6 @@ class LinearRegression(override val uid: String)
         val numFeatures = summarizer.mean.size
         val yMean = statCounter.mean
         val yStd = math.sqrt(statCounter.variance)
    -    // look at glmnet5.m L761 maaaybe that has info
     
         // If the yStd is zero, then the intercept is yMean with zero weights;
         // as a result, training is not needed.
    @@ -140,7 +142,17 @@ class LinearRegression(override val uid: String)
           logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
             s"and the intercept will be the mean of the label; as a result, training is not needed.")
           if (handlePersistence) instances.unpersist()
    -      return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
    +      val weights = Vectors.sparse(numFeatures, Seq())
    +      val intercept = yMean
    +
    +      val model = new LinearRegressionModel(uid, weights, intercept)
    +      val trainingSummary = new LinearRegressionTrainingSummary(
    +        model.transform(dataset),
    +        $(predictionCol),
    +        $(labelCol),
    +        $(featuresCol),
    +        Array(0D))
    +      return copyValues(model.setSummary(trainingSummary))
         }
     
         val featuresMean = summarizer.mean.toArray
    @@ -162,21 +174,33 @@ class LinearRegression(override val uid: String)
         }
     
         val initialWeights = Vectors.zeros(numFeatures)
    -    val states =
    -      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
    -
    -    var state = states.next()
    -    val lossHistory = mutable.ArrayBuilder.make[Double]
    -
    -    while (states.hasNext) {
    -      lossHistory += state.value
    -      state = states.next()
    -    }
    -    lossHistory += state.value
    +    val states = optimizer.iterations(new CachedDiffFunction(costFun),
    +      initialWeights.toBreeze.toDenseVector)
    +
    +    val (weights, objectiveHistory) = {
    +      /*
    +         Note that in Linear Regression, the objective history (loss + regularization) returned
    +         from optimizer is computed in the scaled space given by the following formula.
    +         {{{
    +         L = 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms
    +         }}}
    +       */
    +      val arrayBuilder = mutable.ArrayBuilder.make[Double]
    +      var state: optimizer.State = null
    +      while (states.hasNext) {
    +        state = states.next()
    +        arrayBuilder += state.adjustedValue
    +      }
    +      if (state == null) {
    +        val msg = s"${optimizer.getClass.getName} failed."
    +        logError(msg)
    +        throw new SparkException(msg)
    +      }
     
    -    // The weights are trained in the scaled space; we're converting them back to
    -    // the original space.
    -    val weights = {
    +      /*
    +         The weights are trained in the scaled space; we're converting them back to
    +         the original space.
    +       */
           val rawWeights = state.x.toArray.clone()
           var i = 0
           val len = rawWeights.length
    @@ -184,17 +208,27 @@ class LinearRegression(override val uid: String)
             rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
             i += 1
           }
    -      Vectors.dense(rawWeights)
    +
    +      (Vectors.dense(rawWeights).compressed, arrayBuilder.result())
         }
     
    -    // The intercept in R's GLMNET is computed using closed form after the coefficients are
    -    // converged. See the following discussion for detail.
    -    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
    +    /*
    +       The intercept in R's GLMNET is computed using closed form after the coefficients are
    +       converged. See the following discussion for detail.
    +       http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
    +     */
         val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
    +
         if (handlePersistence) instances.unpersist()
     
    -    // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
    -    copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
    +    val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
    +    val trainingSummary = new LinearRegressionTrainingSummary(
    +      model.transform(dataset),
    +      $(predictionCol),
    +      $(labelCol),
    +      $(featuresCol),
    +      objectiveHistory)
    +    model.setSummary(trainingSummary)
       }
     
       override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
    @@ -212,13 +246,125 @@ class LinearRegressionModel private[ml] (
       extends RegressionModel[Vector, LinearRegressionModel]
       with LinearRegressionParams {
     
    +  private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
    +
    +  /**
    +   * Gets the summary (e.g. residuals, MSE, R-squared) of the model on the training set.
    +   * An exception is thrown if `trainingSummary == None`.
    +   */
    +  def summary: LinearRegressionTrainingSummary = trainingSummary match {
    +    case Some(summ) => summ
    +    case None =>
    +      throw new SparkException(
    +        "No training summary available for this LinearRegressionModel",
    +        new NullPointerException())
    +  }
    +
    +  private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
    +    this.trainingSummary = Some(summary)
    +    this
    +  }
    +
    +  /** Indicates whether a training summary exists for this model instance. */
    +  def hasSummary: Boolean = trainingSummary.isDefined
    +
    +  /**
    +   * Evaluates the model on a test set.
    +   * @param dataset Test dataset to evaluate model on.
    +   */
    +  // TODO: decide on a good name before exposing to public API
    +  private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
    +    val t = udf { features: Vector => predict(features) }
    +    val predictionAndObservations = dataset
    +      .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
    +
    +    new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
    +  }
    +
       override protected def predict(features: Vector): Double = {
         dot(features, weights) + intercept
       }
     
       override def copy(extra: ParamMap): LinearRegressionModel = {
    -    copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
    +    val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
    +    if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
    +    newModel
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Linear regression training results.
    + * @param predictions predictions output by the model's `transform` method.
    + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
    + */
    +@Experimental
    +class LinearRegressionTrainingSummary private[regression] (
    +    predictions: DataFrame,
    +    predictionCol: String,
    +    labelCol: String,
    +    val featuresCol: String,
    +    val objectiveHistory: Array[Double])
    +  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
    +
    +  /** Number of training iterations until termination */
    +  val totalIterations = objectiveHistory.length
    +
    +}
    +
    +/**
    + * :: Experimental ::
    + * Linear regression results evaluated on a dataset.
    + * @param predictions predictions output by the model's `transform` method.
    + */
    +@Experimental
    +class LinearRegressionSummary private[regression] (
    +    @transient val predictions: DataFrame,
    +    val predictionCol: String,
    +    val labelCol: String) extends Serializable {
    +
    +  @transient private val metrics = new RegressionMetrics(
    +    predictions
    +      .select(predictionCol, labelCol)
    +      .map { case Row(pred: Double, label: Double) => (pred, label) } )
    +
    +  /**
    +   * Returns the explained variance regression score.
    +   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
    +   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
    +   */
    +  val explainedVariance: Double = metrics.explainedVariance
    +
    +  /**
    +   * Returns the mean absolute error, which is a risk function corresponding to the
    +   * expected value of the absolute error loss or l1-norm loss.
    +   */
    +  val meanAbsoluteError: Double = metrics.meanAbsoluteError
    +
    +  /**
    +   * Returns the mean squared error, which is a risk function corresponding to the
    +   * expected value of the squared error loss or quadratic loss.
    +   */
    +  val meanSquaredError: Double = metrics.meanSquaredError
    +
    +  /**
    +   * Returns the root mean squared error, which is defined as the square root of
    +   * the mean squared error.
    +   */
    +  val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
    +
    +  /**
    +   * Returns R^2^, the coefficient of determination.
    +   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
    +   */
    +  val r2: Double = metrics.r2
    +
    +  /** Residuals (label - predicted value) */
    +  @transient lazy val residuals: DataFrame = {
    +    val t = udf { (pred: Double, label: Double) => label - pred }
    +    predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
       }
    +
     }
     
     /**
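A fitted LinearRegressionModel now carries a LinearRegressionTrainingSummary attached by `train()`. A short sketch of reading it back after fitting (the `training` DataFrame is hypothetical; the accessors are the ones introduced in this patch):

```
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression().setMaxIter(50)
val model = lr.fit(training)  // `training`: DataFrame with "label" and "features"

if (model.hasSummary) {
  val summary = model.summary
  println(s"iterations: ${summary.totalIterations}")
  println(s"objective history: ${summary.objectiveHistory.mkString(", ")}")
  println(s"RMSE = ${summary.rootMeanSquaredError}, R2 = ${summary.r2}")
  summary.residuals.show()  // DataFrame of (label - prediction)
}
```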
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    index 21c59061a02fa..506a878c2553b 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    @@ -21,14 +21,16 @@ import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -82,9 +84,10 @@ final class RandomForestRegressor(override val uid: String)
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy =
           super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
    -    val oldModel = OldRandomForest.trainRegressor(
    -      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
    -    RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees =
    +      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
    +        .map(_.asInstanceOf[DecisionTreeRegressionModel])
    +    new RandomForestRegressionModel(trees)
       }
     
       override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
    @@ -115,6 +118,12 @@ final class RandomForestRegressionModel private[ml] (
     
       require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
     
    +  /**
    +   * Construct a random forest regression model, with all trees weighted equally.
    +   * @param trees  Component trees
    +   */
    +  def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
    +
       override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
     
       // Note: We may add support for weights (based on tree performance) later on.
    @@ -122,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model.  SPARK-7127
         // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
         // Predict average of tree predictions.
         // Ignore the weights since all are 1.0 for now.
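The transformImpl override above broadcasts the fitted forest once per call instead of serializing it into every task closure (this replaces the removed TODO for SPARK-7127). The same pattern in isolation, under hypothetical names (`model` stands for any object with a `predict(Vector): Double` method and `dataset` for the input DataFrame):

```
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.functions.{col, udf}

// Ship one copy of the (potentially large) model per executor,
// rather than one per task closure, by broadcasting it.
val bcastModel = dataset.sqlContext.sparkContext.broadcast(model)
val predictUDF = udf { (features: Any) =>
  bcastModel.value.predict(features.asInstanceOf[Vector])
}
val predicted = dataset.withColumn("prediction", predictUDF(col("features")))
```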
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    index 4242154be14ce..bbc2427ca7d3d 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    @@ -209,3 +209,132 @@ private object InternalNode {
         }
       }
     }
    +
    +/**
    + * Version of a node used in learning.  This uses vars so that we can modify nodes as we split the
    + * tree by adding children, etc.
    + *
    + * For now, we use node IDs.  These will be kept internal since we hope to remove node IDs
    + * in the future, or at least change the indexing (so that we can support much deeper trees).
    + *
    + * This node can either be:
    + *  - a leaf node, with leftChild, rightChild, split set to None, or
    + *  - an internal node, with all values set
    + *
    + * @param id  We currently use the same indexing as the old implementation in
    + *            [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
    + * @param predictionStats  Predicted label + class probability (for classification).
    + *                         We will later modify this to store aggregate statistics for labels
    + *                         to provide all class probabilities (for classification) and maybe a
    + *                         distribution (for regression).
    + * @param isLeaf  Indicates whether this node will definitely be a leaf in the learned tree,
    + *                so that we do not need to consider splitting it further.
    + * @param stats  Old structure for storing stats about information gain, prediction, etc.
    + *               This is legacy and will be modified in the future.
    + */
    +private[tree] class LearningNode(
    +    var id: Int,
    +    var predictionStats: OldPredict,
    +    var impurity: Double,
    +    var leftChild: Option[LearningNode],
    +    var rightChild: Option[LearningNode],
    +    var split: Option[Split],
    +    var isLeaf: Boolean,
    +    var stats: Option[OldInformationGainStats]) extends Serializable {
    +
    +  /**
    +   * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    +   */
    +  def toNode: Node = {
    +    if (leftChild.nonEmpty) {
    +      assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
    +        "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
    +      new InternalNode(predictionStats.predict, impurity, stats.get.gain,
    +        leftChild.get.toNode, rightChild.get.toNode, split.get)
    +    } else {
    +      new LeafNode(predictionStats.predict, impurity)
    +    }
    +  }
    +
    +}
    +
    +private[tree] object LearningNode {
    +
    +  /** Create a node with some of its fields set. */
    +  def apply(
    +      id: Int,
    +      predictionStats: OldPredict,
    +      impurity: Double,
    +      isLeaf: Boolean): LearningNode = {
    +    new LearningNode(id, predictionStats, impurity, None, None, None, isLeaf, None)
    +  }
    +
    +  /** Create an empty node with the given node index.  Values must be set later on. */
    +  def emptyNode(nodeIndex: Int): LearningNode = {
    +    new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
    +      None, None, None, false, None)
    +  }
    +
    +  // The below indexing methods were copied from spark.mllib.tree.model.Node
    +
    +  /**
    +   * Return the index of the left child of this node.
    +   */
    +  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
    +
    +  /**
    +   * Return the index of the right child of this node.
    +   */
    +  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
    +
    +  /**
    +   * Get the parent index of the given node, or 0 if it is the root.
    +   */
    +  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
    +
    +  /**
    +   * Return the level of the tree that the given node is in.
    +   */
    +  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
    +    throw new IllegalArgumentException(s"0 is not a valid node index.")
    +  } else {
    +    java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
    +  }
    +
    +  /**
    +   * Returns true if this is a left child.
    +   * Note: Returns false for the root.
    +   */
    +  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
    +
    +  /**
    +   * Return the maximum number of nodes which can be in the given level of the tree.
    +   * @param level  Level of tree (0 = root).
    +   */
    +  def maxNodesInLevel(level: Int): Int = 1 << level
    +
    +  /**
    +   * Return the index of the first node in the given level.
    +   * @param level  Level of tree (0 = root).
    +   */
    +  def startIndexInLevel(level: Int): Int = 1 << level
    +
    +  /**
    +   * Traces down from a root node to get the node with the given node index.
    +   * This assumes the node exists.
    +   */
    +  def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
    +    var tmpNode: LearningNode = rootNode
    +    var levelsToGo = indexToLevel(nodeIndex)
    +    while (levelsToGo > 0) {
    +      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
    +        tmpNode = tmpNode.leftChild.get
    +      } else {
    +        tmpNode = tmpNode.rightChild.get
    +      }
    +      levelsToGo -= 1
    +    }
    +    tmpNode
    +  }
    +
    +}
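The index helpers in the LearningNode companion object use the standard 1-based heap layout: the root is node 1, node i has children 2i and 2i+1, and a node's level is the position of its highest set bit. A small worked example under those definitions (shown as comments, since LearningNode is private[tree]):

```
// 1-based heap layout of node ids:
//
//           1            level 0
//         /   \
//        2     3         level 1
//       / \   / \
//      4   5 6   7       level 2
//
// leftChildIndex(3)  == 6     (3 << 1)
// rightChildIndex(3) == 7     ((3 << 1) + 1)
// parentIndex(6)     == 3     (6 >> 1)
// indexToLevel(6)    == 2     (highestOneBit(6) = 4 = 1 << 2)
// isLeftChild(6)     == true  (6 > 1 and even)
// maxNodesInLevel(2) == 4,    startIndexInLevel(2) == 4
```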
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    index 7acdeeee72d23..78199cc2df582 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    @@ -34,9 +34,19 @@ sealed trait Split extends Serializable {
       /** Index of feature which this split tests */
       def featureIndex: Int
     
    -  /** Return true (split to left) or false (split to right) */
    +  /**
    +   * Return true (split to left) or false (split to right).
    +   * @param features  Vector of features (original values, not binned).
    +   */
       private[ml] def shouldGoLeft(features: Vector): Boolean
     
    +  /**
    +   * Return true (split to left) or false (split to right).
    +   * @param binnedFeature Binned feature value.
    +   * @param splits All splits for the given feature.
    +   */
    +  private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
    +
       /** Convert to old Split format */
       private[tree] def toOld: OldSplit
     }
    @@ -94,6 +104,14 @@ final class CategoricalSplit private[ml] (
         }
       }
     
    +  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
    +    if (isLeft) {
    +      categories.contains(binnedFeature.toDouble)
    +    } else {
    +      !categories.contains(binnedFeature.toDouble)
    +    }
    +  }
    +
       override def equals(o: Any): Boolean = {
         o match {
           case other: CategoricalSplit => featureIndex == other.featureIndex &&
    @@ -144,6 +162,16 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
         features(featureIndex) <= threshold
       }
     
    +  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
    +    if (binnedFeature == splits.length) {
    +      // > last split, so split right
    +      false
    +    } else {
    +      val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold
    +      featureValueUpperBound <= threshold
    +    }
    +  }
    +
       override def equals(o: Any): Boolean = {
         o match {
           case other: ContinuousSplit =>
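The binned `shouldGoLeft` overload added to ContinuousSplit decides the direction from a bin index by comparing that bin's upper-bound threshold against this split's threshold. A small numeric illustration (the thresholds are made up for the example):

```
// Suppose a continuous feature was binned with split thresholds
//   splits = [1.0, 2.0, 3.0]   (3 splits => 4 bins; upper bound of bin i is splits(i))
// and this ContinuousSplit has threshold = 2.0.
//
// binnedFeature = 0  -> upper bound 1.0 <= 2.0             -> left
// binnedFeature = 1  -> upper bound 2.0 <= 2.0             -> left
// binnedFeature = 2  -> upper bound 3.0 >  2.0             -> right
// binnedFeature = 3  -> == splits.length (past last split) -> right
```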
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
    new file mode 100644
    index 0000000000000..488e8e4fb5dcd
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
    @@ -0,0 +1,194 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import java.io.IOException
    +
    +import scala.collection.mutable
    +
    +import org.apache.hadoop.fs.{Path, FileSystem}
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.tree.{LearningNode, Split}
    +import org.apache.spark.mllib.tree.impl.BaggedPoint
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +
    +
    +/**
    + * This is used by the node id cache to find the child id that a data point would belong to.
    + * @param split Split information.
    + * @param nodeIndex The current node index of a data point that this will update.
    + */
    +private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {
    +
    +  /**
    +   * Determine a child node index based on the feature value and the split.
    +   * @param binnedFeature Binned feature value.
    +   * @param splits Split information to convert the bin indices to approximate feature values.
    +   * @return Child node index to update to.
    +   */
    +  def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
    +    if (split.shouldGoLeft(binnedFeature, splits)) {
    +      LearningNode.leftChildIndex(nodeIndex)
    +    } else {
    +      LearningNode.rightChildIndex(nodeIndex)
    +    }
    +  }
    +}
    +
    +/**
    + * Each TreePoint belongs to a particular node per tree.
    + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
    + * in each tree. Initially, values should all be 1 for root node.
    + * The nodeIdsForInstances RDD needs to be updated at each iteration.
    + * @param nodeIdsForInstances The initial values in the cache
    + *                            (should be an Array of all 1's, meaning the root nodes).
    + * @param checkpointInterval The checkpointing interval
    + *                           (how often the cache should be checkpointed).
    + */
    +private[spark] class NodeIdCache(
    +  var nodeIdsForInstances: RDD[Array[Int]],
    +  val checkpointInterval: Int) extends Logging {
    +
    +  // Keep a reference to a previous node Ids for instances.
    +  // Because we will keep on re-persisting updated node Ids,
    +  // we want to unpersist the previous RDD.
    +  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
    +
    +  // To keep track of the past checkpointed RDDs.
    +  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
    +  private var rddUpdateCount = 0
    +
    +  // Indicates whether we can checkpoint
    +  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty
    +
    +  // FileSystem instance for deleting checkpoints as needed
    +  private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration)
    +
    +  /**
    +   * Update the node index values in the cache.
    +   * This updates the RDD and its lineage.
    +   * TODO: Passing bin information to executors seems unnecessary and costly.
    +   * @param data The RDD of training rows.
    +   * @param nodeIdUpdaters A map of node index updaters.
    +   *                       The key is the indices of nodes that we want to update.
    +   * @param splits  Split information needed to find child node indices.
    +   */
    +  def updateNodeIndices(
    +      data: RDD[BaggedPoint[TreePoint]],
    +      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
    +      splits: Array[Array[Split]]): Unit = {
    +    if (prevNodeIdsForInstances != null) {
    +      // Unpersist the previous one if one exists.
    +      prevNodeIdsForInstances.unpersist()
    +    }
    +
    +    prevNodeIdsForInstances = nodeIdsForInstances
    +    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
    +      var treeId = 0
    +      while (treeId < nodeIdUpdaters.length) {
    +        val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
    +        if (nodeIdUpdater != null) {
    +          val featureIndex = nodeIdUpdater.split.featureIndex
    +          val newNodeIndex = nodeIdUpdater.updateNodeIndex(
    +            binnedFeature = point.datum.binnedFeatures(featureIndex),
    +            splits = splits(featureIndex))
    +          ids(treeId) = newNodeIndex
    +        }
    +        treeId += 1
    +      }
    +      ids
    +    }
    +
    +    // Keep on persisting new ones.
    +    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
    +    rddUpdateCount += 1
    +
    +    // Handle checkpointing if the directory is not None.
    +    if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
    +      // Let's see if we can delete previous checkpoints.
    +      var canDelete = true
    +      while (checkpointQueue.size > 1 && canDelete) {
    +        // We can delete the oldest checkpoint iff
    +        // the next checkpoint actually exists in the file system.
    +        if (checkpointQueue(1).getCheckpointFile.isDefined) {
    +          val old = checkpointQueue.dequeue()
    +          // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
    +          try {
    +            fs.delete(new Path(old.getCheckpointFile.get), true)
    +          } catch {
    +            case e: IOException =>
    +              logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
    +                s" file: ${old.getCheckpointFile.get}")
    +          }
    +        } else {
    +          canDelete = false
    +        }
    +      }
    +
    +      nodeIdsForInstances.checkpoint()
    +      checkpointQueue.enqueue(nodeIdsForInstances)
    +    }
    +  }
    +
    +  /**
    +   * Call this after training is finished to delete any remaining checkpoints.
    +   */
    +  def deleteAllCheckpoints(): Unit = {
    +    while (checkpointQueue.nonEmpty) {
    +      val old = checkpointQueue.dequeue()
    +      if (old.getCheckpointFile.isDefined) {
    +        try {
    +          fs.delete(new Path(old.getCheckpointFile.get), true)
    +        } catch {
    +          case e: IOException =>
    +            logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
    +              s" file: ${old.getCheckpointFile.get}")
    +        }
    +      }
    +    }
    +    if (prevNodeIdsForInstances != null) {
    +      // Unpersist the previously persisted node ids, if any remain.
    +      prevNodeIdsForInstances.unpersist()
    +    }
    +  }
    +}
    +
    +@DeveloperApi
    +private[spark] object NodeIdCache {
    +  /**
    +   * Initialize the node Id cache with initial node Id values.
    +   * @param data The RDD of training rows.
    +   * @param numTrees The number of trees that we want to create cache for.
    +   * @param checkpointInterval The checkpointing interval
    +   *                           (how often should the cache be checkpointed.).
    +   *                           (how often the cache should be checkpointed).
    +   * @return A node Id cache containing an RDD of initial root node Indices.
    +   */
    +  def init(
    +      data: RDD[BaggedPoint[TreePoint]],
    +      numTrees: Int,
    +      checkpointInterval: Int,
    +      initVal: Int = 1): NodeIdCache = {
    +    new NodeIdCache(
    +      data.map(_ => Array.fill[Int](numTrees)(initVal)),
    +      checkpointInterval)
    +  }
    +}
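NodeIdCache keeps, for every training row, its current node id in each tree, so executors can route rows during stat aggregation without shipping the partially built trees. Its intended lifecycle against the API added here (a sketch; `baggedInput`, `nodeIdUpdaters`, and `splits` are assumed to already exist):

```
// 1. Initialize: every row starts at the root (node id 1) of every tree.
val nodeIdCache = NodeIdCache.init(
  data = baggedInput,              // RDD[BaggedPoint[TreePoint]]
  numTrees = 10,
  checkpointInterval = 10)

// 2. After choosing splits for a group of nodes, push each row one level down.
//    This re-persists the id RDD and periodically checkpoints it.
nodeIdCache.updateNodeIndices(baggedInput, nodeIdUpdaters, splits)

// 3. When training finishes, remove any checkpoint files left on disk.
nodeIdCache.deleteAllCheckpoints()
```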
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
    new file mode 100644
    index 0000000000000..15b56bd844bad
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
    @@ -0,0 +1,1132 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import java.io.IOException
    +
    +import scala.collection.mutable
    +import scala.util.Random
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    +import org.apache.spark.ml.regression.DecisionTreeRegressionModel
    +import org.apache.spark.ml.tree._
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
    +  TimeTracker}
    +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
    +import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
    +
    +
    +private[ml] object RandomForest extends Logging {
    +
    +  /**
    +   * Train a random forest.
    +   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    +   * @return an unweighted set of trees
    +   */
    +  def run(
    +      input: RDD[LabeledPoint],
    +      strategy: OldStrategy,
    +      numTrees: Int,
    +      featureSubsetStrategy: String,
    +      seed: Long,
    +      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
    +
    +    val timer = new TimeTracker()
    +
    +    timer.start("total")
    +
    +    timer.start("init")
    +
    +    val retaggedInput = input.retag(classOf[LabeledPoint])
    +    val metadata =
    +      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
    +    logDebug("algo = " + strategy.algo)
    +    logDebug("numTrees = " + numTrees)
    +    logDebug("seed = " + seed)
    +    logDebug("maxBins = " + metadata.maxBins)
    +    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
    +    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
    +    logDebug("subsamplingRate = " + strategy.subsamplingRate)
    +
    +    // Find the splits and the corresponding bins (interval between the splits) using a sample
    +    // of the input data.
    +    timer.start("findSplitsBins")
    +    val splits = findSplits(retaggedInput, metadata)
    +    timer.stop("findSplitsBins")
    +    logDebug("numBins: feature: number of bins")
    +    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
    +      s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
    +    }.mkString("\n"))
    +
    +    // Bin feature values (TreePoint representation).
    +    // Cache input RDD for speedup during multiple passes.
    +    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)
    +
    +    val withReplacement = numTrees > 1
    +
    +    val baggedInput = BaggedPoint
    +      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
    +      .persist(StorageLevel.MEMORY_AND_DISK)
    +
    +    // depth of the decision tree
    +    val maxDepth = strategy.maxDepth
    +    require(maxDepth <= 30,
    +      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
    +
    +    // Max memory usage for aggregates
    +    // TODO: Calculate memory usage more precisely.
    +    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    +    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
    +    val maxMemoryPerNode = {
    +      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
    +        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
    +        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
    +          .take(metadata.numFeaturesPerNode).map(_._2))
    +      } else {
    +        None
    +      }
    +      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    +    }
    +    require(maxMemoryPerNode <= maxMemoryUsage,
    +      s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
    +        " which is too small for the given features." +
    +        s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
    +
    +    timer.stop("init")
    +
    +    /*
    +     * The main idea here is to perform group-wise training of the decision tree nodes thus
    +     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
    +     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
    +     * in lower levels).
    +     */
    +
    +    // Create an RDD of node Id cache.
    +    // At first, all the rows belong to the root nodes (node Id == 1).
    +    val nodeIdCache = if (strategy.useNodeIdCache) {
    +      Some(NodeIdCache.init(
    +        data = baggedInput,
    +        numTrees = numTrees,
    +        checkpointInterval = strategy.checkpointInterval,
    +        initVal = 1))
    +    } else {
    +      None
    +    }
    +
    +    // FIFO queue of nodes to train: (treeIndex, node)
    +    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
    +
    +    val rng = new Random()
    +    rng.setSeed(seed)
    +
    +    // Allocate and queue root nodes.
    +    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
    +    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
    +
    +    while (nodeQueue.nonEmpty) {
    +      // Collect some nodes to split, and choose features for each node (if subsampling).
    +      // Each group of nodes may come from one or multiple trees, and at multiple levels.
    +      val (nodesForGroup, treeToNodeToIndexInfo) =
    +        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
    +      // Sanity check (should never occur):
    +      assert(nodesForGroup.nonEmpty,
    +        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
    +
    +      // Choose node splits, and enqueue new nodes as needed.
    +      timer.start("findBestSplits")
    +      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
    +        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
    +      timer.stop("findBestSplits")
    +    }
    +
    +    baggedInput.unpersist()
    +
    +    timer.stop("total")
    +
    +    logInfo("Internal timing for DecisionTree:")
    +    logInfo(s"$timer")
    +
    +    // Delete any remaining checkpoints used for node Id cache.
    +    if (nodeIdCache.nonEmpty) {
    +      try {
    +        nodeIdCache.get.deleteAllCheckpoints()
    +      } catch {
    +        case e: IOException =>
    +          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
    +      }
    +    }
    +
    +    parentUID match {
    +      case Some(uid) =>
    +        if (strategy.algo == OldAlgo.Classification) {
    +          topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
    +        } else {
    +          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
    +        }
    +      case None =>
    +        if (strategy.algo == OldAlgo.Classification) {
    +          topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
    +        } else {
    +          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
    +        }
    +    }
    +  }
    +
    +  /**
    +   * Get the node index corresponding to this data point.
    +   * This function mimics prediction, passing an example from the root node down to a leaf
    +   * or unsplit node; that node's index is returned.
    +   *
    +   * @param node  Node in tree from which to classify the given data point.
    +   * @param binnedFeatures  Binned feature vector for data point.
    +   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    +   * @return  Leaf index if the data point reaches a leaf.
    +   *          Otherwise, last node reachable in tree matching this example.
    +   *          Note: This is the global node index, i.e., the index used in the tree.
    +   *                This index is different from the index used while training a particular
    +   *                group of nodes in one call to [[findBestSplits()]].
    +   */
    +  private def predictNodeIndex(
    +      node: LearningNode,
    +      binnedFeatures: Array[Int],
    +      splits: Array[Array[Split]]): Int = {
    +    if (node.isLeaf || node.split.isEmpty) {
    +      node.id
    +    } else {
    +      val split = node.split.get
    +      val featureIndex = split.featureIndex
    +      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
    +      if (node.leftChild.isEmpty) {
    +        // Not yet split. Return index from next layer of nodes to train
    +        if (splitLeft) {
    +          LearningNode.leftChildIndex(node.id)
    +        } else {
    +          LearningNode.rightChildIndex(node.id)
    +        }
    +      } else {
    +        if (splitLeft) {
    +          predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
    +        } else {
    +          predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
    +        }
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
    +   *
    +   * For ordered features, a single bin is updated.
    +   * For unordered features, bins correspond to subsets of categories; either the left or right bin
    +   * for each subset is updated.
    +   *
    +   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    +   *             each (feature, bin).
    +   * @param treePoint  Data point being aggregated.
    +   * @param splits possible splits indexed (numFeatures)(numSplits)
    +   * @param unorderedFeatures  Set of indices of unordered features.
    +   * @param instanceWeight  Weight (importance) of instance in dataset.
    +   */
    +  private def mixedBinSeqOp(
    +      agg: DTStatsAggregator,
    +      treePoint: TreePoint,
    +      splits: Array[Array[Split]],
    +      unorderedFeatures: Set[Int],
    +      instanceWeight: Double,
    +      featuresForNode: Option[Array[Int]]): Unit = {
    +    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
    +      // Use subsampled features
    +      featuresForNode.get.length
    +    } else {
    +      // Use all features
    +      agg.metadata.numFeatures
    +    }
    +    // Iterate over features.
    +    var featureIndexIdx = 0
    +    while (featureIndexIdx < numFeaturesPerNode) {
    +      val featureIndex = if (featuresForNode.nonEmpty) {
    +        featuresForNode.get.apply(featureIndexIdx)
    +      } else {
    +        featureIndexIdx
    +      }
    +      if (unorderedFeatures.contains(featureIndex)) {
    +        // Unordered feature
    +        val featureValue = treePoint.binnedFeatures(featureIndex)
    +        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
    +          agg.getLeftRightFeatureOffsets(featureIndexIdx)
    +        // Update the left or right bin for each split.
    +        val numSplits = agg.metadata.numSplits(featureIndex)
    +        val featureSplits = splits(featureIndex)
    +        var splitIndex = 0
    +        while (splitIndex < numSplits) {
    +          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
    +            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
    +          } else {
    +            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
    +          }
    +          splitIndex += 1
    +        }
    +      } else {
    +        // Ordered feature
    +        val binIndex = treePoint.binnedFeatures(featureIndex)
    +        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
    +      }
    +      featureIndexIdx += 1
    +    }
    +  }
    +
    +  /**
    +   * Helper for binSeqOp, for regression and for classification with only ordered features.
    +   *
    +   * For each feature, the sufficient statistics of one bin are updated.
    +   *
    +   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    +   *             each (feature, bin).
    +   * @param treePoint  Data point being aggregated.
    +   * @param instanceWeight  Weight (importance) of instance in dataset.
    +   */
    +  private def orderedBinSeqOp(
    +      agg: DTStatsAggregator,
    +      treePoint: TreePoint,
    +      instanceWeight: Double,
    +      featuresForNode: Option[Array[Int]]): Unit = {
    +    val label = treePoint.label
    +
    +    // Iterate over features.
    +    if (featuresForNode.nonEmpty) {
    +      // Use subsampled features
    +      var featureIndexIdx = 0
    +      while (featureIndexIdx < featuresForNode.get.length) {
    +        val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
    +        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
    +        featureIndexIdx += 1
    +      }
    +    } else {
    +      // Use all features
    +      val numFeatures = agg.metadata.numFeatures
    +      var featureIndex = 0
    +      while (featureIndex < numFeatures) {
    +        val binIndex = treePoint.binnedFeatures(featureIndex)
    +        agg.update(featureIndex, binIndex, label, instanceWeight)
    +        featureIndex += 1
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Given a group of nodes, this finds the best split for each node.
    +   *
    +   * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    +   * @param metadata Learning and dataset metadata
    +   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
    +   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
    +   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
    +   *                              where nodeIndexInfo stores the index in the group and the
    +   *                              feature subsets (if using feature subsets).
    +   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    +   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
    +   *                   Updated with new non-leaf nodes which are created.
    +   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
    +   *                    each value in the array is the data point's node Id
    +   *                    for a corresponding tree. This is used to prevent the need
    +   *                    to pass the entire tree to the executors during
    +   *                    the node stat aggregation phase.
    +   */
    +  private[tree] def findBestSplits(
    +      input: RDD[BaggedPoint[TreePoint]],
    +      metadata: DecisionTreeMetadata,
    +      topNodes: Array[LearningNode],
    +      nodesForGroup: Map[Int, Array[LearningNode]],
    +      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
    +      splits: Array[Array[Split]],
    +      nodeQueue: mutable.Queue[(Int, LearningNode)],
    +      timer: TimeTracker = new TimeTracker,
    +      nodeIdCache: Option[NodeIdCache] = None): Unit = {
    +
    +    /*
    +     * The high-level descriptions of the best split optimizations are noted here.
    +     *
    +     * *Group-wise training*
    +     * We perform bin calculations for groups of nodes to reduce the number of
    +     * passes over the data.  Each iteration requires more computation and storage,
    +     * but saves several iterations over the data.
    +     *
    +     * *Bin-wise computation*
    +     * We use a bin-wise best split computation strategy instead of a straightforward best split
    +     * computation strategy. Instead of analyzing each sample for contribution to the left/right
    +     * child node impurity of every split, we first categorize each feature of a sample into a
    +     * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
    +     * to calculate information gain for each split.
    +     *
    +     * *Aggregation over partitions*
    +     * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
    +     * the number of splits in advance. Thus, we store the aggregates (at the appropriate
    +     * indices) in a single array for all bins and rely upon the RDD aggregate method to
    +     * drastically reduce the communication overhead.
    +     */
    +
    +    // numNodes:  Number of nodes in this group
    +    val numNodes = nodesForGroup.values.map(_.length).sum
    +    logDebug("numNodes = " + numNodes)
    +    logDebug("numFeatures = " + metadata.numFeatures)
    +    logDebug("numClasses = " + metadata.numClasses)
    +    logDebug("isMulticlass = " + metadata.isMulticlass)
    +    logDebug("isMulticlassWithCategoricalFeatures = " +
    +      metadata.isMulticlassWithCategoricalFeatures)
    +    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
    +
    +    /**
    +     * Performs a sequential aggregation over a partition for a particular tree and node.
    +     *
    +     * For each feature, the aggregate sufficient statistics are updated for the relevant
    +     * bins.
    +     *
    +     * @param treeIndex Index of the tree that we want to perform aggregation for.
    +     * @param nodeInfo The node info for the tree node.
    +     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
    +     *            for each (node, feature, bin).
    +     * @param baggedPoint Data point being aggregated.
    +     */
    +    def nodeBinSeqOp(
    +        treeIndex: Int,
    +        nodeInfo: NodeIndexInfo,
    +        agg: Array[DTStatsAggregator],
    +        baggedPoint: BaggedPoint[TreePoint]): Unit = {
    +      if (nodeInfo != null) {
    +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
    +        val featuresForNode = nodeInfo.featureSubset
    +        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    +        if (metadata.unorderedFeatures.isEmpty) {
    +          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
    +        } else {
    +          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
    +            metadata.unorderedFeatures, instanceWeight, featuresForNode)
    +        }
    +      }
    +    }
    +
    +    /**
    +     * Performs a sequential aggregation over a partition.
    +     *
    +     * Each data point contributes to one node. For each feature,
    +     * the aggregate sufficient statistics are updated for the relevant bins.
    +     *
    +     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    +     *             each (node, feature, bin).
    +     * @param baggedPoint   Data point being aggregated.
    +     * @return  agg
    +     */
    +    def binSeqOp(
    +        agg: Array[DTStatsAggregator],
    +        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
    +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    +        val nodeIndex =
    +          predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
    +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
    +      }
    +      agg
    +    }
    +
    +    /**
    +     * Do the same thing as binSeqOp, but with nodeIdCache.
    +     */
    +    def binSeqOpWithNodeIdCache(
    +        agg: Array[DTStatsAggregator],
    +        dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
    +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    +        val baggedPoint = dataPoint._1
    +        val nodeIdCache = dataPoint._2
    +        val nodeIndex = nodeIdCache(treeIndex)
    +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
    +      }
    +
    +      agg
    +    }
    +
    +    /**
    +     * Get the (node index in group) --> (feature indices) map,
    +     * a shortcut for finding the feature indices used at a node, given its index in the group.
    +     */
    +    def getNodeToFeatures(
    +        treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
    +      if (!metadata.subsamplingFeatures) {
    +        None
    +      } else {
    +        val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
    +        treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
    +          nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
    +            assert(nodeIndexInfo.featureSubset.isDefined)
    +            mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
    +          }
    +        }
    +        Some(mutableNodeToFeatures.toMap)
    +      }
    +    }
    +
    +    // array of nodes to train indexed by node index in group
    +    val nodes = new Array[LearningNode](numNodes)
    +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
    +      nodesForTree.foreach { node =>
    +        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
    +      }
    +    }
    +
    +    // Calculate best splits for all nodes in the group
    +    timer.start("chooseSplits")
    +
    +    // In each partition, iterate over all instances and compute aggregate stats for each node,
    +    // yielding a (nodeIndex, nodeAggregateStats) pair for each node.
    +    // After a `reduceByKey` operation,
    +    // the stats for each node are shuffled to a particular partition and combined there,
    +    // where the best split for that node is then found.
    +    // Finally, only the best splits are collected to the driver to construct the decision trees.
    +    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
    +    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
    +
    +    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
    +      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
    +        // Construct a nodeStatsAggregators array to hold node aggregate stats,
    +        // each node will have a nodeStatsAggregator
    +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
    +          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
    +            Some(nodeToFeatures(nodeIndex))
    +          }
    +          new DTStatsAggregator(metadata, featuresForNode)
    +        }
    +
+        // iterate over all instances in the current partition and update aggregate stats
    +        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
    +
+        // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with those from other partitions using `reduceByKey`
    +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
    +      }
    +    } else {
    +      input.mapPartitions { points =>
    +        // Construct a nodeStatsAggregators array to hold node aggregate stats,
    +        // each node will have a nodeStatsAggregator
    +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
    +          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
    +            Some(nodeToFeatures(nodeIndex))
    +          }
    +          new DTStatsAggregator(metadata, featuresForNode)
    +        }
    +
+        // iterate over all instances in the current partition and update aggregate stats
    +        points.foreach(binSeqOp(nodeStatsAggregators, _))
    +
+        // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with those from other partitions using `reduceByKey`
    +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
    +      }
    +    }
    +
    +    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
    +      case (nodeIndex, aggStats) =>
    +        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
    +          Some(nodeToFeatures(nodeIndex))
    +        }
    +
    +        // find best split for each node
    +        val (split: Split, stats: InformationGainStats, predict: Predict) =
    +          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
    +        (nodeIndex, (split, stats, predict))
    +    }.collectAsMap()
    +
    +    timer.stop("chooseSplits")
    +
    +    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
    +      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
    +        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
    +    } else {
    +      null
    +    }
    +    // Iterate over all nodes in this group.
    +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
    +      nodesForTree.foreach { node =>
    +        val nodeIndex = node.id
    +        val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
    +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
    +        val (split: Split, stats: InformationGainStats, predict: Predict) =
    +          nodeToBestSplits(aggNodeIndex)
    +        logDebug("best split = " + split)
    +
    +        // Extract info for this node.  Create children if not leaf.
    +        val isLeaf =
    +          (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
    +        node.predictionStats = predict
    +        node.isLeaf = isLeaf
    +        node.stats = Some(stats)
    +        node.impurity = stats.impurity
    +        logDebug("Node = " + node)
    +
    +        if (!isLeaf) {
    +          node.split = Some(split)
    +          val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
    +          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
    +          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
    +          node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
    +            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
    +          node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
    +            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
    +
    +          if (nodeIdCache.nonEmpty) {
    +            val nodeIndexUpdater = NodeIndexUpdater(
    +              split = split,
    +              nodeIndex = nodeIndex)
    +            nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
    +          }
    +
    +          // enqueue left child and right child if they are not leaves
    +          if (!leftChildIsLeaf) {
    +            nodeQueue.enqueue((treeIndex, node.leftChild.get))
    +          }
    +          if (!rightChildIsLeaf) {
    +            nodeQueue.enqueue((treeIndex, node.rightChild.get))
    +          }
    +
    +          logDebug("leftChildIndex = " + node.leftChild.get.id +
    +            ", impurity = " + stats.leftImpurity)
    +          logDebug("rightChildIndex = " + node.rightChild.get.id +
    +            ", impurity = " + stats.rightImpurity)
    +        }
    +      }
    +    }
    +
    +    if (nodeIdCache.nonEmpty) {
    +      // Update the cache if needed.
    +      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits)
    +    }
    +  }
    +
    +  /**
    +   * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
    +   * @param leftImpurityCalculator left node aggregates for this (feature, split)
+   * @param rightImpurityCalculator right node aggregates for this (feature, split)
    +   * @return information gain and statistics for split
    +   */
    +  private def calculateGainForSplit(
    +      leftImpurityCalculator: ImpurityCalculator,
    +      rightImpurityCalculator: ImpurityCalculator,
    +      metadata: DecisionTreeMetadata,
    +      impurity: Double): InformationGainStats = {
    +    val leftCount = leftImpurityCalculator.count
    +    val rightCount = rightImpurityCalculator.count
    +
    +    // If left child or right child doesn't satisfy minimum instances per node,
    +    // then this split is invalid, return invalid information gain stats.
    +    if ((leftCount < metadata.minInstancesPerNode) ||
    +      (rightCount < metadata.minInstancesPerNode)) {
    +      return InformationGainStats.invalidInformationGainStats
    +    }
    +
    +    val totalCount = leftCount + rightCount
    +
    +    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
    +    val rightImpurity = rightImpurityCalculator.calculate()
    +
    +    val leftWeight = leftCount / totalCount.toDouble
    +    val rightWeight = rightCount / totalCount.toDouble
    +
    +    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    +
    +    // if information gain doesn't satisfy minimum information gain,
    +    // then this split is invalid, return invalid information gain stats.
    +    if (gain < metadata.minInfoGain) {
    +      return InformationGainStats.invalidInformationGainStats
    +    }
    +
+    // calculate the left and right child predictions
    +    val leftPredict = calculatePredict(leftImpurityCalculator)
    +    val rightPredict = calculatePredict(rightImpurityCalculator)
    +
    +    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
    +      leftPredict, rightPredict)
    +  }
    +
    +  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
    +    val predict = impurityCalculator.predict
    +    val prob = impurityCalculator.prob(predict)
    +    new Predict(predict, prob)
    +  }
    +
    +  /**
+   * Calculate the prediction and impurity for the current node, given stats of any split.
    +   * Note that this function is called only once for each node.
    +   * @param leftImpurityCalculator left node aggregates for a split
    +   * @param rightImpurityCalculator right node aggregates for a split
    +   * @return predict value and impurity for current node
    +   */
    +  private def calculatePredictImpurity(
    +      leftImpurityCalculator: ImpurityCalculator,
    +      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
    +    val parentNodeAgg = leftImpurityCalculator.copy
    +    parentNodeAgg.add(rightImpurityCalculator)
    +    val predict = calculatePredict(parentNodeAgg)
    +    val impurity = parentNodeAgg.calculate()
    +
    +    (predict, impurity)
    +  }
    +
    +  /**
    +   * Find the best split for a node.
    +   * @param binAggregates Bin statistics.
    +   * @return tuple for best split: (Split, information gain, prediction at node)
    +   */
    +  private def binsToBestSplit(
    +      binAggregates: DTStatsAggregator,
    +      splits: Array[Array[Split]],
    +      featuresForNode: Option[Array[Int]],
    +      node: LearningNode): (Split, InformationGainStats, Predict) = {
    +
+    // For the top node, the prediction and impurity are computed below from the aggregates;
+    // for other nodes, reuse the values already stored on the node.
    +    val level = LearningNode.indexToLevel(node.id)
    +    var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
    +      None
    +    } else {
    +      Some((node.predictionStats, node.impurity))
    +    }
    +
    +    // For each (feature, split), calculate the gain, and select the best (feature, split).
    +    val (bestSplit, bestSplitStats) =
    +      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
    +        val featureIndex = if (featuresForNode.nonEmpty) {
    +          featuresForNode.get.apply(featureIndexIdx)
    +        } else {
    +          featureIndexIdx
    +        }
    +        val numSplits = binAggregates.metadata.numSplits(featureIndex)
    +        if (binAggregates.metadata.isContinuous(featureIndex)) {
    +          // Cumulative sum (scanLeft) of bin statistics.
    +          // Afterwards, binAggregates for a bin is the sum of aggregates for
    +          // that bin + all preceding bins.
    +          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
    +          var splitIndex = 0
    +          while (splitIndex < numSplits) {
    +            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
    +            splitIndex += 1
    +          }
    +          // Find best split.
    +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    +            Range(0, numSplits).map { case splitIdx =>
    +              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
    +              val rightChildStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
    +              rightChildStats.subtract(leftChildStats)
    +              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
    +                calculatePredictImpurity(leftChildStats, rightChildStats)))
    +              val gainStats = calculateGainForSplit(leftChildStats,
    +                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
    +              (splitIdx, gainStats)
    +            }.maxBy(_._2.gain)
    +          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
    +        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
    +          // Unordered categorical feature
    +          val (leftChildOffset, rightChildOffset) =
    +            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
    +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    +            Range(0, numSplits).map { splitIndex =>
    +              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
    +              val rightChildStats =
    +                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
    +              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
    +                calculatePredictImpurity(leftChildStats, rightChildStats)))
    +              val gainStats = calculateGainForSplit(leftChildStats,
    +                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
    +              (splitIndex, gainStats)
    +            }.maxBy(_._2.gain)
    +          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
    +        } else {
    +          // Ordered categorical feature
    +          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
    +          val numCategories = binAggregates.metadata.numBins(featureIndex)
    +
    +          /* Each bin is one category (feature value).
    +           * The bins are ordered based on centroidForCategories, and this ordering determines which
    +           * splits are considered.  (With K categories, we consider K - 1 possible splits.)
    +           *
    +           * centroidForCategories is a list: (category, centroid)
    +           */
    +          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
    +            // For categorical variables in multiclass classification,
    +            // the bins are ordered by the impurity of their corresponding labels.
    +            Range(0, numCategories).map { case featureValue =>
    +              val categoryStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +              val centroid = if (categoryStats.count != 0) {
    +                categoryStats.calculate()
    +              } else {
    +                Double.MaxValue
    +              }
    +              (featureValue, centroid)
    +            }
    +          } else { // regression or binary classification
    +            // For categorical variables in regression and binary classification,
    +            // the bins are ordered by the centroid of their corresponding labels.
    +            Range(0, numCategories).map { case featureValue =>
    +              val categoryStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +              val centroid = if (categoryStats.count != 0) {
    +                categoryStats.predict
    +              } else {
    +                Double.MaxValue
    +              }
    +              (featureValue, centroid)
    +            }
    +          }
    +
    +          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
    +
    +          // bins sorted by centroids
    +          val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
    +
    +          logDebug("Sorted centroids for categorical variable = " +
    +            categoriesSortedByCentroid.mkString(","))
    +
    +          // Cumulative sum (scanLeft) of bin statistics.
    +          // Afterwards, binAggregates for a bin is the sum of aggregates for
    +          // that bin + all preceding bins.
    +          var splitIndex = 0
    +          while (splitIndex < numSplits) {
    +            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
    +            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
    +            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
    +            splitIndex += 1
    +          }
    +          // lastCategory = index of bin with total aggregates for this (node, feature)
    +          val lastCategory = categoriesSortedByCentroid.last._1
    +          // Find best split.
    +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    +            Range(0, numSplits).map { splitIndex =>
    +              val featureValue = categoriesSortedByCentroid(splitIndex)._1
    +              val leftChildStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +              val rightChildStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
    +              rightChildStats.subtract(leftChildStats)
    +              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
    +                calculatePredictImpurity(leftChildStats, rightChildStats)))
    +              val gainStats = calculateGainForSplit(leftChildStats,
    +                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
    +              (splitIndex, gainStats)
    +            }.maxBy(_._2.gain)
    +          val categoriesForSplit =
    +            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
    +          val bestFeatureSplit =
    +            new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
    +          (bestFeatureSplit, bestFeatureGainStats)
    +        }
    +      }.maxBy(_._2.gain)
    +
    +    (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
    +  }
    +
    +  /**
+   * Returns splits for decision tree calculation.
    +   * Continuous and categorical features are handled differently.
    +   *
    +   * Continuous features:
    +   *   For each feature, there are numBins - 1 possible splits representing the possible binary
    +   *   decisions at each node in the tree.
    +   *   This finds locations (feature values) for splits using a subsample of the data.
    +   *
    +   * Categorical features:
    +   *   For each feature, there is 1 bin per split.
    +   *   Splits and bins are handled in 2 ways:
    +   *   (a) "unordered features"
    +   *       For multiclass classification with a low-arity feature
    +   *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    +   *       the feature is split based on subsets of categories.
    +   *   (b) "ordered features"
    +   *       For regression and binary classification,
    +   *       and for multiclass classification with a high-arity feature,
    +   *       there is one bin per category.
    +   *
    +   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    +   * @param metadata Learning and dataset metadata
+   * @return Splits, an Array of [[org.apache.spark.ml.tree.Split]]
+   *          of size (numFeatures, numSplits).
    +   */
    +  protected[tree] def findSplits(
    +      input: RDD[LabeledPoint],
    +      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
    +
    +    logDebug("isMulticlass = " + metadata.isMulticlass)
    +
    +    val numFeatures = metadata.numFeatures
    +
    +    // Sample the input only if there are continuous features.
    +    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
    +    val sampledInput = if (hasContinuousFeatures) {
    +      // Calculate the number of samples for approximate quantile calculation.
    +      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
    +      val fraction = if (requiredSamples < metadata.numExamples) {
    +        requiredSamples.toDouble / metadata.numExamples
    +      } else {
    +        1.0
    +      }
    +      logDebug("fraction of data used for calculating quantiles = " + fraction)
    +      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
    +    } else {
    +      new Array[LabeledPoint](0)
    +    }
    +
    +    val splits = new Array[Array[Split]](numFeatures)
    +
    +    // Find all splits.
    +    // Iterate over all features.
    +    var featureIndex = 0
    +    while (featureIndex < numFeatures) {
    +      if (metadata.isContinuous(featureIndex)) {
    +        val featureSamples = sampledInput.map(_.features(featureIndex))
    +        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
    +
    +        val numSplits = featureSplits.length
    +        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
    +        splits(featureIndex) = new Array[Split](numSplits)
    +
    +        var splitIndex = 0
    +        while (splitIndex < numSplits) {
    +          val threshold = featureSplits(splitIndex)
    +          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
    +          splitIndex += 1
    +        }
    +      } else {
    +        // Categorical feature
    +        if (metadata.isUnordered(featureIndex)) {
    +          val numSplits = metadata.numSplits(featureIndex)
    +          val featureArity = metadata.featureArity(featureIndex)
    +          // TODO: Use an implicit representation mapping each category to a subset of indices.
    +          //       I.e., track indices such that we can calculate the set of bins for which
    +          //       feature value x splits to the left.
    +          // Unordered features
    +          // 2^(maxFeatureValue - 1) - 1 combinations
    +          splits(featureIndex) = new Array[Split](numSplits)
    +          var splitIndex = 0
    +          while (splitIndex < numSplits) {
    +            val categories: List[Double] =
    +              extractMultiClassCategories(splitIndex + 1, featureArity)
    +            splits(featureIndex)(splitIndex) =
    +              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
    +            splitIndex += 1
    +          }
    +        } else {
    +          // Ordered features
    +          //   Bins correspond to feature values, so we do not need to compute splits or bins
    +          //   beforehand.  Splits are constructed as needed during training.
    +          splits(featureIndex) = new Array[Split](0)
    +        }
    +      }
    +      featureIndex += 1
    +    }
    +    splits
    +  }
    +
    +  /**
+   * Nested method to extract the list of eligible categories given an index. It extracts the
+   * positions of the ones in the binary representation of the input. For example, if the
+   * binary representation of a number is 01101 (13), the output list is (3.0, 2.0, 0.0).
+   * maxFeatureValue is the number of rightmost digits that will be tested for ones.
    +   */
    +  private[tree] def extractMultiClassCategories(
    +      input: Int,
    +      maxFeatureValue: Int): List[Double] = {
    +    var categories = List[Double]()
    +    var j = 0
    +    var bitShiftedInput = input
    +    while (j < maxFeatureValue) {
    +      if (bitShiftedInput % 2 != 0) {
    +        // updating the list of categories.
    +        categories = j.toDouble :: categories
    +      }
    +      // Right shift by one
    +      bitShiftedInput = bitShiftedInput >> 1
    +      j += 1
    +    }
    +    categories
    +  }
    +
    +  /**
    +   * Find splits for a continuous feature
+   * NOTE: The returned number of splits is based on `featureSamples` and
+   *       could differ from the specified `numSplits`.
    +   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
    +   * @param featureSamples feature values of each sample
    +   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numBins` will be changed accordingly
    +   *                       if there are not enough splits to be found
    +   * @param featureIndex feature index to find splits
    +   * @return array of splits
    +   */
    +  private[tree] def findSplitsForContinuousFeature(
    +      featureSamples: Array[Double],
    +      metadata: DecisionTreeMetadata,
    +      featureIndex: Int): Array[Double] = {
    +    require(metadata.isContinuous(featureIndex),
    +      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
    +
    +    val splits = {
    +      val numSplits = metadata.numSplits(featureIndex)
    +
    +      // get count for each distinct value
    +      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
    +        m + ((x, m.getOrElse(x, 0) + 1))
    +      }
    +      // sort distinct values
    +      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
    +
+      // if the number of possible splits is not greater than numSplits, just return all of them
    +      val possibleSplits = valueCounts.length
    +      if (possibleSplits <= numSplits) {
    +        valueCounts.map(_._1)
    +      } else {
    +        // stride between splits
    +        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
    +        logDebug("stride = " + stride)
    +
+        // iterate over `valueCounts` to find splits
    +        val splitsBuilder = mutable.ArrayBuilder.make[Double]
    +        var index = 1
    +        // currentCount: sum of counts of values that have been visited
    +        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // If `currentCount` is the value closest to `targetCount`,
+        // then the current value is a split threshold.
+        // After finding a split threshold, `targetCount` is increased by `stride`.
    +        var targetCount = stride
    +        while (index < valueCounts.length) {
    +          val previousCount = currentCount
    +          currentCount += valueCounts(index)._2
    +          val previousGap = math.abs(previousCount - targetCount)
    +          val currentGap = math.abs(currentCount - targetCount)
+          // If adding the count of the current value to currentCount
+          // makes the gap between currentCount and targetCount larger,
+          // then the previous value is a split threshold.
    +          if (previousGap < currentGap) {
    +            splitsBuilder += valueCounts(index - 1)._1
    +            targetCount += stride
    +          }
    +          index += 1
    +        }
    +
    +        splitsBuilder.result()
    +      }
    +    }
    +
    +    // TODO: Do not fail; just ignore the useless feature.
    +    assert(splits.length > 0,
    +      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
    +        "  Please remove this feature and then try again.")
    +    // set number of splits accordingly
    +    metadata.setNumSplits(featureIndex, splits.length)
    +
    +    splits
    +  }
    +
    +  private[tree] class NodeIndexInfo(
    +      val nodeIndexInGroup: Int,
    +      val featureSubset: Option[Array[Int]]) extends Serializable
    +
    +  /**
    +   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
    +   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
    +   * will be needed; this allows an adaptive number of nodes since different nodes may require
    +   * different amounts of memory (if featureSubsetStrategy is not "all").
    +   *
    +   * @param nodeQueue  Queue of nodes to split.
    +   * @param maxMemoryUsage  Bound on size of aggregate statistics.
    +   * @return  (nodesForGroup, treeToNodeToIndexInfo).
    +   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
    +   *
+   *          treeToNodeToIndexInfo holds the indices of selected features for each node:
    +   *            treeIndex --> (global) node index --> (node index in group, feature indices).
    +   *          The (global) node index is the index in the tree; the node index in group is the
    +   *           index in [0, numNodesInGroup) of the node in this group.
    +   *          The feature indices are None if not subsampling features.
    +   */
    +  private[tree] def selectNodesToSplit(
    +      nodeQueue: mutable.Queue[(Int, LearningNode)],
    +      maxMemoryUsage: Long,
    +      metadata: DecisionTreeMetadata,
    +      rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
    +    // Collect some nodes to split:
    +    //  nodesForGroup(treeIndex) = nodes to split
    +    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
    +    val mutableTreeToNodeToIndexInfo =
    +      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
    +    var memUsage: Long = 0L
    +    var numNodesInGroup = 0
    +    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
    +      val (treeIndex, node) = nodeQueue.head
    +      // Choose subset of features for node (if subsampling).
    +      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
    +        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
    +          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
    +      } else {
    +        None
    +      }
    +      // Check if enough memory remains to add this node to the group.
    +      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    +      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
    +        nodeQueue.dequeue()
    +        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
    +          node
    +        mutableTreeToNodeToIndexInfo
    +          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
    +          = new NodeIndexInfo(numNodesInGroup, featureSubset)
    +      }
    +      numNodesInGroup += 1
    +      memUsage += nodeMemUsage
    +    }
    +    // Convert mutable maps to immutable ones.
    +    val nodesForGroup: Map[Int, Array[LearningNode]] =
    +      mutableNodesForGroup.mapValues(_.toArray).toMap
    +    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
    +    (nodesForGroup, treeToNodeToIndexInfo)
    +  }
    +
    +  /**
    +   * Get the number of values to be stored for this node in the bin aggregates.
    +   * @param featureSubset  Indices of features which may be split at this node.
    +   *                       If None, then use all features.
    +   */
    +  private def aggregateSizeForNode(
    +      metadata: DecisionTreeMetadata,
    +      featureSubset: Option[Array[Int]]): Long = {
    +    val totalBins = if (featureSubset.nonEmpty) {
    +      featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
    +    } else {
    +      metadata.numBins.map(_.toLong).sum
    +    }
    +    if (metadata.isClassification) {
    +      metadata.numClasses * totalBins
    +    } else {
    +      3 * totalBins
    +    }
    +  }
    +
    +}
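
A minimal standalone sketch (not part of this patch) of the weighted information-gain
computation described in `calculateGainForSplit` above, using Gini impurity and made-up
class counts; all names here are illustrative:

```scala
object GainSketch {
  // Gini impurity for an array of per-class instance counts.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  def main(args: Array[String]): Unit = {
    val leftCounts = Array(8.0, 2.0)   // left child: 8 instances of class 0, 2 of class 1
    val rightCounts = Array(3.0, 7.0)  // right child: 3 of class 0, 7 of class 1
    val parentCounts = leftCounts.zip(rightCounts).map { case (l, r) => l + r }

    val leftCount = leftCounts.sum
    val rightCount = rightCounts.sum
    val totalCount = leftCount + rightCount

    // gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity,
    // mirroring the formula in calculateGainForSplit.
    val gain = gini(parentCounts) -
      (leftCount / totalCount) * gini(leftCounts) -
      (rightCount / totalCount) * gini(rightCounts)
    println(f"gain = $gain%.4f")  // 0.1250 for these counts
  }
}
```

In the patch itself the per-class counts come from `ImpurityCalculator`s, and splits whose
children violate `minInstancesPerNode` or whose gain falls below `minInfoGain` are rejected
via `invalidInformationGainStats`.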
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
    new file mode 100644
    index 0000000000000..9fa27e5e1f721
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
    @@ -0,0 +1,134 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.apache.spark.ml.tree.{ContinuousSplit, Split}
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
    +import org.apache.spark.rdd.RDD
    +
    +
    +/**
    + * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsample of the data as follows:
    + *  (a) Continuous features are binned into ranges.
    + *  (b) Unordered categorical features are binned based on subsets of feature values.
    + *      "Unordered categorical features" are categorical features with low arity used in
    + *      multiclass classification.
    + *  (c) Ordered categorical features are binned based on feature values.
    + *      "Ordered categorical features" are categorical features with high arity,
    + *      or any categorical feature used in regression or binary classification.
    + *
    + * @param label  Label from LabeledPoint
    + * @param binnedFeatures  Binned feature values.
    + *                        Same length as LabeledPoint.features, but values are bin indices.
    + */
    +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
    +  extends Serializable {
    +}
    +
    +private[spark] object TreePoint {
    +
    +  /**
    +   * Convert an input dataset into its TreePoint representation,
    +   * binning feature values in preparation for DecisionTree training.
    +   * @param input     Input dataset.
    +   * @param splits    Splits for features, of size (numFeatures, numSplits).
    +   * @param metadata  Learning and dataset metadata
    +   * @return  TreePoint dataset representation
    +   */
    +  def convertToTreeRDD(
    +      input: RDD[LabeledPoint],
    +      splits: Array[Array[Split]],
    +      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
    +    // Construct arrays for featureArity for efficiency in the inner loop.
    +    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
    +    var featureIndex = 0
    +    while (featureIndex < metadata.numFeatures) {
    +      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
    +      featureIndex += 1
    +    }
    +    val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
    +      if (arity == 0) {
    +        splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
    +      } else {
    +        Array.empty[Double]
    +      }
    +    }
    +    input.map { x =>
    +      TreePoint.labeledPointToTreePoint(x, thresholds, featureArity)
    +    }
    +  }
    +
    +  /**
    +   * Convert one LabeledPoint into its TreePoint representation.
    +   * @param thresholds  For each feature, split thresholds for continuous features,
    +   *                    empty for categorical features.
    +   * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
    +   *                      for categorical features.
    +   */
    +  private def labeledPointToTreePoint(
    +      labeledPoint: LabeledPoint,
    +      thresholds: Array[Array[Double]],
    +      featureArity: Array[Int]): TreePoint = {
    +    val numFeatures = labeledPoint.features.size
    +    val arr = new Array[Int](numFeatures)
    +    var featureIndex = 0
    +    while (featureIndex < numFeatures) {
    +      arr(featureIndex) =
    +        findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
    +      featureIndex += 1
    +    }
    +    new TreePoint(labeledPoint.label, arr)
    +  }
    +
    +  /**
    +   * Find discretized value for one (labeledPoint, feature).
    +   *
    +   * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
    +   *       (mllib) tree API.  We want to maintain the same behavior as the old tree API.
    +   *
    +   * @param featureArity  0 for continuous features; number of categories for categorical features.
    +   */
    +  private def findBin(
    +      featureIndex: Int,
    +      labeledPoint: LabeledPoint,
    +      featureArity: Int,
    +      thresholds: Array[Double]): Int = {
    +    val featureValue = labeledPoint.features(featureIndex)
    +
    +    if (featureArity == 0) {
    +      val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
    +      if (idx >= 0) {
    +        idx
    +      } else {
    +        -idx - 1
    +      }
    +    } else {
    +      // Categorical feature bins are indexed by feature values.
    +      if (featureValue < 0 || featureValue >= featureArity) {
    +        throw new IllegalArgumentException(
    +          s"DecisionTree given invalid data:" +
    +            s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
    +            s" but a data point gives it value $featureValue.\n" +
    +            "  Bad data point: " + labeledPoint.toString)
    +      }
    +      featureValue.toInt
    +    }
    +  }
    +}
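
A small standalone sketch (not part of this patch) of how `findBin` maps a continuous
feature value to a bin index via binary search over the sorted split thresholds; the
numbers are illustrative:

```scala
object BinningSketch {
  def binIndex(value: Double, thresholds: Array[Double]): Int = {
    // binarySearch returns the index if found, otherwise -(insertionPoint) - 1
    val idx = java.util.Arrays.binarySearch(thresholds, value)
    if (idx >= 0) idx else -idx - 1
  }

  def main(args: Array[String]): Unit = {
    val thresholds = Array(0.5, 1.5, 2.5)  // 3 splits => 4 bins, indexed 0 to 3
    println(binIndex(0.3, thresholds))     // 0: below the first threshold
    println(binIndex(1.5, thresholds))     // 1: a value equal to a threshold keeps that index
    println(binIndex(9.9, thresholds))     // 3: above all thresholds
  }
}
```

Categorical features skip the search entirely: the integer feature value is used directly
as the bin index after the range check shown in `findBin`.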
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
    index 1929f9d02156e..22873909c33fa 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
    @@ -17,6 +17,7 @@
     
     package org.apache.spark.ml.tree
     
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
     
     /**
      * Abstraction for Decision Tree models.
    @@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
       /** Weights for each tree, zippable with [[trees]] */
       def treeWeights: Array[Double]
     
    +  /** Weights used by the python wrappers. */
    +  // Note: An array cannot be returned directly due to serialization problems.
    +  private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
    +
       /** Summary of the model */
       override def toString: String = {
         // Implementing classes should generally override this method to be more descriptive.
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    index e2444ab65b43b..f979319cc4b58 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    @@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
     /**
      * Params for [[CrossValidator]] and [[CrossValidatorModel]].
      */
    -private[ml] trait CrossValidatorParams extends Params {
    -
    -  /**
    -   * param for the estimator to be cross-validated
    -   * @group param
    -   */
    -  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
    -
    -  /** @group getParam */
    -  def getEstimator: Estimator[_] = $(estimator)
    -
    -  /**
    -   * param for estimator param maps
    -   * @group param
    -   */
    -  val estimatorParamMaps: Param[Array[ParamMap]] =
    -    new Param(this, "estimatorParamMaps", "param maps for the estimator")
    -
    -  /** @group getParam */
    -  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
    -
    -  /**
    -   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
    -   * metric
    -   * @group param
    -   */
    -  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
    -    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
    -
    -  /** @group getParam */
    -  def getEvaluator: Evaluator = $(evaluator)
    -
    +private[ml] trait CrossValidatorParams extends ValidatorParams {
       /**
        * Param for number of folds for cross validation.  Must be >= 2.
        * Default: 3
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
    new file mode 100644
    index 0000000000000..c0edc730b6fd6
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
    @@ -0,0 +1,168 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tuning
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.evaluation.Evaluator
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.types.StructType
    +
    +/**
    + * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
    + */
    +private[ml] trait TrainValidationSplitParams extends ValidatorParams {
    +  /**
    +   * Param for ratio between train and validation data. Must be between 0 and 1.
    +   * Default: 0.75
    +   * @group param
    +   */
    +  val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
    +    "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
    +
    +  /** @group getParam */
    +  def getTrainRatio: Double = $(trainRatio)
    +
    +  setDefault(trainRatio -> 0.75)
    +}
    +
    +/**
    + * :: Experimental ::
    + * Validation for hyper-parameter tuning.
    + * Randomly splits the input dataset into train and validation sets,
+ * and uses the evaluation metric on the validation set to select the best model.
    + * Similar to [[CrossValidator]], but only splits the set once.
    + */
    +@Experimental
    +class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
    +  with TrainValidationSplitParams with Logging {
    +
    +  def this() = this(Identifiable.randomUID("tvs"))
    +
    +  /** @group setParam */
    +  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
    +
    +  /** @group setParam */
    +  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
    +
    +  /** @group setParam */
    +  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
    +
    +  /** @group setParam */
    +  def setTrainRatio(value: Double): this.type = set(trainRatio, value)
    +
    +  override def fit(dataset: DataFrame): TrainValidationSplitModel = {
    +    val schema = dataset.schema
    +    transformSchema(schema, logging = true)
    +    val sqlCtx = dataset.sqlContext
    +    val est = $(estimator)
    +    val eval = $(evaluator)
    +    val epm = $(estimatorParamMaps)
    +    val numModels = epm.length
    +    val metrics = new Array[Double](epm.length)
    +
    +    val Array(training, validation) =
    +      dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
    +    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
    +    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
    +
    +    // multi-model training
+    logDebug("Train split with multiple sets of parameters.")
    +    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
    +    trainingDataset.unpersist()
    +    var i = 0
    +    while (i < numModels) {
    +      // TODO: duplicate evaluator to take extra params from input
    +      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
    +      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
    +      metrics(i) += metric
    +      i += 1
    +    }
    +    validationDataset.unpersist()
    +
    +    logInfo(s"Train validation split metrics: ${metrics.toSeq}")
    +    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
    +    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
    +    logInfo(s"Best train validation split metric: $bestMetric.")
    +    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
    +    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    $(estimator).transformSchema(schema)
    +  }
    +
    +  override def validateParams(): Unit = {
    +    super.validateParams()
    +    val est = $(estimator)
    +    for (paramMap <- $(estimatorParamMaps)) {
    +      est.copy(paramMap).validateParams()
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): TrainValidationSplit = {
    +    val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
    +    if (copied.isDefined(estimator)) {
    +      copied.setEstimator(copied.getEstimator.copy(extra))
    +    }
    +    if (copied.isDefined(evaluator)) {
    +      copied.setEvaluator(copied.getEvaluator.copy(extra))
    +    }
    +    copied
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model from train validation split.
    + *
    + * @param uid Id.
+ * @param bestModel The best model determined by the estimator.
    + * @param validationMetrics Evaluated validation metrics.
    + */
    +@Experimental
    +class TrainValidationSplitModel private[ml] (
    +    override val uid: String,
    +    val bestModel: Model[_],
    +    val validationMetrics: Array[Double])
    +  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
    +
    +  override def validateParams(): Unit = {
    +    bestModel.validateParams()
    +  }
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    transformSchema(dataset.schema, logging = true)
    +    bestModel.transform(dataset)
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    bestModel.transformSchema(schema)
    +  }
    +
    +  override def copy(extra: ParamMap): TrainValidationSplitModel = {
+    val copied = new TrainValidationSplitModel(
    +      uid,
    +      bestModel.copy(extra).asInstanceOf[Model[_]],
    +      validationMetrics.clone())
    +    copyValues(copied, extra)
    +  }
    +}
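
A minimal usage sketch of the new estimator (not part of this patch), e.g. from spark-shell;
it assumes DataFrames `training` and `test` with "features" and "label" columns already exist:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(new RegressionEvaluator())
  .setTrainRatio(0.8)  // 80% of the data for training, 20% for validation

// Fits one model per ParamMap on the training split, picks the one with the best
// validation metric, then refits that ParamMap on the full dataset.
val model = tvs.fit(training)
model.transform(test).select("features", "label", "prediction").show()
```

Unlike CrossValidator, each candidate ParamMap is evaluated on a single train/validation
split, so this is roughly k times cheaper than k-fold cross-validation at the cost of a
noisier metric estimate.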
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
    new file mode 100644
    index 0000000000000..8897ab0825acd
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
    @@ -0,0 +1,60 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tuning
    +
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.Estimator
    +import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.{Param, ParamMap, Params}
    +
    +/**
    + * :: DeveloperApi ::
    + * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
    + */
    +@DeveloperApi
    +private[ml] trait ValidatorParams extends Params {
    +
    +  /**
    +   * param for the estimator to be validated
    +   * @group param
    +   */
    +  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
    +
    +  /** @group getParam */
    +  def getEstimator: Estimator[_] = $(estimator)
    +
    +  /**
    +   * param for estimator param maps
    +   * @group param
    +   */
    +  val estimatorParamMaps: Param[Array[ParamMap]] =
    +    new Param(this, "estimatorParamMaps", "param maps for the estimator")
    +
    +  /** @group getParam */
    +  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
    +
    +  /**
    +   * param for the evaluator used to select hyper-parameters that maximize the validated metric
    +   * @group param
    +   */
    +  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
    +    "evaluator used to select hyper-parameters that maximize the validated metric")
    +
    +  /** @group getParam */
    +  def getEvaluator: Evaluator = $(evaluator)
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
    index 7cd53c6d7ef79..76f651488aef9 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
    @@ -32,10 +32,15 @@ private[spark] object SchemaUtils {
        * @param colName  column name
        * @param dataType  required column data type
        */
    -  def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
    +  def checkColumnType(
    +      schema: StructType,
    +      colName: String,
    +      dataType: DataType,
    +      msg: String = ""): Unit = {
         val actualDataType = schema(colName).dataType
    +    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
         require(actualDataType.equals(dataType),
    -      s"Column $colName must be of type $dataType but was actually $actualDataType.")
    +      s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
       }
     
       /**
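
The new optional message is appended to the type-mismatch error. A hypothetical call site
(not part of this patch) could look like:

```scala
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.VectorUDT

// Assumes `schema` is the StructType of the input DataFrame; the extra message is only
// shown when the type check fails.
SchemaUtils.checkColumnType(schema, "features", new VectorUDT,
  "Expected the features column to contain feature vectors.")
```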
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
    new file mode 100644
    index 0000000000000..8d4174124b5c4
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
    @@ -0,0 +1,153 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.util
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.{Accumulator, SparkContext}
    +
    +/**
    + * Abstract class for stopwatches.
    + */
    +private[spark] abstract class Stopwatch extends Serializable {
    +
    +  @transient private var running: Boolean = false
    +  private var startTime: Long = _
    +
    +  /**
    +   * Name of the stopwatch.
    +   */
    +  val name: String
    +
    +  /**
    +   * Starts the stopwatch.
    +   * Throws an exception if the stopwatch is already running.
    +   */
    +  def start(): Unit = {
    +    assume(!running, "start() called but the stopwatch is already running.")
    +    running = true
    +    startTime = now
    +  }
    +
    +  /**
    +   * Stops the stopwatch and returns the duration of the last session in milliseconds.
    +   * Throws an exception if the stopwatch is not running.
    +   */
    +  def stop(): Long = {
    +    assume(running, "stop() called but the stopwatch is not running.")
    +    val duration = now - startTime
    +    add(duration)
    +    running = false
    +    duration
    +  }
    +
    +  /**
    +   * Checks whether the stopwatch is running.
    +   */
    +  def isRunning: Boolean = running
    +
    +  /**
    +   * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
    +   * is running.
    +   */
    +  def elapsed(): Long
    +
    +  override def toString: String = s"$name: ${elapsed()}ms"
    +
    +  /**
    +   * Gets the current time in milliseconds.
    +   */
    +  protected def now: Long = System.currentTimeMillis()
    +
    +  /**
    +   * Adds input duration to total elapsed time.
    +   */
    +  protected def add(duration: Long): Unit
    +}
    +
    +/**
    + * A local [[Stopwatch]].
    + */
    +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
    +
    +  private var elapsedTime: Long = 0L
    +
    +  override def elapsed(): Long = elapsedTime
    +
    +  override protected def add(duration: Long): Unit = {
    +    elapsedTime += duration
    +  }
    +}
    +
    +/**
+ * A distributed [[Stopwatch]] using a Spark accumulator.
    + * @param sc SparkContext
    + */
    +private[spark] class DistributedStopwatch(
    +    sc: SparkContext,
    +    override val name: String) extends Stopwatch {
    +
    +  private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
    +
    +  override def elapsed(): Long = elapsedTime.value
    +
    +  override protected def add(duration: Long): Unit = {
    +    elapsedTime += duration
    +  }
    +}
    +
    +/**
+ * A container that holds multiple stopwatches, which can be local and/or distributed.
    + * @param sc SparkContext
    + */
    +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
    +
    +  private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
    +
    +  /**
    +   * Adds a local stopwatch.
    +   * @param name stopwatch name
    +   */
    +  def addLocal(name: String): this.type = {
    +    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    +    stopwatches(name) = new LocalStopwatch(name)
    +    this
    +  }
    +
    +  /**
    +   * Adds a distributed stopwatch.
    +   * @param name stopwatch name
    +   */
    +  def addDistributed(name: String): this.type = {
    +    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    +    stopwatches(name) = new DistributedStopwatch(sc, name)
    +    this
    +  }
    +
    +  /**
    +   * Gets a stopwatch.
    +   * @param name stopwatch name
    +   */
    +  def apply(name: String): Stopwatch = stopwatches(name)
    +
    +  override def toString: String = {
    +    stopwatches.values.toArray.sortBy(_.name)
    +      .map(c => s"  $c")
    +      .mkString("{\n", ",\n", "\n}")
    +  }
    +}
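
A brief usage sketch of the stopwatch helpers (not part of this patch; the classes are
`private[spark]`, so this only applies to code inside Spark's own packages, with `sc` an
existing SparkContext):

```scala
val timer = new MultiStopwatch(sc)
  .addLocal("init")           // driver-side timing only
  .addDistributed("train")    // aggregated across tasks via an accumulator

timer("init").start()
// ... driver-side setup ...
timer("init").stop()

sc.parallelize(1 to 1000, 4).foreachPartition { iter =>
  timer("train").start()
  iter.foreach(_ => ())       // placeholder for per-partition work
  timer("train").stop()
}

println(timer)  // prints a small block like: {  init: 12ms,  train: 87ms }
```

Since the distributed stopwatch's `elapsed()` reads an accumulator value, it should only be
inspected on the driver after the job has completed.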
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
    new file mode 100644
    index 0000000000000..0ec88ef77d695
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
    @@ -0,0 +1,53 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.api.python
    +
    +import java.util.{List => JList}
    +
    +import scala.collection.JavaConverters._
    +import scala.collection.mutable.ArrayBuffer
    +
    +import org.apache.spark.SparkContext
    +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
    +import org.apache.spark.mllib.clustering.GaussianMixtureModel
    +
    +/**
    + * Wrapper around GaussianMixtureModel to provide helper methods in Python
    + */
    +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
    +  val weights: Vector = Vectors.dense(model.weights)
    +  val k: Int = weights.size
    +
    +  /**
    +   * Returns gaussians as a List of Vectors and Matrices corresponding to each MultivariateGaussian.
    +   */
    +  val gaussians: JList[Object] = {
    +    val modelGaussians = model.gaussians
    +    var i = 0
    +    var mu = ArrayBuffer.empty[Vector]
    +    var sigma = ArrayBuffer.empty[Matrix]
    +    while (i < k) {
    +      mu += modelGaussians(i).mu
    +      sigma += modelGaussians(i).sigma
    +      i += 1
    +    }
    +    List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
    +  }
    +
    +  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
    +}
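
    A brief sketch of what the wrapper above exposes to the Python side (hypothetical; `model` is a
    trained GaussianMixtureModel and the surrounding Py4J plumbing is not shown):

        import org.apache.spark.mllib.linalg.{Matrix, Vector}

        // `wrapper.weights` is a dense Vector of the k mixture weights;
        // `wrapper.gaussians` is a two-element Java list: means first, then covariance matrices.
        val wrapper = new GaussianMixtureModelWrapper(model)
        val means = wrapper.gaussians.get(0).asInstanceOf[Array[Vector]]
        val sigmas = wrapper.gaussians.get(1).asInstanceOf[Array[Matrix]]
        assert(wrapper.weights.size == wrapper.k && means.length == wrapper.k)
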
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    index e628059c4af8e..6f080d32bbf4d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    @@ -43,7 +43,7 @@ import org.apache.spark.mllib.recommendation._
     import org.apache.spark.mllib.regression._
     import org.apache.spark.mllib.stat.correlation.CorrelationNames
     import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
    -import org.apache.spark.mllib.stat.test.ChiSqTestResult
    +import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult}
     import org.apache.spark.mllib.stat.{
       KernelDensity, MultivariateStatisticalSummary, Statistics}
     import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
    @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
           seed: java.lang.Long,
           initialModelWeights: java.util.ArrayList[Double],
           initialModelMu: java.util.ArrayList[Vector],
    -      initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
    +      initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
         val gmmAlg = new GaussianMixture()
           .setK(k)
           .setConvergenceTol(convergenceTol)
    @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
         if (seed != null) gmmAlg.setSeed(seed)
     
         try {
    -      val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
    -      var wt = ArrayBuffer.empty[Double]
    -      var mu = ArrayBuffer.empty[Vector]
    -      var sigma = ArrayBuffer.empty[Matrix]
    -      for (i <- 0 until model.k) {
    -          wt += model.weights(i)
    -          mu += model.gaussians(i).mu
    -          sigma += model.gaussians(i).sigma
    -      }
    -      List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
    +      new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
         } finally {
           data.rdd.unpersist(blocking = false)
         }
    @@ -502,6 +493,39 @@ private[python] class PythonMLLibAPI extends Serializable {
         new MatrixFactorizationModelWrapper(model)
       }
     
    +  /**
    +   * Java stub for Python mllib LDA.run()
    +   */
    +  def trainLDAModel(
    +      data: JavaRDD[java.util.List[Any]],
    +      k: Int,
    +      maxIterations: Int,
    +      docConcentration: Double,
    +      topicConcentration: Double,
    +      seed: java.lang.Long,
    +      checkpointInterval: Int,
    +      optimizer: String): LDAModel = {
    +    val algo = new LDA()
    +      .setK(k)
    +      .setMaxIterations(maxIterations)
    +      .setDocConcentration(docConcentration)
    +      .setTopicConcentration(topicConcentration)
    +      .setCheckpointInterval(checkpointInterval)
    +      .setOptimizer(optimizer)
    +
    +    if (seed != null) algo.setSeed(seed)
    +
    +    val documents = data.rdd.map(_.asScala.toArray).map { r =>
    +      r(0) match {
    +        case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
    +        case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
    +        case _ => throw new IllegalArgumentException("input values contain an invalid type.")
    +      }
    +    }
    +    algo.run(documents)
    +  }
    +
    +
       /**
        * Java stub for Python mllib FPGrowth.train().  This stub returns a handle
        * to the Java object instead of the content of the Java object.  Extra care
    @@ -1060,6 +1084,18 @@ private[python] class PythonMLLibAPI extends Serializable {
         LinearDataGenerator.generateLinearRDD(
           sc, nexamples, nfeatures, eps, nparts, intercept)
       }
    +
    +  /**
    +   * Java stub for Statistics.kolmogorovSmirnovTest()
    +   */
    +  def kolmogorovSmirnovTest(
    +      data: JavaRDD[Double],
    +      distName: String,
    +      params: JList[Double]): KolmogorovSmirnovTestResult = {
    +    val paramsSeq = params.asScala.toSeq
    +    Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*)
    +  }
    +
     }
     
     /**
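
    The new kolmogorovSmirnovTest stub above simply forwards to the existing Scala API. For
    reference, a small sketch of the underlying call (hypothetical sample data; `sc` is an existing
    SparkContext):

        import org.apache.spark.mllib.stat.Statistics
        import org.apache.spark.rdd.RDD

        // One-sample, two-sided KS test against a standard normal distribution (mean 0, stddev 1).
        val samples: RDD[Double] = sc.parallelize(Seq(0.1, 0.15, 0.2, 0.3, 0.25))
        val ksResult = Statistics.kolmogorovSmirnovTest(samples, "norm", 0.0, 1.0)
        println(ksResult)   // prints the KS statistic, p-value and null-hypothesis summary
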
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
    index 35a0db76f3a8c..ba73024e3c04d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
    @@ -36,6 +36,7 @@ trait ClassificationModel extends Serializable {
        *
        * @param testData RDD representing data points to be predicted
        * @return an RDD[Double] where each entry contains the corresponding prediction
    +   * @since 0.8.0
        */
       def predict(testData: RDD[Vector]): RDD[Double]
     
    @@ -44,6 +45,7 @@ trait ClassificationModel extends Serializable {
        *
        * @param testData array representing a single data point
        * @return predicted category from the trained model
    +   * @since 0.8.0
        */
       def predict(testData: Vector): Double
     
    @@ -51,6 +53,7 @@ trait ClassificationModel extends Serializable {
        * Predict values for examples stored in a JavaRDD.
        * @param testData JavaRDD representing data points to be predicted
        * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
    +   * @since 0.8.0
        */
       def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
         predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    index 2df4d21e8cd55..268642ac6a2f6 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    @@ -85,6 +85,7 @@ class LogisticRegressionModel (
        * in Binary Logistic Regression. An example with prediction score greater than or equal to
        * this threshold is identified as an positive, and negative otherwise. The default value is 0.5.
        * It is only used for binary classification.
    +   * @since 1.0.0
        */
       @Experimental
       def setThreshold(threshold: Double): this.type = {
    @@ -96,6 +97,7 @@ class LogisticRegressionModel (
        * :: Experimental ::
        * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
        * It is only used for binary classification.
    +   * @since 1.3.0
        */
       @Experimental
       def getThreshold: Option[Double] = threshold
    @@ -104,6 +106,7 @@ class LogisticRegressionModel (
        * :: Experimental ::
        * Clears the threshold so that `predict` will output raw prediction scores.
        * It is only used for binary classification.
    +   * @since 1.0.0
        */
       @Experimental
       def clearThreshold(): this.type = {
    @@ -155,6 +158,9 @@ class LogisticRegressionModel (
         }
       }
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def save(sc: SparkContext, path: String): Unit = {
         GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
           numFeatures, numClasses, weights, intercept, threshold)
    @@ -162,6 +168,9 @@ class LogisticRegressionModel (
     
       override protected def formatVersion: String = "1.0"
     
    +  /**
    +   * @since 1.4.0
    +   */
       override def toString: String = {
         s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}"
       }
    @@ -169,6 +178,9 @@ class LogisticRegressionModel (
     
     object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
         val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
         // Hard-code class name string in case it changes in the future
    @@ -249,6 +261,7 @@ object LogisticRegressionWithSGD {
        * @param miniBatchFraction Fraction of data to be used per iteration.
        * @param initialWeights Initial set of weights to be used. Array should be equal in size to
        *        the number of features in the data.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -271,6 +284,7 @@ object LogisticRegressionWithSGD {
        * @param stepSize Step size to be used for each iteration of gradient descent.
     
        * @param miniBatchFraction Fraction of data to be used per iteration.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -292,6 +306,7 @@ object LogisticRegressionWithSGD {
     
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a LogisticRegressionModel which has the weights and offset from training.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -309,6 +324,7 @@ object LogisticRegressionWithSGD {
        * @param input RDD of (label, array of features) pairs.
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a LogisticRegressionModel which has the weights and offset from training.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -345,6 +361,7 @@ class LogisticRegressionWithLBFGS
        * Set the number of possible outcomes for k classes classification problem in
        * Multinomial Logistic Regression.
        * By default, it is binary logistic regression so k will be set to 2.
    +   * @since 1.3.0
        */
       @Experimental
       def setNumClasses(numClasses: Int): this.type = {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    index f51ee36d0dfcb..2df91c09421e9 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    @@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
      *              where D is number of features
      * @param modelType The type of NB model to fit  can be "multinomial" or "bernoulli"
      */
    -class NaiveBayesModel private[mllib] (
    +class NaiveBayesModel private[spark] (
         val labels: Array[Double],
         val pi: Array[Double],
         val theta: Array[Array[Double]],
    @@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] (
       override def predict(testData: Vector): Double = {
         modelType match {
           case Multinomial =>
    -        val prob = thetaMatrix.multiply(testData)
    -        BLAS.axpy(1.0, piVector, prob)
    -        labels(prob.argmax)
    +        labels(multinomialCalculation(testData).argmax)
           case Bernoulli =>
    -        testData.foreachActive { (index, value) =>
    -          if (value != 0.0 && value != 1.0) {
    -            throw new SparkException(
    -              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
    -          }
    -        }
    -        val prob = thetaMinusNegTheta.get.multiply(testData)
    -        BLAS.axpy(1.0, piVector, prob)
    -        BLAS.axpy(1.0, negThetaSum.get, prob)
    -        labels(prob.argmax)
    -      case _ =>
    -        // This should never happen.
    -        throw new UnknownError(s"Invalid modelType: $modelType.")
    +        labels(bernoulliCalculation(testData).argmax)
    +    }
    +  }
    +
    +  /**
    +   * Predict values for the given data set using the model trained.
    +   *
    +   * @param testData RDD representing data points to be predicted
    +   * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
    +   *         in the same order as class labels
    +   */
    +  def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
    +    val bcModel = testData.context.broadcast(this)
    +    testData.mapPartitions { iter =>
    +      val model = bcModel.value
    +      iter.map(model.predictProbabilities)
         }
       }
     
    +  /**
    +   * Predict posterior class probabilities for a single data point using the model trained.
    +   *
    +   * @param testData array representing a single data point
    +   * @return predicted posterior class probabilities from the trained model,
    +   *         in the same order as class labels
    +   */
    +  def predictProbabilities(testData: Vector): Vector = {
    +    modelType match {
    +      case Multinomial =>
    +        posteriorProbabilities(multinomialCalculation(testData))
    +      case Bernoulli =>
    +        posteriorProbabilities(bernoulliCalculation(testData))
    +    }
    +  }
    +
    +  private def multinomialCalculation(testData: Vector) = {
    +    val prob = thetaMatrix.multiply(testData)
    +    BLAS.axpy(1.0, piVector, prob)
    +    prob
    +  }
    +
    +  private def bernoulliCalculation(testData: Vector) = {
    +    testData.foreachActive((_, value) =>
    +      if (value != 0.0 && value != 1.0) {
    +        throw new SparkException(
    +          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
    +      }
    +    )
    +    val prob = thetaMinusNegTheta.get.multiply(testData)
    +    BLAS.axpy(1.0, piVector, prob)
    +    BLAS.axpy(1.0, negThetaSum.get, prob)
    +    prob
    +  }
    +
    +  private def posteriorProbabilities(logProb: DenseVector) = {
    +    val logProbArray = logProb.toArray
    +    val maxLog = logProbArray.max
    +    val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog))
    +    val probSum = scaledProbs.sum
    +    new DenseVector(scaledProbs.map(_ / probSum))
    +  }
    +
       override def save(sc: SparkContext, path: String): Unit = {
         val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
         NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
    @@ -338,7 +382,7 @@ class NaiveBayes private (
             BLAS.axpy(1.0, c2._2, c1._2)
             (c1._1 + c2._1, c1._2)
           }
    -    ).collect()
    +    ).collect().sortBy(_._1)
     
         val numLabels = aggregated.length
         var numDocuments = 0L
    @@ -381,13 +425,13 @@ class NaiveBayes private (
     object NaiveBayes {
     
       /** String name for multinomial model type. */
    -  private[classification] val Multinomial: String = "multinomial"
    +  private[spark] val Multinomial: String = "multinomial"
     
       /** String name for Bernoulli model type. */
    -  private[classification] val Bernoulli: String = "bernoulli"
    +  private[spark] val Bernoulli: String = "bernoulli"
     
       /* Set of modelTypes that NaiveBayes supports */
    -  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
    +  private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
     
       /**
        * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
    @@ -400,6 +444,7 @@ object NaiveBayes {
        *
        * @param input RDD of `(label, array of features)` pairs.  Every vector should be a frequency
        *              vector or a count vector.
    +   * @since 0.9.0
        */
       def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
         new NaiveBayes().run(input)
    @@ -415,6 +460,7 @@ object NaiveBayes {
        * @param input RDD of `(label, array of features)` pairs.  Every vector should be a frequency
        *              vector or a count vector.
        * @param lambda The smoothing parameter
    +   * @since 0.9.0
        */
       def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
         new NaiveBayes(lambda, Multinomial).run(input)
    @@ -437,6 +483,7 @@ object NaiveBayes {
        *
        * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
        *              multinomial or bernoulli
    +   * @since 0.9.0
        */
       def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
         require(supportedModelTypes.contains(modelType),
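
    The predictProbabilities methods added above return posterior class probabilities rather than a
    hard label. A usage sketch (hypothetical; `training` is an RDD[LabeledPoint] of count vectors):

        import org.apache.spark.mllib.classification.NaiveBayes
        import org.apache.spark.mllib.linalg.Vectors

        val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")
        val posterior = model.predictProbabilities(Vectors.dense(1.0, 0.0, 3.0))
        // posterior(i) is the probability of model.labels(i); the entries sum to 1.0
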
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    index 348485560713e..5b54feeb10467 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    @@ -46,6 +46,7 @@ class SVMModel (
        * Sets the threshold that separates positive predictions from negative predictions. An example
        * with prediction score greater than or equal to this threshold is identified as an positive,
        * and negative otherwise. The default value is 0.0.
    +   * @since 1.3.0
        */
       @Experimental
       def setThreshold(threshold: Double): this.type = {
    @@ -56,6 +57,7 @@ class SVMModel (
       /**
        * :: Experimental ::
        * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
    +   * @since 1.3.0
        */
       @Experimental
       def getThreshold: Option[Double] = threshold
    @@ -63,6 +65,7 @@ class SVMModel (
       /**
        * :: Experimental ::
        * Clears the threshold so that `predict` will output raw prediction scores.
    +   * @since 1.0.0
        */
       @Experimental
       def clearThreshold(): this.type = {
    @@ -81,6 +84,9 @@ class SVMModel (
         }
       }
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def save(sc: SparkContext, path: String): Unit = {
         GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
           numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
    @@ -88,6 +94,9 @@ class SVMModel (
     
       override protected def formatVersion: String = "1.0"
     
    +  /**
    +   * @since 1.4.0
    +   */
       override def toString: String = {
         s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}"
       }
    @@ -95,6 +104,9 @@ class SVMModel (
     
     object SVMModel extends Loader[SVMModel] {
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def load(sc: SparkContext, path: String): SVMModel = {
         val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
         // Hard-code class name string in case it changes in the future
    @@ -173,6 +185,7 @@ object SVMWithSGD {
        * @param miniBatchFraction Fraction of data to be used per iteration.
        * @param initialWeights Initial set of weights to be used. Array should be equal in size to
        *        the number of features in the data.
    +   * @since 0.8.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -196,6 +209,7 @@ object SVMWithSGD {
        * @param stepSize Step size to be used for each iteration of gradient descent.
        * @param regParam Regularization parameter.
        * @param miniBatchFraction Fraction of data to be used per iteration.
    +   * @since 0.8.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -217,6 +231,7 @@ object SVMWithSGD {
        * @param regParam Regularization parameter.
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a SVMModel which has the weights and offset from training.
    +   * @since 0.8.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -235,6 +250,7 @@ object SVMWithSGD {
        * @param input RDD of (label, array of features) pairs.
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a SVMModel which has the weights and offset from training.
    +   * @since 0.8.0
        */
       def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
         train(input, numIterations, 1.0, 0.01, 1.0)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
    index fc509d2ba1470..e459367333d26 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
    @@ -140,6 +140,10 @@ class GaussianMixture private (
         // Get length of the input vectors
         val d = breezeData.first().length
     
    +    // Heuristic to distribute the computation of the [[MultivariateGaussian]]s: distribute
    +    // approximately when d > 25, except when k is very small
    +    val distributeGaussians = ((k - 1.0) / k) * d > 25
    +
         // Determine initial weights and corresponding Gaussians.
         // If the user supplied an initial GMM, we use those values, otherwise
         // we start with uniform weights, a random mean from the data, and
    @@ -171,14 +175,25 @@ class GaussianMixture private (
           // Create new distributions based on the partial assignments
           // (often referred to as the "M" step in literature)
           val sumWeights = sums.weights.sum
    -      var i = 0
    -      while (i < k) {
    -        val mu = sums.means(i) / sums.weights(i)
    -        BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
    -          Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
    -        weights(i) = sums.weights(i) / sumWeights
    -        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
    -        i = i + 1
    +
    +      if (distributeGaussians) {
    +        val numPartitions = math.min(k, 1024)
    +        val tuples =
    +          Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i)))
    +        val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) =>
    +          updateWeightsAndGaussians(mean, sigma, weight, sumWeights)
    +        }.collect.unzip
    +        Array.copy(ws, 0, weights, 0, ws.length)
    +        Array.copy(gs, 0, gaussians, 0, gs.length)
    +      } else {
    +        var i = 0
    +        while (i < k) {
    +          val (weight, gaussian) =
    +            updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights)
    +          weights(i) = weight
    +          gaussians(i) = gaussian
    +          i = i + 1
    +        }
           }
     
           llhp = llh // current becomes previous
    @@ -192,6 +207,19 @@ class GaussianMixture private (
       /** Java-friendly version of [[run()]] */
       def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
     
    +  private def updateWeightsAndGaussians(
    +      mean: BDV[Double],
    +      sigma: BreezeMatrix[Double],
    +      weight: Double,
    +      sumWeights: Double): (Double, MultivariateGaussian) = {
    +    val mu = (mean /= weight)
    +    BLAS.syr(-weight, Vectors.fromBreeze(mu),
    +      Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix])
    +    val newWeight = weight / sumWeights
    +    val newGaussian = new MultivariateGaussian(mu, sigma / weight)
    +    (newWeight, newGaussian)
    +  }
    +
       /** Average of dense breeze vectors */
       private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
         val v = BDV.zeros[Double](x(0).length)
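
    The distributeGaussians flag above trades a driver-side loop for a small Spark job when the
    per-cluster covariance updates become expensive. A quick illustration of the heuristic (not part
    of the patch):

        // Distribute the M-step roughly when d > 25, unless k is very small.
        def shouldDistribute(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25

        println(shouldDistribute(k = 2, d = 60))   // true:  0.5 * 60 = 30 > 25
        println(shouldDistribute(k = 10, d = 20))  // false: 0.9 * 20 = 18 <= 25
        println(shouldDistribute(k = 1, d = 100))  // false: a single component is never distributed
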
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
    index 0f8d6a399682d..0a65403f4ec95 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
    @@ -85,9 +85,7 @@ class KMeans private (
        * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
        */
       def setInitializationMode(initializationMode: String): this.type = {
    -    if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
    -      throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
    -    }
    +    KMeans.validateInitMode(initializationMode)
         this.initializationMode = initializationMode
         this
       }
    @@ -156,6 +154,21 @@ class KMeans private (
         this
       }
     
    +  // Initial cluster centers can be provided as a KMeansModel object rather than using the
    +  // random or k-means|| initializationMode
    +  private var initialModel: Option[KMeansModel] = None
    +
    +  /**
    +   * Set the initial starting point, bypassing the random initialization or k-means||.
    +   * The condition model.k == this.k must be met; failure results
    +   * in an IllegalArgumentException.
    +   */
    +  def setInitialModel(model: KMeansModel): this.type = {
    +    require(model.k == k, "mismatched cluster count")
    +    initialModel = Some(model)
    +    this
    +  }
    +
       /**
        * Train a K-means model on the given set of points; `data` should be cached for high
        * performance, because this is an iterative algorithm.
    @@ -193,20 +206,34 @@ class KMeans private (
     
         val initStartTime = System.nanoTime()
     
    -    val centers = if (initializationMode == KMeans.RANDOM) {
    -      initRandom(data)
    +    // Only one run is allowed when initialModel is given
    +    val numRuns = if (initialModel.nonEmpty) {
    +      if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
    +      1
         } else {
    -      initKMeansParallel(data)
    +      runs
         }
     
    +    val centers = initialModel match {
    +      case Some(kMeansCenters) => {
    +        Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
    +      }
    +      case None => {
    +        if (initializationMode == KMeans.RANDOM) {
    +          initRandom(data)
    +        } else {
    +          initKMeansParallel(data)
    +        }
    +      }
    +    }
         val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
         logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
           " seconds.")
     
    -    val active = Array.fill(runs)(true)
    -    val costs = Array.fill(runs)(0.0)
    +    val active = Array.fill(numRuns)(true)
    +    val costs = Array.fill(numRuns)(0.0)
     
    -    var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
    +    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
         var iteration = 0
     
         val iterationStartTime = System.nanoTime()
    @@ -521,6 +548,14 @@ object KMeans {
           v2: VectorWithNorm): Double = {
         MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
       }
    +
    +  private[spark] def validateInitMode(initMode: String): Boolean = {
    +    initMode match {
    +      case KMeans.RANDOM => true
    +      case KMeans.K_MEANS_PARALLEL => true
    +      case _ => false
    +    }
    +  }
     }
     
     /**
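
    setInitialModel above lets callers seed k-means with previously computed centers instead of
    random or k-means|| initialization. A usage sketch (hypothetical; `data` is a cached
    RDD[Vector], and KMeansModel's public constructor taking an Array[Vector] is assumed):

        import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
        import org.apache.spark.mllib.linalg.Vectors

        val warmStart = new KMeansModel(Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0)))

        val model = new KMeans()
          .setK(2)                       // must equal warmStart.k, or setInitialModel fails its require
          .setMaxIterations(20)
          .setInitialModel(warmStart)    // runs > 1 would be ignored with a warning
          .run(data)
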
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
    index a410547a72fda..ab124e6d77c5e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
    @@ -23,11 +23,10 @@ import org.apache.spark.Logging
     import org.apache.spark.annotation.{DeveloperApi, Experimental}
     import org.apache.spark.api.java.JavaPairRDD
     import org.apache.spark.graphx._
    -import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.util.Utils
     
    -
     /**
      * :: Experimental ::
      *
    @@ -49,14 +48,15 @@ import org.apache.spark.util.Utils
     class LDA private (
         private var k: Int,
         private var maxIterations: Int,
    -    private var docConcentration: Double,
    +    private var docConcentration: Vector,
         private var topicConcentration: Double,
         private var seed: Long,
         private var checkpointInterval: Int,
         private var ldaOptimizer: LDAOptimizer) extends Logging {
     
    -  def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
    -    seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
    +  def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1),
    +    topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10,
    +    ldaOptimizer = new EMLDAOptimizer)
     
       /**
        * Number of topics to infer.  I.e., the number of soft cluster centers.
    @@ -77,37 +77,50 @@ class LDA private (
        * Concentration parameter (commonly named "alpha") for the prior placed on documents'
        * distributions over topics ("theta").
        *
    -   * This is the parameter to a symmetric Dirichlet distribution.
    +   * This is the parameter to a Dirichlet distribution.
        */
    -  def getDocConcentration: Double = this.docConcentration
    +  def getDocConcentration: Vector = this.docConcentration
     
       /**
        * Concentration parameter (commonly named "alpha") for the prior placed on documents'
        * distributions over topics ("theta").
        *
    -   * This is the parameter to a symmetric Dirichlet distribution, where larger values
    -   * mean more smoothing (more regularization).
    +   * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing
    +   * (more regularization).
        *
    -   * If set to -1, then docConcentration is set automatically.
    -   *  (default = -1 = automatic)
    +   * If set to a singleton vector Vector(-1), then docConcentration is set automatically. If set to
    +   * singleton vector Vector(t) where t != -1, then t is replicated to a vector of length k during
    +   * [[LDAOptimizer.initialize()]]. Otherwise, the [[docConcentration]] vector must be length k.
    +   * (default = Vector(-1) = automatic)
        *
        * Optimizer-specific parameter settings:
        *  - EM
    -   *     - Value should be > 1.0
    -   *     - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
    -   *       Asuncion et al. (2009), who recommend a +1 adjustment for EM.
    +   *     - Currently only supports symmetric distributions, so all values in the vector should be
    +   *       the same.
    +   *     - Values should be > 1.0
    +   *     - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
    +   *       from Asuncion et al. (2009), who recommend a +1 adjustment for EM.
        *  - Online
    -   *     - Value should be >= 0
    -   *     - default = (1.0 / k), following the implementation from
    +   *     - Values should be >= 0
    +   *     - default = uniformly (1.0 / k), following the implementation from
        *       [[https://github.com/Blei-Lab/onlineldavb]].
        */
    -  def setDocConcentration(docConcentration: Double): this.type = {
    +  def setDocConcentration(docConcentration: Vector): this.type = {
         this.docConcentration = docConcentration
         this
       }
     
    +  /** Replicates a Double docConcentration to create a symmetric prior. */
    +  def setDocConcentration(docConcentration: Double): this.type = {
    +    this.docConcentration = Vectors.dense(docConcentration)
    +    this
    +  }
    +
       /** Alias for [[getDocConcentration]] */
    -  def getAlpha: Double = getDocConcentration
    +  def getAlpha: Vector = getDocConcentration
    +
    +  /** Alias for [[setDocConcentration()]] */
    +  def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
     
       /** Alias for [[setDocConcentration()]] */
       def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
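
    With docConcentration now a Vector, an asymmetric document-topic prior can be supplied directly,
    while the Double overload keeps the old symmetric behaviour. A configuration sketch
    (hypothetical; `corpus` is an RDD[(Long, Vector)] of term-count vectors):

        import org.apache.spark.mllib.clustering.LDA
        import org.apache.spark.mllib.linalg.Vectors

        val lda = new LDA()
          .setK(3)
          .setDocConcentration(Vectors.dense(1.1, 1.1, 1.1))  // length-k vector; Vectors.dense(-1) = automatic
          .setTopicConcentration(1.1)
          .setMaxIterations(50)
        val ldaModel = lda.run(corpus)
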
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    index 974b26924dfb8..6cfad3fbbdb87 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    @@ -17,13 +17,21 @@
     
     package org.apache.spark.mllib.clustering
     
    -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
    -
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
    +import breeze.numerics.{exp, lgamma}
    +import org.apache.hadoop.fs.Path
    +import org.json4s.DefaultFormats
    +import org.json4s.JsonDSL._
    +import org.json4s.jackson.JsonMethods._
    +
    +import org.apache.spark.SparkContext
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.api.java.JavaPairRDD
    -import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
    -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
    +import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
    +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
    +import org.apache.spark.mllib.util.{Loader, Saveable}
     import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.{Row, SQLContext}
     import org.apache.spark.util.BoundedPriorityQueue
     
     /**
    @@ -35,7 +43,7 @@ import org.apache.spark.util.BoundedPriorityQueue
      * including local and distributed data structures.
      */
     @Experimental
    -abstract class LDAModel private[clustering] {
    +abstract class LDAModel private[clustering] extends Saveable {
     
       /** Number of topics */
       def k: Int
    @@ -43,6 +51,31 @@ abstract class LDAModel private[clustering] {
       /** Vocabulary size (number of terms or terms in the vocabulary) */
       def vocabSize: Int
     
    +  /**
    +   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
    +   * distributions over topics ("theta").
    +   *
    +   * This is the parameter to a Dirichlet distribution.
    +   */
    +  def docConcentration: Vector
    +
    +  /**
    +   * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
    +   * distributions over terms.
    +   *
    +   * This is the parameter to a symmetric Dirichlet distribution.
    +   *
    +   * Note: The topics' distributions over terms are called "beta" in the original LDA paper
    +   * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
    +   */
    +  def topicConcentration: Double
    +
    +  /**
    +   * Shape parameter for random initialization of variational parameter gamma.
    +   * Used for variational inference for perplexity and other test-time computations.
    +   */
    +  protected def gammaShape: Double
    +
       /**
        * Inferred topics, where each topic is represented by a distribution over terms.
        * This is a matrix of size vocabSize x k, where each column is a topic.
    @@ -153,12 +186,14 @@ abstract class LDAModel private[clustering] {
      * This model stores only the inferred topics.
      * It may be used for computing topics for new documents, but it may give less accurate answers
      * than the [[DistributedLDAModel]].
    - *
      * @param topics Inferred topics (vocabSize x k matrix).
      */
     @Experimental
     class LocalLDAModel private[clustering] (
    -    private val topics: Matrix) extends LDAModel with Serializable {
    +    val topics: Matrix,
    +    override val docConcentration: Vector,
    +    override val topicConcentration: Double,
    +    override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
     
       override def k: Int = topics.numCols
     
    @@ -176,12 +211,218 @@ class LocalLDAModel private[clustering] (
         }.toArray
       }
     
    +  override protected def formatVersion = "1.0"
    +
    +  override def save(sc: SparkContext, path: String): Unit = {
    +    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
    +      gammaShape)
    +  }
       // TODO
       // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
     
    -  // TODO:
    -  // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
    +  /**
    +   * Calculate the log variational bound on perplexity. See Equation (16) in the original
    +   * Online LDA paper.
    +   * @param documents test corpus to use for calculating perplexity
    +   * @return the log perplexity per word
    +   */
    +  def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
    +    val corpusWords = documents
    +      .map { case (_, termCounts) => termCounts.toArray.sum }
    +      .sum()
    +    val batchVariationalBound = bound(documents, docConcentration,
    +      topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize)
    +    val perWordBound = batchVariationalBound / corpusWords
    +
    +    perWordBound
    +  }
    +
    +  /**
    +   * Estimate the variational likelihood bound from `documents`:
    +   *    log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
    +   * This bound is derived by decomposing the LDA model to:
    +   *    log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p)
    +   * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in the original Online LDA paper.
    +   * @param documents a subset of the test corpus
    +   * @param alpha document-topic Dirichlet prior parameters
    +   * @param eta topic-word Dirichlet prior parameters
    +   * @param lambda parameters for variational q(beta | lambda) topic-word distributions
    +   * @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
    +   *                   topic mixture distributions
    +   * @param k number of topics
    +   * @param vocabSize number of unique terms in the entire test corpus
    +   */
    +  private def bound(
    +      documents: RDD[(Long, Vector)],
    +      alpha: Vector,
    +      eta: Double,
    +      lambda: BDM[Double],
    +      gammaShape: Double,
    +      k: Int,
    +      vocabSize: Long): Double = {
    +    val brzAlpha = alpha.toBreeze.toDenseVector
    +    // transpose because dirichletExpectation normalizes by row and we need to normalize
    +    // by topic (columns of lambda)
    +    val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
    +
    +    var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
    +      var docScore = 0.0D
    +      val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
    +        termCounts, exp(Elogbeta), brzAlpha, gammaShape, k)
    +      val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
    +
    +      // E[log p(doc | theta, beta)]
    +      termCounts.foreachActive { case (idx, count) =>
    +        docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t)
    +      }
    +      // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector
    +      docScore += sum((brzAlpha - gammad) :* Elogthetad)
    +      docScore += sum(lgamma(gammad) - lgamma(brzAlpha))
    +      docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
    +
    +      docScore
    +    }.sum()
    +
    +    // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar
    +    score += sum((eta - lambda) :* Elogbeta)
    +    score += sum(lgamma(lambda) - lgamma(eta))
    +
    +    val sumEta = eta * vocabSize
    +    score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
    +
    +    score
    +  }
    +
    +  /**
    +   * Predicts the topic mixture distribution for each document (often called "theta" in the
    +   * literature).  Returns a vector of zeros for an empty document.
    +   *
    +   * This uses a variational approximation following Hoffman et al. (2010), where the approximate
    +   * distribution is called "gamma."  Technically, this method returns this approximation "gamma"
    +   * for each document.
    +   * @param documents documents to predict topic mixture distributions for
    +   * @return An RDD of (document ID, topic mixture distribution for document)
    +   */
    +  // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
    +  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
    +    // Double transpose because dirichletExpectation normalizes by row and we need to normalize
    +    // by topic (columns of lambda)
    +    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
    +    val docConcentrationBrz = this.docConcentration.toBreeze
    +    val gammaShape = this.gammaShape
    +    val k = this.k
    +
    +    documents.map { case (id: Long, termCounts: Vector) =>
    +      if (termCounts.numNonzeros == 0) {
    +        (id, Vectors.zeros(k))
    +      } else {
    +        val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
    +          termCounts,
    +          expElogbeta,
    +          docConcentrationBrz,
    +          gammaShape,
    +          k)
    +        (id, Vectors.dense(normalize(gamma, 1.0).toArray))
    +      }
    +    }
    +  }
    +
    +}
    +
    +
    +@Experimental
    +object LocalLDAModel extends Loader[LocalLDAModel] {
    +
    +  private object SaveLoadV1_0 {
    +
    +    val thisFormatVersion = "1.0"
    +
    +    val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
    +
    +    // Store the distribution of terms of each topic and the column index in topicsMatrix
    +    // as a Row in data.
    +    case class Data(topic: Vector, index: Int)
    +
    +    def save(
    +        sc: SparkContext,
    +        path: String,
    +        topicsMatrix: Matrix,
    +        docConcentration: Vector,
    +        topicConcentration: Double,
    +        gammaShape: Double): Unit = {
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      import sqlContext.implicits._
    +
    +      val k = topicsMatrix.numCols
    +      val metadata = compact(render
    +        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
    +          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
    +          ("docConcentration" -> docConcentration.toArray.toSeq) ~
    +          ("topicConcentration" -> topicConcentration) ~
    +          ("gammaShape" -> gammaShape)))
    +      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
    +
    +      val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
    +      val topics = Range(0, k).map { topicInd =>
    +        Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
    +      }.toSeq
    +      sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
    +    }
    +
    +    def load(
    +        sc: SparkContext,
    +        path: String,
    +        docConcentration: Vector,
    +        topicConcentration: Double,
    +        gammaShape: Double): LocalLDAModel = {
    +      val dataPath = Loader.dataPath(path)
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      val dataFrame = sqlContext.read.parquet(dataPath)
    +
    +      Loader.checkSchema[Data](dataFrame.schema)
    +      val topics = dataFrame.collect()
    +      val vocabSize = topics(0).getAs[Vector](0).size
    +      val k = topics.size
    +
    +      val brzTopics = BDM.zeros[Double](vocabSize, k)
    +      topics.foreach { case Row(vec: Vector, ind: Int) =>
    +        brzTopics(::, ind) := vec.toBreeze
    +      }
    +      val topicsMat = Matrices.fromBreeze(brzTopics)
    +
    +      // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
    +      new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
    +    }
    +  }
     
    +  override def load(sc: SparkContext, path: String): LocalLDAModel = {
    +    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    +    implicit val formats = DefaultFormats
    +    val expectedK = (metadata \ "k").extract[Int]
    +    val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
    +    val docConcentration =
    +      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
    +    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
    +    val gammaShape = (metadata \ "gammaShape").extract[Double]
    +    val classNameV1_0 = SaveLoadV1_0.thisClassName
    +
    +    val model = (loadedClassName, loadedVersion) match {
    +      case (className, "1.0") if className == classNameV1_0 =>
    +        SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
    +      case _ => throw new Exception(
    +        s"LocalLDAModel.load did not recognize model with (className, format version):" +
    +          s"($loadedClassName, $loadedVersion).  Supported:\n" +
    +          s"  ($classNameV1_0, 1.0)")
    +    }
    +
    +    val topicsMatrix = model.topicsMatrix
    +    require(expectedK == topicsMatrix.numCols,
    +      s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
    +    require(expectedVocabSize == topicsMatrix.numRows,
    +      s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
    +        s"but got ${topicsMatrix.numRows}")
    +    model
    +  }
     }
     
     /**
    @@ -193,28 +434,25 @@ class LocalLDAModel private[clustering] (
      * than the [[LocalLDAModel]].
      */
     @Experimental
    -class DistributedLDAModel private (
    -    private val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
    -    private val globalTopicTotals: LDA.TopicCounts,
    +class DistributedLDAModel private[clustering] (
    +    private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
    +    private[clustering] val globalTopicTotals: LDA.TopicCounts,
         val k: Int,
         val vocabSize: Int,
    -    private val docConcentration: Double,
    -    private val topicConcentration: Double,
    +    override val docConcentration: Vector,
    +    override val topicConcentration: Double,
    +    override protected[clustering] val gammaShape: Double,
         private[spark] val iterationTimes: Array[Double]) extends LDAModel {
     
       import LDA._
     
    -  private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
    -    this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
    -      state.topicConcentration, iterationTimes)
    -  }
    -
       /**
        * Convert model to a local model.
        * The local model stores the inferred topics but not the topic distributions for training
        * documents.
        */
    -  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix)
    +  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
    +    gammaShape)
     
       /**
        * Inferred topics, where each topic is represented by a distribution over terms.
    @@ -286,8 +524,9 @@ class DistributedLDAModel private (
        *    hyperparameters.
        */
       lazy val logLikelihood: Double = {
    -    val eta = topicConcentration
    -    val alpha = docConcentration
    +    // TODO: generalize this for asymmetric (non-scalar) alpha
    +    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
    +    val eta = this.topicConcentration
         assert(eta > 1.0)
         assert(alpha > 1.0)
         val N_k = globalTopicTotals
    @@ -311,8 +550,9 @@ class DistributedLDAModel private (
        *  log P(topics, topic distributions for docs | alpha, eta)
        */
       lazy val logPrior: Double = {
    -    val eta = topicConcentration
    -    val alpha = docConcentration
    +    // TODO: generalize this for asymmetric (non-scalar) alpha
    +    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
    +    val eta = this.topicConcentration
         // Term vertices: Compute phi_{wk}.  Use to compute prior log probability.
         // Doc vertex: Compute theta_{kj}.  Use to compute prior log probability.
         val N_k = globalTopicTotals
    @@ -323,12 +563,12 @@ class DistributedLDAModel private (
               val N_wk = vertex._2
               val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
               val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
    -          (eta - 1.0) * brzSum(phi_wk.map(math.log))
    +          (eta - 1.0) * sum(phi_wk.map(math.log))
             } else {
               val N_kj = vertex._2
               val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
               val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
    -          (alpha - 1.0) * brzSum(theta_kj.map(math.log))
    +          (alpha - 1.0) * sum(theta_kj.map(math.log))
             }
         }
         graph.vertices.aggregate(0.0)(seqOp, _ + _)
    @@ -354,4 +594,142 @@ class DistributedLDAModel private (
       // TODO:
       // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
     
    +  override protected def formatVersion = "1.0"
    +
    +  override def save(sc: SparkContext, path: String): Unit = {
    +    DistributedLDAModel.SaveLoadV1_0.save(
    +      sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
    +      iterationTimes, gammaShape)
    +  }
     }
    +
    +
    +@Experimental
    +object DistributedLDAModel extends Loader[DistributedLDAModel] {
    +
    +  private object SaveLoadV1_0 {
    +
    +    val thisFormatVersion = "1.0"
    +
    +    val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
    +
    +    // Store globalTopicTotals as a Vector.
    +    case class Data(globalTopicTotals: Vector)
    +
    +    // Store each term and document vertex with an id and the topicWeights.
    +    case class VertexData(id: Long, topicWeights: Vector)
    +
    +    // Store each edge with the source id, destination id and tokenCounts.
    +    case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
    +
    +    def save(
    +        sc: SparkContext,
    +        path: String,
    +        graph: Graph[LDA.TopicCounts, LDA.TokenCount],
    +        globalTopicTotals: LDA.TopicCounts,
    +        k: Int,
    +        vocabSize: Int,
    +        docConcentration: Vector,
    +        topicConcentration: Double,
    +        iterationTimes: Array[Double],
    +        gammaShape: Double): Unit = {
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      import sqlContext.implicits._
    +
    +      val metadata = compact(render
    +        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
    +          ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
    +          ("docConcentration" -> docConcentration.toArray.toSeq) ~
    +          ("topicConcentration" -> topicConcentration) ~
    +          ("iterationTimes" -> iterationTimes.toSeq) ~
    +          ("gammaShape" -> gammaShape)))
    +      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
    +
    +      val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
    +      sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
    +        .write.parquet(newPath)
    +
    +      val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
    +      graph.vertices.map { case (ind, vertex) =>
    +        VertexData(ind, Vectors.fromBreeze(vertex))
    +      }.toDF().write.parquet(verticesPath)
    +
    +      val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
    +      graph.edges.map { case Edge(srcId, dstId, prop) =>
    +        EdgeData(srcId, dstId, prop)
    +      }.toDF().write.parquet(edgesPath)
    +    }
    +
    +    def load(
    +        sc: SparkContext,
    +        path: String,
    +        vocabSize: Int,
    +        docConcentration: Vector,
    +        topicConcentration: Double,
    +        iterationTimes: Array[Double],
    +        gammaShape: Double): DistributedLDAModel = {
    +      val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
    +      val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
    +      val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      val dataFrame = sqlContext.read.parquet(dataPath)
    +      val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
    +      val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
    +
    +      Loader.checkSchema[Data](dataFrame.schema)
    +      Loader.checkSchema[VertexData](vertexDataFrame.schema)
    +      Loader.checkSchema[EdgeData](edgeDataFrame.schema)
    +      val globalTopicTotals: LDA.TopicCounts =
    +        dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
    +      val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
    +        case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
    +      }
    +
    +      val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
    +        case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
    +      }
    +      val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
    +
    +      new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
    +        docConcentration, topicConcentration, gammaShape, iterationTimes)
    +    }
    +
    +  }
    +
    +  override def load(sc: SparkContext, path: String): DistributedLDAModel = {
    +    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    +    implicit val formats = DefaultFormats
    +    val expectedK = (metadata \ "k").extract[Int]
    +    val vocabSize = (metadata \ "vocabSize").extract[Int]
    +    val docConcentration =
    +      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
    +    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
    +    val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
    +    val gammaShape = (metadata \ "gammaShape").extract[Double]
    +    val classNameV1_0 = SaveLoadV1_0.thisClassName
    +
    +    val model = (loadedClassName, loadedVersion) match {
    +      case (className, "1.0") if className == classNameV1_0 => {
    +        DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
    +          topicConcentration, iterationTimes.toArray, gammaShape)
    +      }
    +      case _ => throw new Exception(
    +        s"DistributedLDAModel.load did not recognize model with (className, format version):" +
    +          s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")
    +    }
    +
    +    require(model.vocabSize == vocabSize,
    +      s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
    +    require(model.docConcentration == docConcentration,
    +      s"DistributedLDAModel requires $docConcentration docConcentration, " +
    +        s"got ${model.docConcentration} docConcentration")
    +    require(model.topicConcentration == topicConcentration,
    +      s"DistributedLDAModel requires $topicConcentration docConcentration, " +
    +        s"got ${model.topicConcentration} docConcentration")
    +    require(expectedK == model.k,
    +      s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
    +    model
    +  }
    +
    +}
    +
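
For reference, a minimal round-trip sketch (not part of this patch) for the new DistributedLDAModel persistence. It assumes the save(sc, path) entry point that this loader pairs with, a local SparkContext, and a placeholder /tmp path; the two-document corpus is synthetic.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
    import org.apache.spark.mllib.linalg.Vectors

    object LDASaveLoadSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[2]", "LDASaveLoadSketch")
        // (docId, term-count vector) pairs for a toy corpus
        val corpus = sc.parallelize(Seq(
          (0L, Vectors.dense(1.0, 0.0, 3.0)),
          (1L, Vectors.dense(0.0, 2.0, 1.0))))
        // The default EM optimizer produces a DistributedLDAModel
        val model = new LDA().setK(2).setMaxIterations(10).run(corpus)
          .asInstanceOf[DistributedLDAModel]
        model.save(sc, "/tmp/ldaModel")   // writes JSON metadata plus the Parquet parts above
        val sameModel = DistributedLDAModel.load(sc, "/tmp/ldaModel")
        assert(sameModel.k == model.k && sameModel.vocabSize == model.vocabSize)
        sc.stop()
      }
    }
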
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    index 8e5154b902d1d..d6f8b29a43dfd 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    @@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
     
     import java.util.Random
     
    -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
    -import breeze.numerics.{digamma, exp, abs}
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
    +import breeze.numerics.{abs, exp}
     import breeze.stats.distributions.{Gamma, RandBasis}
     
     import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.graphx._
     import org.apache.spark.graphx.impl.GraphImpl
     import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
    -import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
    +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
     import org.apache.spark.rdd.RDD
     
     /**
    @@ -95,8 +95,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
        * Compute bipartite term/doc graph.
        */
       override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
    +    val docConcentration = lda.getDocConcentration(0)
    +    require({
    +      lda.getDocConcentration.toArray.forall(_ == docConcentration)
    +    }, "EMLDAOptimizer currently only supports symmetric document-topic priors")
     
    -    val docConcentration = lda.getDocConcentration
         val topicConcentration = lda.getTopicConcentration
         val k = lda.getK
     
    @@ -139,8 +142,9 @@ final class EMLDAOptimizer extends LDAOptimizer {
         this.k = k
         this.vocabSize = docs.take(1).head._2.size
         this.checkpointInterval = lda.getCheckpointInterval
    -    this.graphCheckpointer = new
    -      PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
    +    this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
    +      checkpointInterval, graph.vertices.sparkContext)
    +    this.graphCheckpointer.update(this.graph)
         this.globalTopicTotals = computeGlobalTopicTotals()
         this
       }
    @@ -185,7 +189,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
         // Update the vertex descriptors with the new counts.
         val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
         graph = newGraph
    -    graphCheckpointer.updateGraph(newGraph)
    +    graphCheckpointer.update(newGraph)
         globalTopicTotals = computeGlobalTopicTotals()
         this
       }
    @@ -205,7 +209,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
       override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
         require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
         this.graphCheckpointer.deleteAllCheckpoints()
    -    new DistributedLDAModel(this, iterationTimes)
    +    // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
    +    // conversion
    +    new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
    +      Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
    +      100, iterationTimes)
       }
     }
     
    @@ -229,10 +237,10 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
       private var vocabSize: Int = 0
     
       /** alias for docConcentration */
    -  private var alpha: Double = 0
    +  private var alpha: Vector = Vectors.dense(0)
     
       /** (private[clustering] for debugging)  Get docConcentration */
    -  private[clustering] def getAlpha: Double = alpha
    +  private[clustering] def getAlpha: Vector = alpha
     
       /** alias for topicConcentration */
       private var eta: Double = 0
    @@ -343,7 +351,19 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         this.k = lda.getK
         this.corpusSize = docs.count()
         this.vocabSize = docs.first()._2.size
    -    this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
    +    this.alpha = if (lda.getDocConcentration.size == 1) {
    +      if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
    +      else {
    +        require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
    +        Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
    +      }
    +    } else {
    +      require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
    +      lda.getDocConcentration.foreachActive { case (_, x) =>
    +        require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
    +      }
    +      lda.getDocConcentration
    +    }
         this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
         this.randomGenerator = new Random(lda.getSeed)
     
    @@ -370,76 +390,52 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         iteration += 1
         val k = this.k
         val vocabSize = this.vocabSize
    -    val Elogbeta = dirichletExpectation(lambda)
    -    val expElogbeta = exp(Elogbeta)
    -    val alpha = this.alpha
    +    val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t
    +    val alpha = this.alpha.toBreeze
         val gammaShape = this.gammaShape
     
    -    val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
    -      val stat = BDM.zeros[Double](k, vocabSize)
    -      docs.foreach { doc =>
    -        val termCounts = doc._2
    -        val (ids: List[Int], cts: Array[Double]) = termCounts match {
    -          case v: DenseVector => ((0 until v.size).toList, v.values)
    -          case v: SparseVector => (v.indices.toList, v.values)
    -          case v => throw new IllegalArgumentException("Online LDA does not support vector type "
    -            + v.getClass)
    -        }
    -
    -        // Initialize the variational distribution q(theta|gamma) for the mini-batch
    -        var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
    -        var Elogthetad = digamma(gammad) - digamma(sum(gammad))     // 1 * K
    -        var expElogthetad = exp(Elogthetad)                         // 1 * K
    -        val expElogbetad = expElogbeta(::, ids).toDenseMatrix       // K * ids
    -
    -        var phinorm = expElogthetad * expElogbetad + 1e-100         // 1 * ids
    -        var meanchange = 1D
    -        val ctsVector = new BDV[Double](cts).t                      // 1 * ids
    -
    -        // Iterate between gamma and phi until convergence
    -        while (meanchange > 1e-3) {
    -          val lastgamma = gammad
    -          //        1*K                  1 * ids               ids * k
    -          gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
    -          Elogthetad = digamma(gammad) - digamma(sum(gammad))
    -          expElogthetad = exp(Elogthetad)
    -          phinorm = expElogthetad * expElogbetad + 1e-100
    -          meanchange = sum(abs(gammad - lastgamma)) / k
    -        }
    +    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
    +      val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
     
    -        val m1 = expElogthetad.t
    -        val m2 = (ctsVector / phinorm).t.toDenseVector
    -        var i = 0
    -        while (i < ids.size) {
    -          stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
    -          i += 1
    +      val stat = BDM.zeros[Double](k, vocabSize)
    +      var gammaPart = List[BDV[Double]]()
    +      nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) =>
    +        val ids: List[Int] = termCounts match {
    +          case v: DenseVector => (0 until v.size).toList
    +          case v: SparseVector => v.indices.toList
             }
    +        val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
    +          termCounts, expElogbeta, alpha, gammaShape, k)
    +        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
    +        gammaPart = gammad :: gammaPart
           }
    -      Iterator(stat)
    +      Iterator((stat, gammaPart))
         }
    -
    -    val statsSum: BDM[Double] = stats.reduce(_ += _)
    -    val batchResult = statsSum :* expElogbeta
    +    val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
    +    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
    +      stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
    +    val batchResult = statsSum :* expElogbeta.t
     
         // Note that this is an optimization to avoid batch.count
    -    update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
    +    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
         this
       }
     
    -  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
    -    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
    -  }
    -
       /**
        * Update lambda based on the batch submitted. batchSize can be different for each iteration.
        */
    -  private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
    +  private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = {
         // weight of the mini-batch.
    -    val weight = math.pow(getTau0 + iter, -getKappa)
    +    val weight = rho()
     
         // Update lambda based on documents.
    -    lambda = lambda * (1 - weight) +
    -      (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
    +    lambda := (1 - weight) * lambda +
    +      weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
    +  }
    +
    +  /** Calculates learning rate rho, which decays as a function of [[iteration]] */
    +  private def rho(): Double = {
    +    math.pow(getTau0 + this.iteration, -getKappa)
       }
     
       /**
    @@ -453,15 +449,57 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         new BDM[Double](col, row, temp).t
       }
     
    +  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
    +    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
    +  }
    +
    +}
    +
    +/**
    + * Serializable companion object containing helper methods and shared code for
    + * [[OnlineLDAOptimizer]] and [[LocalLDAModel]].
    + */
    +private[clustering] object OnlineLDAOptimizer {
       /**
    -   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
    -   * uses digamma which is accurate but expensive.
    +   * Uses variational inference to infer the topic distribution `gammad` given the term counts
    +   * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will
    +   * throw a BLAS error.
    +   *
    +   * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001)
    +   * avoids explicit computation of variational parameter `phi`.
    +   * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]]
        */
    -  private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
    -    val rowSum = sum(alpha(breeze.linalg.*, ::))
    -    val digAlpha = digamma(alpha)
    -    val digRowSum = digamma(rowSum)
    -    val result = digAlpha(::, breeze.linalg.*) - digRowSum
    -    result
    +  private[clustering] def variationalTopicInference(
    +      termCounts: Vector,
    +      expElogbeta: BDM[Double],
    +      alpha: breeze.linalg.Vector[Double],
    +      gammaShape: Double,
    +      k: Int): (BDV[Double], BDM[Double]) = {
    +    val (ids: List[Int], cts: Array[Double]) = termCounts match {
    +      case v: DenseVector => ((0 until v.size).toList, v.values)
    +      case v: SparseVector => (v.indices.toList, v.values)
    +    }
    +    // Initialize the variational distribution q(theta|gamma) for the mini-batch
    +    val gammad: BDV[Double] =
    +      new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                   // K
    +    val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad))  // K
    +    val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                        // ids * K
    +
    +    val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
    +    var meanchange = 1D
    +    val ctsVector = new BDV[Double](cts)                                         // ids
    +
    +    // Iterate between gamma and phi until convergence
    +    while (meanchange > 1e-3) {
    +      val lastgamma = gammad.copy
    +      //        K                  K * ids               ids
    +      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
    +      expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
    +      phinorm := expElogbetad * expElogthetad :+ 1e-100
    +      meanchange = sum(abs(gammad - lastgamma)) / k
    +    }
    +
    +    val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
    +    (gammad, sstatsd)
       }
     }
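
For reference, a small configuration sketch (not part of this patch): with alpha now a Vector, the online optimizer can take an asymmetric document-topic prior, while the EM optimizer above still requires a symmetric one. It assumes the Vector-valued setDocConcentration introduced alongside this change and an existing corpus: RDD[(Long, Vector)].

    import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.Vectors

    val k = 3
    val onlineLDA = new LDA()
      .setK(k)
      .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(0.05))
      .setDocConcentration(Vectors.dense(0.1, 0.5, 1.0))  // one alpha entry per topic
    // val model = onlineLDA.run(corpus)  // corpus: RDD[(Long, Vector)], assumed to exist
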
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
    new file mode 100644
    index 0000000000000..f7e5ce1665fe6
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
    @@ -0,0 +1,55 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.clustering
    +
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum}
    +import breeze.numerics._
    +
    +/**
    + * Utility methods for LDA.
    + */
    +object LDAUtils {
    +  /**
    +   * Log Sum Exp with overflow protection using the identity:
    +   * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\}
    +   */
    +  private[clustering] def logSumExp(x: BDV[Double]): Double = {
    +    val a = max(x)
    +    a + log(sum(exp(x :- a)))
    +  }
    +
    +  /**
    +   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
    +   * uses [[breeze.numerics.digamma]] which is accurate but expensive.
    +   */
    +  private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = {
    +    digamma(alpha) - digamma(sum(alpha))
    +  }
    +
    +  /**
    +   * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha are
    +   * Dirichlet parameters.
    +   */
    +  private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
    +    val rowSum = sum(alpha(breeze.linalg.*, ::))
    +    val digAlpha = digamma(alpha)
    +    val digRowSum = digamma(rowSum)
    +    val result = digAlpha(::, breeze.linalg.*) - digRowSum
    +    result
    +  }
    +
    +}
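
For reference, since these helpers are private[clustering], the following sketch re-derives the same two identities directly with Breeze so they are easy to check interactively.

    import breeze.linalg.{DenseVector => BDV, max, sum}
    import breeze.numerics.{digamma, exp, log}

    val x = BDV(1.0, 2.0, 3.0)
    val a = max(x)
    val logSumExp = a + log(sum(exp(x :- a)))             // == log(e^1 + e^2 + e^3), overflow-safe

    val alpha = BDV(0.5, 0.5, 2.0)
    val eLogTheta = digamma(alpha) - digamma(sum(alpha))  // E[log theta] for theta ~ Dir(alpha)
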
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    index e7a243f854e33..407e43a024a2e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    @@ -153,6 +153,27 @@ class PowerIterationClustering private[clustering] (
         this
       }
     
    +  /**
    +   * Run the PIC algorithm on Graph.
    +   *
    +   * @param graph an affinity matrix represented as graph, which is the matrix A in the PIC paper.
    +   *              The similarity s,,ij,, represented as the edge between vertices (i, j) must
    +   *              be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For
    +   *              any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,)
    +   *              or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we
    +   *              assume s,,ij,, = 0.0.
    +   *
    +   * @return a [[PowerIterationClusteringModel]] that contains the clustering result
    +   */
    +  def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = {
    +    val w = normalize(graph)
    +    val w0 = initMode match {
    +      case "random" => randomInit(w)
    +      case "degree" => initDegreeVector(w)
    +    }
    +    pic(w0)
    +  }
    +
       /**
        * Run the PIC algorithm.
        *
    @@ -212,6 +233,31 @@ object PowerIterationClustering extends Logging {
       @Experimental
       case class Assignment(id: Long, cluster: Int)
     
    +  /**
    +   * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W).
    +   */
    +  private[clustering]
    +  def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = {
    +    val vD = graph.aggregateMessages[Double](
    +      sendMsg = ctx => {
    +        val i = ctx.srcId
    +        val j = ctx.dstId
    +        val s = ctx.attr
    +        if (s < 0.0) {
    +          throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.")
    +        }
    +        if (s > 0.0) {
    +          ctx.sendToSrc(s)
    +        }
    +      },
    +      mergeMsg = _ + _,
    +      TripletFields.EdgeOnly)
    +    GraphImpl.fromExistingRDDs(vD, graph.edges)
    +      .mapTriplets(
    +        e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
    +        TripletFields.Src)
    +  }
    +
       /**
        * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
        */
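
For reference, a usage sketch (not part of this patch) of the new graph-based entry point on a tiny affinity graph; k, the edge weights, and the local master are arbitrary, and the vertex attributes are placeholders since normalize only reads the edges.

    import org.apache.spark.SparkContext
    import org.apache.spark.graphx.{Edge, Graph}
    import org.apache.spark.mllib.clustering.PowerIterationClustering

    val sc = new SparkContext("local[2]", "PICGraphSketch")
    // Edge attributes are the pairwise similarities s_ij (nonnegative).
    val similarities = sc.parallelize(Seq(
      Edge(0L, 1L, 1.0), Edge(1L, 2L, 0.9), Edge(2L, 3L, 0.1), Edge(3L, 4L, 1.0)))
    val vertices = sc.parallelize((0L to 4L).map(id => (id, 0.0)))
    val affinityGraph: Graph[Double, Double] = Graph(vertices, similarities)

    val model = new PowerIterationClustering()
      .setK(2)
      .setMaxIterations(20)
      .run(affinityGraph)
    model.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))
    sc.stop()
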
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
    index e577bf87f885e..408847afa800d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
    @@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
           )
         summary
       }
    +  private lazy val SSerr = math.pow(summary.normL2(1), 2)
    +  private lazy val SStot = summary.variance(0) * (summary.count - 1)
    +  private lazy val SSreg = {
    +    val yMean = summary.mean(0)
    +    predictionAndObservations.map {
    +      case (prediction, _) => math.pow(prediction - yMean, 2)
    +    }.sum()
    +  }
     
       /**
    -   * Returns the explained variance regression score.
    -   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
    -   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
    +   * Returns the variance explained by regression.
    +   * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n
    +   * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
        */
       def explainedVariance: Double = {
    -    1 - summary.variance(1) / summary.variance(0)
    +    SSreg / summary.count
       }
     
       /**
    @@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
        * expected value of the squared error loss or quadratic loss.
        */
       def meanSquaredError: Double = {
    -    val rmse = summary.normL2(1) / math.sqrt(summary.count)
    -    rmse * rmse
    +    SSerr / summary.count
       }
     
       /**
    @@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
        * the mean squared error.
        */
       def rootMeanSquaredError: Double = {
    -    summary.normL2(1) / math.sqrt(summary.count)
    +    math.sqrt(this.meanSquaredError)
       }
     
       /**
    -   * Returns R^2^, the coefficient of determination.
    -   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
    +   * Returns R^2^, the unadjusted coefficient of determination.
    +   * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
        */
       def r2: Double = {
    -    1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
    +    1 - SSerr / SStot
       }
     }
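
For reference, a quick check (not part of this patch) of the reworked metrics on a toy prediction/observation RDD; sc is an assumed existing SparkContext.

    import org.apache.spark.mllib.evaluation.RegressionMetrics

    val predictionAndObservations = sc.parallelize(Seq(
      (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
    val metrics = new RegressionMetrics(predictionAndObservations)
    println(s"MSE  = ${metrics.meanSquaredError}")                 // SSerr / n
    println(s"RMSE = ${metrics.rootMeanSquaredError}")             // sqrt(MSE)
    println(s"R^2  = ${metrics.r2}")                               // 1 - SSerr / SStot
    println(s"explained variance = ${metrics.explainedVariance}")  // SSreg / n (new definition)
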
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
    index f087d06d2a46a..cbbd2b0c8d060 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
    @@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging {
         }
         newSentences.unpersist()
     
    -    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
    -    var i = 0
    -    while (i < vocabSize) {
    -      val word = bcVocab.value(i).word
    -      val vector = new Array[Float](vectorSize)
    -      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
    -      word2VecMap += word -> vector
    -      i += 1
    -    }
    -
    -    new Word2VecModel(word2VecMap.toMap)
    +    val wordArray = vocab.map(_.word)
    +    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
       }
     
       /**
    @@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging {
     /**
      * :: Experimental ::
      * Word2Vec model
    + * @param wordIndex maps each word to an index, which can retrieve the corresponding
    + *                  vector from wordVectors
    + * @param wordVectors array of length numWords * vectorSize, vector corresponding
    + *                    to the word mapped with index i can be retrieved by the slice
    + *                    (i * vectorSize, i * vectorSize + vectorSize)
      */
     @Experimental
    -class Word2VecModel private[spark] (
    -    model: Map[String, Array[Float]]) extends Serializable with Saveable {
    -
    -  // wordList: Ordered list of words obtained from model.
    -  private val wordList: Array[String] = model.keys.toArray
    -
    -  // wordIndex: Maps each word to an index, which can retrieve the corresponding
    -  //            vector from wordVectors (see below).
    -  private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
    +class Word2VecModel private[mllib] (
    +    private val wordIndex: Map[String, Int],
    +    private val wordVectors: Array[Float]) extends Serializable with Saveable {
     
    -  // vectorSize: Dimension of each word's vector.
    -  private val vectorSize = model.head._2.size
       private val numWords = wordIndex.size
    +  // vectorSize: Dimension of each word's vector.
    +  private val vectorSize = wordVectors.length / numWords
    +
    +  // wordList: Ordered list of words obtained from wordIndex.
    +  private val wordList: Array[String] = {
    +    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
    +    wl.toArray
    +  }
     
    -  // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
    -  //              mapped with index i can be retrieved by the slice
    -  //              (ind * vectorSize, ind * vectorSize + vectorSize)
       // wordVecNorms: Array of length numWords, each value being the Euclidean norm
       //               of the wordVector.
    -  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
    -    val wordVectors = new Array[Float](vectorSize * numWords)
    +  private val wordVecNorms: Array[Double] = {
         val wordVecNorms = new Array[Double](numWords)
         var i = 0
         while (i < numWords) {
    -      val vec = model.get(wordList(i)).get
    -      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
    +      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
           wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
           i += 1
         }
    -    (wordVectors, wordVecNorms)
    +    wordVecNorms
    +  }
    +
    +  def this(model: Map[String, Array[Float]]) = {
    +    this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
       }
     
       private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
    @@ -484,8 +479,9 @@ class Word2VecModel private[spark] (
        * @return vector representation of word
        */
       def transform(word: String): Vector = {
    -    model.get(word) match {
    -      case Some(vec) =>
    +    wordIndex.get(word) match {
    +      case Some(ind) =>
    +        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
             Vectors.dense(vec.map(_.toDouble))
           case None =>
             throw new IllegalStateException(s"$word not in vocabulary")
    @@ -511,7 +507,7 @@ class Word2VecModel private[spark] (
        */
       def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
         require(num > 0, "Number of similar words should > 0")
    -
    +    // TODO: optimize top-k
         val fVector = vector.toArray.map(_.toFloat)
         val cosineVec = Array.fill[Float](numWords)(0)
         val alpha: Float = 1
    @@ -521,13 +517,13 @@ class Word2VecModel private[spark] (
           "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
     
         // Need not divide with the norm of the given vector since it is constant.
    -    val updatedCosines = new Array[Double](numWords)
    +    val cosVec = cosineVec.map(_.toDouble)
         var ind = 0
         while (ind < numWords) {
    -      updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
    +      cosVec(ind) /= wordVecNorms(ind)
           ind += 1
         }
    -    wordList.zip(updatedCosines)
    +    wordList.zip(cosVec)
           .toSeq
           .sortBy(- _._2)
           .take(num + 1)
    @@ -548,6 +544,23 @@ class Word2VecModel private[spark] (
     @Experimental
     object Word2VecModel extends Loader[Word2VecModel] {
     
    +  private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
    +    model.keys.zipWithIndex.toMap
    +  }
    +
    +  private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
    +    require(model.nonEmpty, "Word2VecMap should be non-empty")
    +    val (vectorSize, numWords) = (model.head._2.size, model.size)
    +    val wordList = model.keys.toArray
    +    val wordVectors = new Array[Float](vectorSize * numWords)
    +    var i = 0
    +    while (i < numWords) {
    +      Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
    +      i += 1
    +    }
    +    wordVectors
    +  }
    +
       private object SaveLoadV1_0 {
     
         val formatVersionV1_0 = "1.0"
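
For reference, the refactor above changes how vectors are stored (a flat wordVectors array plus a wordIndex) but not the public behaviour; a typical usage sketch (not part of this patch) still looks like the following, assuming an existing SparkContext sc and throwaway toy sentences.

    import org.apache.spark.mllib.feature.Word2Vec

    val sentences = sc.parallelize(Seq(
      "spark mllib supports word2vec".split(" ").toSeq,
      "word2vec maps words to vectors".split(" ").toSeq))
    val model = new Word2Vec().setVectorSize(10).setMinCount(1).setSeed(42L).fit(sentences)
    val vector = model.transform("word2vec")          // dense vector of length 10
    val synonyms = model.findSynonyms("word2vec", 2)  // Array[(String, Double)]
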
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
    new file mode 100644
    index 0000000000000..72d0ea0c12e1e
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
    @@ -0,0 +1,119 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm
    +
    +import scala.reflect.ClassTag
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaRDD
    +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
    +import org.apache.spark.mllib.fpm.AssociationRules.Rule
    +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
    +import org.apache.spark.rdd.RDD
    +
    +/**
    + * :: Experimental ::
    + *
+ * Generates association rules from an [[RDD[FreqItemset[Item]]]]. This method only generates
    + * association rules which have a single item as the consequent.
    + *
    + * @since 1.5.0
    + */
    +@Experimental
    +class AssociationRules private[fpm] (
    +    private var minConfidence: Double) extends Logging with Serializable {
    +
    +  /**
    +   * Constructs a default instance with default parameters {minConfidence = 0.8}.
    +   *
    +   * @since 1.5.0
    +   */
    +  def this() = this(0.8)
    +
    +  /**
    +   * Sets the minimal confidence (default: `0.8`).
    +   *
    +   * @since 1.5.0
    +   */
    +  def setMinConfidence(minConfidence: Double): this.type = {
    +    require(minConfidence >= 0.0 && minConfidence <= 1.0)
    +    this.minConfidence = minConfidence
    +    this
    +  }
    +
    +  /**
    +   * Computes the association rules with confidence above [[minConfidence]].
    +   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
+   * @return an [[RDD[Rule[Item]]]] containing the association rules.
    +   *
    +   * @since 1.5.0
    +   */
    +  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
    +    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
    +    val candidates = freqItemsets.flatMap { itemset =>
    +      val items = itemset.items
    +      items.flatMap { item =>
    +        items.partition(_ == item) match {
    +          case (consequent, antecedent) if !antecedent.isEmpty =>
    +            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
    +          case _ => None
    +        }
    +      }
    +    }
    +
    +    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
    +    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
+      .map { case (antecedent, ((consequent, freqUnion), freqAntecedent)) =>
+        new Rule(antecedent.toArray, consequent.toArray, freqUnion, freqAntecedent)
+      }.filter(_.confidence >= minConfidence)
    +  }
    +
    +  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
    +    val tag = fakeClassTag[Item]
    +    run(freqItemsets.rdd)(tag)
    +  }
    +}
    +
    +object AssociationRules {
    +
    +  /**
    +   * :: Experimental ::
    +   *
    +   * An association rule between sets of items.
    +   * @param antecedent hypotheses of the rule
    +   * @param consequent conclusion of the rule
    +   * @tparam Item item type
    +   *
    +   * @since 1.5.0
    +   */
    +  @Experimental
    +  class Rule[Item] private[fpm] (
    +      val antecedent: Array[Item],
    +      val consequent: Array[Item],
    +      freqUnion: Double,
    +      freqAntecedent: Double) extends Serializable {
    +
    +    def confidence: Double = freqUnion.toDouble / freqAntecedent
    +
    +    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
    +      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
    +      s"A valid association rule must have disjoint antecedent and " +
    +        s"consequent but ${sharedItems} is present in both."
    +    })
    +  }
    +}
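
For reference, a usage sketch (not part of this patch) that runs the new AssociationRules directly on precomputed frequent itemsets; here [a] => [b] has confidence 12/15 = 0.8 and just clears the threshold. sc is an assumed existing SparkContext.

    import org.apache.spark.mllib.fpm.AssociationRules
    import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

    val freqItemsets = sc.parallelize(Seq(
      new FreqItemset(Array("a"), 15L),
      new FreqItemset(Array("b"), 35L),
      new FreqItemset(Array("a", "b"), 12L)))
    val ar = new AssociationRules().setMinConfidence(0.8)
    ar.run(freqItemsets).collect().foreach { rule =>
      println(s"${rule.antecedent.mkString("[", ",", "]")} => " +
        s"${rule.consequent.mkString("[", ",", "]")}, confidence = ${rule.confidence}")
    }
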
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
    index abac08022ea47..e2370a52f4930 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
    @@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
    -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
    +import org.apache.spark.mllib.fpm.FPGrowth._
     import org.apache.spark.rdd.RDD
     import org.apache.spark.storage.StorageLevel
     
    @@ -36,11 +36,23 @@ import org.apache.spark.storage.StorageLevel
      * :: Experimental ::
      *
      * Model trained by [[FPGrowth]], which holds frequent itemsets.
    - * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]]
    + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
      * @tparam Item item type
    + *
    + * @since 1.3.0
      */
     @Experimental
    -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
    +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
    +  /**
    +   * Generates association rules for the [[Item]]s in [[freqItemsets]].
    +   * @param confidence minimal confidence of the rules produced
    +   * @since 1.5.0
    +   */
    +  def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
    +    val associationRules = new AssociationRules(confidence)
    +    associationRules.run(freqItemsets)
    +  }
    +}
     
     /**
      * :: Experimental ::
    @@ -58,21 +70,26 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex
      *
      * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning
      *       (Wikipedia)]]
    + *
    + * @since 1.3.0
      */
     @Experimental
     class FPGrowth private (
         private var minSupport: Double,
    -    private var numPartitions: Int,
    -    private var ordered: Boolean) extends Logging with Serializable {
    +    private var numPartitions: Int) extends Logging with Serializable {
     
       /**
        * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same
    -   * as the input data, ordered: `false`}.
    +   * as the input data}.
    +   *
    +   * @since 1.3.0
        */
    -  def this() = this(0.3, -1, false)
    +  def this() = this(0.3, -1)
     
       /**
        * Sets the minimal support level (default: `0.3`).
    +   *
    +   * @since 1.3.0
        */
       def setMinSupport(minSupport: Double): this.type = {
         this.minSupport = minSupport
    @@ -81,25 +98,20 @@ class FPGrowth private (
     
       /**
        * Sets the number of partitions used by parallel FP-growth (default: same as input data).
    +   *
    +   * @since 1.3.0
        */
       def setNumPartitions(numPartitions: Int): this.type = {
         this.numPartitions = numPartitions
         this
       }
     
    -  /**
    -   * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine
    -   * itemsets).
    -   */
    -  def setOrdered(ordered: Boolean): this.type = {
    -    this.ordered = ordered
    -    this
    -  }
    -
       /**
        * Computes an FP-Growth model that contains frequent itemsets.
        * @param data input data set, each element contains a transaction
        * @return an [[FPGrowthModel]]
    +   *
    +   * @since 1.3.0
        */
       def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
         if (data.getStorageLevel == StorageLevel.NONE) {
    @@ -165,7 +177,7 @@ class FPGrowth private (
         .flatMap { case (part, tree) =>
           tree.extract(minCount, x => partitioner.getPartition(x) == part)
         }.map { case (ranks, count) =>
    -      new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered)
    +      new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
         }
       }
     
    @@ -181,12 +193,9 @@ class FPGrowth private (
           itemToRank: Map[Item, Int],
           partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
         val output = mutable.Map.empty[Int, Array[Int]]
    -    // Filter the basket by frequent items pattern
    +    // Filter the basket by frequent items pattern and sort their ranks.
         val filtered = transaction.flatMap(itemToRank.get)
    -    if (!this.ordered) {
    -      ju.Arrays.sort(filtered)
    -    }
    -    // Generate conditional transactions
    +    ju.Arrays.sort(filtered)
         val n = filtered.length
         var i = n - 1
         while (i >= 0) {
    @@ -203,6 +212,8 @@ class FPGrowth private (
     
     /**
      * :: Experimental ::
    + *
    + * @since 1.3.0
      */
     @Experimental
     object FPGrowth {
    @@ -211,21 +222,16 @@ object FPGrowth {
        * Frequent itemset.
        * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
        * @param freq frequency
    -   * @param ordered indicates if items represents an itemset (false) or sequence (true)
        * @tparam Item item type
    +   *
    +   * @since 1.3.0
        */
    -  class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean)
    -    extends Serializable {
    -
    -    /**
    -     * Auxillary constructor, assumes unordered by default.
    -     */
    -    def this(items: Array[Item], freq: Long) {
    -      this(items, freq, false)
    -    }
    +  class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
     
         /**
          * Returns items in a Java List.
    +     *
    +     * @since 1.3.0
          */
         def javaItems: java.util.List[Item] = {
           items.toList.asJava
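
For reference, an end-to-end sketch (not part of this patch) that mines frequent itemsets and then chains into the new generateAssociationRules hook; the transactions and thresholds are arbitrary, and sc is an assumed existing SparkContext.

    import org.apache.spark.mllib.fpm.FPGrowth

    val transactions = sc.parallelize(Seq(
      Array("r", "z", "h", "k", "p"),
      Array("z", "y", "x", "w", "v", "u", "t", "s"),
      Array("s", "x", "o", "n", "r"),
      Array("x", "z", "y", "m", "t", "s", "q", "e"),
      Array("z"),
      Array("x", "z", "y", "r", "q", "t", "p"))).cache()

    val model = new FPGrowth().setMinSupport(0.5).setNumPartitions(2).run(transactions)
    model.freqItemsets.collect().foreach { itemset =>
      println(s"${itemset.items.mkString("[", ",", "]")}: ${itemset.freq}")
    }
    model.generateAssociationRules(0.9).collect().foreach { rule =>
      println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")}")
    }
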
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    new file mode 100644
    index 0000000000000..0ea792081086d
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    @@ -0,0 +1,94 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.fpm
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.Logging
    +
    +/**
+ * Calculates all patterns of a projected database locally.
    + */
    +private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    +
    +  /**
    +   * Calculate all patterns of a projected database.
    +   * @param minCount minimum count
    +   * @param maxPatternLength maximum pattern length
    +   * @param prefixes prefixes in reversed order
    +   * @param database the projected database
    +   * @return a set of sequential pattern pairs,
    +   *         the key of pair is sequential pattern (a list of items in reversed order),
    +   *         the value of pair is the pattern's count.
    +   */
    +  def run(
    +      minCount: Long,
    +      maxPatternLength: Int,
    +      prefixes: List[Int],
    +      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
    +    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
    +    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
    +    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
    +    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
    +      val newPrefixes = item :: prefixes
    +      val newProjected = project(filteredDatabase, item)
    +      Iterator.single((newPrefixes, count)) ++
    +        run(minCount, maxPatternLength, newPrefixes, newProjected)
    +    }
    +  }
    +
    +  /**
    +   * Calculate suffix sequence immediately after the first occurrence of an item.
    +   * @param item item to get suffix after
    +   * @param sequence sequence to extract suffix from
    +   * @return suffix sequence
    +   */
    +  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    +    val index = sequence.indexOf(item)
    +    if (index == -1) {
    +      Array()
    +    } else {
    +      sequence.drop(index + 1)
    +    }
    +  }
    +
    +  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
    +    database
    +      .map(getSuffix(prefix, _))
    +      .filter(_.nonEmpty)
    +  }
    +
    +  /**
+   * Generates frequent items by filtering the input data using the minimum count threshold.
    +   * @param minCount the minimum count for an item to be frequent
    +   * @param database database of sequences
    +   * @return freq item to count map
    +   */
    +  private def getFreqItemAndCounts(
    +      minCount: Long,
    +      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
    +    // TODO: use PrimitiveKeyOpenHashMap
    +    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
    +    database.foreach { sequence =>
    +      sequence.distinct.foreach { item =>
    +        counts(item) += 1L
    +      }
    +    }
    +    counts.filter(_._2 >= minCount)
    +  }
    +}
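
For reference, LocalPrefixSpan is private[fpm]; the short sketch below re-implements getSuffix and the projection step on plain collections purely to make the semantics concrete.

    def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
      val index = sequence.indexOf(item)
      if (index == -1) Array() else sequence.drop(index + 1)
    }

    val database = Seq(Array(1, 2, 3), Array(2, 3, 1), Array(3, 1, 2))
    // Project on prefix item 2: keep everything after the first 2, drop empty suffixes.
    val projected = database.map(getSuffix(2, _)).filter(_.nonEmpty)
    // projected.map(_.toList) == Seq(List(3), List(3, 1))
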
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    new file mode 100644
    index 0000000000000..e6752332cdeeb
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    @@ -0,0 +1,249 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.fpm
    +
    +import scala.collection.mutable.ArrayBuffer
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +
    +/**
    + *
    + * :: Experimental ::
    + *
    + * A parallel PrefixSpan algorithm to mine sequential pattern.
    + * The PrefixSpan algorithm is described in
    + * [[http://doi.org/10.1109/ICDE.2001.914830]].
    + *
+ * @param minSupport the minimal support level of the sequential pattern; any pattern that
+ *                   appears more than (minSupport * size-of-the-dataset) times will be output
+ * @param maxPatternLength the maximal length of the sequential pattern; any frequent pattern
+ *                   whose length does not exceed maxPatternLength will be output
    + *
    + * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
    + *       (Wikipedia)]]
    + */
    +@Experimental
    +class PrefixSpan private (
    +    private var minSupport: Double,
    +    private var maxPatternLength: Int) extends Logging with Serializable {
    +
    +  /**
    +   * The maximum number of items allowed in a projected database before local processing. If a
    +   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
    +   */
    +  // TODO: make configurable with a better default value, 10000 may be too small
    +  private val maxLocalProjDBSize: Long = 10000
    +
    +  /**
    +   * Constructs a default instance with default parameters
    +   * {minSupport: `0.1`, maxPatternLength: `10`}.
    +   */
    +  def this() = this(0.1, 10)
    +
    +  /**
+   * Gets the minimal support (i.e. the minimum frequency of occurrence required before a pattern
+   * is considered frequent).
    +   */
    +  def getMinSupport: Double = this.minSupport
    +
    +  /**
    +   * Sets the minimal support level (default: `0.1`).
    +   */
    +  def setMinSupport(minSupport: Double): this.type = {
    +    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
    +    this.minSupport = minSupport
    +    this
    +  }
    +
    +  /**
+   * Gets the maximal pattern length (i.e. the longest sequential pattern to consider).
    +   */
    +  def getMaxPatternLength: Double = this.maxPatternLength
    +
    +  /**
    +   * Sets maximal pattern length (default: `10`).
    +   */
    +  def setMaxPatternLength(maxPatternLength: Int): this.type = {
    +    // TODO: support unbounded pattern length when maxPatternLength = 0
    +    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
    +    this.maxPatternLength = maxPatternLength
    +    this
    +  }
    +
    +  /**
    +   * Find the complete set of sequential patterns in the input sequences.
    +   * @param sequences input data set, contains a set of sequences,
    +   *                  a sequence is an ordered list of elements.
    +   * @return a set of sequential pattern pairs,
    +   *         the key of pair is pattern (a list of elements),
    +   *         the value of pair is the pattern's count.
    +   */
    +  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
    +    val sc = sequences.sparkContext
    +
    +    if (sequences.getStorageLevel == StorageLevel.NONE) {
    +      logWarning("Input data is not cached.")
    +    }
    +
    +    // Convert min support to a min number of transactions for this dataset
    +    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
    +
+    // Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold
    +    val freqItemCounts = sequences
    +      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
    +      .reduceByKey(_ + _)
    +      .filter(_._2 >= minCount)
    +      .collect()
    +
    +    // Pairs of (length 1 prefix, suffix consisting of frequent items)
    +    val itemSuffixPairs = {
    +      val freqItems = freqItemCounts.map(_._1).toSet
    +      sequences.flatMap { seq =>
    +        val filteredSeq = seq.filter(freqItems.contains(_))
    +        freqItems.flatMap { item =>
    +          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
    +          candidateSuffix match {
    +            case suffix if !suffix.isEmpty => Some((List(item), suffix))
    +            case _ => None
    +          }
    +        }
    +      }
    +    }
    +
    +    // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
    +    // frequent length-one prefixes)
    +    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
    +
+    // Remaining work, split into pairs to process locally and pairs to process distributively
    +    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
    +
    +    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
    +    // projected database sizes <= `maxLocalProjDBSize`)
    +    while (pairsForDistributed.count() != 0) {
    +      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
    +        extendPrefixes(minCount, pairsForDistributed)
    +      pairsForDistributed.unpersist()
    +      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
    +      pairsForDistributed = largerPairsPart
    +      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
    +      pairsForLocal ++= smallerPairsPart
    +      resultsAccumulator ++= nextPatternAndCounts.collect()
    +    }
    +
    +    // Process the small projected databases locally
    +    val remainingResults = getPatternsInLocal(
    +      minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
    +
    +    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
    +      .map { case (pattern, count) => (pattern.toArray, count) }
    +  }
    +
    +
    +  /**
    +   * Partitions the prefix-suffix pairs by projected database size.
    +   * @param prefixSuffixPairs prefix (length n) and suffix pairs,
    +   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
    +   *         greater than [[maxLocalProjDBSize]]
    +   */
    +  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
    +    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
    +    val prefixToSuffixSize = prefixSuffixPairs
    +      .aggregateByKey(0)(
    +        seqOp = { case (count, suffix) => count + suffix.length },
    +        combOp = { _ + _ })
    +    val smallPrefixes = prefixToSuffixSize
    +      .filter(_._2 <= maxLocalProjDBSize)
    +      .keys
    +      .collect()
    +      .toSet
    +    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
    +    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
    +    (small.collect(), large)
    +  }
    +
    +  /**
    +   * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
    +   * and remaining work.
    +   * @param minCount minimum count
    +   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
    +   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
    +   *         prefix, corresponding suffix) pairs.
    +   */
    +  private def extendPrefixes(
    +      minCount: Long,
    +      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
    +    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
    +
    +    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
    +    // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
    +    val prefixItemPairAndCounts = prefixSuffixPairs
    +      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
    +      .reduceByKey(_ + _)
    +      .filter(_._2 >= minCount)
    +
    +    // Map from prefix to set of possible next items from suffix
    +    val prefixToNextItems = prefixItemPairAndCounts
    +      .keys
    +      .groupByKey()
    +      .mapValues(_.toSet)
    +      .collect()
    +      .toMap
    +
    +
    +    // Frequent patterns with length N+1 and their corresponding counts
    +    val extendedPrefixAndCounts = prefixItemPairAndCounts
    +      .map { case ((prefix, item), count) => (item :: prefix, count) }
    +
    +    // Remaining work, all prefixes will have length N+1
    +    val extendedPrefixAndSuffix = prefixSuffixPairs
    +      .filter(x => prefixToNextItems.contains(x._1))
    +      .flatMap { case (prefix, suffix) =>
    +        val frequentNextItems = prefixToNextItems(prefix)
    +        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
    +        frequentNextItems.flatMap { item =>
    +          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
    +            case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
    +            case _ => None
    +          }
    +        }
    +      }
    +
    +    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
    +  }
    +
    +  /**
+   * Calculates the patterns locally.
+   * @param minCount the absolute minimum count
+   * @param data pairs of prefix and the corresponding projected sequence database
    +   * @return patterns
    +   */
    +  private def getPatternsInLocal(
    +      minCount: Long,
    +      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
    +    data.flatMap {
    +      case (prefix, projDB) =>
    +        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
    +          .map { case (pattern: List[Int], count: Long) =>
    +          (pattern.reverse, count)
    +        }
    +    }
    +  }
    +}
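
For reference, a usage sketch (not part of this patch) that mines a tiny sequence database with the new PrefixSpan API; the support and pattern-length settings are arbitrary, and sc is an assumed existing SparkContext.

    import org.apache.spark.mllib.fpm.PrefixSpan

    val sequences = sc.parallelize(Seq(
      Array(1, 3, 4, 5),
      Array(2, 3, 1),
      Array(2, 4, 1),
      Array(3, 1, 3, 4, 5),
      Array(3, 4, 4, 3),
      Array(6, 5, 3)), 2).cache()

    val prefixSpan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
    prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
      println(s"${pattern.mkString("[", ",", "]")}: $count")
    }
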
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
    new file mode 100644
    index 0000000000000..72d3aabc9b1f4
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
    @@ -0,0 +1,154 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.impl
    +
    +import scala.collection.mutable
    +
    +import org.apache.hadoop.fs.{Path, FileSystem}
    +
    +import org.apache.spark.{SparkContext, Logging}
    +import org.apache.spark.storage.StorageLevel
    +
    +
    +/**
    + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
    + * (such as Graphs and DataFrames).  In documentation, we use the phrase "Dataset" to refer to
    + * the distributed data type (RDD, Graph, etc.).
    + *
    + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
    + * as well as unpersisting and removing checkpoint files.
    + *
    + * Users should call update() when a new Dataset has been created,
    + * before the Dataset has been materialized.  After updating [[PeriodicCheckpointer]], users are
    + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
    + * occur.
    + *
    + * When update() is called, this does the following:
    + *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
    + *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
    + *  - If using checkpointing and the checkpoint interval has been reached,
    + *     - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
    + *     - Remove older checkpoints.
    + *
    + * WARNINGS:
    + *  - This class should NOT be copied (since copies may conflict on which Datasets should be
    + *    checkpointed).
    + *  - This class removes checkpoint files once later Datasets have been checkpointed.
    + *    However, references to the older Datasets will still return isCheckpointed = true.
    + *
    + * @param checkpointInterval  Datasets will be checkpointed at this interval
    + * @param sc  SparkContext for the Datasets given to this checkpointer
    + * @tparam T  Dataset type, such as RDD[Double]
    + */
    +private[mllib] abstract class PeriodicCheckpointer[T](
    +    val checkpointInterval: Int,
    +    val sc: SparkContext) extends Logging {
    +
    +  /** FIFO queue of past checkpointed Datasets */
    +  private val checkpointQueue = mutable.Queue[T]()
    +
    +  /** FIFO queue of past persisted Datasets */
    +  private val persistedQueue = mutable.Queue[T]()
    +
    +  /** Number of times [[update()]] has been called */
    +  private var updateCount = 0
    +
    +  /**
    +   * Update with a new Dataset. Handle persistence and checkpointing as needed.
    +   * Since this handles persistence and checkpointing, this should be called before the Dataset
    +   * has been materialized.
    +   *
    +   * @param newData  New Dataset created from previous Datasets in the lineage.
    +   */
    +  def update(newData: T): Unit = {
    +    persist(newData)
    +    persistedQueue.enqueue(newData)
    +    // We keep at most 3 Datasets persisted (the newest one plus its recent predecessors) to
    +    // support the semantics of this class: Users should call [[update()]] when a new Dataset
    +    // has been created, before the Dataset has been materialized.
    +    while (persistedQueue.size > 3) {
    +      val dataToUnpersist = persistedQueue.dequeue()
    +      unpersist(dataToUnpersist)
    +    }
    +    updateCount += 1
    +
    +    // Handle checkpointing (after persisting)
    +    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
    +      // Add new checkpoint before removing old checkpoints.
    +      checkpoint(newData)
    +      checkpointQueue.enqueue(newData)
    +      // Remove checkpoints before the latest one.
    +      var canDelete = true
    +      while (checkpointQueue.size > 1 && canDelete) {
    +        // Delete the oldest checkpoint only if the next checkpoint exists.
    +        if (isCheckpointed(checkpointQueue.head)) {
    +          removeCheckpointFile()
    +        } else {
    +          canDelete = false
    +        }
    +      }
    +    }
    +  }
    +
    +  /** Checkpoint the Dataset */
    +  protected def checkpoint(data: T): Unit
    +
    +  /** Return true iff the Dataset is checkpointed */
    +  protected def isCheckpointed(data: T): Boolean
    +
    +  /**
    +   * Persist the Dataset.
    +   * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
    +   */
    +  protected def persist(data: T): Unit
    +
    +  /** Unpersist the Dataset */
    +  protected def unpersist(data: T): Unit
    +
    +  /** Get list of checkpoint files for this given Dataset */
    +  protected def getCheckpointFiles(data: T): Iterable[String]
    +
    +  /**
    +   * Call this at the end to delete any remaining checkpoint files.
    +   */
    +  def deleteAllCheckpoints(): Unit = {
    +    while (checkpointQueue.nonEmpty) {
    +      removeCheckpointFile()
    +    }
    +  }
    +
    +  /**
    +   * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
    +   * This prints a warning but does not fail if the files cannot be removed.
    +   */
    +  private def removeCheckpointFile(): Unit = {
    +    val old = checkpointQueue.dequeue()
    +    // Since the old checkpoint is not deleted by Spark, we manually delete it.
    +    val fs = FileSystem.get(sc.hadoopConfiguration)
    +    getCheckpointFiles(old).foreach { checkpointFile =>
    +      try {
    +        fs.delete(new Path(checkpointFile), true)
    +      } catch {
    +        case e: Exception =>
    +          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
    +            checkpointFile)
    +      }
    +    }
    +  }
    +
    +}
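The contract described in the class doc (call `update()` on each new Dataset before materializing it, then clean up) is easiest to see as a calling pattern. Below is a hedged sketch using the RDD-based subclass added later in this patch; `sc`, the `nextIteration` step, and the interval are illustrative assumptions, and since the class is `private[mllib]` this pattern only applies to code inside MLlib itself.

```scala
// Hedged sketch of the PeriodicCheckpointer calling pattern (assumes an existing SparkContext
// `sc` with a checkpoint directory set, and code living in org.apache.spark.mllib so that the
// private[mllib] class is visible).
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer

def nextIteration(prev: RDD[Double]): RDD[Double] = prev.map(_ * 0.9)  // stand-in update step

var data: RDD[Double] = sc.parallelize(1 to 1000).map(_.toDouble)
val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval = 10, sc)
checkpointer.update(data)
data.count()  // materialize after update(), as the contract requires

for (iter <- 1 to 50) {
  data = nextIteration(data)
  checkpointer.update(data)  // persist the new RDD, and checkpoint it every 10th update
  data.count()               // materialize so persisting/checkpointing actually take effect
}
checkpointer.deleteAllCheckpoints()  // remove any remaining checkpoint files at the end
```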
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
    index 6e5dd119dd653..11a059536c50c 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
    @@ -17,11 +17,7 @@
     
     package org.apache.spark.mllib.impl
     
    -import scala.collection.mutable
    -
    -import org.apache.hadoop.fs.{Path, FileSystem}
    -
    -import org.apache.spark.Logging
    +import org.apache.spark.SparkContext
     import org.apache.spark.graphx.Graph
     import org.apache.spark.storage.StorageLevel
     
    @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel
      * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
      * unpersisting and removing checkpoint files.
      *
    - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
    + * Users should call update() when a new graph has been created,
      * before the graph has been materialized.  After updating [[PeriodicGraphCheckpointer]], users are
      * responsible for materializing the graph to ensure that persisting and checkpointing actually
      * occur.
      *
    - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
    + * When update() is called, this does the following:
      *  - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
      *  - Unpersist graphs from queue until there are at most 3 persisted graphs.
      *  - If using checkpointing and the checkpoint interval has been reached,
    @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
      * Example usage:
      * {{{
      *  val (graph1, graph2, graph3, ...) = ...
    - *  val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
    + *  val cp = new PeriodicGraphCheckpointer(2, sc)
      *  graph1.vertices.count(); graph1.edges.count()
      *  // persisted: graph1
      *  cp.updateGraph(graph2)
    @@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel
      *  // checkpointed: graph4
      * }}}
      *
    - * @param currentGraph  Initial graph
      * @param checkpointInterval Graphs will be checkpointed at this interval
      * @tparam VD  Vertex descriptor type
      * @tparam ED  Edge descriptor type
      *
    - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
    + * TODO: Move this out of MLlib?
      */
     private[mllib] class PeriodicGraphCheckpointer[VD, ED](
    -    var currentGraph: Graph[VD, ED],
    -    val checkpointInterval: Int) extends Logging {
    -
    -  /** FIFO queue of past checkpointed RDDs */
    -  private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
    -
    -  /** FIFO queue of past persisted RDDs */
    -  private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
    -
    -  /** Number of times [[updateGraph()]] has been called */
    -  private var updateCount = 0
    -
    -  /**
    -   * Spark Context for the Graphs given to this checkpointer.
    -   * NOTE: This code assumes that only one SparkContext is used for the given graphs.
    -   */
    -  private val sc = currentGraph.vertices.sparkContext
    +    checkpointInterval: Int,
    +    sc: SparkContext)
    +  extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
     
    -  updateGraph(currentGraph)
    +  override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
     
    -  /**
    -   * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
    -   * Since this handles persistence and checkpointing, this should be called before the graph
    -   * has been materialized.
    -   *
    -   * @param newGraph  New graph created from previous graphs in the lineage.
    -   */
    -  def updateGraph(newGraph: Graph[VD, ED]): Unit = {
    -    if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
    -      newGraph.persist()
    -    }
    -    persistedQueue.enqueue(newGraph)
    -    // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
    -    // Users should call [[updateGraph()]] when a new graph has been created,
    -    // before the graph has been materialized.
    -    while (persistedQueue.size > 3) {
    -      val graphToUnpersist = persistedQueue.dequeue()
    -      graphToUnpersist.unpersist(blocking = false)
    -    }
    -    updateCount += 1
    +  override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
     
    -    // Handle checkpointing (after persisting)
    -    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
    -      // Add new checkpoint before removing old checkpoints.
    -      newGraph.checkpoint()
    -      checkpointQueue.enqueue(newGraph)
    -      // Remove checkpoints before the latest one.
    -      var canDelete = true
    -      while (checkpointQueue.size > 1 && canDelete) {
    -        // Delete the oldest checkpoint only if the next checkpoint exists.
    -        if (checkpointQueue.get(1).get.isCheckpointed) {
    -          removeCheckpointFile()
    -        } else {
    -          canDelete = false
    -        }
    -      }
    +  override protected def persist(data: Graph[VD, ED]): Unit = {
    +    if (data.vertices.getStorageLevel == StorageLevel.NONE) {
    +      data.persist()
         }
       }
     
    -  /**
    -   * Call this at the end to delete any remaining checkpoint files.
    -   */
    -  def deleteAllCheckpoints(): Unit = {
    -    while (checkpointQueue.size > 0) {
    -      removeCheckpointFile()
    -    }
    -  }
    +  override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
     
    -  /**
    -   * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
    -   * This prints a warning but does not fail if the files cannot be removed.
    -   */
    -  private def removeCheckpointFile(): Unit = {
    -    val old = checkpointQueue.dequeue()
    -    // Since the old checkpoint is not deleted by Spark, we manually delete it.
    -    val fs = FileSystem.get(sc.hadoopConfiguration)
    -    old.getCheckpointFiles.foreach { checkpointFile =>
    -      try {
    -        fs.delete(new Path(checkpointFile), true)
    -      } catch {
    -        case e: Exception =>
    -          logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
    -            checkpointFile)
    -      }
    -    }
    +  override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
    +    data.getCheckpointFiles
       }
    -
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
    new file mode 100644
    index 0000000000000..f31ed2aa90a64
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
    @@ -0,0 +1,97 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.impl
    +
    +import org.apache.spark.SparkContext
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +
    +
    +/**
    + * This class helps with persisting and checkpointing RDDs.
    + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
    + * unpersisting and removing checkpoint files.
    + *
    + * Users should call update() when a new RDD has been created,
    + * before the RDD has been materialized.  After updating [[PeriodicRDDCheckpointer]], users are
    + * responsible for materializing the RDD to ensure that persisting and checkpointing actually
    + * occur.
    + *
    + * When update() is called, this does the following:
    + *  - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
    + *  - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
    + *  - If using checkpointing and the checkpoint interval has been reached,
    + *     - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
    + *     - Remove older checkpoints.
    + *
    + * WARNINGS:
    + *  - This class should NOT be copied (since copies may conflict on which RDDs should be
    + *    checkpointed).
    + *  - This class removes checkpoint files once later RDDs have been checkpointed.
    + *    However, references to the older RDDs will still return isCheckpointed = true.
    + *
    + * Example usage:
    + * {{{
    + *  val (rdd1, rdd2, rdd3, ...) = ...
    + *  val cp = new PeriodicRDDCheckpointer(2, sc)
    + *  cp.update(rdd1)
    + *  rdd1.count();
    + *  // persisted: rdd1
    + *  cp.update(rdd2)
    + *  rdd2.count();
    + *  // persisted: rdd1, rdd2
    + *  // checkpointed: rdd2
    + *  cp.update(rdd3)
    + *  rdd3.count();
    + *  // persisted: rdd1, rdd2, rdd3
    + *  // checkpointed: rdd2
    + *  cp.update(rdd4)
    + *  rdd4.count();
    + *  // persisted: rdd2, rdd3, rdd4
    + *  // checkpointed: rdd4
    + *  cp.update(rdd5)
    + *  rdd5.count();
    + *  // persisted: rdd3, rdd4, rdd5
    + *  // checkpointed: rdd4
    + * }}}
    + *
    + * @param checkpointInterval  RDDs will be checkpointed at this interval
    + * @tparam T  RDD element type
    + *
    + * TODO: Move this out of MLlib?
    + */
    +private[mllib] class PeriodicRDDCheckpointer[T](
    +    checkpointInterval: Int,
    +    sc: SparkContext)
    +  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
    +
    +  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
    +
    +  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
    +
    +  override protected def persist(data: RDD[T]): Unit = {
    +    if (data.getStorageLevel == StorageLevel.NONE) {
    +      data.persist()
    +    }
    +  }
    +
    +  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
    +
    +  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
    +    data.getCheckpointFile.map(x => x)
    +  }
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    index 3523f1804325d..9029093e0fa08 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    @@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging {
           C: DenseMatrix): Unit = {
         require(!C.isTransposed,
           "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
    -    if (alpha == 0.0) {
    -      logDebug("gemm: alpha is equal to 0. Returning C.")
    +    if (alpha == 0.0 && beta == 1.0) {
    +      logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
         } else {
           A match {
             case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
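To see why the old guard (`alpha == 0.0` alone) was too broad: `gemm` computes C := alpha * A * B + beta * C, so with alpha = 0 but beta != 1 the call must still scale C in place. A plain-Scala toy illustration of that contract follows (names and matrices are made up; this is not MLlib's private BLAS API).

```scala
// Toy gemm on row-major 2x2 arrays, only to illustrate C := alpha * A * B + beta * C.
def toyGemm(alpha: Double, a: Array[Array[Double]], b: Array[Array[Double]],
            beta: Double, c: Array[Array[Double]]): Unit = {
  for (i <- c.indices; j <- c(i).indices) {
    val ab = a(i).indices.map(k => a(i)(k) * b(k)(j)).sum
    c(i)(j) = alpha * ab + beta * c(i)(j)
  }
}

val c = Array(Array(1.0, 2.0), Array(3.0, 4.0))
toyGemm(0.0, Array(Array(1.0, 0.0), Array(0.0, 1.0)),
        Array(Array(5.0, 6.0), Array(7.0, 8.0)), 2.0, c)
// c is now ((2.0, 4.0), (6.0, 8.0)): alpha == 0 only leaves C untouched when beta == 1.
```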
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    index 75e7004464af9..88914fa875990 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    @@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
     import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
     
     import org.apache.spark.annotation.DeveloperApi
    -import org.apache.spark.sql.Row
    -import org.apache.spark.sql.types._
     import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.types._
     
     /**
      * Trait for a local matrix.
    @@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable {
       /** Map the values of this matrix using a function. Generates a new matrix. Performs the
         * function on only the backing array. For example, an operation such as addition or
         * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
    -  private[mllib] def map(f: Double => Double): Matrix
    +  private[spark] def map(f: Double => Double): Matrix
     
       /** Update all the values of this matrix using the function f. Performed in-place on the
         * backing array. For example, an operation such as addition or subtraction will only be
    @@ -147,16 +147,16 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
           ))
       }
     
    -  override def serialize(obj: Any): Row = {
    +  override def serialize(obj: Any): InternalRow = {
         val row = new GenericMutableRow(7)
         obj match {
           case sm: SparseMatrix =>
             row.setByte(0, 0)
             row.setInt(1, sm.numRows)
             row.setInt(2, sm.numCols)
    -        row.update(3, sm.colPtrs.toSeq)
    -        row.update(4, sm.rowIndices.toSeq)
    -        row.update(5, sm.values.toSeq)
    +        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
    +        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
    +        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
             row.setBoolean(6, sm.isTransposed)
     
           case dm: DenseMatrix =>
    @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
             row.setInt(2, dm.numCols)
             row.setNullAt(3)
             row.setNullAt(4)
    -        row.update(5, dm.values.toSeq)
    +        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
             row.setBoolean(6, dm.isTransposed)
         }
         row
    @@ -173,20 +173,18 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     
       override def deserialize(datum: Any): Matrix = {
         datum match {
    -      // TODO: something wrong with UDT serialization, should never happen.
    -      case m: Matrix => m
    -      case row: Row =>
    -        require(row.length == 7,
    -          s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
    +      case row: InternalRow =>
    +        require(row.numFields == 7,
    +          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
             val tpe = row.getByte(0)
             val numRows = row.getInt(1)
             val numCols = row.getInt(2)
    -        val values = row.getAs[Iterable[Double]](5).toArray
    +        val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
             val isTransposed = row.getBoolean(6)
             tpe match {
               case 0 =>
    -            val colPtrs = row.getAs[Iterable[Int]](3).toArray
    -            val rowIndices = row.getAs[Iterable[Int]](4).toArray
    +            val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
    +            val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
                 new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
               case 1 =>
                 new DenseMatrix(numRows, numCols, values, isTransposed)
    @@ -291,7 +289,7 @@ class DenseMatrix(
     
       override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
     
    -  private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
    +  private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
         isTransposed)
     
       private[mllib] def update(f: Double => Double): DenseMatrix = {
    @@ -557,7 +555,7 @@ class SparseMatrix(
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
       }
     
    -  private[mllib] def map(f: Double => Double) =
    +  private[spark] def map(f: Double => Double) =
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
     
       private[mllib] def update(f: Double => Double): SparseMatrix = {
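For reference, fields 3 through 5 of the serialized row above are the compressed sparse column (CSC) arrays of a `SparseMatrix` (`colPtrs`, `rowIndices`, `values`). A small hedged example of how those arrays describe a matrix (the concrete 3 x 2 matrix is made up for illustration):

```scala
import org.apache.spark.mllib.linalg.SparseMatrix

// 3 x 2 matrix in CSC form:
//   1.0  0.0
//   0.0  2.0
//   3.0  0.0
// colPtrs = [0, 2, 3]: column 0 owns entries 0-1, column 1 owns entry 2.
val sm = new SparseMatrix(
  numRows = 3, numCols = 2,
  colPtrs = Array(0, 2, 3),
  rowIndices = Array(0, 2, 1),
  values = Array(1.0, 3.0, 2.0))

println(sm.toArray.mkString(", "))  // column-major: 1.0, 0.0, 3.0, 0.0, 2.0, 0.0
```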
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
    index 9669c364bad8f..b416d50a5631e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
    @@ -25,3 +25,11 @@ import org.apache.spark.annotation.Experimental
      */
     @Experimental
     case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType)
    +
    +/**
    + * :: Experimental ::
    + * Represents QR factors.
    + */
    +@Experimental
    +case class QRDecomposition[UType, VType](Q: UType, R: VType)
    +
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    index c9c27425d2877..89a1818db0d1d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    @@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
     import org.apache.spark.SparkException
     import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.mllib.util.NumericParser
    -import org.apache.spark.sql.Row
    +import org.apache.spark.sql.catalyst.InternalRow
     import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
     import org.apache.spark.sql.types._
     
    @@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
           toDense
         }
       }
    +
    +  /**
    +   * Find the index of a maximal element.  Returns the index of the first maximal element in
    +   * case of a tie.
    +   * Returns -1 if vector has length 0.
    +   */
    +  def argmax: Int
     }
     
     /**
    @@ -175,51 +181,41 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
           StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
       }
     
    -  override def serialize(obj: Any): Row = {
    +  override def serialize(obj: Any): InternalRow = {
         obj match {
           case SparseVector(size, indices, values) =>
             val row = new GenericMutableRow(4)
             row.setByte(0, 0)
             row.setInt(1, size)
    -        row.update(2, indices.toSeq)
    -        row.update(3, values.toSeq)
    +        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
    +        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
             row
           case DenseVector(values) =>
             val row = new GenericMutableRow(4)
             row.setByte(0, 1)
             row.setNullAt(1)
             row.setNullAt(2)
    -        row.update(3, values.toSeq)
    -        row
    -      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
    -      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
    -      // TODO: deserialize may get called twice. See SPARK-7186.
    -      case row: Row =>
    +        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
             row
         }
       }
     
       override def deserialize(datum: Any): Vector = {
         datum match {
    -      case row: Row =>
    -        require(row.length == 4,
    -          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
    +      case row: InternalRow =>
    +        require(row.numFields == 4,
    +          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
             val tpe = row.getByte(0)
             tpe match {
               case 0 =>
                 val size = row.getInt(1)
    -            val indices = row.getAs[Iterable[Int]](2).toArray
    -            val values = row.getAs[Iterable[Double]](3).toArray
    +            val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
    +            val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
                 new SparseVector(size, indices, values)
               case 1 =>
    -            val values = row.getAs[Iterable[Double]](3).toArray
    +            val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
                 new DenseVector(values)
             }
    -      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
    -      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
    -      // TODO: deserialize may get called twice. See SPARK-7186.
    -      case v: Vector =>
    -        v
         }
       }
     
    @@ -598,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
         new SparseVector(size, ii, vv)
       }
     
    -  /**
    -   * Find the index of a maximal element.  Returns the first maximal element in case of a tie.
    -   * Returns -1 if vector has length 0.
    -   */
    -  private[spark] def argmax: Int = {
    +  override def argmax: Int = {
         if (size == 0) {
           -1
         } else {
    @@ -642,6 +634,8 @@ class SparseVector(
       require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
         s" indices match the dimension of the values. You provided ${indices.length} indices and " +
         s" ${values.length} values.")
    +  require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
    +    s"which exceeds the specified vector size ${size}.")
     
       override def toString: String =
         s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
    @@ -727,6 +721,51 @@ class SparseVector(
           new SparseVector(size, ii, vv)
         }
       }
    +
    +  override def argmax: Int = {
    +    if (size == 0) {
    +      -1
    +    } else {
    +      // Find the max active entry.
    +      var maxIdx = indices(0)
    +      var maxValue = values(0)
    +      var maxJ = 0
    +      var j = 1
    +      val na = numActives
    +      while (j < na) {
    +        val v = values(j)
    +        if (v > maxValue) {
    +          maxValue = v
    +          maxIdx = indices(j)
    +          maxJ = j
    +        }
    +        j += 1
    +      }
    +
    +      // If the max active entry is nonpositive and there are inactive entries, find the first zero.
    +      if (maxValue <= 0.0 && na < size) {
    +        if (maxValue == 0.0) {
    +          // If there exists an inactive entry before maxIdx, find it and return its index.
    +          if (maxJ < maxIdx) {
    +            var k = 0
    +            while (k < maxJ && indices(k) == k) {
    +              k += 1
    +            }
    +            maxIdx = k
    +          }
    +        } else {
    +          // If the max active value is negative, find and return the first inactive index.
    +          var k = 0
    +          while (k < na && indices(k) == k) {
    +            k += 1
    +          }
    +          maxIdx = k
    +        }
    +      }
    +
    +      maxIdx
    +    }
    +  }
     }
     
     object SparseVector {
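Since `argmax` is now part of the public `Vector` API (with the `SparseVector` implementation above handling implicit zeros), a few hedged examples of the tie-breaking behaviour; the concrete vectors are illustrative only.

```scala
import org.apache.spark.mllib.linalg.Vectors

// The first maximal element wins ties.
Vectors.dense(1.0, 3.0, 3.0, 2.0).argmax                  // 1

// All active values are negative, so the first implicit zero (index 0) is maximal.
Vectors.sparse(5, Array(1, 3), Array(-1.0, -2.0)).argmax  // 0

// An explicit zero at index 1 ties with the implicit zero at index 0; the earlier index wins.
Vectors.sparse(4, Array(1, 3), Array(0.0, -1.0)).argmax   // 0
```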
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    index 1626da9c3d2ee..bfc90c9ef8527 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    @@ -22,7 +22,7 @@ import java.util.Arrays
     import scala.collection.mutable.ListBuffer
     
     import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy,
    -  svd => brzSvd}
    +  svd => brzSvd, MatrixSingularException, inv}
     import breeze.numerics.{sqrt => brzSqrt}
     import com.github.fommil.netlib.BLAS.{getInstance => blas}
     
    @@ -497,6 +497,50 @@ class RowMatrix(
         columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
       }
     
    +  /**
    +   * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR
    +   * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape.
    +   * Reference:
    +   *  Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce
    +   *  architectures"  ([[http://dx.doi.org/10.1145/1996092.1996103]])
    +   *
    +   * @param computeQ whether to compute Q
    +   * @return QRDecomposition(Q, R), Q = null if computeQ = false.
    +   */
    +  def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = {
    +    val col = numCols().toInt
    +    // split rows horizontally into smaller matrices, and compute QR for each of them
    +    val blockQRs = rows.glom().map { partRows =>
    +      val bdm = BDM.zeros[Double](partRows.length, col)
    +      var i = 0
    +      partRows.foreach { row =>
    +        bdm(i, ::) := row.toBreeze.t
    +        i += 1
    +      }
    +      breeze.linalg.qr.reduced(bdm).r
    +    }
    +
    +    // combine the R part from previous results vertically into a tall matrix
    +    val combinedR = blockQRs.treeReduce{ (r1, r2) =>
    +      val stackedR = BDM.vertcat(r1, r2)
    +      breeze.linalg.qr.reduced(stackedR).r
    +    }
    +    val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix)
    +    val finalQ = if (computeQ) {
    +      try {
    +        val invR = inv(combinedR)
    +        this.multiply(Matrices.fromBreeze(invR))
    +      } catch {
    +        case err: MatrixSingularException =>
    +          logWarning("R is not invertible and return Q as null")
    +          null
    +      }
    +    } else {
    +      null
    +    }
    +    QRDecomposition(finalQ, finalR)
    +  }
    +
       /**
        * Find all similar columns using the DIMSUM sampling algorithm, described in two papers
        *
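A hedged usage sketch for the new `tallSkinnyQR` (assumes an existing `SparkContext` named `sc`; the 3 x 2 matrix is illustrative):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val rows = sc.parallelize(Seq(
  Vectors.dense(3.0, 1.0),
  Vectors.dense(4.0, 2.0),
  Vectors.dense(0.0, 5.0)))

val qr = new RowMatrix(rows).tallSkinnyQR(computeQ = true)
// qr.R is a small local upper-triangular Matrix; qr.Q is a RowMatrix
// (null when computeQ = false or when R turns out to be singular).
println(qr.R)
```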
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
    index 35e81fcb3de0d..1facf83d806d0 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
    @@ -72,7 +72,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int
           val w1 = windowSize - 1
           // Get the first w1 items of each partition, starting from the second partition.
           val nextHeads =
    -        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
    +        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n)
           val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
           var i = 0
           var partitionIndex = 0
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
    index 93290e6508529..56c549ef99cb7 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
    @@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel
     
     /**
      * A more compact class to represent a rating than Tuple3[Int, Int, Double].
    + * @since 0.8.0
      */
     case class Rating(user: Int, product: Int, rating: Double)
     
    @@ -254,6 +255,7 @@ class ALS private (
     
     /**
      * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
    + * @since 0.8.0
      */
     object ALS {
       /**
    @@ -269,6 +271,7 @@ object ALS {
        * @param lambda     regularization factor (recommended: 0.01)
        * @param blocks     level of parallelism to split computation into
        * @param seed       random seed
    +   * @since 0.9.1
        */
       def train(
           ratings: RDD[Rating],
    @@ -293,6 +296,7 @@ object ALS {
        * @param iterations number of iterations of ALS (recommended: 10-20)
        * @param lambda     regularization factor (recommended: 0.01)
        * @param blocks     level of parallelism to split computation into
    +   * @since 0.8.0
        */
       def train(
           ratings: RDD[Rating],
    @@ -315,6 +319,7 @@ object ALS {
        * @param rank       number of features to use
        * @param iterations number of iterations of ALS (recommended: 10-20)
        * @param lambda     regularization factor (recommended: 0.01)
    +   * @since 0.8.0
        */
       def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
         : MatrixFactorizationModel = {
    @@ -331,6 +336,7 @@ object ALS {
        * @param ratings    RDD of (userID, productID, rating) pairs
        * @param rank       number of features to use
        * @param iterations number of iterations of ALS (recommended: 10-20)
    +   * @since 0.8.0
        */
       def train(ratings: RDD[Rating], rank: Int, iterations: Int)
         : MatrixFactorizationModel = {
    @@ -351,6 +357,7 @@ object ALS {
        * @param blocks     level of parallelism to split computation into
        * @param alpha      confidence parameter
        * @param seed       random seed
    +   * @since 0.8.1
        */
       def trainImplicit(
           ratings: RDD[Rating],
    @@ -377,6 +384,7 @@ object ALS {
        * @param lambda     regularization factor (recommended: 0.01)
        * @param blocks     level of parallelism to split computation into
        * @param alpha      confidence parameter
    +   * @since 0.8.1
        */
       def trainImplicit(
           ratings: RDD[Rating],
    @@ -401,6 +409,7 @@ object ALS {
        * @param iterations number of iterations of ALS (recommended: 10-20)
        * @param lambda     regularization factor (recommended: 0.01)
        * @param alpha      confidence parameter
    +   * @since 0.8.1
        */
       def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
         : MatrixFactorizationModel = {
    @@ -418,6 +427,7 @@ object ALS {
        * @param ratings    RDD of (userID, productID, rating) pairs
        * @param rank       number of features to use
        * @param iterations number of iterations of ALS (recommended: 10-20)
    +   * @since 0.8.1
        */
       def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
         : MatrixFactorizationModel = {
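The `@since` tags above annotate the long-standing `ALS.train` entry points; for orientation, here is a hedged minimal usage sketch (assumes an existing `SparkContext` `sc`; the ratings, rank, and iteration count are made up).

```scala
import org.apache.spark.mllib.recommendation.{ALS, Rating}

val ratings = sc.parallelize(Seq(
  Rating(1, 10, 4.0), Rating(1, 20, 1.0),
  Rating(2, 10, 5.0), Rating(2, 30, 3.0)))

// rank, iterations, lambda as documented above (recommended: 10-20 iterations, lambda ~ 0.01).
val model = ALS.train(ratings, 2, 5, 0.01)
println(model.predict(2, 20))
```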
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
    index 43d219a49cf4e..261ca9cef0c5b 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
    @@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel
      *                     the features computed for this user.
      * @param productFeatures RDD of tuples where each tuple represents the productId
      *                        and the features computed for this product.
    + * @since 0.8.0
      */
     class MatrixFactorizationModel(
         val rank: Int,
    @@ -73,7 +74,9 @@ class MatrixFactorizationModel(
         }
       }
     
    -  /** Predict the rating of one user for one product. */
    +  /** Predict the rating of one user for one product.
    +   * @since 0.8.0
    +   */
       def predict(user: Int, product: Int): Double = {
         val userVector = userFeatures.lookup(user).head
         val productVector = productFeatures.lookup(product).head
    @@ -111,6 +114,7 @@ class MatrixFactorizationModel(
        *
        * @param usersProducts  RDD of (user, product) pairs.
        * @return RDD of Ratings.
    +   * @since 0.9.0
        */
       def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
         // Previously the partitions of ratings are only based on the given products.
    @@ -142,6 +146,7 @@ class MatrixFactorizationModel(
     
       /**
        * Java-friendly version of [[MatrixFactorizationModel.predict]].
    +   * @since 1.2.0
        */
       def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
         predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
    @@ -157,6 +162,7 @@ class MatrixFactorizationModel(
        *  by score, decreasing. The first returned is the one predicted to be most strongly
        *  recommended to the user. The score is an opaque value that indicates how strongly
        *  recommended the product is.
    +   *  @since 1.1.0
        */
       def recommendProducts(user: Int, num: Int): Array[Rating] =
         MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num)
    @@ -173,6 +179,7 @@ class MatrixFactorizationModel(
        *  by score, decreasing. The first returned is the one predicted to be most strongly
        *  recommended to the product. The score is an opaque value that indicates how strongly
        *  recommended the user is.
    +   *  @since 1.1.0
        */
       def recommendUsers(product: Int, num: Int): Array[Rating] =
         MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num)
    @@ -180,6 +187,20 @@ class MatrixFactorizationModel(
     
       protected override val formatVersion: String = "1.0"
     
    +  /**
    +   * Save this model to the given path.
    +   *
    +   * This saves:
    +   *  - human-readable (JSON) model metadata to path/metadata/
    +   *  - Parquet formatted data to path/data/
    +   *
    +   * The model may be loaded using [[Loader.load]].
    +   *
    +   * @param sc  Spark context used to save model data.
    +   * @param path  Path specifying the directory in which to save this model.
    +   *              If the directory already exists, this method throws an exception.
    +   * @since 1.3.0
    +   */
       override def save(sc: SparkContext, path: String): Unit = {
         MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
       }
    @@ -191,6 +212,7 @@ class MatrixFactorizationModel(
        * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of
        * rating objects which contains the same userId, recommended productID and a "score" in the
        * rating field. Semantics of score is same as recommendProducts API
    +   * @since 1.4.0
        */
       def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = {
         MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map {
    @@ -208,6 +230,7 @@ class MatrixFactorizationModel(
        * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array
        * of rating objects which contains the recommended userId, same productID and a "score" in the
        * rating field. Semantics of score is same as recommendUsers API
    +   * @since 1.4.0
        */
       def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = {
         MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map {
    @@ -218,6 +241,9 @@ class MatrixFactorizationModel(
       }
     }
     
    +/**
    + * @since 1.3.0
    + */
     object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
     
       import org.apache.spark.mllib.util.Loader._
    @@ -292,6 +318,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
         }
       }
     
    +  /**
    +   * Load a model from the given path.
    +   *
    +   * The model should have been saved by [[Saveable.save]].
    +   *
    +   * @param sc  Spark context used for loading model files.
    +   * @param path  Path specifying the directory to which the model was saved.
    +   * @return  Model instance
    +   * @since 1.3.0
    +   */
       override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
         val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
         val classNameV1_0 = SaveLoadV1_0.thisClassName
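The scaladoc added above spells out the save/load contract for `MatrixFactorizationModel`; a hedged sketch of the round trip (assumes `sc` and a trained `model` as in the earlier ALS sketch; the path is illustrative and must not already exist):

```scala
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel

model.save(sc, "target/tmp/alsModel")
val sameModel = MatrixFactorizationModel.load(sc, "target/tmp/alsModel")
println(sameModel.recommendProducts(1, 2).mkString(", "))
```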
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
    index 58a50f9c19f14..93a6753efd4d9 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
    @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD
      *   .setBandwidth(3.0)
      * val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
      * }}}
    + * @since 1.4.0
      */
     @Experimental
     class KernelDensity extends Serializable {
    @@ -51,6 +52,7 @@ class KernelDensity extends Serializable {
     
       /**
        * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
    +   * @since 1.4.0
        */
       def setBandwidth(bandwidth: Double): this.type = {
         require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
    @@ -60,6 +62,7 @@ class KernelDensity extends Serializable {
     
       /**
        * Sets the sample to use for density estimation.
    +   * @since 1.4.0
        */
       def setSample(sample: RDD[Double]): this.type = {
         this.sample = sample
    @@ -68,6 +71,7 @@ class KernelDensity extends Serializable {
     
       /**
        * Sets the sample to use for density estimation (for Java users).
    +   * @since 1.4.0
        */
       def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
         this.sample = sample.rdd.asInstanceOf[RDD[Double]]
    @@ -76,6 +80,7 @@ class KernelDensity extends Serializable {
     
       /**
        * Estimates probability density function at the given array of points.
    +   * @since 1.4.0
        */
       def estimate(points: Array[Double]): Array[Double] = {
         val sample = this.sample
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
    index d321cc554c1cc..62da9f2ef22a3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
    @@ -33,6 +33,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
      * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
      * Zero elements (including explicit zero values) are skipped when calling add(),
      * to have time complexity O(nnz) instead of O(n) for each column.
    + * @since 1.1.0
      */
     @DeveloperApi
     class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
    @@ -52,6 +53,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
        *
        * @param sample The sample in dense/sparse vector format to be added into this summarizer.
        * @return This MultivariateOnlineSummarizer object.
    +   * @since 1.1.0
        */
       def add(sample: Vector): this.type = {
         if (n == 0) {
    @@ -107,6 +109,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
        *
        * @param other The other MultivariateOnlineSummarizer to be merged.
        * @return This MultivariateOnlineSummarizer object.
    +   * @since 1.1.0
        */
       def merge(other: MultivariateOnlineSummarizer): this.type = {
         if (this.totalCnt != 0 && other.totalCnt != 0) {
    @@ -149,6 +152,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         this
       }
     
    +  /**
    +   * @since 1.1.0
    +   */
       override def mean: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
    @@ -161,6 +167,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         Vectors.dense(realMean)
       }
     
    +  /**
    +   * @since 1.1.0
    +   */
       override def variance: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
    @@ -183,14 +192,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         Vectors.dense(realVariance)
       }
     
    +  /**
    +   * @since 1.1.0
    +   */
       override def count: Long = totalCnt
     
    +  /**
    +   * @since 1.1.0
    +   */
       override def numNonzeros: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
         Vectors.dense(nnz)
       }
     
    +  /**
    +   * @since 1.1.0
    +   */
       override def max: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
    @@ -202,6 +220,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         Vectors.dense(currMax)
       }
     
    +  /**
    +   * @since 1.1.0
    +   */
       override def min: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
    @@ -213,6 +234,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         Vectors.dense(currMin)
       }
     
    +  /**
    +   * @since 1.2.0
    +   */
       override def normL2: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
    @@ -227,6 +251,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         Vectors.dense(realMagnitude)
       }
     
    +  /**
    +   * @since 1.2.0
    +   */
       override def normL1: Vector = {
         require(totalCnt > 0, s"Nothing has been added to this summarizer.")
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
    index 6a364c93284af..3bb49f12289e1 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
    @@ -21,46 +21,55 @@ import org.apache.spark.mllib.linalg.Vector
     
     /**
      * Trait for multivariate statistical summary of a data matrix.
    + * @since 1.0.0
      */
     trait MultivariateStatisticalSummary {
     
       /**
        * Sample mean vector.
    +   * @since 1.0.0
        */
       def mean: Vector
     
       /**
        * Sample variance vector. Should return a zero vector if the sample size is 1.
    +   * @since 1.0.0
        */
       def variance: Vector
     
       /**
        * Sample size.
    +   * @since 1.0.0
        */
       def count: Long
     
       /**
        * Number of nonzero elements (including explicitly presented zero values) in each column.
    +   * @since 1.0.0
        */
       def numNonzeros: Vector
     
       /**
        * Maximum value of each column.
    +   * @since 1.0.0
        */
       def max: Vector
     
       /**
        * Minimum value of each column.
    +   * @since 1.0.0
        */
       def min: Vector
     
       /**
        * Euclidean magnitude of each column
    +   * @since 1.2.0
        */
       def normL2: Vector
     
       /**
        * L1 norm of each column
    +   * @since 1.2.0
        */
       def normL1: Vector
     }
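A hedged sketch of computing these column-wise statistics with the online summarizer documented above (assumes an existing `SparkContext` `sc`; the two observation vectors are illustrative):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

val observations = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0),
  Vectors.dense(3.0, 4.0)))

// add() folds samples into a per-partition summarizer; merge() combines partition results.
val summary = observations.treeAggregate(new MultivariateOnlineSummarizer)(
  (agg, v) => agg.add(v),
  (a, b) => a.merge(b))

println(summary.mean)         // [2.0,2.0]
println(summary.numNonzeros)  // [2.0,1.0]
```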
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
    index 900007ec6bc74..f84502919e381 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
    @@ -17,18 +17,22 @@
     
     package org.apache.spark.mllib.stat
     
    +import scala.annotation.varargs
    +
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.mllib.linalg.distributed.RowMatrix
     import org.apache.spark.mllib.linalg.{Matrix, Vector}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.stat.correlation.Correlations
    -import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult}
    +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest,
    +  KolmogorovSmirnovTestResult}
     import org.apache.spark.rdd.RDD
     
     /**
      * :: Experimental ::
      * API for statistical functions in MLlib.
    + * @since 1.1.0
      */
     @Experimental
     object Statistics {
    @@ -38,6 +42,7 @@ object Statistics {
        *
        * @param X an RDD[Vector] for which column-wise summary statistics are to be computed.
        * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics.
    +   * @since 1.1.0
        */
       def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = {
         new RowMatrix(X).computeColumnSummaryStatistics()
    @@ -49,6 +54,7 @@ object Statistics {
        *
        * @param X an RDD[Vector] for which the correlation matrix is to be computed.
        * @return Pearson correlation matrix comparing columns in X.
    +   * @since 1.1.0
        */
       def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X)
     
    @@ -65,6 +71,7 @@ object Statistics {
        * @param method String specifying the method to use for computing correlation.
        *               Supported: `pearson` (default), `spearman`
        * @return Correlation matrix comparing columns in X.
    +   * @since 1.1.0
        */
       def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method)
     
    @@ -78,10 +85,14 @@ object Statistics {
        * @param x RDD[Double] of the same cardinality as y.
        * @param y RDD[Double] of the same cardinality as x.
        * @return A Double containing the Pearson correlation between the two input RDD[Double]s
    +   * @since 1.1.0
        */
       def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
     
    -  /** Java-friendly version of [[corr()]] */
    +  /**
    +   * Java-friendly version of [[corr()]]
    +   * @since 1.4.1
    +   */
       def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
         corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
     
    @@ -98,10 +109,14 @@ object Statistics {
        *               Supported: `pearson` (default), `spearman`
        * @return A Double containing the correlation between the two input RDD[Double]s using the
        *         specified method.
    +   * @since 1.1.0
        */
       def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
     
    -  /** Java-friendly version of [[corr()]] */
    +  /**
    +   * Java-friendly version of [[corr()]]
    +   * @since 1.4.1
    +   */
       def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
         corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
     
    @@ -118,6 +133,7 @@ object Statistics {
        *                 `expected` is rescaled if the `expected` sum differs from the `observed` sum.
        * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
        *         the method used, and the null hypothesis.
    +   * @since 1.1.0
        */
       def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = {
         ChiSqTest.chiSquared(observed, expected)
    @@ -132,6 +148,7 @@ object Statistics {
        * @param observed Vector containing the observed categorical counts/relative frequencies.
        * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
        *         the method used, and the null hypothesis.
    +   * @since 1.1.0
        */
       def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed)
     
    @@ -142,6 +159,7 @@ object Statistics {
        * @param observed The contingency matrix (containing either counts or relative frequencies).
        * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
        *         the method used, and the null hypothesis.
    +   * @since 1.1.0
        */
       def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed)
     
    @@ -154,8 +172,44 @@ object Statistics {
        *             Real-valued features will be treated as categorical for each distinct value.
        * @return an array containing the ChiSquaredTestResult for every feature against the label.
        *         The order of the elements in the returned array reflects the order of input features.
    +   * @since 1.1.0
        */
       def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
         ChiSqTest.chiSquaredFeatures(data)
       }
    +
    +  /**
    +   * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
    +   * continuous distribution. By comparing the largest difference between the empirical cumulative
     +   * distribution of the sample data and the theoretical distribution we can provide a test for
     +   * the null hypothesis that the sample data comes from that theoretical distribution.
    +   * For more information on KS Test:
    +   * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
    +   *
    +   * @param data an `RDD[Double]` containing the sample of data to test
    +   * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
    +   *        statistic, p-value, and null hypothesis.
    +   */
    +  def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double)
    +    : KolmogorovSmirnovTestResult = {
    +    KolmogorovSmirnovTest.testOneSample(data, cdf)
    +  }
    +
    +  /**
    +   * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability
    +   * distribution equality. Currently supports the normal distribution, taking as parameters
    +   * the mean and standard deviation.
    +   * (distName = "norm")
    +   * @param data an `RDD[Double]` containing the sample of data to test
    +   * @param distName a `String` name for a theoretical distribution
    +   * @param params `Double*` specifying the parameters to be used for the theoretical distribution
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
    +   *        statistic, p-value, and null hypothesis.
    +   */
    +  @varargs
    +  def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*)
    +    : KolmogorovSmirnovTestResult = {
    +    KolmogorovSmirnovTest.testOneSample(data, distName, params: _*)
    +  }
     }
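
The two `kolmogorovSmirnovTest` overloads added above take either an arbitrary CDF function or a distribution name plus its parameters (only "norm" so far). A minimal spark-shell style sketch of how they would be called; the SparkContext `sc`, the sample values, and the stand-in CDF are illustrative assumptions, not part of the patch:

    import org.apache.spark.mllib.stat.Statistics
    import org.apache.spark.rdd.RDD

    // `sc` is assumed to be an existing SparkContext; the sample values are made up.
    val data: RDD[Double] = sc.parallelize(Seq(0.1, -0.4, 1.3, 0.2, -0.7, 0.9))

    // Overload 1: test against an explicit CDF (a logistic CDF used here as a stand-in).
    val cdf = (x: Double) => 1.0 / (1.0 + math.exp(-x))
    val result1 = Statistics.kolmogorovSmirnovTest(data, cdf)

    // Overload 2: varargs convenience form, distribution name plus (mean, stddev).
    val result2 = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0)

    println(s"statistic = ${result2.statistic}, p-value = ${result2.pValue}")
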
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
    index cf51b24ff777f..9aa7763d7890d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
    @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils
      *
      * @param mu The mean vector of the distribution
      * @param sigma The covariance matrix of the distribution
    + * @since 1.3.0
      */
     @DeveloperApi
     class MultivariateGaussian (
    @@ -60,12 +61,16 @@ class MultivariateGaussian (
        */
       private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
     
    -  /** Returns density of this multivariate Gaussian at given point, x */
    +  /** Returns density of this multivariate Gaussian at given point, x
    +    * @since 1.3.0
    +    */
       def pdf(x: Vector): Double = {
         pdf(x.toBreeze)
       }
     
    -  /** Returns the log-density of this multivariate Gaussian at given point, x */
    +  /** Returns the log-density of this multivariate Gaussian at given point, x
    +    * @since 1.3.0
    +    */
       def logpdf(x: Vector): Double = {
         logpdf(x.toBreeze)
       }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    new file mode 100644
    index 0000000000000..2b3ed6df486c9
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    @@ -0,0 +1,194 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.stat.test
    +
    +import scala.annotation.varargs
    +
    +import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution}
    +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest}
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.rdd.RDD
    +
    +/**
     + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
     + * continuous distribution. By comparing the largest difference between the empirical cumulative
     + * distribution of the sample data and the theoretical distribution we can provide a test for
     + * the null hypothesis that the sample data comes from that theoretical distribution.
    + * For more information on KS Test:
    + * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
    + *
    + * Implementation note: We seek to implement the KS test with a minimal number of distributed
    + * passes. We sort the RDD, and then perform the following operations on a per-partition basis:
    + * calculate an empirical cumulative distribution value for each observation, and a theoretical
    + * cumulative distribution value. We know the latter to be correct, while the former will be off by
    + * a constant (how large the constant is depends on how many values precede it in other partitions).
    + * However, given that this constant simply shifts the empirical CDF upwards, but doesn't
    + * change its shape, and furthermore, that constant is the same within a given partition, we can
    + * pick 2 values in each partition that can potentially resolve to the largest global distance.
    + * Namely, we pick the minimum distance and the maximum distance. Additionally, we keep track of how
    + * many elements are in each partition. Once these three values have been returned for every
    + * partition, we can collect and operate locally. Locally, we can now adjust each distance by the
    + * appropriate constant (the cumulative sum of number of elements in the prior partitions divided by
     + * the data set size). Finally, we take the maximum absolute value, and this is the statistic.
    + */
    +private[stat] object KolmogorovSmirnovTest extends Logging {
    +
    +  // Null hypothesis for the type of KS test to be included in the result.
    +  object NullHypothesis extends Enumeration {
    +    type NullHypothesis = Value
    +    val OneSampleTwoSided = Value("Sample follows theoretical distribution")
    +  }
    +
    +  /**
    +   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
    +   * @param data `RDD[Double]` data on which to run test
    +   * @param cdf `Double => Double` function to calculate the theoretical CDF
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
    +   *        results (p-value, statistic, and null hypothesis)
    +   */
    +  def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = {
    +    val n = data.count().toDouble
    +    val localData = data.sortBy(x => x).mapPartitions { part =>
    +      val partDiffs = oneSampleDifferences(part, n, cdf) // local distances
    +      searchOneSampleCandidates(partDiffs) // candidates: local extrema
    +    }.collect()
    +    val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme
    +    evalOneSampleP(ksStat, n.toLong)
    +  }
    +
    +  /**
    +   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
    +   * @param data `RDD[Double]` data on which to run test
    +   * @param distObj `RealDistribution` a theoretical distribution
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
    +   *        results (p-value, statistic, and null hypothesis)
    +   */
    +  def testOneSample(data: RDD[Double], distObj: RealDistribution): KolmogorovSmirnovTestResult = {
    +    val cdf = (x: Double) => distObj.cumulativeProbability(x)
    +    testOneSample(data, cdf)
    +  }
    +
    +  /**
    +   * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a
    +   * partition
    +   * @param partData `Iterator[Double]` 1 partition of a sorted RDD
    +   * @param n `Double` the total size of the RDD
     +   * @param cdf `Double => Double` a function that calculates the theoretical CDF of a value
     +   * @return `Iterator[(Double, Double)]` Unadjusted (i.e. off by a constant) potential extrema
    +   *        in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF,
    +   *        the second element corresponds to empirical CDF - CDF.  We can then search the resulting
    +   *        iterator for the minimum of the first and the maximum of the second element, and provide
    +   *        this as a partition's candidate extrema
    +   */
    +  private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double)
    +    : Iterator[(Double, Double)] = {
    +    // zip data with index (within that partition)
    +    // calculate local (unadjusted) empirical CDF and subtract CDF
    +    partData.zipWithIndex.map { case (v, ix) =>
    +      // dp and dl are later adjusted by constant, when global info is available
    +      val dp = (ix + 1) / n
    +      val dl = ix / n
    +      val cdfVal = cdf(v)
    +      (dl - cdfVal, dp - cdfVal)
    +    }
    +  }
    +
    +  /**
    +   * Search the unadjusted differences in a partition and return the
    +   * two extrema (furthest below and furthest above CDF), along with a count of elements in that
    +   * partition
    +   * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF
     +   *                 and CDF in a partition, which come as a tuple of
    +   *                 (empirical CDF - 1/N - CDF, empirical CDF - CDF)
    +   * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements
    +   */
    +  private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)])
    +    : Iterator[(Double, Double, Double)] = {
    +    val initAcc = (Double.MaxValue, Double.MinValue, 0.0)
    +    val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) =>
    +      (math.min(pMin, dl), math.max(pMax, dp), pCt + 1)
    +    }
    +    val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults)
    +    results.iterator
    +  }
    +
    +  /**
    +   * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after
    +   * adjusting local extrema estimates from individual partitions with the amount of elements in
    +   * preceding partitions
    +   * @param localData `Array[(Double, Double, Double)]` A local array containing the collected
    +   *                 results of `searchOneSampleCandidates` across all partitions
     +   * @param n `Double` The size of the RDD
     +   * @return The one-sample Kolmogorov-Smirnov statistic
    +   */
    +  private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double)
    +    : Double = {
    +    val initAcc = (Double.MinValue, 0.0)
    +    // adjust differences based on the number of elements preceding it, which should provide
    +    // the correct distance between empirical CDF and CDF
    +    val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
    +      val adjConst = prevCt / n
    +      val dist1 = math.abs(minCand + adjConst)
    +      val dist2 = math.abs(maxCand + adjConst)
    +      val maxVal = Array(prevMax, dist1, dist2).max
    +      (maxVal, prevCt + ct)
    +    }
    +    results._1
    +  }
    +
    +  /**
    +   * A convenience function that allows running the KS test for 1 set of sample data against
    +   * a named distribution
    +   * @param data the sample data that we wish to evaluate
    +   * @param distName the name of the theoretical distribution
    +   * @param params Variable length parameter for distribution's parameters
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the
    +   *        test results (p-value, statistic, and null hypothesis)
    +   */
    +  @varargs
    +  def testOneSample(data: RDD[Double], distName: String, params: Double*)
    +    : KolmogorovSmirnovTestResult = {
    +    val distObj =
    +      distName match {
    +        case "norm" => {
    +          if (params.nonEmpty) {
     +            // if parameters are passed, there must be exactly 2 (mean, standard deviation)
    +            require(params.length == 2, "Normal distribution requires mean and standard " +
    +              "deviation as parameters")
    +            new NormalDistribution(params(0), params(1))
    +          } else {
    +            // if no parameters passed in initializes to standard normal
     +            // if no parameters are passed, initialize to the standard normal
     +            logInfo("No parameters specified for normal distribution, " +
     +              "initialized to standard normal (i.e. N(0, 1))")
    +          }
    +        }
     +        case _ => throw new UnsupportedOperationException(s"$distName not yet supported " +
     +          s"through convenience method. Current options are: ['norm'].")
    +      }
    +
    +    testOneSample(data, distObj)
    +  }
    +
    +  private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = {
    +    val pval = 1 - new CommonMathKolmogorovSmirnovTest().cdf(ksStat, n.toInt)
    +    new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString)
    +  }
    +}
    +
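
To make the implementation note in the new file concrete: each partition reports only its unadjusted extrema and its element count, and the driver then shifts every candidate by the number of elements preceding that partition, divided by n. A toy, driver-only sketch of that two-phase search using plain Scala collections (names and the stand-in CDF are illustrative; this is not the library code):

    // `partitions` stands in for a sorted RDD[Double] split across partitions.
    val cdf: Double => Double = x => 1.0 / (1.0 + math.exp(-x))   // stand-in theoretical CDF
    val partitions: Seq[Seq[Double]] = Seq(Seq(-1.2, -0.3), Seq(0.1, 0.4, 0.9))
    val n = partitions.map(_.size).sum.toDouble

    // Phase 1 (per partition): unadjusted extrema plus the partition's element count.
    val candidates = partitions.map { part =>
      val diffs = part.zipWithIndex.map { case (v, ix) =>
        (ix / n - cdf(v), (ix + 1) / n - cdf(v))   // (ecdf - 1/n - cdf, ecdf - cdf), off by a constant
      }
      (diffs.map(_._1).min, diffs.map(_._2).max, part.size.toDouble)
    }

    // Phase 2 (driver): shift each candidate by the count of elements in earlier partitions.
    val (ksStat, _) = candidates.foldLeft((Double.MinValue, 0.0)) {
      case ((best, prevCt), (minCand, maxCand, ct)) =>
        val shift = prevCt / n
        (Seq(best, math.abs(minCand + shift), math.abs(maxCand + shift)).max, prevCt + ct)
    }
    println(ksStat)
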
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
    index 4784f9e947908..f44be13706695 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
    @@ -90,3 +90,20 @@ class ChiSqTestResult private[stat] (override val pValue: Double,
           super.toString
       }
     }
    +
    +/**
    + * :: Experimental ::
    + * Object containing the test results for the Kolmogorov-Smirnov test.
    + */
    +@Experimental
    +class KolmogorovSmirnovTestResult private[stat] (
    +    override val pValue: Double,
    +    override val statistic: Double,
    +    override val nullHypothesis: String) extends TestResult[Int] {
    +
    +  override val degreesOfFreedom = 0
    +
    +  override def toString: String = {
    +    "Kolmogorov-Smirnov test summary:\n" + super.toString
    +  }
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
    index a835f96d5d0e3..9ce6faa137c41 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
    @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
     import org.apache.spark.Logging
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.api.java.JavaRDD
    +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.configuration.BoostingStrategy
     import org.apache.spark.mllib.tree.configuration.Algo._
    @@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging {
           false
         }
     
    +    // Prepare periodic checkpointers
    +    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
    +      treeStrategy.getCheckpointInterval, input.sparkContext)
    +    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
    +      treeStrategy.getCheckpointInterval, input.sparkContext)
    +
         timer.stop("init")
     
         logDebug("##########")
         logDebug("Building tree 0")
         logDebug("##########")
    -    var data = input
     
         // Initialize tree
         timer.start("building tree 0")
    -    val firstTreeModel = new DecisionTree(treeStrategy).run(data)
    +    val firstTreeModel = new DecisionTree(treeStrategy).run(input)
         val firstTreeWeight = 1.0
         baseLearners(0) = firstTreeModel
         baseLearnerWeights(0) = firstTreeWeight
     
         var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
           computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
    +    predErrorCheckpointer.update(predError)
         logDebug("error of gbt = " + predError.values.mean())
     
         // Note: A model of type regression is used since we require raw prediction
    @@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging {
     
         var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
           computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
    +    if (validate) validatePredErrorCheckpointer.update(validatePredError)
         var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
         var bestM = 1
     
    -    // pseudo-residual for second iteration
    -    data = predError.zip(input).map { case ((pred, _), point) =>
    -      LabeledPoint(-loss.gradient(pred, point.label), point.features)
    -    }
    -
         var m = 1
    -    while (m < numIterations) {
    +    var doneLearning = false
    +    while (m < numIterations && !doneLearning) {
    +      // Update data with pseudo-residuals
    +      val data = predError.zip(input).map { case ((pred, _), point) =>
    +        LabeledPoint(-loss.gradient(pred, point.label), point.features)
    +      }
    +
           timer.start(s"building tree $m")
           logDebug("###################################################")
           logDebug("Gradient boosting tree iteration " + m)
           logDebug("###################################################")
           val model = new DecisionTree(treeStrategy).run(data)
           timer.stop(s"building tree $m")
    -      // Create partial model
    +      // Update partial model
           baseLearners(m) = model
           // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
           //       Technically, the weight should be optimized for the particular loss.
           //       However, the behavior should be reasonable, though not optimal.
           baseLearnerWeights(m) = learningRate
    -      // Note: A model of type regression is used since we require raw prediction
    -      val partialModel = new GradientBoostedTreesModel(
    -        Regression, baseLearners.slice(0, m + 1),
    -        baseLearnerWeights.slice(0, m + 1))
     
           predError = GradientBoostedTreesModel.updatePredictionError(
             input, predError, baseLearnerWeights(m), baseLearners(m), loss)
    +      predErrorCheckpointer.update(predError)
           logDebug("error of gbt = " + predError.values.mean())
     
           if (validate) {
    @@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging {
     
             validatePredError = GradientBoostedTreesModel.updatePredictionError(
               validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
    +        validatePredErrorCheckpointer.update(validatePredError)
             val currentValidateError = validatePredError.values.mean()
             if (bestValidateError - currentValidateError < validationTol) {
    -          return new GradientBoostedTreesModel(
    -            boostingStrategy.treeStrategy.algo,
    -            baseLearners.slice(0, bestM),
    -            baseLearnerWeights.slice(0, bestM))
    +          doneLearning = true
             } else if (currentValidateError < bestValidateError) {
    -            bestValidateError = currentValidateError
    -            bestM = m + 1
    +          bestValidateError = currentValidateError
    +          bestM = m + 1
             }
           }
    -      // Update data with pseudo-residuals
    -      data = predError.zip(input).map { case ((pred, _), point) =>
    -        LabeledPoint(-loss.gradient(pred, point.label), point.features)
    -      }
           m += 1
         }
     
    @@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging {
         logInfo("Internal timing for DecisionTree:")
         logInfo(s"$timer")
     
    +    predErrorCheckpointer.deleteAllCheckpoints()
    +    validatePredErrorCheckpointer.deleteAllCheckpoints()
         if (persistedInput) input.unpersist()
     
         if (validate) {
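
The rewrite above replaces the early `return` with a `doneLearning` flag and checkpoints the running prediction-error RDDs through `PeriodicRDDCheckpointer`, so lineage does not grow without bound across boosting iterations. A stripped-down sketch of just the early-stopping control flow (plain Scala; `updateError` is a hypothetical stand-in for fitting tree m and recomputing the validation error):

    // Returns the number of trees to keep, mirroring the bestM bookkeeping in the patch.
    def boost(numIterations: Int, validationTol: Double,
              updateError: Int => Double): Int = {
      var bestValidateError = Double.MaxValue
      var bestM = 1
      var m = 1
      var doneLearning = false
      while (m < numIterations && !doneLearning) {
        val currentValidateError = updateError(m)   // fit tree m, recompute validation error
        if (bestValidateError - currentValidateError < validationTol) {
          doneLearning = true                       // improvement too small: stop boosting
        } else if (currentValidateError < bestValidateError) {
          bestValidateError = currentValidateError
          bestM = m + 1                             // best prefix of trees seen so far
        }
        m += 1
      }
      bestM
    }
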
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
    index 2d6b01524ff3d..9fd30c9b56319 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
    @@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
      *                     learning rate should be between in the interval (0, 1]
      * @param validationTol Useful when runWithValidation is used. If the error rate on the
      *                      validation input between two iterations is less than the validationTol
    - *                      then stop. Ignored when [[run]] is used.
    + *                      then stop.  Ignored when
    + *                      [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
      */
     @Experimental
     case class BoostingStrategy(
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
    index 089010c81ffb6..572815df0bc4a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
    @@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom
      * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
      *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
      */
    -private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
    +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
       extends Serializable
     
    -private[tree] object BaggedPoint {
    +private[spark] object BaggedPoint {
     
       /**
        * Convert an input dataset into its BaggedPoint representation,
    @@ -60,7 +60,7 @@ private[tree] object BaggedPoint {
           subsamplingRate: Double,
           numSubsamples: Int,
           withReplacement: Boolean,
    -      seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
    +      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
         if (withReplacement) {
           convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
         } else {
    @@ -76,7 +76,7 @@ private[tree] object BaggedPoint {
           input: RDD[Datum],
           subsamplingRate: Double,
           numSubsamples: Int,
    -      seed: Int): RDD[BaggedPoint[Datum]] = {
    +      seed: Long): RDD[BaggedPoint[Datum]] = {
         input.mapPartitionsWithIndex { (partitionIndex, instances) =>
           // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
           val rng = new XORShiftRandom
    @@ -100,7 +100,7 @@ private[tree] object BaggedPoint {
           input: RDD[Datum],
           subsample: Double,
           numSubsamples: Int,
    -      seed: Int): RDD[BaggedPoint[Datum]] = {
    +      seed: Long): RDD[BaggedPoint[Datum]] = {
         input.mapPartitionsWithIndex { (partitionIndex, instances) =>
           // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
           val poisson = new PoissonDistribution(subsample)
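
The only signature change in this file is widening `seed` from `Int` to `Long`; reproducibility still comes from deriving each partition's generator from `seed + partitionIndex + 1`. A Spark-free sketch of that seeding pattern for the without-replacement path, using `scala.util.Random` in place of `XORShiftRandom` (names are illustrative):

    import scala.util.Random

    // Each "partition" seeds its own RNG from the global seed, so the same seed and the same
    // partitioning reproduce the same subsample weights, analogous to BaggedPoint.
    def bagWithoutReplacement[T](partitions: Seq[Seq[T]], subsamplingRate: Double,
                                 numSubsamples: Int, seed: Long): Seq[Seq[(T, Array[Double])]] = {
      partitions.zipWithIndex.map { case (part, partitionIndex) =>
        val rng = new Random(seed + partitionIndex + 1)
        part.map { datum =>
          val weights = Array.fill(numSubsamples) {
            if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0
          }
          (datum, weights)
        }
      }
    }
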
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
    index ce8825cc03229..7985ed4b4c0fa 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
    @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.impurity._
      * and helps with indexing.
      * This class is abstract to support learning with and without feature subsampling.
      */
    -private[tree] class DTStatsAggregator(
    +private[spark] class DTStatsAggregator(
         val metadata: DecisionTreeMetadata,
         featureSubset: Option[Array[Int]]) extends Serializable {
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
    index f73896e37c05e..9fe264656ede7 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
    @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD
      *                      I.e., the feature takes values in {0, ..., arity - 1}.
      * @param numBins  Number of bins for each feature.
      */
    -private[tree] class DecisionTreeMetadata(
    +private[spark] class DecisionTreeMetadata(
         val numFeatures: Int,
         val numExamples: Long,
         val numClasses: Int,
    @@ -94,7 +94,7 @@ private[tree] class DecisionTreeMetadata(
     
     }
     
    -private[tree] object DecisionTreeMetadata extends Logging {
    +private[spark] object DecisionTreeMetadata extends Logging {
     
       /**
        * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
    @@ -128,9 +128,13 @@ private[tree] object DecisionTreeMetadata extends Logging {
         // based on the number of training examples.
         if (strategy.categoricalFeaturesInfo.nonEmpty) {
           val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
    +      val maxCategory =
    +        strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
           require(maxCategoriesPerFeature <= maxPossibleBins,
    -        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
    -          s"in categorical features (= $maxCategoriesPerFeature)")
    +        s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
    +        s"number of values in each categorical feature, but categorical feature $maxCategory " +
    +        s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
    +        "features with a large number of values, or add more training examples.")
         }
     
         val unorderedFeatures = new mutable.HashSet[Int]()
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
    index bdd0f576b048d..8f9eb24b57b55 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
    @@ -75,7 +75,7 @@ private[tree] case class NodeIndexUpdater(
      *                           (how often should the cache be checkpointed.).
      */
     @DeveloperApi
    -private[tree] class NodeIdCache(
    +private[spark] class NodeIdCache(
       var nodeIdsForInstances: RDD[Array[Int]],
       val checkpointInterval: Int) {
     
    @@ -170,7 +170,7 @@ private[tree] class NodeIdCache(
     }
     
     @DeveloperApi
    -private[tree] object NodeIdCache {
    +private[spark] object NodeIdCache {
       /**
        * Initialize the node Id cache with initial node Id values.
        * @param data The RDD of training rows.
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
    index d215d68c4279e..aac84243d5ce1 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
    @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental
      * Time tracker implementation which holds labeled timers.
      */
     @Experimental
    -private[tree] class TimeTracker extends Serializable {
    +private[spark] class TimeTracker extends Serializable {
     
       private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
    index 50b292e71b067..21919d69a38a3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
    @@ -37,11 +37,11 @@ import org.apache.spark.rdd.RDD
      * @param binnedFeatures  Binned feature values.
      *                        Same length as LabeledPoint.features, but values are bin indices.
      */
    -private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
    +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
       extends Serializable {
     }
     
    -private[tree] object TreePoint {
    +private[spark] object TreePoint {
     
       /**
        * Convert an input dataset into its TreePoint representation,
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    index 72eb24c49264a..578749d85a4e6 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    @@ -57,7 +57,7 @@ trait Impurity extends Serializable {
      * Note: Instances of this class do not hold the data; they operate on views of the data.
      * @param statsSize  Length of the vector of sufficient statistics for one bin.
      */
    -private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
    +private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
     
       /**
        * Merge the stats from one bin into another.
    @@ -95,7 +95,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
      * (node, feature, bin).
      * @param stats  Array of sufficient statistics for a (node, feature, bin).
      */
    -private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
    +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
     
       /**
        * Make a deep copy of this [[ImpurityCalculator]].
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
    index a5582d3ef3324..011a5d57422f7 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
    @@ -42,11 +42,11 @@ object SquaredError extends Loss {
        * @return Loss gradient
        */
       override def gradient(prediction: Double, label: Double): Double = {
    -    2.0 * (prediction - label)
    +    - 2.0 * (label - prediction)
       }
     
       override private[mllib] def computeError(prediction: Double, label: Double): Double = {
    -    val err = prediction - label
    +    val err = label - prediction
         err * err
       }
     }
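
The rewritten gradient is the same quantity expressed against the (label - prediction) residual: for L = (label - prediction)^2, the derivative with respect to the prediction is -2 * (label - prediction), i.e. 2 * (prediction - label). A quick standalone finite-difference check (hypothetical values, not library code):

    // Verify that -2 * (label - prediction) matches the numerical derivative of the squared error.
    val (prediction, label, eps) = (0.7, 1.2, 1e-6)
    def loss(p: Double): Double = (label - p) * (label - p)
    val analytic = -2.0 * (label - prediction)
    val numeric = (loss(prediction + eps) - loss(prediction - eps)) / (2 * eps)
    assert(math.abs(analytic - numeric) < 1e-5)
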
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
    index 2d087c967f679..dc9e0f9f51ffb 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
    @@ -67,7 +67,7 @@ class InformationGainStats(
     }
     
     
    -private[tree] object InformationGainStats {
    +private[spark] object InformationGainStats {
       /**
        * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
        * denote that current split doesn't satisfies minimum info gain or
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
    index 6eaebaf7dba9f..e6bcff48b022c 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
    @@ -64,8 +64,10 @@ object KMeansDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length < 6) {
    +      // scalastyle:off println
           println("Usage: KMeansGenerator " +
             "      []")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    index b4e33c98ba7e5..87eeb5db05d26 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    @@ -153,8 +153,10 @@ object LinearDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length < 2) {
    +      // scalastyle:off println
           println("Usage: LinearDataGenerator " +
             "  [num_examples] [num_features] [num_partitions]")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
    index 9d802678c4a77..c09cbe69bb971 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
    @@ -64,8 +64,10 @@ object LogisticRegressionDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length != 5) {
    +      // scalastyle:off println
           println("Usage: LogisticRegressionGenerator " +
             "    ")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
    index bd73a866c8a82..16f430599a515 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
    @@ -55,8 +55,10 @@ import org.apache.spark.rdd.RDD
     object MFDataGenerator {
       def main(args: Array[String]) {
         if (args.length < 2) {
    +      // scalastyle:off println
           println("Usage: MFDataGenerator " +
             "  [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
    index a8e30cc9d730c..ad20b7694a779 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
    @@ -37,8 +37,10 @@ object SVMDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length < 2) {
    +      // scalastyle:off println
           println("Usage: SVMGenerator " +
             "  [num_examples] [num_features] [num_partitions]")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
    new file mode 100644
    index 0000000000000..09a9fba0c19cf
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
    @@ -0,0 +1,98 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification;
    +
    +import java.io.Serializable;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.VectorUDT;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.DataTypes;
    +import org.apache.spark.sql.types.Metadata;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +
    +public class JavaNaiveBayesSuite implements Serializable {
    +
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  public void validatePrediction(DataFrame predictionAndLabels) {
    +    for (Row r : predictionAndLabels.collect()) {
    +      double prediction = r.getAs(0);
    +      double label = r.getAs(1);
    +      assert(prediction == label);
    +    }
    +  }
    +
    +  @Test
    +  public void naiveBayesDefaultParams() {
    +    NaiveBayes nb = new NaiveBayes();
     +    assert(nb.getLabelCol().equals("label"));
     +    assert(nb.getFeaturesCol().equals("features"));
     +    assert(nb.getPredictionCol().equals("prediction"));
     +    assert(nb.getLambda() == 1.0);
     +    assert(nb.getModelType().equals("multinomial"));
    +  }
    +
    +  @Test
    +  public void testNaiveBayes() {
     +    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
    +      RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
    +      RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
    +      RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
    +      RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
    +      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
    +      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
    +    ));
    +
    +    StructType schema = new StructType(new StructField[]{
    +      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
    +      new StructField("features", new VectorUDT(), false, Metadata.empty())
    +    });
    +
    +    DataFrame dataset = jsql.createDataFrame(jrdd, schema);
    +    NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
    +    NaiveBayesModel model = nb.fit(dataset);
    +
    +    DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
    +    validatePrediction(predictionAndLabels);
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
    new file mode 100644
    index 0000000000000..d09fa7fd5637c
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
    @@ -0,0 +1,72 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering;
    +
    +import java.io.Serializable;
    +import java.util.Arrays;
    +import java.util.List;
    +
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +import static org.junit.Assert.assertArrayEquals;
    +import static org.junit.Assert.assertEquals;
    +import static org.junit.Assert.assertTrue;
    +
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.SQLContext;
    +
    +public class JavaKMeansSuite implements Serializable {
    +
    +  private transient int k = 5;
    +  private transient JavaSparkContext sc;
    +  private transient DataFrame dataset;
    +  private transient SQLContext sql;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaKMeansSuite");
    +    sql = new SQLContext(sc);
    +
    +    dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void fitAndTransform() {
    +    KMeans kmeans = new KMeans().setK(k).setSeed(1);
    +    KMeansModel model = kmeans.fit(dataset);
    +
    +    Vector[] centers = model.clusterCenters();
    +    assertEquals(k, centers.length);
    +
    +    DataFrame transformed = model.transform(dataset);
     +    List<String> columns = Arrays.asList(transformed.columns());
     +    List<String> expectedColumns = Arrays.asList("features", "prediction");
    +    for (String column: expectedColumns) {
    +      assertTrue(columns.contains(column));
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
    new file mode 100644
    index 0000000000000..5cf43fec6f29e
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
    @@ -0,0 +1,114 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import java.io.Serializable;
    +import java.util.List;
    +
    +import scala.Tuple2;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.function.Function;
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.distributed.RowMatrix;
    +import org.apache.spark.mllib.linalg.Matrix;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.SQLContext;
    +
    +public class JavaPCASuite implements Serializable {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext sqlContext;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaPCASuite");
    +    sqlContext = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  public static class VectorPair implements Serializable {
    +    private Vector features = Vectors.dense(0.0);
    +    private Vector expected = Vectors.dense(0.0);
    +
    +    public void setFeatures(Vector features) {
    +      this.features = features;
    +    }
    +
    +    public Vector getFeatures() {
    +      return this.features;
    +    }
    +
    +    public void setExpected(Vector expected) {
    +      this.expected = expected;
    +    }
    +
    +    public Vector getExpected() {
    +      return this.expected;
    +    }
    +  }
    +
    +  @Test
    +  public void testPCA() {
     +    List<Vector> points = Lists.newArrayList(
    +      Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}),
    +      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
    +      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    +    );
     +    JavaRDD<Vector> dataRDD = jsc.parallelize(points, 2);
    +
    +    RowMatrix mat = new RowMatrix(dataRDD.rdd());
    +    Matrix pc = mat.computePrincipalComponents(3);
     +    JavaRDD<Vector> expected = mat.multiply(pc).rows().toJavaRDD();
    +
     +    JavaRDD<VectorPair> featuresExpected = dataRDD.zip(expected).map(
     +      new Function<Tuple2<Vector, Vector>, VectorPair>() {
     +        public VectorPair call(Tuple2<Vector, Vector> pair) {
    +          VectorPair featuresExpected = new VectorPair();
    +          featuresExpected.setFeatures(pair._1());
    +          featuresExpected.setExpected(pair._2());
    +          return featuresExpected;
    +        }
    +      }
    +    );
    +
    +    DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
    +    PCAModel pca = new PCA()
    +      .setInputCol("features")
    +      .setOutputCol("pca_features")
    +      .setK(3)
    +      .fit(df);
     +    List<Row> result = pca.transform(df).select("pca_features", "expected").toJavaRDD().collect();
    +    for (Row r : result) {
    +      Assert.assertEquals(r.get(1), r.get(0));
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    index 3ae09d39ef500..dc6ce8061f62b 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    @@ -96,11 +96,8 @@ private void init() {
           new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
     
         setDefault(myIntParam(), 1);
    -    setDefault(myIntParam().w(1));
         setDefault(myDoubleParam(), 0.5);
    -    setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
         setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
    -    setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
       }
     
       @Override
    diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
    index 71b041818d7ee..ebe800e749e05 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
    @@ -57,7 +57,7 @@ public void runDT() {
          JavaRDD<LabeledPoint> data = sc.parallelize(
           LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
          Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
    -    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
    +    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
     
         // This tests setters. Training with various options is tested in Scala.
         DecisionTreeRegressor dt = new DecisionTreeRegressor()
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
    index 581c033f08ebe..d272a42c8576f 100644
    --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
    +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
    @@ -19,6 +19,7 @@
     
     import java.io.Serializable;
     import java.util.ArrayList;
    +import java.util.Arrays;
     
     import scala.Tuple2;
     
    @@ -28,12 +29,13 @@
     import org.junit.Before;
     import org.junit.Test;
     
    +import org.apache.spark.api.java.function.Function;
     import org.apache.spark.api.java.JavaPairRDD;
     import org.apache.spark.api.java.JavaRDD;
     import org.apache.spark.api.java.JavaSparkContext;
     import org.apache.spark.mllib.linalg.Matrix;
     import org.apache.spark.mllib.linalg.Vector;
    -
    +import org.apache.spark.mllib.linalg.Vectors;
     
     public class JavaLDASuite implements Serializable {
       private transient JavaSparkContext sc;
    @@ -58,7 +60,10 @@ public void tearDown() {
     
       @Test
       public void localLDAModel() {
    -    LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
    +    Matrix topics = LDASuite$.MODULE$.tinyTopics();
    +    double[] topicConcentration = new double[topics.numRows()];
    +    Arrays.fill(topicConcentration, 1.0D / topics.numRows());
    +    LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D);
     
         // Check: basic parameters
         assertEquals(model.k(), tinyK);
    @@ -110,7 +115,15 @@ public void distributedLDAModel() {
     
         // Check: topic distributions
          JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
    -    assertEquals(topicDistributions.count(), corpus.count());
     +    // SPARK-5562: topicDistributions() returns the distribution over topics only for the
     +    // non-empty docs, so compare against nonEmptyCorpus instead of corpus.
     +    JavaPairRDD<Long, Vector> nonEmptyCorpus = corpus.filter(
     +      new Function<Tuple2<Long, Vector>, Boolean>() {
     +        public Boolean call(Tuple2<Long, Vector> tuple2) {
    +          return Vectors.norm(tuple2._2(), 1.0) != 0.0;
    +        }
    +    });
    +    assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
       }
     
       @Test
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
    new file mode 100644
    index 0000000000000..b3815ae6039c0
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
    @@ -0,0 +1,58 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm;
    +
    +import java.io.Serializable;
    +
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +import com.google.common.collect.Lists;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
    +
    +
    +public class JavaAssociationRulesSuite implements Serializable {
    +  private transient JavaSparkContext sc;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaFPGrowth");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void runAssociationRules() {
    +
    +    @SuppressWarnings("unchecked")
     +    JavaRDD<FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList(
     +      new FreqItemset<String>(new String[] {"a"}, 15L),
     +      new FreqItemset<String>(new String[] {"b"}, 35L),
     +      new FreqItemset<String>(new String[] {"a", "b"}, 18L)
    +    ));
    +
     +    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
    +  }
    +}
    +
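
The new Java suite above only checks that run() executes on a small set of frequent itemsets; it does not look at the generated rules. A hedged Scala sketch of how the same public AssociationRules API is typically driven end to end, using setMinConfidence and the Rule fields (antecedent, consequent, confidence); the 0.8 threshold and the local master are illustrative choices, not values from this patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.AssociationRules
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

object AssociationRulesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("AssociationRulesSketch"))

    // Same frequent itemsets as in the Java suite above.
    val freqItemsets = sc.parallelize(Seq(
      new FreqItemset(Array("a"), 15L),
      new FreqItemset(Array("b"), 35L),
      new FreqItemset(Array("a", "b"), 18L)))

    // Keep only rules whose confidence is at least 0.8 (illustrative threshold).
    val rules = new AssociationRules().setMinConfidence(0.8).run(freqItemsets)
    rules.collect().foreach { rule =>
      println(rule.antecedent.mkString("{", ",", "}") + " => " +
        rule.consequent.mkString("{", ",", "}") + ": " + rule.confidence)
    }
    sc.stop()
  }
}
```
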
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
    index bd0edf2b9ea62..9ce2c52dca8b6 100644
    --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
    @@ -29,7 +29,6 @@
     
     import org.apache.spark.api.java.JavaRDD;
     import org.apache.spark.api.java.JavaSparkContext;
    -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
     
     public class JavaFPGrowthSuite implements Serializable {
       private transient JavaSparkContext sc;
    @@ -62,10 +61,10 @@ public void runFPGrowth() {
           .setNumPartitions(2)
           .run(rdd);
     
     -    List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
     +    List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
         assertEquals(18, freqItemsets.size());
     
     -    for (FreqItemset<String> itemset: freqItemsets) {
     +    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
           // Test return types.
            List<String> items = itemset.javaItems();
           long freq = itemset.freq();
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
    index c5fd2f9d5a22a..6355e0f179496 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
    @@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite {
         // Attribute.fromStructField should accept any NumericType, not just DoubleType
         val longFldWithMeta = new StructField("x", LongType, false, metadata)
         assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
    -    val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
    +    val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
         assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
    index 82c345491bb3c..a7bc77965fefd 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
    @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.util.Utils
     
     
     /**
    @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("Checkpointing") {
    +    val tempDir = Utils.createTempDir()
    +    val path = tempDir.toURI.toString
    +    sc.setCheckpointDir(path)
    +
    +    val categoricalFeatures = Map.empty[Int, Int]
    +    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
    +    val gbt = new GBTClassifier()
    +      .setMaxDepth(2)
    +      .setLossType("logistic")
    +      .setMaxIter(5)
    +      .setStepSize(0.1)
    +      .setCheckpointInterval(2)
    +    val model = gbt.fit(df)
    +
    +    sc.checkpointDir = None
    +    Utils.deleteRecursively(tempDir)
    +  }
    +
       // TODO: Reinstate test once runWithValidation is implemented   SPARK-7132
       /*
       test("runWithValidation stops early and performs better on a validation dataset") {
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
    index ba8fbee84197c..b7dd44753896a 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
    @@ -77,6 +77,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(lr.getRawPredictionCol === "rawPrediction")
         assert(lr.getProbabilityCol === "probability")
         assert(lr.getFitIntercept)
    +    assert(lr.getStandardization)
         val model = lr.fit(dataset)
         model.transform(dataset)
           .select("label", "probability", "prediction", "rawPrediction")
    @@ -208,8 +209,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       }
     
       test("binary logistic regression with intercept without regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -230,18 +234,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V5     -0.7996864
          */
         val interceptR = 2.8366423
    -    val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
    +    val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
    +
    +    assert(model1.intercept ~== interceptR relTol 1E-3)
    +    assert(model1.weights ~= weightsR relTol 1E-3)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
     +    // Without regularization, training with or without standardization converges to the same solution.
    +    assert(model2.intercept ~== interceptR relTol 1E-3)
    +    assert(model2.weights ~= weightsR relTol 1E-3)
       }
     
       test("binary logistic regression without intercept without regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -263,19 +271,24 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V5     -0.7407946
          */
         val interceptR = 0.0
    -    val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
    +    val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
    +
    +    assert(model1.intercept ~== interceptR relTol 1E-3)
    +    assert(model1.weights ~= weightsR relTol 1E-2)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
     +    // Without regularization, training with or without standardization converges to the same solution.
    +    assert(model2.intercept ~== interceptR relTol 1E-3)
    +    assert(model2.weights ~= weightsR relTol 1E-2)
       }
     
       test("binary logistic regression with intercept with L1 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(1.0).setRegParam(0.12)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -295,20 +308,46 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V4     -0.04325749
            data.V5     -0.02481551
          */
    -    val interceptR = -0.05627428
    -    val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551)
    -
    -    assert(model.intercept ~== interceptR relTol 1E-2)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
    -    assert(model.weights(3) ~== weightsR(3) relTol 2E-2)
    +    val interceptR1 = -0.05627428
    +    val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-2)
    +    assert(model1.weights ~= weightsR1 absTol 2E-2)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
    +           standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                           s0
    +       (Intercept)  0.3722152
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.1665453
    +       data.V5       .
    +     */
    +    val interceptR2 = 0.3722152
    +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0)
    +
    +    assert(model2.intercept ~== interceptR2 relTol 1E-2)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression without intercept with L1 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -      .setElasticNetParam(1.0).setRegParam(0.12)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -329,20 +368,46 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V4     -0.05189203
            data.V5     -0.03891782
          */
    -    val interceptR = 0.0
    -    val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782)
    +    val interceptR1 = 0.0
    +    val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
    +    assert(model1.weights ~= weightsR1 absTol 1E-3)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
    +           intercept=FALSE, standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.08420782
    +       data.V5       .
    +     */
    +    val interceptR2 = 0.0
    +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0)
    +
    +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression with intercept with L2 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(0.0).setRegParam(1.37)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -362,20 +427,46 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V4     -0.04865309
            data.V5     -0.10062872
          */
    -    val interceptR = 0.15021751
    -    val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
    -
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    val interceptR1 = 0.15021751
    +    val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
    +    assert(model1.weights ~= weightsR1 relTol 1E-3)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
    +           standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)  0.48657516
    +       data.V2     -0.05155371
    +       data.V3      0.02301057
    +       data.V4     -0.11482896
    +       data.V5     -0.06266838
    +     */
    +    val interceptR2 = 0.48657516
    +    val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838)
    +
    +    assert(model2.intercept ~== interceptR2 relTol 1E-3)
    +    assert(model2.weights ~= weightsR2 relTol 1E-3)
       }
     
       test("binary logistic regression without intercept with L2 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -      .setElasticNetParam(0.0).setRegParam(1.37)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -396,20 +487,46 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V4     -0.04708770
            data.V5     -0.09799775
          */
    -    val interceptR = 0.0
    -    val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
    +    val interceptR1 = 0.0
    +    val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
    +
    +    assert(model1.intercept ~== interceptR1 absTol 1E-3)
    +    assert(model1.weights ~= weightsR1 relTol 1E-2)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
    +           intercept=FALSE, standardize=FALSE))
    +       weights
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                             s0
    +       (Intercept)   .
    +       data.V2     -0.005679651
    +       data.V3      0.048967094
    +       data.V4     -0.093714016
    +       data.V5     -0.053314311
    +     */
    +    val interceptR2 = 0.0
    +    val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311)
    +
    +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
    +    assert(model2.weights ~= weightsR2 relTol 1E-2)
       }
     
       test("binary logistic regression with intercept with ElasticNet regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(0.38).setRegParam(0.21)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -429,20 +546,46 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V4     -0.08849250
            data.V5     -0.15458796
          */
    -    val interceptR = 0.57734851
    -    val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
    -
    -    assert(model.intercept ~== interceptR relTol 6E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 5E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 5E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    val interceptR1 = 0.57734851
    +    val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 6E-3)
    +    assert(model1.weights ~== weightsR1 absTol 5E-3)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
    +           standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)  0.51555993
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.18807395
    +       data.V5     -0.05350074
    +     */
    +    val interceptR2 = 0.51555993
    +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074)
    +
    +    assert(model2.intercept ~== interceptR2 relTol 6E-3)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression without intercept with ElasticNet regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -      .setElasticNetParam(0.38).setRegParam(0.21)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -463,20 +606,46 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V4     -0.081203769
            data.V5     -0.142534158
          */
    -    val interceptR = 0.0
    -    val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
    +    val interceptR1 = 0.0
    +    val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
    +    assert(model1.weights ~= weightsR1 absTol 1E-2)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
    +           intercept=FALSE, standardize=FALSE))
    +       weights
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) absTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) absTol 1E-2)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2       .
    +       data.V3      0.03345223
    +       data.V4     -0.11304532
    +       data.V5       .
    +     */
    +    val interceptR2 = 0.0
    +    val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0)
    +
    +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression with intercept with strong L1 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(1.0).setRegParam(6.0)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label }
           .treeAggregate(new MultiClassSummarizer)(
    @@ -500,13 +669,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            }}}
          */
         val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
    -    val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)
    +    val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
    +
    +    assert(model1.intercept ~== interceptTheory relTol 1E-5)
    +    assert(model1.weights ~= weightsTheory absTol 1E-6)
     
    -    assert(model.intercept ~== interceptTheory relTol 1E-5)
    -    assert(model.weights(0) ~== weightsTheory(0) absTol 1E-6)
    -    assert(model.weights(1) ~== weightsTheory(1) absTol 1E-6)
    -    assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
    -    assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)
    +    assert(model2.intercept ~== interceptTheory relTol 1E-5)
    +    assert(model2.weights ~= weightsTheory absTol 1E-6)
     
         /*
            Using the following R code to load the data and train the model using glmnet package.
    @@ -527,12 +696,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            data.V5       .
          */
         val interceptR = -0.248065
    -    val weightsR = Array(0.0, 0.0, 0.0, 0.0)
    +    val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
     
    -    assert(model.intercept ~== interceptR relTol 1E-5)
    -    assert(model.weights(0) ~== weightsR(0) absTol 1E-6)
    -    assert(model.weights(1) ~== weightsR(1) absTol 1E-6)
    -    assert(model.weights(2) ~== weightsR(2) absTol 1E-6)
    -    assert(model.weights(3) ~== weightsR(3) absTol 1E-6)
    +    assert(model1.intercept ~== interceptR relTol 1E-5)
    +    assert(model1.weights ~= weightsR absTol 1E-6)
       }
     }
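
A note on the comments added above, which say that without regularization the fits with and without standardization converge to the same solution. This is a sketch of the usual glmnet-style argument, not a description of Spark's internal optimizer: standardization divides each feature x_j by its sample standard deviation sigma_j, and a coefficient fitted in the scaled space corresponds to sigma_j * w_j in the original space. The unpenalized negative log-likelihood is invariant under that reparameterization, so the two optima coincide; once the penalty is applied to the scaled coefficients, the penalty seen in the original space becomes (constant-factor conventions aside)

```latex
\lambda \sum_j \sigma_j \, \lvert w_j \rvert \quad \text{(L1)}
\qquad \text{or} \qquad
\lambda \sum_j \sigma_j^2 \, w_j^2 \quad \text{(L2)},
```

which differs from the plain lasso/ridge penalty unless all sigma_j are equal. That is why the regularized tests above now carry a second set of glmnet reference values computed with standardize=FALSE.
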
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
    new file mode 100644
    index 0000000000000..76381a2741296
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
    @@ -0,0 +1,116 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg._
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.mllib.util.TestingUtils._
    +import org.apache.spark.mllib.classification.NaiveBayesSuite._
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.Row
    +
    +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  def validatePrediction(predictionAndLabels: DataFrame): Unit = {
    +    val numOfErrorPredictions = predictionAndLabels.collect().count {
    +      case Row(prediction: Double, label: Double) =>
    +        prediction != label
    +    }
     +    // At least 80% of the predictions should be correct.
    +    assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
    +  }
    +
    +  def validateModelFit(
    +      piData: Vector,
    +      thetaData: Matrix,
    +      model: NaiveBayesModel): Unit = {
    +    assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
    +      Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
    +    assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
    +  }
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new NaiveBayes)
    +    val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
    +      theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
    +    ParamsSuite.checkParams(model)
    +  }
    +
    +  test("naive bayes: default params") {
    +    val nb = new NaiveBayes
    +    assert(nb.getLabelCol === "label")
    +    assert(nb.getFeaturesCol === "features")
    +    assert(nb.getPredictionCol === "prediction")
    +    assert(nb.getLambda === 1.0)
    +    assert(nb.getModelType === "multinomial")
    +  }
    +
    +  test("Naive Bayes Multinomial") {
    +    val nPoints = 1000
    +    val piArray = Array(0.5, 0.1, 0.4).map(math.log)
    +    val thetaArray = Array(
    +      Array(0.70, 0.10, 0.10, 0.10), // label 0
    +      Array(0.10, 0.70, 0.10, 0.10), // label 1
    +      Array(0.10, 0.10, 0.70, 0.10)  // label 2
    +    ).map(_.map(math.log))
    +    val pi = Vectors.dense(piArray)
    +    val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
    +
    +    val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 42, "multinomial"))
    +    val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
    +    val model = nb.fit(testDataset)
    +
    +    validateModelFit(pi, theta, model)
    +    assert(model.hasParent)
    +
    +    val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 17, "multinomial"))
    +    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
    +
    +    validatePrediction(predictionAndLabels)
    +  }
    +
    +  test("Naive Bayes Bernoulli") {
    +    val nPoints = 10000
    +    val piArray = Array(0.5, 0.3, 0.2).map(math.log)
    +    val thetaArray = Array(
    +      Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
    +      Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
    +      Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30)  // label 2
    +    ).map(_.map(math.log))
    +    val pi = Vectors.dense(piArray)
    +    val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
    +
    +    val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 45, "bernoulli"))
    +    val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
    +    val model = nb.fit(testDataset)
    +
    +    validateModelFit(pi, theta, model)
    +    assert(model.hasParent)
    +
    +    val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 20, "bernoulli"))
    +    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
    +
    +    validatePrediction(predictionAndLabels)
    +  }
    +}
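
For context on validateModelFit in the new suite above: the generated pi and theta are log class priors and log conditional term probabilities (both are built with math.log), so the helper exponentiates the model parameters and the generating values before comparing them with an absolute tolerance of 0.05. The multinomial scoring rule those parameters feed, written as a sketch rather than quoted from the implementation, is

```latex
\log p(y = c \mid x) \;=\; \log \pi_c + \sum_j x_j \log \theta_{c j} + \text{const},
```

with the constant independent of the class c; the Bernoulli test that follows swaps the term-count factors for presence/absence terms.
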
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
    index 75cf5bd4ead4f..3775292f6dca7 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
    @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
     
     import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.attribute.NominalAttribute
    +import org.apache.spark.ml.feature.StringIndexer
     import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
     import org.apache.spark.ml.util.MetadataUtils
     import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    @@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
         ova.fit(datasetWithLabelMetadata)
       }
     
    +  test("SPARK-8092: ensure label features and prediction cols are configurable") {
    +    val labelIndexer = new StringIndexer()
    +      .setInputCol("label")
    +      .setOutputCol("indexed")
    +
    +    val indexedDataset = labelIndexer
    +      .fit(dataset)
    +      .transform(dataset)
    +      .drop("label")
    +      .withColumnRenamed("features", "f")
    +
    +    val ova = new OneVsRest()
    +    ova.setClassifier(new LogisticRegression())
    +      .setLabelCol(labelIndexer.getOutputCol)
    +      .setFeaturesCol("f")
    +      .setPredictionCol("p")
    +
    +    val ovaModel = ova.fit(indexedDataset)
    +    val transformedDataset = ovaModel.transform(indexedDataset)
    +    val outputFields = transformedDataset.schema.fieldNames.toSet
    +    assert(outputFields.contains("p"))
    +  }
    +
       test("SPARK-8049: OneVsRest shouldn't output temp columns") {
         val logReg = new LogisticRegression()
           .setMaxIter(1)
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
    index 1b6b69c7dc71e..ab711c8e4b215 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
    @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
     import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.ml.tree.LeafNode
    -import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.{DataFrame, Row}
     
     /**
      * Test suite for [[RandomForestClassifier]].
    @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
       test("params") {
         ParamsSuite.checkParams(new RandomForestClassifier)
         val model = new RandomForestClassificationModel("rfc",
    -      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
    +      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
         ParamsSuite.checkParams(model)
       }
     
    @@ -167,9 +167,19 @@ private object RandomForestClassifierSuite {
         val newModel = rf.fit(newData)
         // Use parent from newTree since this is not checked anyways.
         val oldModelAsNew = RandomForestClassificationModel.fromOld(
    -      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
    +      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
    +      numClasses)
         TreeTests.checkEqual(oldModelAsNew, newModel)
         assert(newModel.hasParent)
         assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
    +    assert(newModel.numClasses == numClasses)
    +    val results = newModel.transform(newData)
    +    results.select("rawPrediction", "prediction").collect().foreach {
    +      case Row(raw: Vector, prediction: Double) => {
    +        assert(raw.size == numClasses)
    +        val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
    +        assert(predFromRaw == prediction)
    +      }
    +    }
       }
     }
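
The assertions added above check that prediction matches the argmax of rawPrediction. The zipWithIndex.maxBy(_._1)._2 idiom used there is easy to misread, so here is a tiny self-contained Scala illustration with made-up per-class scores:

```scala
// Hypothetical per-class scores (e.g. tree votes) for a 3-class problem.
val raw = Array(2.0, 7.0, 1.0)

// Pair each score with its class index, keep the pair with the largest score,
// and take its index: that index is the predicted class.
val predFromRaw = raw.zipWithIndex.maxBy(_._1)._2
assert(predFromRaw == 1)
```
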
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
    new file mode 100644
    index 0000000000000..1f15ac02f4008
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
    @@ -0,0 +1,114 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.{DataFrame, SQLContext}
    +
    +private[clustering] case class TestRow(features: Vector)
    +
    +object KMeansSuite {
    +  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
    +    val sc = sql.sparkContext
    +    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
    +      .map(v => new TestRow(v))
    +    sql.createDataFrame(rdd)
    +  }
    +}
    +
    +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  final val k = 5
    +  @transient var dataset: DataFrame = _
    +
    +  override def beforeAll(): Unit = {
    +    super.beforeAll()
    +
    +    dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
    +  }
    +
    +  test("default parameters") {
    +    val kmeans = new KMeans()
    +
    +    assert(kmeans.getK === 2)
    +    assert(kmeans.getFeaturesCol === "features")
    +    assert(kmeans.getPredictionCol === "prediction")
    +    assert(kmeans.getMaxIter === 20)
    +    assert(kmeans.getRuns === 1)
    +    assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
    +    assert(kmeans.getInitSteps === 5)
    +    assert(kmeans.getEpsilon === 1e-4)
    +  }
    +
    +  test("set parameters") {
    +    val kmeans = new KMeans()
    +      .setK(9)
    +      .setFeaturesCol("test_feature")
    +      .setPredictionCol("test_prediction")
    +      .setMaxIter(33)
    +      .setRuns(7)
    +      .setInitMode(MLlibKMeans.RANDOM)
    +      .setInitSteps(3)
    +      .setSeed(123)
    +      .setEpsilon(1e-3)
    +
    +    assert(kmeans.getK === 9)
    +    assert(kmeans.getFeaturesCol === "test_feature")
    +    assert(kmeans.getPredictionCol === "test_prediction")
    +    assert(kmeans.getMaxIter === 33)
    +    assert(kmeans.getRuns === 7)
    +    assert(kmeans.getInitMode === MLlibKMeans.RANDOM)
    +    assert(kmeans.getInitSteps === 3)
    +    assert(kmeans.getSeed === 123)
    +    assert(kmeans.getEpsilon === 1e-3)
    +  }
    +
    +  test("parameters validation") {
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setK(1)
    +    }
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setInitMode("no_such_a_mode")
    +    }
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setInitSteps(0)
    +    }
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setRuns(0)
    +    }
    +  }
    +
    +  test("fit & transform") {
    +    val predictionColName = "kmeans_prediction"
    +    val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
    +    val model = kmeans.fit(dataset)
    +    assert(model.clusterCenters.length === k)
    +
    +    val transformed = model.transform(dataset)
    +    val expectedColumns = Array("features", predictionColName)
    +    expectedColumns.foreach { column =>
    +      assert(transformed.columns.contains(column))
    +    }
    +    val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
    +    assert(clusters.size === k)
    +    assert(clusters === Set(0, 1, 2, 3, 4))
    +  }
    +}
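
A short sketch of why the "fit & transform" test above can assert exactly k clusters: generateKMeansData emits only k distinct points (every component equals i % k), each repeated rows / k times, so a fit with K = k places one center per distinct point and every cluster index shows up in the predictions. The distinct-point count can be reproduced without Spark:

```scala
// Mirrors generateKMeansData(sqlContext, rows = 50, dim = 3, k = 5) above,
// but with plain Scala collections instead of an RDD of mllib vectors.
val rows = 50
val dim = 3
val k = 5
val points = (1 to rows).map(i => Seq.fill(dim)((i % k).toDouble))

assert(points.distinct.size == k)                                   // only k distinct feature vectors
assert(points.groupBy(identity).values.forall(_.size == rows / k))  // each repeated rows / k times
```
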
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
    new file mode 100644
    index 0000000000000..e90d9d4ef21ff
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
    @@ -0,0 +1,73 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.mllib.util.TestingUtils._
    +
    +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
    +  }
    +
    +  test("CountVectorizerModel common cases") {
    +    val df = sqlContext.createDataFrame(Seq(
    +      (0, "a b c d".split(" ").toSeq,
    +        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
    +      (1, "a b b c d  a".split(" ").toSeq,
    +        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
    +      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))),
    +      (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string
    +      (4, "a notInDict d".split(" ").toSeq,
    +        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
    +    )).toDF("id", "words", "expected")
    +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +    val output = cv.transform(df).collect()
    +    output.foreach { p =>
    +      val features = p.getAs[Vector]("features")
    +      val expected = p.getAs[Vector]("expected")
    +      assert(features ~== expected absTol 1e-14)
    +    }
    +  }
    +
    +  test("CountVectorizerModel with minTermFreq") {
    +    val df = sqlContext.createDataFrame(Seq(
    +      (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
    +      (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))),
    +      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())),
    +      (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq())))
    +    ).toDF("id", "words", "expected")
    +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +      .setMinTermFreq(3)
    +    val output = cv.transform(df).collect()
    +    output.foreach { p =>
    +      val features = p.getAs[Vector]("features")
    +      val expected = p.getAs[Vector]("expected")
    +      assert(features ~== expected absTol 1e-14)
    +    }
    +  }
    +}
    +
    +
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
    index 65846a846b7b4..321eeb843941c 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
    @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
         val output = encoder.transform(df)
         val group = AttributeGroup.fromStructField(output.schema("encoded"))
         assert(group.size === 2)
    -    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
    -    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
    +    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
    +    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
       }
     
       test("input column without ML attribute") {
    @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
         val output = encoder.transform(df)
         val group = AttributeGroup.fromStructField(output.schema("encoded"))
         assert(group.size === 2)
    -    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
    -    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
    +    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
    +    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
    new file mode 100644
    index 0000000000000..436e66bab09b0
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
    @@ -0,0 +1,82 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.sql.types._
    +
    +class RFormulaParserSuite extends SparkFunSuite {
    +  private def checkParse(
    +      formula: String,
    +      label: String,
    +      terms: Seq[String],
    +      schema: StructType = null) {
    +    val resolved = RFormulaParser.parse(formula).resolve(schema)
    +    assert(resolved.label == label)
    +    assert(resolved.terms == terms)
    +  }
    +
    +  test("parse simple formulas") {
    +    checkParse("y ~ x", "y", Seq("x"))
    +    checkParse("y ~ x + x", "y", Seq("x"))
    +    checkParse("y ~   ._foo  ", "y", Seq("._foo"))
    +    checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
    +  }
    +
    +  test("parse dot") {
    +    val schema = (new StructType)
    +      .add("a", "int", true)
    +      .add("b", "long", false)
    +      .add("c", "string", true)
    +    checkParse("a ~ .", "a", Seq("b", "c"), schema)
    +  }
    +
    +  test("parse deletion") {
    +    val schema = (new StructType)
    +      .add("a", "int", true)
    +      .add("b", "long", false)
    +      .add("c", "string", true)
    +    checkParse("a ~ c - b", "a", Seq("c"), schema)
    +  }
    +
    +  test("parse additions and deletions in order") {
    +    val schema = (new StructType)
    +      .add("a", "int", true)
    +      .add("b", "long", false)
    +      .add("c", "string", true)
    +    checkParse("a ~ . - b + . - c", "a", Seq("b"), schema)
    +  }
    +
    +  test("dot ignores complex column types") {
    +    val schema = (new StructType)
    +      .add("a", "int", true)
    +      .add("b", "tinyint", false)
     +      .add("c", "map<string,string>", true)
    +    checkParse("a ~ .", "a", Seq("b"), schema)
    +  }
    +
    +  test("parse intercept") {
    +    assert(RFormulaParser.parse("a ~ b").hasIntercept)
    +    assert(RFormulaParser.parse("a ~ b + 1").hasIntercept)
    +    assert(RFormulaParser.parse("a ~ b - 0").hasIntercept)
    +    assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept)
    +    assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept)
    +    assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
    +    assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
    +  }
    +}
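
The "parse additions and deletions in order" case above is the least obvious one: terms are resolved left to right, with "." expanding to every column other than the label and a deletion removing a term wherever it currently appears. A hedged walk-through of that single case; the package declaration is only there because RFormulaParser is package-private, an assumption mirrored from the suite living in the same package:

```scala
package org.apache.spark.ml.feature

import org.apache.spark.sql.types.StructType

object RFormulaDotSketch {
  def main(args: Array[String]): Unit = {
    val schema = (new StructType)
      .add("a", "int", true)
      .add("b", "long", false)
      .add("c", "string", true)

    // "a ~ . - b + . - c", applied left to right against the non-label columns {b, c}:
    //   .    -> b, c
    //   - b  -> c
    //   + .  -> c, b
    //   - c  -> b
    val resolved = RFormulaParser.parse("a ~ . - b + . - c").resolve(schema)
    assert(resolved.label == "a")
    assert(resolved.terms == Seq("b"))
  }
}
```
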
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
    new file mode 100644
    index 0000000000000..6aed3243afce8
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
    @@ -0,0 +1,126 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.attribute._
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
    +  test("params") {
    +    ParamsSuite.checkParams(new RFormula())
    +  }
    +
    +  test("transform numeric data") {
    +    val formula = new RFormula().setFormula("id ~ v1 + v2")
    +    val original = sqlContext.createDataFrame(
    +      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
    +    val model = formula.fit(original)
    +    val result = model.transform(original)
    +    val resultSchema = model.transformSchema(original.schema)
    +    val expected = sqlContext.createDataFrame(
    +      Seq(
    +        (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
    +        (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
    +      ).toDF("id", "v1", "v2", "features", "label")
    +    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
    +    assert(result.schema.toString == resultSchema.toString)
    +    assert(resultSchema == expected.schema)
    +    assert(result.collect() === expected.collect())
    +  }
    +
    +  test("features column already exists") {
    +    val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
    +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
    +    intercept[IllegalArgumentException] {
    +      formula.fit(original)
    +    }
    +    intercept[IllegalArgumentException] {
    +      formula.fit(original)
    +    }
    +  }
    +
    +  test("label column already exists") {
    +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
    +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
    +    val model = formula.fit(original)
    +    val resultSchema = model.transformSchema(original.schema)
    +    assert(resultSchema.length == 3)
    +    assert(resultSchema.toString == model.transform(original).schema.toString)
    +  }
    +
    +  test("label column already exists but is not double type") {
    +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
    +    val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
    +    val model = formula.fit(original)
    +    intercept[IllegalArgumentException] {
    +      model.transformSchema(original.schema)
    +    }
    +    intercept[IllegalArgumentException] {
    +      model.transform(original)
    +    }
    +  }
    +
    +  test("allow missing label column for test datasets") {
    +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
    +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
    +    val model = formula.fit(original)
    +    val resultSchema = model.transformSchema(original.schema)
    +    assert(resultSchema.length == 3)
    +    assert(!resultSchema.exists(_.name == "label"))
    +    assert(resultSchema.toString == model.transform(original).schema.toString)
    +  }
    +
    +  test("encodes string terms") {
    +    val formula = new RFormula().setFormula("id ~ a + b")
    +    val original = sqlContext.createDataFrame(
    +      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
    +    ).toDF("id", "a", "b")
    +    val model = formula.fit(original)
    +    val result = model.transform(original)
    +    val resultSchema = model.transformSchema(original.schema)
    +    val expected = sqlContext.createDataFrame(
    +      Seq(
    +        (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
    +        (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
    +        (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
    +        (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
    +      ).toDF("id", "a", "b", "features", "label")
    +    assert(result.schema.toString == resultSchema.toString)
    +    assert(result.collect() === expected.collect())
    +  }
    +
    +  test("attribute generation") {
    +    val formula = new RFormula().setFormula("id ~ a + b")
    +    val original = sqlContext.createDataFrame(
    +      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
    +    ).toDF("id", "a", "b")
    +    val model = formula.fit(original)
    +    val result = model.transform(original)
    +    val attrs = AttributeGroup.fromStructField(result.schema("features"))
    +    val expectedAttrs = new AttributeGroup(
    +      "features",
    +      Array(
    +        new BinaryAttribute(Some("a__bar"), Some(1)),
    +        new BinaryAttribute(Some("a__foo"), Some(2)),
    +        new NumericAttribute(Some("b"), Some(3))))
    +    assert(attrs === expectedAttrs)
    +  }
    +}
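
For reference, a minimal usage sketch of the RFormula API these tests exercise (assumes an existing `sqlContext`; the data and column names are illustrative, not part of the patch):

    import org.apache.spark.ml.feature.RFormula

    // Toy frame: "a" is a string column, "b" and "id" are numeric.
    val df = sqlContext.createDataFrame(
      Seq((1, "foo", 4.0), (2, "bar", 5.0), (3, "baz", 5.0))
    ).toDF("id", "a", "b")

    // "id ~ a + b": use "id" as the label, dummy-encode the string term "a",
    // and keep "b" as a numeric feature.
    val formula = new RFormula()
      .setFormula("id ~ a + b")
      .setFeaturesCol("features")
      .setLabelCol("label")

    // fit() resolves the categorical levels of "a"; transform() appends a
    // "features" Vector column and a Double "label" column.
    val output = formula.fit(df).transform(df)
    output.select("features", "label").show()
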
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
    index 8c85c96d5c6d8..03120c828ca96 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
    @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
     
     import scala.beans.{BeanInfo, BeanProperty}
     
    -import org.apache.spark.{SparkException, SparkFunSuite}
    +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
     import org.apache.spark.ml.attribute._
     import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
    @@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
    -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
    +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
       import VectorIndexerSuite.FeatureData
     
    @@ -113,11 +113,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
         model.transform(sparsePoints1) // should work
         intercept[SparkException] {
           model.transform(densePoints2).collect()
    -      println("Did not throw error when fit, transform were called on vectors of different lengths")
    +      logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
         }
         intercept[SparkException] {
           vectorIndexer.fit(badPoints)
    -      println("Did not throw error when fitting vectors of different lengths in same RDD.")
    +      logInfo("Did not throw error when fitting vectors of different lengths in same RDD.")
         }
       }
     
    @@ -196,7 +196,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
             }
           } catch {
             case e: org.scalatest.exceptions.TestFailedException =>
    -          println(errMsg)
    +          logError(errMsg)
               throw e
           }
         }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
    index 9682edcd9ba84..dbdce0c9dea54 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
    @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees =>
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.util.Utils
     
     
     /**
    @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(predictions.min() < -1)
       }
     
    +  test("Checkpointing") {
    +    val tempDir = Utils.createTempDir()
    +    val path = tempDir.toURI.toString
    +    sc.setCheckpointDir(path)
    +
    +    val df = sqlContext.createDataFrame(data)
    +    val gbt = new GBTRegressor()
    +      .setMaxDepth(2)
    +      .setMaxIter(5)
    +      .setStepSize(0.1)
    +      .setCheckpointInterval(2)
    +    val model = gbt.fit(df)
    +
    +    sc.checkpointDir = None
    +    Utils.deleteRecursively(tempDir)
    +  }
    +
       // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
       /*
       test("runWithValidation stops early and performs better on a validation dataset") {
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
    new file mode 100644
    index 0000000000000..66e4b170bae80
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
    @@ -0,0 +1,148 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.regression
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
    +import org.apache.spark.sql.{DataFrame, Row}
    +
    +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
    +  private val schema = StructType(
    +    Array(
    +      StructField("label", DoubleType),
    +      StructField("features", DoubleType),
    +      StructField("weight", DoubleType)))
    +
    +  private val predictionSchema = StructType(Array(StructField("features", DoubleType)))
    +
    +  private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
    +    val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d))
    +    val parallelData = sc.parallelize(data)
    +
    +    sqlContext.createDataFrame(parallelData, schema)
    +  }
    +
    +  private def generatePredictionInput(features: Seq[Double]): DataFrame = {
    +    val data = Seq.tabulate(features.size)(i => Row(features(i)))
    +
    +    val parallelData = sc.parallelize(data)
    +    sqlContext.createDataFrame(parallelData, predictionSchema)
    +  }
    +
    +  test("isotonic regression predictions") {
    +    val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))
    +    val trainer = new IsotonicRegression().setIsotonicParam(true)
    +
    +    val model = trainer.fit(dataset)
    +
    +    val predictions = model
    +      .transform(dataset)
    +      .select("prediction").map {
    +        case Row(pred) => pred
    +      }.collect()
    +
    +    assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
    +
    +    assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8))
    +    assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
    +    assert(model.parentModel.isotonic)
    +  }
    +
    +  test("antitonic regression predictions") {
    +    val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1))
    +    val trainer = new IsotonicRegression().setIsotonicParam(false)
    +
    +    val model = trainer.fit(dataset)
    +    val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
    +
    +    val predictions = model
    +      .transform(features)
    +      .select("prediction").map {
    +        case Row(pred) => pred
    +      }.collect()
    +
    +    assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
    +  }
    +
    +  test("params validation") {
    +    val dataset = generateIsotonicInput(Seq(1, 2, 3))
    +    val ir = new IsotonicRegression
    +    ParamsSuite.checkParams(ir)
    +    val model = ir.fit(dataset)
    +    ParamsSuite.checkParams(model)
    +  }
    +
    +  test("default params") {
    +    val dataset = generateIsotonicInput(Seq(1, 2, 3))
    +    val ir = new IsotonicRegression()
    +    assert(ir.getLabelCol === "label")
    +    assert(ir.getFeaturesCol === "features")
    +    assert(ir.getWeightCol === "weight")
    +    assert(ir.getPredictionCol === "prediction")
    +    assert(ir.getIsotonicParam === true)
    +
    +    val model = ir.fit(dataset)
    +    model.transform(dataset)
    +      .select("label", "features", "prediction", "weight")
    +      .collect()
    +
    +    assert(model.getLabelCol === "label")
    +    assert(model.getFeaturesCol === "features")
    +    assert(model.getWeightCol === "weight")
    +    assert(model.getPredictionCol === "prediction")
    +    assert(model.getIsotonicParam === true)
    +    assert(model.hasParent)
    +  }
    +
    +  test("set parameters") {
    +    val isotonicRegression = new IsotonicRegression()
    +      .setIsotonicParam(false)
    +      .setWeightParam("w")
    +      .setFeaturesCol("f")
    +      .setLabelCol("l")
    +      .setPredictionCol("p")
    +
    +    assert(isotonicRegression.getIsotonicParam === false)
    +    assert(isotonicRegression.getWeightCol === "w")
    +    assert(isotonicRegression.getFeaturesCol === "f")
    +    assert(isotonicRegression.getLabelCol === "l")
    +    assert(isotonicRegression.getPredictionCol === "p")
    +  }
    +
    +  test("missing column") {
    +    val dataset = generateIsotonicInput(Seq(1, 2, 3))
    +
    +    intercept[IllegalArgumentException] {
    +      new IsotonicRegression().setWeightParam("w").fit(dataset)
    +    }
    +
    +    intercept[IllegalArgumentException] {
    +      new IsotonicRegression().setFeaturesCol("f").fit(dataset)
    +    }
    +
    +    intercept[IllegalArgumentException] {
    +      new IsotonicRegression().setLabelCol("l").fit(dataset)
    +    }
    +
    +    intercept[IllegalArgumentException] {
    +      new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset)
    +    }
    +  }
    +}
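
A minimal sketch of the spark.ml IsotonicRegression estimator exercised by this new suite (assumes existing `sc` and `sqlContext`; the single-double "features" schema mirrors the suite):

    import org.apache.spark.ml.regression.IsotonicRegression
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    val schema = StructType(Array(
      StructField("label", DoubleType),
      StructField("features", DoubleType),
      StructField("weight", DoubleType)))

    // Feature i gets label labels(i) with unit weight.
    val labels = Seq(1.0, 2.0, 3.0, 1.0, 6.0)
    val rows = labels.zipWithIndex.map { case (l, i) => Row(l, i.toDouble, 1.0) }
    val dataset = sqlContext.createDataFrame(sc.parallelize(rows), schema)

    // setIsotonicParam(false) would fit an antitonic (non-increasing) sequence instead.
    val model = new IsotonicRegression().setIsotonicParam(true).fit(dataset)

    // "prediction" holds the pooled monotone fit for each feature value.
    model.transform(dataset).select("features", "prediction").show()
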
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    index 5f39d44f37352..7cdda3db88ad1 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    @@ -18,7 +18,8 @@
     package org.apache.spark.ml.regression
     
     import org.apache.spark.SparkFunSuite
    -import org.apache.spark.mllib.linalg.DenseVector
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
     import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.sql.{DataFrame, Row}
    @@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new LinearRegression)
    +    val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
    +    ParamsSuite.checkParams(model)
    +  }
    +
    +  test("linear regression: default params") {
    +    val lir = new LinearRegression
    +    assert(lir.getLabelCol === "label")
    +    assert(lir.getFeaturesCol === "features")
    +    assert(lir.getPredictionCol === "prediction")
    +    assert(lir.getRegParam === 0.0)
    +    assert(lir.getElasticNetParam === 0.0)
    +    assert(lir.getFitIntercept)
    +    val model = lir.fit(dataset)
    +    model.transform(dataset)
    +      .select("label", "prediction")
    +      .collect()
    +    assert(model.getFeaturesCol === "features")
    +    assert(model.getPredictionCol === "prediction")
    +    assert(model.intercept !== 0.0)
    +    assert(model.hasParent)
    +  }
    +
       test("linear regression with intercept without regularization") {
         val trainer = new LinearRegression
         val model = trainer.fit(dataset)
    @@ -75,11 +100,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V3. 7.198257
          */
         val interceptR = 6.298698
    -    val weightsR = Array(4.700706, 7.199082)
    +    val weightsR = Vectors.dense(4.700706, 7.199082)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -104,11 +128,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V2. 6.995908
            as.numeric.data.V3. 5.275131
          */
    -    val weightsR = Array(6.995908, 5.275131)
    +    val weightsR = Vectors.dense(6.995908, 5.275131)
     
    -    assert(model.intercept ~== 0 relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.intercept ~== 0 absTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
         /*
            Then again with the data with no intercept:
            > weightsWithoutIntercept
    @@ -118,11 +141,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data3.V2. 4.70011
            as.numeric.data3.V3. 7.19943
          */
    -    val weightsWithoutInterceptR = Array(4.70011, 7.19943)
    +    val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
     
    -    assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
    -    assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
    -    assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
    +    assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3)
    +    assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3)
       }
     
       test("linear regression with intercept with L1 regularization") {
    @@ -139,11 +161,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V3. 6.679841
          */
         val interceptR = 6.24300
    -    val weightsR = Array(4.024821, 6.679841)
    +    val weightsR = Vectors.dense(4.024821, 6.679841)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -169,11 +190,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V3. 4.772913
          */
         val interceptR = 0.0
    -    val weightsR = Array(6.299752, 4.772913)
    +    val weightsR = Vectors.dense(6.299752, 4.772913)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.intercept ~== interceptR absTol 1E-5)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -197,11 +217,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V3. 4.926260
          */
         val interceptR = 5.269376
    -    val weightsR = Array(3.736216, 5.712356)
    +    val weightsR = Vectors.dense(3.736216, 5.712356)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -227,11 +246,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V3. 4.214502
          */
         val interceptR = 0.0
    -    val weightsR = Array(5.522875, 4.214502)
    +    val weightsR = Vectors.dense(5.522875, 4.214502)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.intercept ~== interceptR absTol 1E-3)
    +    assert(model.weights ~== weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -255,11 +273,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.data.V3. 5.200403
          */
         val interceptR = 5.696056
    -    val weightsR = Array(3.670489, 6.001122)
    +    val weightsR = Vectors.dense(3.670489, 6.001122)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~== weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -285,11 +302,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            as.numeric.dataM.V3. 4.322251
          */
         val interceptR = 0.0
    -    val weightsR = Array(5.673348, 4.322251)
    +    val weightsR = Vectors.dense(5.673348, 4.322251)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.intercept ~== interceptR absTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -298,4 +314,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             assert(prediction1 ~== prediction2 relTol 1E-5)
         }
       }
    +
    +  test("linear regression model training summary") {
    +    val trainer = new LinearRegression
    +    val model = trainer.fit(dataset)
    +
    +    // Training results for the model should be available
    +    assert(model.hasSummary)
    +
    +    // Residuals in [[LinearRegressionResults]] should equal those manually computed
    +    val expectedResiduals = dataset.select("features", "label")
    +      .map { case Row(features: DenseVector, label: Double) =>
    +      val prediction =
    +        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    +      label - prediction
    +    }
    +      .zip(model.summary.residuals.map(_.getDouble(0)))
    +      .collect()
    +      .foreach { case (manualResidual: Double, resultResidual: Double) =>
    +      assert(manualResidual ~== resultResidual relTol 1E-5)
    +    }
    +
    +    /*
    +       Use the following R code to generate model training results.
    +
    +       predictions <- predict(fit, newx=features)
    +       residuals <- label - predictions
    +       > mean(residuals^2) # MSE
    +       [1] 0.009720325
    +       > mean(abs(residuals)) # MAD
    +       [1] 0.07863206
    +       > cor(predictions, label)^2# r^2
    +               [,1]
    +       s0 0.9998749
    +     */
    +    assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
+    assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5)
    +    assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
    +
    +    // Objective function should be monotonically decreasing for linear regression
    +    assert(
    +      model.summary
    +        .objectiveHistory
    +        .sliding(2)
    +        .forall(x => x(0) >= x(1)))
    +  }
    +
    +  test("linear regression model testset evaluation summary") {
    +    val trainer = new LinearRegression
    +    val model = trainer.fit(dataset)
    +
    +    // Evaluating on training dataset should yield results summary equal to training summary
    +    val testSummary = model.evaluate(dataset)
    +    assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
    +    assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
    +    model.summary.residuals.select("residuals").collect()
    +      .zip(testSummary.residuals.select("residuals").collect())
    +      .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
    +  }
    +
     }
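
The last two tests cover the new training-summary API: after fitting, `model.summary` exposes residuals, error metrics, and the per-iteration objective history, while `model.evaluate(df)` recomputes the same metrics on another DataFrame. A hedged usage sketch (assumes an existing `dataset` with "label"/"features" columns, as in the suite):

    import org.apache.spark.ml.regression.LinearRegression

    val model = new LinearRegression().fit(dataset)

    // Summary computed on the training data.
    val summary = model.summary
    println(s"MSE = ${summary.meanSquaredError}, R2 = ${summary.r2}")
    summary.residuals.show(5)                  // DataFrame with a "residuals" column
    summary.objectiveHistory.foreach(println)  // loss per iteration, non-increasing

    // Same metrics evaluated on a (possibly different) DataFrame.
    val testSummary = model.evaluate(dataset)
    println(testSummary.meanSquaredError)
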
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
    new file mode 100644
    index 0000000000000..c8e58f216cceb
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
    @@ -0,0 +1,139 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tuning
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.classification.LogisticRegression
    +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
    +import org.apache.spark.ml.param.ParamMap
    +import org.apache.spark.ml.param.shared.HasInputCol
    +import org.apache.spark.ml.regression.LinearRegression
    +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
    +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.types.StructType
    +
    +class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
    +  test("train validation with logistic regression") {
    +    val dataset = sqlContext.createDataFrame(
    +      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
    +
    +    val lr = new LogisticRegression
    +    val lrParamMaps = new ParamGridBuilder()
    +      .addGrid(lr.regParam, Array(0.001, 1000.0))
    +      .addGrid(lr.maxIter, Array(0, 10))
    +      .build()
    +    val eval = new BinaryClassificationEvaluator
    +    val cv = new TrainValidationSplit()
    +      .setEstimator(lr)
    +      .setEstimatorParamMaps(lrParamMaps)
    +      .setEvaluator(eval)
    +      .setTrainRatio(0.5)
    +    val cvModel = cv.fit(dataset)
    +    val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
    +    assert(cv.getTrainRatio === 0.5)
    +    assert(parent.getRegParam === 0.001)
    +    assert(parent.getMaxIter === 10)
    +    assert(cvModel.validationMetrics.length === lrParamMaps.length)
    +  }
    +
    +  test("train validation with linear regression") {
    +    val dataset = sqlContext.createDataFrame(
    +        sc.parallelize(LinearDataGenerator.generateLinearInput(
    +            6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
    +
    +    val trainer = new LinearRegression
    +    val lrParamMaps = new ParamGridBuilder()
    +      .addGrid(trainer.regParam, Array(1000.0, 0.001))
    +      .addGrid(trainer.maxIter, Array(0, 10))
    +      .build()
    +    val eval = new RegressionEvaluator()
    +    val cv = new TrainValidationSplit()
    +      .setEstimator(trainer)
    +      .setEstimatorParamMaps(lrParamMaps)
    +      .setEvaluator(eval)
    +      .setTrainRatio(0.5)
    +    val cvModel = cv.fit(dataset)
    +    val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
    +    assert(parent.getRegParam === 0.001)
    +    assert(parent.getMaxIter === 10)
    +    assert(cvModel.validationMetrics.length === lrParamMaps.length)
    +
+    eval.setMetricName("r2")
    +    val cvModel2 = cv.fit(dataset)
    +    val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
    +    assert(parent2.getRegParam === 0.001)
    +    assert(parent2.getMaxIter === 10)
    +    assert(cvModel2.validationMetrics.length === lrParamMaps.length)
    +  }
    +
    +  test("validateParams should check estimatorParamMaps") {
    +    import TrainValidationSplitSuite._
    +
    +    val est = new MyEstimator("est")
    +    val eval = new MyEvaluator
    +    val paramMaps = new ParamGridBuilder()
    +      .addGrid(est.inputCol, Array("input1", "input2"))
    +      .build()
    +
    +    val cv = new TrainValidationSplit()
    +      .setEstimator(est)
    +      .setEstimatorParamMaps(paramMaps)
    +      .setEvaluator(eval)
    +      .setTrainRatio(0.5)
    +    cv.validateParams() // This should pass.
    +
    +    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
    +    cv.setEstimatorParamMaps(invalidParamMaps)
    +    intercept[IllegalArgumentException] {
    +      cv.validateParams()
    +    }
    +  }
    +}
    +
    +object TrainValidationSplitSuite {
    +
    +  abstract class MyModel extends Model[MyModel]
    +
    +  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
    +
    +    override def validateParams(): Unit = require($(inputCol).nonEmpty)
    +
    +    override def fit(dataset: DataFrame): MyModel = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override def transformSchema(schema: StructType): StructType = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
    +  }
    +
    +  class MyEvaluator extends Evaluator {
    +
    +    override def evaluate(dataset: DataFrame): Double = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override val uid: String = "eval"
    +
    +    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
    +  }
    +}
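
A condensed sketch of the TrainValidationSplit workflow these tests validate: a single random train/validation split (controlled by `trainRatio`, rather than k folds) picks the best ParamMap by the evaluator's metric (assumes an existing `dataset` DataFrame of labeled points, as in the suite):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

    val lr = new LogisticRegression
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.001, 0.1, 1000.0))
      .addGrid(lr.maxIter, Array(5, 10))
      .build()

    val tvs = new TrainValidationSplit()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setTrainRatio(0.75)  // 75% of rows train, 25% validate

    val tvsModel = tvs.fit(dataset)
    tvsModel.validationMetrics.foreach(println)  // one metric per ParamMap
    val best = tvsModel.bestModel                // model chosen by the evaluator
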
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
    new file mode 100644
    index 0000000000000..9e6bc7193c13b
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.util
    +
    +import java.util.Random
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  import StopwatchSuite._
    +
    +  private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
    +    assert(sw.name === "sw")
    +    assert(sw.elapsed() === 0L)
    +    assert(!sw.isRunning)
    +    intercept[AssertionError] {
    +      sw.stop()
    +    }
    +    val duration = checkStopwatch(sw)
    +    val elapsed = sw.elapsed()
    +    assert(elapsed === duration)
    +    val duration2 = checkStopwatch(sw)
    +    val elapsed2 = sw.elapsed()
    +    assert(elapsed2 === duration + duration2)
    +    assert(sw.toString === s"sw: ${elapsed2}ms")
    +    sw.start()
    +    assert(sw.isRunning)
    +    intercept[AssertionError] {
    +      sw.start()
    +    }
    +  }
    +
    +  test("LocalStopwatch") {
    +    val sw = new LocalStopwatch("sw")
    +    testStopwatchOnDriver(sw)
    +  }
    +
    +  test("DistributedStopwatch on driver") {
    +    val sw = new DistributedStopwatch(sc, "sw")
    +    testStopwatchOnDriver(sw)
    +  }
    +
    +  test("DistributedStopwatch on executors") {
    +    val sw = new DistributedStopwatch(sc, "sw")
    +    val rdd = sc.parallelize(0 until 4, 4)
    +    val acc = sc.accumulator(0L)
    +    rdd.foreach { i =>
    +      acc += checkStopwatch(sw)
    +    }
    +    assert(!sw.isRunning)
    +    val elapsed = sw.elapsed()
    +    assert(elapsed === acc.value)
    +  }
    +
    +  test("MultiStopwatch") {
    +    val sw = new MultiStopwatch(sc)
    +      .addLocal("local")
    +      .addDistributed("spark")
    +    assert(sw("local").name === "local")
    +    assert(sw("spark").name === "spark")
    +    intercept[NoSuchElementException] {
    +      sw("some")
    +    }
    +    assert(sw.toString === "{\n  local: 0ms,\n  spark: 0ms\n}")
    +    val localDuration = checkStopwatch(sw("local"))
    +    val sparkDuration = checkStopwatch(sw("spark"))
    +    val localElapsed = sw("local").elapsed()
    +    val sparkElapsed = sw("spark").elapsed()
    +    assert(localElapsed === localDuration)
    +    assert(sparkElapsed === sparkDuration)
    +    assert(sw.toString ===
    +      s"{\n  local: ${localElapsed}ms,\n  spark: ${sparkElapsed}ms\n}")
    +    val rdd = sc.parallelize(0 until 4, 4)
    +    val acc = sc.accumulator(0L)
    +    rdd.foreach { i =>
    +      sw("local").start()
    +      val duration = checkStopwatch(sw("spark"))
    +      sw("local").stop()
    +      acc += duration
    +    }
    +    val localElapsed2 = sw("local").elapsed()
    +    assert(localElapsed2 === localElapsed)
    +    val sparkElapsed2 = sw("spark").elapsed()
    +    assert(sparkElapsed2 === sparkElapsed + acc.value)
    +  }
    +}
    +
    +private object StopwatchSuite extends SparkFunSuite {
    +
    +  /**
    +   * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and
    +   * returns the duration reported by the stopwatch.
    +   */
    +  def checkStopwatch(sw: Stopwatch): Long = {
    +    val ubStart = now
    +    sw.start()
    +    val lbStart = now
    +    Thread.sleep(new Random().nextInt(10))
    +    val lb = now - lbStart
    +    val duration = sw.stop()
    +    val ub = now - ubStart
    +    assert(duration >= lb && duration <= ub)
    +    duration
    +  }
    +
    +  /** The current time in milliseconds. */
    +  private def now: Long = System.currentTimeMillis()
    +}
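
The stopwatch classes above appear to be Spark-internal utilities (defined alongside this suite in org.apache.spark.ml.util), so the sketch below is only expected to compile inside that package; it mirrors how the suite combines a driver-side timer with an accumulator-backed distributed one:

    import org.apache.spark.ml.util.MultiStopwatch

    val sw = new MultiStopwatch(sc)
      .addLocal("init")         // timed on the driver only
      .addDistributed("train")  // accumulates time spent on executors

    sw("init").start()
    // ... driver-side setup ...
    sw("init").stop()

    sc.parallelize(0 until 4, 4).foreach { _ =>
      sw("train").start()
      // ... per-partition work ...
      sw("train").stop()
    }

    println(sw)  // e.g. "{\n  init: <ms>,\n  train: <ms>\n}"
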
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    index f7fc8730606af..cffa1ab700f80 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    @@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification
     
     import scala.util.Random
     
    -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
     import breeze.stats.distributions.{Multinomial => BrzMultinomial}
     
     import org.apache.spark.{SparkException, SparkFunSuite}
    -import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
    +import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
     object NaiveBayesSuite {
    @@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
     
         // Test prediction on Array.
         validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
    +
    +    // Test posteriors
    +    validationData.map(_.features).foreach { features =>
    +      val predicted = model.predictProbabilities(features).toArray
    +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
    +      val expected = expectedMultinomialProbabilities(model, features)
    +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
    +    }
    +  }
    +
    +  /**
    +   * @param model Multinomial Naive Bayes model
    +   * @param testData input to compute posterior probabilities for
    +   * @return posterior class probabilities (in order of labels) for input
    +   */
    +  private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
    +    val piVector = new BDV(model.pi)
    +    // model.theta is row-major; treat it as col-major representation of transpose, and transpose:
    +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
    +    val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze)
    +    val classProbs = logClassProbs.toArray.map(math.exp)
    +    val classProbsSum = classProbs.sum
    +    classProbs.map(_ / classProbsSum)
       }
     
       test("Naive Bayes Bernoulli") {
    @@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
     
         // Test prediction on Array.
         validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
    +
    +    // Test posteriors
    +    validationData.map(_.features).foreach { features =>
    +      val predicted = model.predictProbabilities(features).toArray
    +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
    +      val expected = expectedBernoulliProbabilities(model, features)
    +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
    +    }
    +  }
    +
    +  /**
    +   * @param model Bernoulli Naive Bayes model
    +   * @param testData input to compute posterior probabilities for
    +   * @return posterior class probabilities (in order of labels) for input
    +   */
    +  private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
    +    val piVector = new BDV(model.pi)
    +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
    +    val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length,
    +      model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t
    +    val testBreeze = testData.toBreeze
    +    val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze
    +    val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze)
    +    val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze)
    +    val classProbs = logClassProbs.toArray.map(math.exp)
    +    val classProbsSum = classProbs.sum
    +    classProbs.map(_ / classProbsSum)
       }
     
       test("detect negative values") {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
    index fd653296c9d97..d7b291d5a6330 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
    @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.streaming.dstream.DStream
    -import org.apache.spark.streaming.TestSuiteBase
    +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
     
     class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
     
       // use longer wait time to ensure job completion
       override def maxWaitTimeMillis: Int = 30000
     
    +  var ssc: StreamingContext = _
    +
    +  override def afterFunction() {
    +    super.afterFunction()
    +    if (ssc != null) {
    +      ssc.stop()
    +    }
    +  }
    +
       // Test if we can accurately learn B for Y = logistic(BX) on streaming data
       test("parameter accuracy") {
     
    @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
         }
     
         // apply model training to input stream
    -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
           model.trainOn(inputDStream)
           inputDStream.count()
         })
    @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     
         // apply model training to input stream, storing the intermediate results
         // (we add a count to ensure the result is a DStream)
    -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
           model.trainOn(inputDStream)
           inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B)))
           inputDStream.count()
    @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
         }
     
         // apply model predictions to test stream
    -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
           model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
         })
     
    @@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
         }
     
         // train and predict
    -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
           model.trainOn(inputDStream)
           model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
         })
    @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
           .setNumIterations(10)
         val numBatches = 10
         val emptyInput = Seq.empty[Seq[LabeledPoint]]
    -    val ssc = setupStreams(emptyInput,
    +    ssc = setupStreams(emptyInput,
           (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    index 0dbbd7127444f..3003c62d9876c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       }
    +
    +  test("Initialize using given cluster centers") {
    +    val points = Seq(
    +      Vectors.dense(0.0, 0.0),
    +      Vectors.dense(1.0, 0.0),
    +      Vectors.dense(0.0, 1.0),
    +      Vectors.dense(1.0, 1.0)
    +    )
    +    val rdd = sc.parallelize(points, 3)
    +    // creating an initial model
    +    val initialModel = new KMeansModel(Array(points(0), points(2)))
    +
    +    val returnModel = new KMeans()
    +      .setK(2)
    +      .setMaxIterations(0)
    +      .setInitialModel(initialModel)
    +      .run(rdd)
+    // comparing the returned model and the initial model
    +    assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
    +    assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
    +  }
    +
     }
     
     object KMeansSuite extends SparkFunSuite {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    index 406affa25539d..c43e1e575c09c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    @@ -17,19 +17,22 @@
     
     package org.apache.spark.mllib.clustering
     
    -import breeze.linalg.{DenseMatrix => BDM}
    +import breeze.linalg.{DenseMatrix => BDM, max, argmax}
     
     import org.apache.spark.SparkFunSuite
    -import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
    +import org.apache.spark.graphx.Edge
    +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
    +import org.apache.spark.util.Utils
     
     class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import LDASuite._
     
       test("LocalLDAModel") {
    -    val model = new LocalLDAModel(tinyTopics)
    +    val model = new LocalLDAModel(tinyTopics,
    +      Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
     
         // Check: basic parameters
         assert(model.k === tinyK)
    @@ -80,28 +83,25 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(model.topicsMatrix === localModel.topicsMatrix)
     
         // Check: topic summaries
    -    //  The odd decimal formatting and sorting is a hack to do a robust comparison.
    -    val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
    -      // cut values to 3 digits after the decimal place
    -      terms.zip(termWeights).map { case (term, weight) =>
    -        ("%.3f".format(weight).toDouble, term.toInt)
    -      }
    -    }.sortBy(_.mkString(""))
    -    val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
    -      // cut values to 3 digits after the decimal place
    -      terms.zip(termWeights).map { case (term, weight) =>
    -        ("%.3f".format(weight).toDouble, term.toInt)
    -      }
    -    }.sortBy(_.mkString(""))
    -    roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
    -      assert(t1 === t2)
    +    val topicSummary = model.describeTopics().map { case (terms, termWeights) =>
    +      Vectors.sparse(tinyVocabSize, terms, termWeights)
    +    }.sortBy(_.toString)
    +    val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
    +      Vectors.sparse(tinyVocabSize, terms, termWeights)
    +    }.sortBy(_.toString)
    +    topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) =>
    +      assert(topics ~== topicsLocal absTol 0.01)
         }
     
         // Check: per-doc topic distributions
         val topicDistributions = model.topicDistributions.collect()
    +
         //  Ensure all documents are covered.
    -    assert(topicDistributions.length === tinyCorpus.length)
    -    assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+    // SPARK-5562: topicDistributions only covers the non-empty docs, so compare against
+    // nonEmptyTinyCorpus instead of tinyCorpus.
    +    val nonEmptyTinyCorpus = getNonEmptyDoc(tinyCorpus)
    +    assert(topicDistributions.length === nonEmptyTinyCorpus.length)
    +    assert(nonEmptyTinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
         //  Ensure we have proper distributions
         topicDistributions.foreach { case (docId, topicDistribution) =>
           assert(topicDistribution.size === tinyK)
    @@ -127,22 +127,38 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("setter alias") {
         val lda = new LDA().setAlpha(2.0).setBeta(3.0)
    -    assert(lda.getAlpha === 2.0)
    -    assert(lda.getDocConcentration === 2.0)
    +    assert(lda.getAlpha.toArray.forall(_ === 2.0))
    +    assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
         assert(lda.getBeta === 3.0)
         assert(lda.getTopicConcentration === 3.0)
       }
     
    +  test("initializing with alpha length != k or 1 fails") {
    +    intercept[IllegalArgumentException] {
    +      val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4))
    +      val corpus = sc.parallelize(tinyCorpus, 2)
    +      lda.run(corpus)
    +    }
    +  }
    +
    +  test("initializing with elements in alpha < 0 fails") {
    +    intercept[IllegalArgumentException] {
    +      val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
    +      val corpus = sc.parallelize(tinyCorpus, 2)
    +      lda.run(corpus)
    +    }
    +  }
    +
       test("OnlineLDAOptimizer initialization") {
         val lda = new LDA().setK(2)
         val corpus = sc.parallelize(tinyCorpus, 2)
         val op = new OnlineLDAOptimizer().initialize(corpus, lda)
         op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567)
    -    assert(op.getAlpha == 0.5) // default 1.0 / k
    -    assert(op.getEta == 0.5)   // default 1.0 / k
    -    assert(op.getKappa == 0.9876)
    -    assert(op.getMiniBatchFraction == 0.123)
    -    assert(op.getTau0 == 567)
    +    assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k
    +    assert(op.getEta === 0.5)   // default 1.0 / k
    +    assert(op.getKappa === 0.9876)
    +    assert(op.getMiniBatchFraction === 0.123)
    +    assert(op.getTau0 === 567)
       }
     
       test("OnlineLDAOptimizer one iteration") {
    @@ -174,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     
     // verify the result; note this generates results identical to
         // [[https://github.com/Blei-Lab/onlineldavb]]
    -    val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
    -    val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
    -    assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
    -    assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
    +    val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t)
    +    val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t)
    +    val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950)
    +    val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050)
    +    assert(topic1 ~== expectedTopic1 absTol 0.01)
    +    assert(topic2 ~== expectedTopic2 absTol 0.01)
       }
     
       test("OnlineLDAOptimizer with toy data") {
    @@ -213,6 +231,263 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("LocalLDAModel logPerplexity") {
    +    val k = 2
    +    val vocabSize = 6
    +    val alpha = 0.01
    +    val eta = 0.01
    +    val gammaShape = 100
    +    // obtained from LDA model trained in gensim, see below
    +    val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
    +      1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
    +      0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
    +
    +    def toydata: Array[(Long, Vector)] = Array(
    +      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
    +      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
    +      Vectors.sparse(6, Array(4, 5), Array(1, 1))
    +    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
    +    val docs = sc.parallelize(toydata)
+
    +    val ldaModel: LocalLDAModel = new LocalLDAModel(
    +      topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
    +
    +    /* Verify results using gensim:
    +       import numpy as np
    +       from gensim import models
    +       corpus = [
    +          [(0, 1.0), (1, 1.0)],
    +          [(1, 1.0), (2, 1.0)],
    +          [(0, 1.0), (2, 1.0)],
    +          [(3, 1.0), (4, 1.0)],
    +          [(3, 1.0), (5, 1.0)],
    +          [(4, 1.0), (5, 1.0)]]
    +       np.random.seed(2345)
    +       lda = models.ldamodel.LdaModel(
    +          corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
    +          decay=0.51, offset=1024)
    +       print(lda.log_perplexity(corpus))
    +       > -3.69051285096
    +     */
    +
    +    assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D)
    +  }
    +
    +  test("LocalLDAModel predict") {
    +    val k = 2
    +    val vocabSize = 6
    +    val alpha = 0.01
    +    val eta = 0.01
    +    val gammaShape = 100
    +    // obtained from LDA model trained in gensim, see below
    +    val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
    +      1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
    +      0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
    +
    +    def toydata: Array[(Long, Vector)] = Array(
    +      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
    +      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
    +      Vectors.sparse(6, Array(4, 5), Array(1, 1))
    +    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
    +    val docs = sc.parallelize(toydata)
    +
    +    val ldaModel: LocalLDAModel = new LocalLDAModel(
    +      topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
    +
    +    /* Verify results using gensim:
    +       import numpy as np
    +       from gensim import models
    +       corpus = [
    +          [(0, 1.0), (1, 1.0)],
    +          [(1, 1.0), (2, 1.0)],
    +          [(0, 1.0), (2, 1.0)],
    +          [(3, 1.0), (4, 1.0)],
    +          [(3, 1.0), (5, 1.0)],
    +          [(4, 1.0), (5, 1.0)]]
    +       np.random.seed(2345)
    +       lda = models.ldamodel.LdaModel(
    +          corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
    +          decay=0.51, offset=1024)
    +       print(list(lda.get_document_topics(corpus)))
    +       > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)],
    +       > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)],
    +       > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]]
    +     */
    +
    +    val expectedPredictions = List(
    +      (0, 0.99504), (0, 0.99504),
    +      (0, 0.99504), (1, 0.99504),
    +      (1, 0.99504), (1, 0.99504))
    +
    +    val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) =>
    +        // convert results to expectedPredictions format, which only has highest probability topic
    +        val topicsBz = topics.toBreeze.toDenseVector
    +        (id, (argmax(topicsBz), max(topicsBz)))
    +      }.sortByKey()
    +      .values
    +      .collect()
    +
    +    expectedPredictions.zip(actualPredictions).forall { case (expected, actual) =>
    +      expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)
    +    }
    +  }
    +
    +  test("OnlineLDAOptimizer with asymmetric prior") {
    +    def toydata: Array[(Long, Vector)] = Array(
    +      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
    +      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
    +      Vectors.sparse(6, Array(4, 5), Array(1, 1))
    +    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
    +
    +    val docs = sc.parallelize(toydata)
    +    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
    +      .setGammaShape(1e10)
    +    val lda = new LDA().setK(2)
    +      .setDocConcentration(Vectors.dense(0.00001, 0.1))
    +      .setTopicConcentration(0.01)
    +      .setMaxIterations(100)
    +      .setOptimizer(op)
    +      .setSeed(12345)
    +
    +    val ldaModel = lda.run(docs)
    +    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
    +    val topics = topicIndices.map { case (terms, termWeights) =>
    +      terms.zip(termWeights)
    +    }
    +
    +    /* Verify results with Python:
    +
    +       import numpy as np
    +       from gensim import models
    +       corpus = [
    +           [(0, 1.0), (1, 1.0)],
    +           [(1, 1.0), (2, 1.0)],
    +           [(0, 1.0), (2, 1.0)],
    +           [(3, 1.0), (4, 1.0)],
    +           [(3, 1.0), (5, 1.0)],
    +           [(4, 1.0), (5, 1.0)]]
    +       np.random.seed(10)
    +       lda = models.ldamodel.LdaModel(
    +           corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100)
    +       lda.print_topics()
    +
    +       > ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5',
    +          '0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5']
    +     */
    +    topics.foreach { topic =>
    +      assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 })
    +    }
    +  }
    +
    +  test("model save/load") {
    +    // Test for LocalLDAModel.
    +    val localModel = new LocalLDAModel(tinyTopics,
    +      Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
    +    val tempDir1 = Utils.createTempDir()
    +    val path1 = tempDir1.toURI.toString
    +
    +    // Test for DistributedLDAModel.
    +    val k = 3
    +    val docConcentration = 1.2
    +    val topicConcentration = 1.5
    +    val lda = new LDA()
    +    lda.setK(k)
    +      .setDocConcentration(docConcentration)
    +      .setTopicConcentration(topicConcentration)
    +      .setMaxIterations(5)
    +      .setSeed(12345)
    +    val corpus = sc.parallelize(tinyCorpus, 2)
    +    val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
    +    val tempDir2 = Utils.createTempDir()
    +    val path2 = tempDir2.toURI.toString
    +
    +    try {
    +      localModel.save(sc, path1)
    +      distributedModel.save(sc, path2)
    +      val sameLocalModel = LocalLDAModel.load(sc, path1)
    +      assert(sameLocalModel.topicsMatrix === localModel.topicsMatrix)
    +      assert(sameLocalModel.k === localModel.k)
    +      assert(sameLocalModel.vocabSize === localModel.vocabSize)
    +      assert(sameLocalModel.docConcentration === localModel.docConcentration)
    +      assert(sameLocalModel.topicConcentration === localModel.topicConcentration)
    +      assert(sameLocalModel.gammaShape === localModel.gammaShape)
    +
    +      val sameDistributedModel = DistributedLDAModel.load(sc, path2)
    +      assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
    +      assert(distributedModel.k === sameDistributedModel.k)
    +      assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
    +      assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
    +      assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
    +      assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
    +      assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
    +      assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
    +
    +      val graph = distributedModel.graph
    +      val sameGraph = sameDistributedModel.graph
    +      assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect())
    +      val edge = graph.edges.map {
    +        case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos)
    +      }.sortBy(x => (x._1, x._2)).collect()
    +      val sameEdge = sameGraph.edges.map {
    +        case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos)
    +      }.sortBy(x => (x._1, x._2)).collect()
    +      assert(edge === sameEdge)
    +    } finally {
    +      Utils.deleteRecursively(tempDir1)
    +      Utils.deleteRecursively(tempDir2)
    +    }
    +  }
    +
    +  test("EMLDAOptimizer with empty docs") {
    +    val vocabSize = 6
    +    val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty))
    +    val emptyDocs = emptyDocsArray
    +      .zipWithIndex.map { case (wordCounts, docId) =>
    +        (docId.toLong, wordCounts)
    +    }
    +    val distributedEmptyDocs = sc.parallelize(emptyDocs, 2)
    +
    +    val op = new EMLDAOptimizer()
    +    val lda = new LDA()
    +      .setK(3)
    +      .setMaxIterations(5)
    +      .setSeed(12345)
    +      .setOptimizer(op)
    +
    +    val model = lda.run(distributedEmptyDocs)
    +    assert(model.vocabSize === vocabSize)
    +  }
    +
    +  test("OnlineLDAOptimizer with empty docs") {
    +    val vocabSize = 6
    +    val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty))
    +    val emptyDocs = emptyDocsArray
    +      .zipWithIndex.map { case (wordCounts, docId) =>
    +        (docId.toLong, wordCounts)
    +    }
    +    val distributedEmptyDocs = sc.parallelize(emptyDocs, 2)
    +
    +    val op = new OnlineLDAOptimizer()
    +    val lda = new LDA()
    +      .setK(3)
    +      .setMaxIterations(5)
    +      .setSeed(12345)
    +      .setOptimizer(op)
    +
    +    val model = lda.run(distributedEmptyDocs)
    +    assert(model.vocabSize === vocabSize)
    +  }
    +
     }
     
     private[clustering] object LDASuite {
    @@ -232,12 +507,17 @@ private[clustering] object LDASuite {
       }
     
       def tinyCorpus: Array[(Long, Vector)] = Array(
    +    Vectors.dense(0, 0, 0, 0, 0), // empty doc
         Vectors.dense(1, 3, 0, 2, 8),
         Vectors.dense(0, 2, 1, 0, 4),
         Vectors.dense(2, 3, 12, 3, 1),
    +    Vectors.dense(0, 0, 0, 0, 0), // empty doc
         Vectors.dense(0, 3, 1, 9, 8),
         Vectors.dense(1, 1, 4, 2, 6)
       ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
       assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
     
    +  def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter {
    +    case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0
    +  }
     }
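
A brief usage sketch of the `getNonEmptyDoc` helper added above (illustrative, not part of the patch; `LDASuite` is package-private, so this assumes code living in the same `org.apache.spark.mllib.clustering` package):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    val corpus: Array[(Long, Vector)] = LDASuite.tinyCorpus   // 7 docs, 2 of them all-zero
    val nonEmpty = LDASuite.getNonEmptyDoc(corpus)
    assert(nonEmpty.length == corpus.length - 2)              // both empty docs are dropped
    assert(nonEmpty.forall { case (_, wc) => Vectors.norm(wc, p = 1.0) != 0.0 })
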
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    index 19e65f1b53ab5..189000512155f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    @@ -68,6 +68,54 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon
         assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
       }
     
    +  test("power iteration clustering on graph") {
    +    /*
    +     We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for
    +     edge (3, 4).
    +
    +     15-14 -13 -12
    +     |           |
    +     4 . 3 - 2  11
    +     |   | x |   |
    +     5   0 - 1  10
    +     |           |
    +     6 - 7 - 8 - 9
    +     */
    +
    +    val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0),
    +      (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge
    +      (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0),
    +      (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0))
    +
    +    val edges = similarities.flatMap { case (i, j, s) =>
    +      if (i != j) {
    +        Seq(Edge(i, j, s), Edge(j, i, s))
    +      } else {
    +        None
    +      }
    +    }
    +    val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0)
    +
    +    val model = new PowerIterationClustering()
    +      .setK(2)
    +      .run(graph)
    +    val predictions = Array.fill(2)(mutable.Set.empty[Long])
    +    model.assignments.collect().foreach { a =>
    +      predictions(a.cluster) += a.id
    +    }
    +    assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
    +
    +    val model2 = new PowerIterationClustering()
    +      .setK(2)
    +      .setInitializationMode("degree")
    +      .run(sc.parallelize(similarities, 2))
    +    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
    +    model2.assignments.collect().foreach { a =>
    +      predictions2(a.cluster) += a.id
    +    }
    +    assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
    +  }
    +
       test("normalize and powerIter") {
         /*
          Test normalize() with the following graph:
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
    index ac01622b8a089..3645d29dccdb2 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
    @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
     import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.streaming.TestSuiteBase
    +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
     import org.apache.spark.streaming.dstream.DStream
     import org.apache.spark.util.random.XORShiftRandom
     
    @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
     
       override def maxWaitTimeMillis: Int = 30000
     
    +  var ssc: StreamingContext = _
    +
    +  override def afterFunction() {
    +    super.afterFunction()
    +    if (ssc != null) {
    +      ssc.stop()
    +    }
    +  }
    +
       test("accuracy for single center and equivalence to grand average") {
         // set parameters
         val numBatches = 10
    @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
         val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
     
         // setup and run the model training
    -    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
           model.trainOn(inputDStream)
           inputDStream.count()
         })
    @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
         val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
     
         // setup and run the model training
    -    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
           kMeans.trainOn(inputDStream)
           inputDStream.count()
         })
    @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
           StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
     
         // setup and run the model training
    -    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
           kMeans.trainOn(inputDStream)
           inputDStream.count()
         })
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    index 9de2bdb6d7246..4b7f1be58f99b 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    @@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._
     
     class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  test("regression metrics") {
    +  test("regression metrics for unbiased (includes intercept term) predictor") {
    +    /* Verify results in R:
    +       preds = c(2.25, -0.25, 1.75, 7.75)
    +       obs = c(3.0, -0.5, 2.0, 7.0)
    +
    +       SStot = sum((obs - mean(obs))^2)
    +       SSreg = sum((preds - mean(obs))^2)
    +       SSerr = sum((obs - preds)^2)
    +
    +       explainedVariance = SSreg / length(obs)
    +       explainedVariance
    +       > [1] 8.796875
    +       meanAbsoluteError = mean(abs(preds - obs))
    +       meanAbsoluteError
    +       > [1] 0.5
    +       meanSquaredError = mean((preds - obs)^2)
    +       meanSquaredError
    +       > [1] 0.3125
    +       rmse = sqrt(meanSquaredError)
    +       rmse
    +       > [1] 0.559017
    +       r2 = 1 - SSerr / SStot
    +       r2
    +       > [1] 0.9571734
    +     */
    +    val predictionAndObservations = sc.parallelize(
    +      Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
    +    val metrics = new RegressionMetrics(predictionAndObservations)
    +    assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
    +      "explained variance regression score mismatch")
    +    assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
    +    assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
    +    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
    +      "root mean squared error mismatch")
    +    assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
    +  }
    +
    +  test("regression metrics for biased (no intercept term) predictor") {
    +    /* Verify results in R:
    +       preds = c(2.5, 0.0, 2.0, 8.0)
    +       obs = c(3.0, -0.5, 2.0, 7.0)
    +
    +       SStot = sum((obs - mean(obs))^2)
    +       SSreg = sum((preds - mean(obs))^2)
    +       SSerr = sum((obs - preds)^2)
    +
    +       explainedVariance = SSreg / length(obs)
    +       explainedVariance
    +       > [1] 8.859375
    +       meanAbsoluteError = mean(abs(preds - obs))
    +       meanAbsoluteError
    +       > [1] 0.5
    +       meanSquaredError = mean((preds - obs)^2)
    +       meanSquaredError
    +       > [1] 0.375
    +       rmse = sqrt(meanSquaredError)
    +       rmse
    +       > [1] 0.6123724
    +       r2 = 1 - SSerr / SStot
    +       r2
    +       > [1] 0.9486081
    +     */
         val predictionAndObservations = sc.parallelize(
           Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
         val metrics = new RegressionMetrics(predictionAndObservations)
    -    assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
    +    assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
           "explained variance regression score mismatch")
         assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
         assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
         assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
           "root mean squared error mismatch")
    -    assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
    +    assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
       }
     
       test("regression metrics with complete fitting") {
         val predictionAndObservations = sc.parallelize(
           Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
         val metrics = new RegressionMetrics(predictionAndObservations)
    -    assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
    +    assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
           "explained variance regression score mismatch")
         assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
         assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
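
To make the quantities being asserted on explicit, here is a plain-Scala (REPL-style) re-derivation of the R numbers quoted for the unbiased predictor; it is illustrative only and uses no Spark APIs:

    val preds = Seq(2.25, -0.25, 1.75, 7.75)
    val obs = Seq(3.0, -0.5, 2.0, 7.0)
    val n = obs.size.toDouble
    val meanObs = obs.sum / n
    val ssTot = obs.map(o => math.pow(o - meanObs, 2)).sum
    val ssReg = preds.map(p => math.pow(p - meanObs, 2)).sum
    val ssErr = preds.zip(obs).map { case (p, o) => math.pow(o - p, 2) }.sum
    println(ssReg / n)                                            // explainedVariance = 8.796875
    println(preds.zip(obs).map { case (p, o) => math.abs(p - o) }.sum / n)  // MAE = 0.5
    println(ssErr / n)                                            // MSE = 0.3125
    println(math.sqrt(ssErr / n))                                 // RMSE ~ 0.559017
    println(1 - ssErr / ssTot)                                    // r2 ~ 0.9571734
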
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
    index b6818369208d7..a864eec460f2b 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
    @@ -37,6 +37,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(syms.length == 2)
         assert(syms(0)._1 == "b")
         assert(syms(1)._1 == "c")
    +
    +    // Test that a model built by Word2Vec (i.e. from wordVectors and wordIndex)
    +    // and a model built from a Word2VecMap give the same values.
    +    val word2VecMap = model.getVectors
    +    val newModel = new Word2VecModel(word2VecMap)
    +    assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq))
    +  }
    +
    +  test("Word2Vec throws exception when vocabulary is empty") {
    +    intercept[IllegalArgumentException] {
    +      val sentence = "a b c"
    +      val localDoc = Seq(sentence, sentence)
    +      val doc = sc.parallelize(localDoc)
    +        .map(line => line.split(" ").toSeq)
    +      new Word2Vec().setMinCount(10).fit(doc)
    +    }
       }
     
       test("Word2VecModel") {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
    new file mode 100644
    index 0000000000000..77a2773c36f56
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
    @@ -0,0 +1,89 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("association rules using String type") {
    +    val freqItemsets = sc.parallelize(Seq(
    +      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
    +      (Set("r"), 3L),
    +      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
    +      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
    +      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
    +      (Set("t", "y", "x"), 3L),
    +      (Set("t", "y", "x", "z"), 3L)
    +    ).map {
    +      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
    +    })
    +
    +    val ar = new AssociationRules()
    +
    +    val results1 = ar
    +      .setMinConfidence(0.9)
    +      .run(freqItemsets)
    +      .collect()
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("r z h k p",
    +              "z y x w v u t s",
    +              "s x o n r",
    +              "x z y m t s q e",
    +              "z",
    +              "x z y r q t p"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       ars = apriori(transactions,
    +                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
    +       arsDF = as(ars, "data.frame")
    +       arsDF$support = arsDF$support * length(transactions)
    +       names(arsDF)[names(arsDF) == "support"] = "freq"
    +       > nrow(arsDF)
    +       [1] 23
    +       > sum(arsDF$confidence == 1)
    +       [1] 23
    +     */
    +    assert(results1.size === 23)
    +    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
    +
    +    val results2 = ar
    +      .setMinConfidence(0)
    +      .run(freqItemsets)
    +      .collect()
    +
    +    /* Verify results using the `R` code:
    +       ars = apriori(transactions,
    +                  parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
    +       arsDF = as(ars, "data.frame")
    +       arsDF$support = arsDF$support * length(transactions)
    +       names(arsDF)[names(arsDF) == "support"] = "freq"
    +       nrow(arsDF)
    +       sum(arsDF$confidence == 1)
    +       > nrow(arsDF)
    +       [1] 30
    +       > sum(arsDF$confidence == 1)
    +       [1] 23
    +     */
    +    assert(results2.size === 30)
    +    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
    +  }
    +}
    +
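
For readers auditing the assertions above (an illustrative, plain-Scala sketch, not part of the patch): the confidence of a rule X => Y is freq(X union Y) / freq(X). Using two of the frequent itemsets fed to the test:

    val freq = Map(Set("t", "y") -> 3L, Set("t", "y", "x") -> 3L)
    val confidence = freq(Set("t", "y", "x")).toDouble / freq(Set("t", "y"))
    // confidence == 1.0, consistent with all 23 rules at confidence 1 in the R cross-check
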
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
    index 1a8a1e79f2810..4a9bfdb348d9f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
    @@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
     class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     
     
    -  test("FP-Growth frequent itemsets using String type") {
    +  test("FP-Growth using String type") {
         val transactions = Seq(
           "r z h k p",
           "z y x w v u t s",
    @@ -38,18 +38,59 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
         val model6 = fpg
           .setMinSupport(0.9)
           .setNumPartitions(1)
    -      .setOrdered(false)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("r z h k p",
    +              "z y x w v u t s",
    +              "s x o n r",
    +              "x z y m t s q e",
    +              "z",
    +              "x z y r q t p"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       > eclat(transactions, parameter = list(support = 0.9))
    +       ...
    +       eclat - zero frequent items
    +       set of 0 itemsets
    +     */
         assert(model6.freqItemsets.count() === 0)
     
         val model3 = fpg
           .setMinSupport(0.5)
           .setNumPartitions(2)
    -      .setOrdered(false)
           .run(rdd)
         val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
           (itemset.items.toSet, itemset.freq)
         }
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.5))
    +       fpDF = as(sort(fp), "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > fpDF
    +              items freq
    +       13       {z}    5
    +       14       {x}    4
    +       1      {s,x}    3
    +       2  {t,x,y,z}    3
    +       3    {t,y,z}    3
    +       4    {t,x,y}    3
    +       5    {x,y,z}    3
    +       6      {y,z}    3
    +       7      {x,y}    3
    +       8      {t,y}    3
    +       9    {t,x,z}    3
    +       10     {t,z}    3
    +       11     {t,x}    3
    +       12     {x,z}    3
    +       15       {t}    3
    +       16       {y}    3
    +       17       {s}    3
    +       18       {r}    3
    +     */
         val expected = Set(
           (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
           (Set("r"), 3L),
    @@ -63,19 +104,35 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
         val model2 = fpg
           .setMinSupport(0.3)
           .setNumPartitions(4)
    -      .setOrdered(false)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.3))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 54
    +     */
         assert(model2.freqItemsets.count() === 54)
     
         val model1 = fpg
           .setMinSupport(0.1)
           .setNumPartitions(8)
    -      .setOrdered(false)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.1))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 625
    +     */
         assert(model1.freqItemsets.count() === 625)
       }
     
    -  test("FP-Growth frequent sequences using String type"){
    +  test("FP-Growth String type association rule generation") {
         val transactions = Seq(
           "r z h k p",
           "z y x w v u t s",
    @@ -86,36 +143,38 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
           .map(_.split(" "))
         val rdd = sc.parallelize(transactions, 2).cache()
     
    -    val fpg = new FPGrowth()
    -
    -    val model1 = fpg
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("r z h k p",
    +              "z y x w v u t s",
    +              "s x o n r",
    +              "x z y m t s q e",
    +              "z",
    +              "x z y r q t p"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       ars = apriori(transactions,
    +                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
    +       arsDF = as(ars, "data.frame")
    +       arsDF$support = arsDF$support * length(transactions)
    +       names(arsDF)[names(arsDF) == "support"] = "freq"
    +       > nrow(arsDF)
    +       [1] 23
    +       > sum(arsDF$confidence == 1)
    +       [1] 23
    +     */
    +    val rules = (new FPGrowth())
           .setMinSupport(0.5)
           .setNumPartitions(2)
    -      .setOrdered(true)
           .run(rdd)
    +      .generateAssociationRules(0.9)
    +      .collect()
     
    -    /*
    -      Use the following R code to verify association rules using arulesSequences package.
    -
    -      data = read_baskets("path", info = c("sequenceID","eventID","SIZE"))
    -      freqItemSeq = cspade(data, parameter = list(support = 0.5))
    -      resSeq = as(freqItemSeq, "data.frame")
    -      resSeq$support = resSeq$support * length(transactions)
    -      names(resSeq)[names(resSeq) == "support"] = "freq"
    -      resSeq
    -     */
    -    val expected = Set(
    -      (Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L),
    -      (Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L),
    -      (Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L)
    -    )
    -    val freqItemseqs1 = model1.freqItemsets.collect().map { itemset =>
    -      (itemset.items.toSeq, itemset.freq)
    -    }.toSet
    -    assert(freqItemseqs1 == expected)
    +    assert(rules.size === 23)
    +    assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
       }
     
    -  test("FP-Growth frequent itemsets using Int type") {
    +  test("FP-Growth using Int type") {
         val transactions = Seq(
           "1 2 3",
           "1 2 3 4",
    @@ -132,20 +191,53 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
         val model6 = fpg
           .setMinSupport(0.9)
           .setNumPartitions(1)
    -      .setOrdered(false)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("1 2 3",
    +              "1 2 3 4",
    +              "5 4 3 2 1",
    +              "6 5 4 3 2 1",
    +              "2 4",
    +              "1 3",
    +              "1 7"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       > eclat(transactions, parameter = list(support = 0.9))
    +       ...
    +       eclat - zero frequent items
    +       set of 0 itemsets
    +     */
         assert(model6.freqItemsets.count() === 0)
     
         val model3 = fpg
           .setMinSupport(0.5)
           .setNumPartitions(2)
    -      .setOrdered(false)
           .run(rdd)
         assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
           "frequent itemsets should use primitive arrays")
         val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
           (itemset.items.toSet, itemset.freq)
         }
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.5))
    +       fpDF = as(sort(fp), "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > fpDF
    +          items freq
    +      6     {1}    6
    +      3   {1,3}    5
    +      7     {2}    5
    +      8     {3}    5
    +      1   {2,4}    4
    +      2 {1,2,3}    4
    +      4   {2,3}    4
    +      5   {1,2}    4
    +      9     {4}    4
    +     */
         val expected = Set(
           (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
           (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
    @@ -155,15 +247,31 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
         val model2 = fpg
           .setMinSupport(0.3)
           .setNumPartitions(4)
    -      .setOrdered(false)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.3))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 15
    +     */
         assert(model2.freqItemsets.count() === 15)
     
         val model1 = fpg
           .setMinSupport(0.1)
           .setNumPartitions(8)
    -      .setOrdered(false)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.1))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 65
    +     */
         assert(model1.freqItemsets.count() === 65)
       }
     }
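
A quick note on how the fractional minSupport values translate into counts in these fixtures (illustrative, not part of the patch, assuming the usual ceil(minSupport * numTransactions) threshold): with the 7 integer transactions above, minSupport = 0.5 keeps only itemsets occurring in at least 4 of them, which is why every frequency in the Int-type `expected` set is at least 4.

    val numTransactions = 7
    val minSupport = 0.5
    val minCount = math.ceil(minSupport * numTransactions).toLong
    // minCount == 4; {1} (freq 6), {1, 3} (freq 5) and {2, 4} (freq 4) all clear it
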
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    new file mode 100644
    index 0000000000000..6dd2dc926acc5
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    @@ -0,0 +1,113 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("PrefixSpan using Integer type") {
    +
    +    /*
    +      library("arulesSequences")
    +      prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
    +      freqItemSeq = cspade(
    +        prefixSpanSeqs,
    +        parameter = list(support =
    +          2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 ))
    +      resSeq = as(freqItemSeq, "data.frame")
    +      resSeq
    +    */
    +
    +    val sequences = Array(
    +      Array(1, 3, 4, 5),
    +      Array(2, 3, 1),
    +      Array(2, 4, 1),
    +      Array(3, 1, 3, 4, 5),
    +      Array(3, 4, 4, 3),
    +      Array(6, 5, 3))
    +
    +    val rdd = sc.parallelize(sequences, 2).cache()
    +
    +    val prefixspan = new PrefixSpan()
    +      .setMinSupport(0.33)
    +      .setMaxPatternLength(50)
    +    val result1 = prefixspan.run(rdd)
    +    val expectedValue1 = Array(
    +      (Array(1), 4L),
    +      (Array(1, 3), 2L),
    +      (Array(1, 3, 4), 2L),
    +      (Array(1, 3, 4, 5), 2L),
    +      (Array(1, 3, 5), 2L),
    +      (Array(1, 4), 2L),
    +      (Array(1, 4, 5), 2L),
    +      (Array(1, 5), 2L),
    +      (Array(2), 2L),
    +      (Array(2, 1), 2L),
    +      (Array(3), 5L),
    +      (Array(3, 1), 2L),
    +      (Array(3, 3), 2L),
    +      (Array(3, 4), 3L),
    +      (Array(3, 4, 5), 2L),
    +      (Array(3, 5), 2L),
    +      (Array(4), 4L),
    +      (Array(4, 5), 2L),
    +      (Array(5), 3L)
    +    )
    +    assert(compareResults(expectedValue1, result1.collect()))
    +
    +    prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
    +    val result2 = prefixspan.run(rdd)
    +    val expectedValue2 = Array(
    +      (Array(1), 4L),
    +      (Array(3), 5L),
    +      (Array(3, 4), 3L),
    +      (Array(4), 4L),
    +      (Array(5), 3L)
    +    )
    +    assert(compareResults(expectedValue2, result2.collect()))
    +
    +    prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
    +    val result3 = prefixspan.run(rdd)
    +    val expectedValue3 = Array(
    +      (Array(1), 4L),
    +      (Array(1, 3), 2L),
    +      (Array(1, 4), 2L),
    +      (Array(1, 5), 2L),
    +      (Array(2, 1), 2L),
    +      (Array(2), 2L),
    +      (Array(3), 5L),
    +      (Array(3, 1), 2L),
    +      (Array(3, 3), 2L),
    +      (Array(3, 4), 3L),
    +      (Array(3, 5), 2L),
    +      (Array(4), 4L),
    +      (Array(4, 5), 2L),
    +      (Array(5), 3L)
    +    )
    +    assert(compareResults(expectedValue3, result3.collect()))
    +  }
    +
    +  private def compareResults(
    +      expectedValue: Array[(Array[Int], Long)],
    +      actualValue: Array[(Array[Int], Long)]): Boolean = {
    +    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
    +      actualValue.map(x => (x._1.toSeq, x._2)).toSet
    +  }
    +
    +}
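
To make the expected counts above easier to audit (illustrative, not part of the patch): a pattern's support is the number of input sequences that contain it as a not-necessarily-contiguous subsequence, and with 6 sequences a minSupport of 0.33 corresponds to at least 2 supporting sequences. For example, the pattern <1 3> is supported only by Array(1, 3, 4, 5) and Array(3, 1, 3, 4, 5), matching the expected frequency 2L:

    // Returns true if `pattern` occurs in `seq` as a (not necessarily contiguous) subsequence.
    def isSubsequence(seq: Seq[Int], pattern: Seq[Int]): Boolean = {
      var remaining = seq
      pattern.forall { item =>
        val idx = remaining.indexOf(item)
        if (idx >= 0) { remaining = remaining.drop(idx + 1); true } else false
      }
    }
    val seqs = Seq(Seq(1, 3, 4, 5), Seq(2, 3, 1), Seq(2, 4, 1),
      Seq(3, 1, 3, 4, 5), Seq(3, 4, 4, 3), Seq(6, 5, 3))
    assert(seqs.count(isSubsequence(_, Seq(1, 3))) == 2)
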
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
    index d34888af2d73b..e331c75989187 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
    @@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     
       import PeriodicGraphCheckpointerSuite._
     
    -  // TODO: Do I need to call count() on the graphs' RDDs?
    -
       test("Persisting") {
         var graphsToCheck = Seq.empty[GraphToCheck]
     
         val graph1 = createGraph(sc)
    -    val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
    +    val checkpointer =
    +      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
    +    checkpointer.update(graph1)
         graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
         checkPersistence(graphsToCheck, 1)
     
         var iteration = 2
         while (iteration < 9) {
           val graph = createGraph(sc)
    -      checkpointer.updateGraph(graph)
    +      checkpointer.update(graph)
           graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
           checkPersistence(graphsToCheck, iteration)
           iteration += 1
    @@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
         var graphsToCheck = Seq.empty[GraphToCheck]
         sc.setCheckpointDir(path)
         val graph1 = createGraph(sc)
    -    val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
    +    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
    +      checkpointInterval, graph1.vertices.sparkContext)
    +    checkpointer.update(graph1)
         graph1.edges.count()
         graph1.vertices.count()
         graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
    @@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
         var iteration = 2
         while (iteration < 9) {
           val graph = createGraph(sc)
    -      checkpointer.updateGraph(graph)
    +      checkpointer.update(graph)
           graph.vertices.count()
           graph.edges.count()
           graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
    @@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite {
           } else {
             // Graph should never be checkpointed
             assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
    -        assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
    +        assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
           }
         } catch {
           case e: AssertionError =>
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
    new file mode 100644
    index 0000000000000..b2a459a68b5fa
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
    @@ -0,0 +1,173 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.impl
    +
    +import org.apache.hadoop.fs.{FileSystem, Path}
    +
    +import org.apache.spark.{SparkContext, SparkFunSuite}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +import org.apache.spark.util.Utils
    +
    +
    +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  import PeriodicRDDCheckpointerSuite._
    +
    +  test("Persisting") {
    +    var rddsToCheck = Seq.empty[RDDToCheck]
    +
    +    val rdd1 = createRDD(sc)
    +    val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
    +    checkpointer.update(rdd1)
    +    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
    +    checkPersistence(rddsToCheck, 1)
    +
    +    var iteration = 2
    +    while (iteration < 9) {
    +      val rdd = createRDD(sc)
    +      checkpointer.update(rdd)
    +      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
    +      checkPersistence(rddsToCheck, iteration)
    +      iteration += 1
    +    }
    +  }
    +
    +  test("Checkpointing") {
    +    val tempDir = Utils.createTempDir()
    +    val path = tempDir.toURI.toString
    +    val checkpointInterval = 2
    +    var rddsToCheck = Seq.empty[RDDToCheck]
    +    sc.setCheckpointDir(path)
    +    val rdd1 = createRDD(sc)
    +    val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
    +    checkpointer.update(rdd1)
    +    rdd1.count()
    +    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
    +    checkCheckpoint(rddsToCheck, 1, checkpointInterval)
    +
    +    var iteration = 2
    +    while (iteration < 9) {
    +      val rdd = createRDD(sc)
    +      checkpointer.update(rdd)
    +      rdd.count()
    +      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
    +      checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
    +      iteration += 1
    +    }
    +
    +    checkpointer.deleteAllCheckpoints()
    +    rddsToCheck.foreach { rdd =>
    +      confirmCheckpointRemoved(rdd.rdd)
    +    }
    +
    +    Utils.deleteRecursively(tempDir)
    +  }
    +}
    +
    +private object PeriodicRDDCheckpointerSuite {
    +
    +  case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
    +
    +  def createRDD(sc: SparkContext): RDD[Double] = {
    +    sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
    +  }
    +
    +  def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
    +    rdds.foreach { g =>
    +      checkPersistence(g.rdd, g.gIndex, iteration)
    +    }
    +  }
    +
    +  /**
    +   * Check storage level of rdd.
    +   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
    +   * @param iteration  Total number of rdds inserted into checkpointer.
    +   */
    +  def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
    +    try {
    +      if (gIndex + 2 < iteration) {
    +        assert(rdd.getStorageLevel == StorageLevel.NONE)
    +      } else {
    +        assert(rdd.getStorageLevel != StorageLevel.NONE)
    +      }
    +    } catch {
    +      case _: AssertionError =>
    +        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
    +          s"\t gIndex = $gIndex\n" +
    +          s"\t iteration = $iteration\n" +
    +          s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
    +    }
    +  }
    +
    +  def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
    +    rdds.reverse.foreach { g =>
    +      checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
    +    }
    +  }
    +
    +  def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
    +    // Note: We cannot check rdd.isCheckpointed since that value is never updated.
    +    //       Instead, we check for the presence of the checkpoint files.
    +    //       This test should continue to work even after this rdd.isCheckpointed issue
    +    //       is fixed (though it can then be simplified and not look for the files).
    +    val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
    +    rdd.getCheckpointFile.foreach { checkpointFile =>
    +      assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
    +    }
    +  }
    +
    +  /**
    +   * Check checkpointed status of rdd.
    +   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
    +   * @param iteration  Total number of rdds inserted into checkpointer.
    +   */
    +  def checkCheckpoint(
    +      rdd: RDD[_],
    +      gIndex: Int,
    +      iteration: Int,
    +      checkpointInterval: Int): Unit = {
    +    try {
    +      if (gIndex % checkpointInterval == 0) {
    +        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
    +        // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
    +        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
    +          assert(rdd.isCheckpointed, "RDD should be checkpointed")
    +          assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
    +        } else {
    +          confirmCheckpointRemoved(rdd)
    +        }
    +      } else {
    +        // RDD should never be checkpointed
    +        assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
    +        assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
    +      }
    +    } catch {
    +      case e: AssertionError =>
    +        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
    +          s"\t gIndex = $gIndex\n" +
    +          s"\t iteration = $iteration\n" +
    +          s"\t checkpointInterval = $checkpointInterval\n" +
    +          s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
    +          s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
    +          s"  AssertionError message: ${e.getMessage}")
    +    }
    +  }
    +
    +}
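
The new suite above also documents the intended calling pattern for PeriodicRDDCheckpointer. A condensed sketch (not part of the patch; it assumes access from within org.apache.spark.mllib, a live SparkContext `sc`, and a checkpoint directory already set via sc.setCheckpointDir):

    import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer

    val checkpointer = new PeriodicRDDCheckpointer[Double](2 /* checkpointInterval */, sc)
    var data = sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
    for (iteration <- 1 to 8) {
      data = data.map(_ + 1.0)
      checkpointer.update(data)  // persists the newest RDD; checkpoints every 2nd update
      data.count()               // an action is needed before the checkpoint materializes
    }
    checkpointer.deleteAllCheckpoints()
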
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    index b0f3f71113c57..d119e0b50a393 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    @@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite {
         val C10 = C1.copy
         val C11 = C1.copy
         val C12 = C1.copy
    +    val C13 = C1.copy
    +    val C14 = C1.copy
    +    val C15 = C1.copy
    +    val C16 = C1.copy
         val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
         val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
    +    val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
    +    val expected5 = C1.copy
     
         gemm(1.0, dA, B, 2.0, C1)
         gemm(1.0, sA, B, 2.0, C2)
    @@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite {
         assert(C10 ~== expected2 absTol 1e-15)
         assert(C11 ~== expected3 absTol 1e-15)
         assert(C12 ~== expected3 absTol 1e-15)
    +
    +    gemm(0, dA, B, 5, C13)
    +    gemm(0, sA, B, 5, C14)
    +    gemm(0, dA, B, 1, C15)
    +    gemm(0, sA, B, 1, C16)
    +    assert(C13 ~== expected4 absTol 1e-15)
    +    assert(C14 ~== expected4 absTol 1e-15)
    +    assert(C15 ~== expected5 absTol 1e-15)
    +    assert(C16 ~== expected5 absTol 1e-15)
    +
       }
     
       test("gemv") {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    index c4ae0a16f7c04..1c37ea5123e82 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    @@ -21,10 +21,10 @@ import scala.util.Random
     
     import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
     
    -import org.apache.spark.{SparkException, SparkFunSuite}
    +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class VectorsSuite extends SparkFunSuite {
    +class VectorsSuite extends SparkFunSuite with Logging {
     
       val arr = Array(0.1, 0.0, 0.3, 0.4)
       val n = 4
    @@ -57,16 +57,70 @@ class VectorsSuite extends SparkFunSuite {
         assert(vec.values === values)
       }
     
    +  test("sparse vector construction with mismatched indices/values array") {
    +    intercept[IllegalArgumentException] {
    +      Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0))
    +    }
    +    intercept[IllegalArgumentException] {
    +      Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0))
    +    }
    +  }
    +
    +  test("sparse vector construction with too many indices vs size") {
    +    intercept[IllegalArgumentException] {
    +      Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0))
    +    }
    +  }
    +
       test("dense to array") {
         val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
         assert(vec.toArray.eq(arr))
       }
     
    +  test("dense argmax") {
    +    val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
    +    assert(vec.argmax === -1)
    +
    +    val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
    +    assert(vec2.argmax === 3)
    +
    +    val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
    +    assert(vec3.argmax === 3)
    +  }
    +
       test("sparse to array") {
         val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
         assert(vec.toArray === arr)
       }
     
    +  test("sparse argmax") {
    +    val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
    +    assert(vec.argmax === -1)
    +
    +    val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
    +    assert(vec2.argmax === 3)
    +
    +    val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
    +    assert(vec3.argmax === 2)
    +
    +    // check for case that sparse vector is created with
    +    // only negative values {0.0, 0.0,-1.0, -0.7, 0.0}
    +    val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
    +    assert(vec4.argmax === 0)
    +
    +    val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
    +    assert(vec5.argmax === 1)
    +
    +    val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
    +    assert(vec6.argmax === 2)
    +
    +    val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
    +    assert(vec7.argmax === 1)
    +
    +    val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
    +    assert(vec8.argmax === 0)
    +  }
    +
       test("vector equals") {
         val dv1 = Vectors.dense(arr.clone())
         val dv2 = Vectors.dense(arr.clone())
    @@ -142,7 +196,7 @@ class VectorsSuite extends SparkFunSuite {
         malformatted.foreach { s =>
           intercept[SparkException] {
             Vectors.parse(s)
    -        println(s"Didn't detect malformatted string $s.")
    +        logInfo(s"Didn't detect malformatted string $s.")
           }
         }
       }
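
Reading the expectations above, the argmax contract being pinned down appears to be: -1 for an empty vector, otherwise the smallest index holding the maximum value, with a sparse vector's implicit zeros competing on equal footing with stored entries. A standalone sketch of the trickiest case (illustrative, not part of the patch; it assumes the argmax member exercised by this suite):

    import org.apache.spark.mllib.linalg.Vectors

    // Only negative values are stored, so the maximum is the implicit 0.0, whose first
    // occurrence is index 0, matching the vec4 expectation in the test above.
    val v = Vectors.sparse(5, Array(2, 3), Array(-1.0, -0.7))
    assert(v.argmax == 0)
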
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    index b6cb53d0c743e..283ffec1d49d7 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
     
     import scala.util.Random
     
    +import breeze.numerics.abs
     import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
     
     import org.apache.spark.SparkFunSuite
    @@ -238,6 +239,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       }
    +
    +  test("QR Decomposition") {
    +    for (mat <- Seq(denseMat, sparseMat)) {
    +      val result = mat.tallSkinnyQR(true)
    +      val expected = breeze.linalg.qr.reduced(mat.toBreeze())
    +      val calcQ = result.Q
    +      val calcR = result.R
    +      assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze())))
    +      assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]])))
    +      assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze()))
    +      // Decomposition without computing Q
    +      val rOnly = mat.tallSkinnyQR(computeQ = false)
    +      assert(rOnly.Q == null)
    +      assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]])))
    +    }
    +  }
     }
     
     class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
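
One non-obvious detail in the QR test above (an editorial note, not part of the patch): a thin QR factorization is unique only up to the signs of Q's columns and the matching rows of R, which is why the element-wise comparisons against Breeze take abs() of both factors while the reconstruction Q * R is compared directly.

    // A = Q * R = (Q * S) * (S * R) for any S = diag(+/-1), so only |Q| and |R| are
    // comparable entry-wise; the product Q * R must still reproduce the original matrix.
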
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
    index a2a4c5f6b8b70..34c07ed170816 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
    @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer
     import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.LinearDataGenerator
    +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
     import org.apache.spark.streaming.dstream.DStream
    -import org.apache.spark.streaming.TestSuiteBase
     
     class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     
       // use longer wait time to ensure job completion
       override def maxWaitTimeMillis: Int = 20000
     
    +  var ssc: StreamingContext = _
    +
    +  override def afterFunction() {
    +    super.afterFunction()
    +    if (ssc != null) {
    +      ssc.stop()
    +    }
    +  }
    +
       // Assert that two values are equal within tolerance epsilon
       def assertEqual(v1: Double, v2: Double, epsilon: Double) {
         def errorMessage = v1.toString + " did not equal " + v2.toString
    @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
         }
     
         // apply model training to input stream
    -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
           model.trainOn(inputDStream)
           inputDStream.count()
         })
    @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     
         // apply model training to input stream, storing the intermediate results
         // (we add a count to ensure the result is a DStream)
    -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
           model.trainOn(inputDStream)
           inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
           inputDStream.count()
    @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
         }
     
         // apply model predictions to test stream
    -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
           model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
         })
         // collect the output as (true, estimated) tuples
    @@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
         }
     
         // train and predict
    -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
    +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
           model.trainOn(inputDStream)
           model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
         })
    @@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
         val numBatches = 10
         val nPoints = 100
         val emptyInput = Seq.empty[Seq[LabeledPoint]]
    -    val ssc = setupStreams(emptyInput,
    +    ssc = setupStreams(emptyInput,
           (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
    index c292ced75e870..c3eeda012571c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
    @@ -19,13 +19,13 @@ package org.apache.spark.mllib.stat
     
     import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
     
    -import org.apache.spark.SparkFunSuite
    +import org.apache.spark.{Logging, SparkFunSuite}
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
       SpearmanCorrelation}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
    +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
       // test input data
       val xData = Array(1.0, 0.0, -2.0)
    @@ -146,7 +146,7 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
         for (i <- 0 until A.rows; j <- 0 until A.cols) {
           if (!approxEqual(A(i, j), B(i, j), threshold)) {
    -        println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
    +        logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
             return false
           }
         }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
    index b084a5fb4313f..142b90e764a7c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
    @@ -19,6 +19,10 @@ package org.apache.spark.mllib.stat
     
     import java.util.Random
     
    +import org.apache.commons.math3.distribution.{ExponentialDistribution,
    +  NormalDistribution, UniformRealDistribution}
    +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
    +
     import org.apache.spark.{SparkException, SparkFunSuite}
     import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
    @@ -153,4 +157,101 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
           Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
         }
       }
    +
    +  test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") {
    +    // Create theoretical distributions
    +    val stdNormalDist = new NormalDistribution(0, 1)
    +    val expDist = new ExponentialDistribution(0.6)
    +    val unifDist = new UniformRealDistribution()
    +
    +    // set seeds
    +    val seed = 10L
    +    stdNormalDist.reseedRandomGenerator(seed)
    +    expDist.reseedRandomGenerator(seed)
    +    unifDist.reseedRandomGenerator(seed)
    +
    +    // Sample data from the distributions and parallelize it
    +    val n = 100000
    +    val sampledNorm = sc.parallelize(stdNormalDist.sample(n), 10)
    +    val sampledExp = sc.parallelize(expDist.sample(n), 10)
    +    val sampledUnif = sc.parallelize(unifDist.sample(n), 10)
    +
     +    // Use a local Apache Commons Math KS test to verify calculations
    +    val ksTest = new KolmogorovSmirnovTest()
    +    val pThreshold = 0.05
    +
    +    // Comparing a standard normal sample to a standard normal distribution
    +    val result1 = Statistics.kolmogorovSmirnovTest(sampledNorm, "norm", 0, 1)
    +    val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledNorm.collect())
    +    val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n)
    +    // Verify vs apache math commons ks test
    +    assert(result1.statistic ~== referenceStat1 relTol 1e-4)
    +    assert(result1.pValue ~== referencePVal1 relTol 1e-4)
    +    // Cannot reject null hypothesis
    +    assert(result1.pValue > pThreshold)
    +
    +    // Comparing an exponential sample to a standard normal distribution
    +    val result2 = Statistics.kolmogorovSmirnovTest(sampledExp, "norm", 0, 1)
    +    val referenceStat2 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledExp.collect())
    +    val referencePVal2 = 1 - ksTest.cdf(referenceStat2, n)
    +    // verify vs apache math commons ks test
    +    assert(result2.statistic ~== referenceStat2 relTol 1e-4)
    +    assert(result2.pValue ~== referencePVal2 relTol 1e-4)
    +    // reject null hypothesis
    +    assert(result2.pValue < pThreshold)
    +
    +    // Testing the use of a user provided CDF function
    +    // Distribution is not serializable, so will have to create in the lambda
    +    val expCDF = (x: Double) => new ExponentialDistribution(0.2).cumulativeProbability(x)
    +
    +    // Comparing an exponential sample with mean X to an exponential distribution with mean Y
    +    // Where X != Y
    +    val result3 = Statistics.kolmogorovSmirnovTest(sampledExp, expCDF)
    +    val referenceStat3 = ksTest.kolmogorovSmirnovStatistic(new ExponentialDistribution(0.2),
    +      sampledExp.collect())
    +    val referencePVal3 = 1 - ksTest.cdf(referenceStat3, sampledNorm.count().toInt)
    +    // verify vs apache math commons ks test
    +    assert(result3.statistic ~== referenceStat3 relTol 1e-4)
    +    assert(result3.pValue ~== referencePVal3 relTol 1e-4)
    +    // reject null hypothesis
    +    assert(result3.pValue < pThreshold)
    +  }
    +
    +  test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") {
    +    /*
    +      Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample
    +      > sessionInfo()
    +      R version 3.2.0 (2015-04-16)
    +      Platform: x86_64-apple-darwin13.4.0 (64-bit)
    +      > set.seed(20)
    +      > v <- rnorm(20)
    +      > v
    +       [1]  1.16268529 -0.58592447  1.78546500 -1.33259371 -0.44656677  0.56960612
    +       [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222
    +      [13] -0.62812676  1.32322085 -1.52135057 -0.43742787  0.97057758  0.02822264
    +      [19] -0.08578219  0.38921440
    +      > ks.test(v, pnorm, alternative = "two.sided")
    +
    +               One-sample Kolmogorov-Smirnov test
    +
    +      data:  v
    +      D = 0.18874, p-value = 0.4223
    +      alternative hypothesis: two-sided
    +    */
    +
    +    val rKSStat = 0.18874
    +    val rKSPVal = 0.4223
    +    val rData = sc.parallelize(
    +      Array(
    +        1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
    +        -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
    +        -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
    +        -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
    +        0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
    +      )
    +    )
    +    val rCompResult = Statistics.kolmogorovSmirnovTest(rData, "norm", 0, 1)
    +    assert(rCompResult.statistic ~== rKSStat relTol 1e-4)
    +    assert(rCompResult.pValue ~== rKSPVal relTol 1e-4)
    +  }
     }
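For reference, the API exercised by the new tests above is `Statistics.kolmogorovSmirnovTest`, which accepts either a named theoretical distribution with its parameters or a user-supplied CDF. A minimal usage sketch, assuming an existing `SparkContext` named `sc` and illustrative sample values:

```scala
import org.apache.commons.math3.distribution.ExponentialDistribution
import org.apache.spark.mllib.stat.Statistics

// Illustrative sample; real use would test a data-derived RDD[Double].
val data = sc.parallelize(Seq(0.1, 0.15, 0.2, 0.25, 0.3, 0.5, 0.9, 1.3))

// Variant 1: test against a named distribution ("norm" with mean 0, stddev 1).
val normResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
println(s"D = ${normResult.statistic}, p-value = ${normResult.pValue}")

// Variant 2: test against a user-defined CDF. Commons Math distributions are not
// serializable, so the distribution is built inside the closure, as in the test above.
val expCdf = (x: Double) => new ExponentialDistribution(0.5).cumulativeProbability(x)
val expResult = Statistics.kolmogorovSmirnovTest(data, expCdf)
println(s"D = ${expResult.statistic}, p-value = ${expResult.pValue}")
```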
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
    index 8972c229b7ecb..334bf3790fc7a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
    @@ -70,7 +70,7 @@ object EnsembleTestHelper {
           metricName: String = "mse") {
         val predictions = input.map(x => model.predict(x.features))
         val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
    -      prediction - label
    +      label - prediction
         }
         val metric = metricName match {
           case "mse" =>
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
    index 84dd3b342d4c0..6fc9e8df621df 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
    @@ -17,7 +17,7 @@
     
     package org.apache.spark.mllib.tree
     
    -import org.apache.spark.SparkFunSuite
    +import org.apache.spark.{Logging, SparkFunSuite}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.configuration.Algo._
     import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
    @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils
     /**
      * Test suite for [[GradientBoostedTrees]].
      */
    -class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext {
    +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
       test("Regression with continuous features: SquaredError") {
         GradientBoostedTreesSuite.testCombinations.foreach {
    @@ -50,7 +50,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
               EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
             } catch {
               case e: java.lang.AssertionError =>
    -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
    +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                   s" subsamplingRate=$subsamplingRate")
                 throw e
             }
    @@ -80,7 +80,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
               EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae")
             } catch {
               case e: java.lang.AssertionError =>
    -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
    +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                   s" subsamplingRate=$subsamplingRate")
                 throw e
             }
    @@ -111,7 +111,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
               EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9)
             } catch {
               case e: java.lang.AssertionError =>
    -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
    +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                   s" subsamplingRate=$subsamplingRate")
                 throw e
             }
    @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
     
         val algos = Array(Regression, Regression, Classification)
         val losses = Array(SquaredError, AbsoluteError, LogLoss)
    -    (algos zip losses) map {
    -      case (algo, loss) => {
    -        val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
    -          categoricalFeaturesInfo = Map.empty)
    -        val boostingStrategy =
    -          new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
    -        val gbtValidate = new GradientBoostedTrees(boostingStrategy)
    -          .runWithValidation(trainRdd, validateRdd)
    -        val numTrees = gbtValidate.numTrees
    -        assert(numTrees !== numIterations)
    -
    -        // Test that it performs better on the validation dataset.
    -        val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
    -        val (errorWithoutValidation, errorWithValidation) = {
    -          if (algo == Classification) {
    -            val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
    -            (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
    -          } else {
    -            (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
    -          }
    -        }
    -        assert(errorWithValidation <= errorWithoutValidation)
    -
    -        // Test that results from evaluateEachIteration comply with runWithValidation.
    -        // Note that convergenceTol is set to 0.0
    -        val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
    -        assert(evaluationArray.length === numIterations)
    -        assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
    -        var i = 1
    -        while (i < numTrees) {
    -          assert(evaluationArray(i) <= evaluationArray(i - 1))
    -          i += 1
    +    algos.zip(losses).foreach { case (algo, loss) =>
    +      val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
    +        categoricalFeaturesInfo = Map.empty)
    +      val boostingStrategy =
    +        new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
    +      val gbtValidate = new GradientBoostedTrees(boostingStrategy)
    +        .runWithValidation(trainRdd, validateRdd)
    +      val numTrees = gbtValidate.numTrees
    +      assert(numTrees !== numIterations)
    +
    +      // Test that it performs better on the validation dataset.
    +      val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
    +      val (errorWithoutValidation, errorWithValidation) = {
    +        if (algo == Classification) {
    +          val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
    +          (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
    +        } else {
    +          (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
             }
           }
    +      assert(errorWithValidation <= errorWithoutValidation)
    +
    +      // Test that results from evaluateEachIteration comply with runWithValidation.
    +      // Note that convergenceTol is set to 0.0
    +      val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
    +      assert(evaluationArray.length === numIterations)
    +      assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
    +      var i = 1
    +      while (i < numTrees) {
    +        assert(evaluationArray(i) <= evaluationArray(i - 1))
    +        i += 1
    +      }
         }
       }
     
    +  test("Checkpointing") {
    +    val tempDir = Utils.createTempDir()
    +    val path = tempDir.toURI.toString
    +    sc.setCheckpointDir(path)
    +
    +    val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
    +
    +    val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
    +      categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
    +    val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
    +
    +    val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
    +
    +    sc.checkpointDir = None
    +    Utils.deleteRecursively(tempDir)
    +  }
    +
     }
     
     private object GradientBoostedTreesSuite {
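The new checkpointing test shows everything an application needs to enable periodic checkpointing during boosting: a checkpoint directory on the SparkContext plus a checkpointInterval in the tree Strategy. A condensed usage sketch, assuming an existing SparkContext `sc`; the directory path and `trainingData` RDD are placeholders:

```scala
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.SquaredError

// Checkpoint every 2 iterations to keep RDD lineage short during long boosting runs.
sc.setCheckpointDir("/tmp/gbt-checkpoints") // placeholder path

val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
  categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)

// trainingData: RDD[LabeledPoint] prepared elsewhere.
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
```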
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
    index 5e9101cdd3804..525ab68c7921a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
    @@ -26,7 +26,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
     
       override def beforeAll() {
         val conf = new SparkConf()
    -      .setMaster("local-cluster[2, 1, 512]")
    +      .setMaster("local-cluster[2, 1, 1024]")
           .setAppName("test-cluster")
           .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
         sc = new SparkContext(conf)
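For context on the memory bump above: the local-cluster[N, C, M] master string starts N workers with C cores and M MB of memory each, so these tests now get 1024 MB per executor rather than 512 MB. A minimal sketch of the same configuration outside the test trait (the app name and job are arbitrary):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// local-cluster[workers, coresPerWorker, memoryPerWorkerMB]
val conf = new SparkConf()
  .setMaster("local-cluster[2, 1, 1024]")
  .setAppName("local-cluster-example")
  .set("spark.akka.frameSize", "1") // 1 MB, to surface accidental direct serialization of data
val sc = new SparkContext(conf)
try {
  println(sc.parallelize(1 to 100, 4).sum())
} finally {
  sc.stop()
}
```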
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
    index fa4f74d71b7e7..16d7c3ab39b03 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
    @@ -33,7 +33,7 @@ class NumericParserSuite extends SparkFunSuite {
         malformatted.foreach { s =>
           intercept[SparkException] {
             NumericParser.parse(s)
    -        println(s"Didn't detect malformatted string $s.")
    +        throw new RuntimeException(s"Didn't detect malformatted string $s.")
           }
         }
       }
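The println-to-throw change above fits the broader println cleanup in this patch: the statement after NumericParser.parse runs only when parsing unexpectedly succeeds, and throwing there surfaces the diagnostic through the test failure (intercept still fails, since a RuntimeException is not a SparkException) rather than through console output. A small illustration of the pattern with a hypothetical parser:

```scala
import org.scalatest.FunSuite

class InterceptPatternSuite extends FunSuite {
  // Hypothetical parser used only to illustrate the pattern.
  def parsePositive(s: String): Int = {
    val n = s.toInt
    if (n <= 0) throw new IllegalArgumentException(s"Not positive: $s")
    n
  }

  test("invalid input is rejected") {
    Seq("0", "-3").foreach { s =>
      intercept[IllegalArgumentException] {
        parsePositive(s)
        // Reached only if parsePositive did NOT throw; this makes the test fail loudly.
        throw new RuntimeException(s"Didn't detect invalid input $s.")
      }
    }
  }
}
```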
    diff --git a/pom.xml b/pom.xml
    index df06632997029..b4410c6c56de8 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -129,7 +129,7 @@
         ${hadoop.version}
         0.98.7-hadoop2
         hbase
    -    1.4.0
    +    1.6.0
         3.4.5
         2.4.0
         org.spark-project.hive
    @@ -145,7 +145,7 @@
         0.5.0
         2.4.0
         2.0.8
    -    3.1.0
    +    3.1.2
         1.7.7
         hadoop2
         0.7.1
    @@ -153,7 +153,6 @@
         1.2.1
         4.3.2
         3.4.1
    -    ${project.build.directory}/spark-test-classpath.txt
         2.10.4
         2.10
         ${scala.version}
    @@ -162,6 +161,8 @@
         2.4.4
         1.1.1.7
         1.1.2
    +    
    +    false
     
         ${java.home}
     
    @@ -178,6 +179,7 @@
         compile
         compile
         compile
    +    test
     
         
    +    
         
     -      <id>spark-1.4-staging</id>
     -      <name>Spark 1.4 RC4 Staging Repository</name>
     -      <url>https://repository.apache.org/content/repositories/orgapachespark-1112</url>
     +      <id>twttr-repo</id>
     +      <name>Twttr Repository</name>
     +      <url>http://maven.twttr.com</url>
           
             true
           
    @@ -304,17 +306,6 @@
           unused
           1.0.0
         
    -    
     -    <dependency>
     -      <groupId>org.codehaus.groovy</groupId>
     -      <artifactId>groovy-all</artifactId>
     -      <version>2.3.7</version>
     -      <scope>provided</scope>
     -    </dependency>
         
           
             <groupId>com.fasterxml.jackson.module</groupId>
     -        <artifactId>jackson-module-scala_2.10</artifactId>
     +        <artifactId>jackson-module-scala_${scala.binary.version}</artifactId>
             <version>${fasterxml.jackson.version}</version>
             
               
    @@ -739,6 +725,12 @@
             curator-framework
             ${curator.version}
           
     +      <dependency>
     +        <groupId>org.apache.curator</groupId>
     +        <artifactId>curator-test</artifactId>
     +        <version>${curator.version}</version>
     +        <scope>test</scope>
     +      </dependency>
           
             org.apache.hadoop
             hadoop-client
    @@ -1100,6 +1092,12 @@
             ${parquet.version}
             ${parquet.deps.scope}
           
     +      <dependency>
     +        <groupId>org.apache.parquet</groupId>
     +        <artifactId>parquet-avro</artifactId>
     +        <version>${parquet.version}</version>
     +        <scope>${parquet.test.deps.scope}</scope>
     +      </dependency>
           
             org.apache.flume
             flume-ng-core
    @@ -1110,6 +1108,10 @@
                 io.netty
                 netty
               
     +          <exclusion>
     +            <groupId>org.apache.flume</groupId>
     +            <artifactId>flume-ng-auth</artifactId>
     +          </exclusion>
               
                 org.apache.thrift
                 libthrift
    @@ -1305,6 +1307,7 @@
                   false
                   false
                   true
    +              true
                 
               
               
    @@ -1386,6 +1389,58 @@
               maven-deploy-plugin
               2.8.2
             
    +        
    +        
    +        
    +          org.eclipse.m2e
    +          lifecycle-mapping
    +          1.0.0
    +          
    +            
    +              
    +                
    +                  
    +                    org.apache.maven.plugins
    +                    maven-dependency-plugin
    +                    [2.8,)
    +                    
    +                      build-classpath
    +                    
    +                  
    +                  
    +                    
    +                  
    +                
    +                
    +                  
    +                    org.apache.maven.plugins
    +                    maven-jar-plugin
    +                    [2.6,)
    +                    
    +                      test-jar
    +                    
    +                  
    +                  
    +                    
    +                  
    +                
    +                
    +                  
    +                    org.apache.maven.plugins
    +                    maven-antrun-plugin
    +                    [1.8,)
    +                    
    +                      run
    +                    
    +                  
    +                  
    +                    
    +                  
    +                
    +              
    +            
    +          
    +        
           
         
     
    @@ -1403,34 +1458,12 @@
                 
                 
                   test
    -              ${test_classpath_file}
    +              test_classpath
                 
               
             
           
     
    -      
    -      
    -        org.codehaus.gmavenplus
    -        gmavenplus-plugin
    -        1.5
    -        
    -          
    -            process-test-classes
    -            
    -              execute
    -            
    -            
    -              
    -                
    -              
    -            
    -          
    -        
    -      
           
    +          ${create.dependency.reduced.pom}
               
                 
                   
    @@ -1495,36 +1530,6 @@
             org.apache.maven.plugins
             maven-enforcer-plugin
           
    -      
    -        org.codehaus.mojo
    -        build-helper-maven-plugin
    -        
    -          
    -            add-scala-sources
    -            generate-sources
    -            
    -              add-source
    -            
    -            
    -              
    -                src/main/scala
    -              
    -            
    -          
    -          
    -            add-scala-test-sources
    -            generate-test-sources
    -            
    -              add-test-source
    -            
    -            
    -              
    -                src/test/scala
    -              
    -            
    -          
    -        
    -      
           
             net.alchim31.maven
             scala-maven-plugin
    @@ -1714,7 +1719,6 @@
           
             2.3.0
             0.9.3
    -        3.1.1
           
         
     
    @@ -1723,7 +1727,6 @@
           
             2.4.0
             0.9.3
    -        3.1.1
           
         
     
    @@ -1732,7 +1735,6 @@
           
             2.6.0
             0.9.3
    -        3.1.1
             3.4.6
             2.6.0
           
    @@ -1802,6 +1804,15 @@
             ${scala.version}
             org.scala-lang
           
    +      
    +        
    +          
    +            ${jline.groupid}
    +            jline
    +            ${jline.version}
    +          
    +        
    +      
         
     
         
    @@ -1820,10 +1831,28 @@
             scala-2.11
           
           
    -        2.11.6
    +        2.11.7
             2.11
    -        2.12.1
    -        jline
    +      
    +    
    +
    +    
    +      
    +      release
    +      
    +        
    +        true
           
         
     
    diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
    index 680b699e9e4a1..fa36629c37a35 100644
    --- a/project/MimaExcludes.scala
    +++ b/project/MimaExcludes.scala
    @@ -58,30 +58,101 @@ object MimaExcludes {
                   "org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
                 ProblemFilters.exclude[MissingMethodProblem](
                   "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.ml.classification.LogisticCostFun.this"),
                 // SQL execution is considered private.
                 excludePackage("org.apache.spark.sql.execution"),
    -            // NanoTime and CatalystTimestampConverter is only used inside catalyst,
    -            // not needed anymore
    -            ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.timestamp.NanoTime"),
    -              ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.timestamp.NanoTime$"),
    -            ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.CatalystTimestampConverter"),
    -            ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.CatalystTimestampConverter$"),
    -            // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter
    -            ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.ParquetTypeInfo"),
    -            ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.ParquetTypeInfo$")
    +            // Parquet support is considered private.
    +            excludePackage("org.apache.spark.sql.parquet"),
    +            // The old JSON RDD is removed in favor of streaming Jackson
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"),
    +            // local function inside a method
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
               ) ++ Seq(
                 // SPARK-8479 Add numNonzeros and numActives to Matrix.
                 ProblemFilters.exclude[MissingMethodProblem](
                   "org.apache.spark.mllib.linalg.Matrix.numNonzeros"),
                 ProblemFilters.exclude[MissingMethodProblem](
                   "org.apache.spark.mllib.linalg.Matrix.numActives")
    +          ) ++ Seq(
    +            // SPARK-8914 Remove RDDApi
    +            ProblemFilters.exclude[MissingClassProblem](
    +            "org.apache.spark.sql.RDDApi")
    +          ) ++ Seq(
    +            // SPARK-8701 Add input metadata in the batch page.
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.streaming.scheduler.InputInfo$"),
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.streaming.scheduler.InputInfo")
    +          ) ++ Seq(
    +            // SPARK-6797 Support YARN modes for SparkR
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.PairwiseRRDD.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.RRDD.createRWorker"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.RRDD.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.StringRRDD.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.BaseRRDD.this")
    +          ) ++ Seq(
    +            // SPARK-7422 add argmax for sparse vectors
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.mllib.linalg.Vector.argmax")
    +          ) ++ Seq(
    +            // SPARK-8906 Move all internal data source classes into execution.datasources
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException")
               )
    +
             case v if v.startsWith("1.4") =>
               Seq(
                 MimaBuild.excludeSparkPackage("deploy"),
    diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
    index 454678e5fbc28..818b9a6e3c085 100644
    --- a/project/SparkBuild.scala
    +++ b/project/SparkBuild.scala
    @@ -69,6 +69,7 @@ object SparkBuild extends PomBuild {
         import scala.collection.mutable
         var isAlphaYarn = false
         var profiles: mutable.Seq[String] = mutable.Seq("sbt")
    +    // scalastyle:off println
         if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) {
           println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.")
           profiles ++= Seq("spark-ganglia-lgpl")
    @@ -88,6 +89,7 @@ object SparkBuild extends PomBuild {
           println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
           profiles ++= Seq("yarn")
         }
    +    // scalastyle:on println
         profiles
       }
     
    @@ -96,8 +98,10 @@ object SparkBuild extends PomBuild {
         case None => backwardCompatibility
         case Some(v) =>
           if (backwardCompatibility.nonEmpty)
    +        // scalastyle:off println
             println("Note: We ignore environment variables, when use of profile is detected in " +
               "conjunction with environment variable.")
    +        // scalastyle:on println
           v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
         }
     
    @@ -150,7 +154,38 @@ object SparkBuild extends PomBuild {
           if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty
         },
     
    -    javacOptions in Compile ++= Seq("-encoding", "UTF-8")
    +    javacOptions in Compile ++= Seq("-encoding", "UTF-8"),
    +
    +    // Implements -Xfatal-warnings, ignoring deprecation warnings.
    +    // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410.
    +    compile in Compile := {
    +      val analysis = (compile in Compile).value
    +      val s = streams.value
    +
    +      def logProblem(l: (=> String) => Unit, f: File, p: xsbti.Problem) = {
    +        l(f.toString + ":" + p.position.line.fold("")(_ + ":") + " " + p.message)
    +        l(p.position.lineContent)
    +        l("")
    +      }
    +
    +      var failed = 0
    +      analysis.infos.allInfos.foreach { case (k, i) =>
    +        i.reportedProblems foreach { p =>
    +          val deprecation = p.message.contains("is deprecated")
    +
    +          if (!deprecation) {
    +            failed = failed + 1
    +          }
    +
    +          logProblem(if (deprecation) s.log.warn else s.log.error, k, p)
    +        }
    +      }
    +
    +      if (failed > 0) {
    +        sys.error(s"$failed fatal warnings")
    +      }
    +      analysis
    +    }
       )
     
       def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
    @@ -483,8 +518,8 @@ object Unidoc {
             "mllib.tree.impurity", "mllib.tree.model", "mllib.util",
             "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation",
             "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss",
    -        "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param",
    -        "ml.recommendation", "ml.regression", "ml.tuning"
    +        "ml", "ml.attribute", "ml.classification", "ml.clustering", "ml.evaluation", "ml.feature",
    +        "ml.param", "ml.recommendation", "ml.regression", "ml.tuning"
           ),
           "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"),
           "-noqualifier", "java.lang"
    diff --git a/pylintrc b/pylintrc
    new file mode 100644
    index 0000000000000..6a675770da69a
    --- /dev/null
    +++ b/pylintrc
    @@ -0,0 +1,404 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +[MASTER]
    +
    +# Specify a configuration file.
    +#rcfile=
    +
    +# Python code to execute, usually for sys.path manipulation such as
    +# pygtk.require().
    +#init-hook=
    +
    +# Profiled execution.
    +profile=no
    +
    +# Add files or directories to the blacklist. They should be base names, not
    +# paths.
    +ignore=pyspark.heapq3
    +
    +# Pickle collected data for later comparisons.
    +persistent=yes
    +
    +# List of plugins (as comma separated values of python modules names) to load,
    +# usually to register additional checkers.
    +load-plugins=
    +
    +# Use multiple processes to speed up Pylint.
    +jobs=1
    +
    +# Allow loading of arbitrary C extensions. Extensions are imported into the
    +# active Python interpreter and may run arbitrary code.
    +unsafe-load-any-extension=no
    +
    +# A comma-separated list of package or module names from where C extensions may
    +# be loaded. Extensions are loading into the active Python interpreter and may
    +# run arbitrary code
    +extension-pkg-whitelist=
    +
    +# Allow optimization of some AST trees. This will activate a peephole AST
    +# optimizer, which will apply various small optimizations. For instance, it can
    +# be used to obtain the result of joining multiple strings with the addition
    +# operator. Joining a lot of strings can lead to a maximum recursion error in
    +# Pylint and this flag can prevent that. It has one side effect, the resulting
    +# AST will be different than the one from reality.
    +optimize-ast=no
    +
    +
    +[MESSAGES CONTROL]
    +
    +# Only show warnings with the listed confidence levels. Leave empty to show
    +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
    +confidence=
    +
    +# Enable the message, report, category or checker with the given id(s). You can
    +# either give multiple identifier separated by comma (,) or put this option
    +# multiple time. See also the "--disable" option for examples.
    +enable=
    +
    +# Disable the message, report, category or checker with the given id(s). You
    +# can either give multiple identifiers separated by comma (,) or put this
    +# option multiple times (only on the command line, not in the configuration
    +# file where it should appear only once).You can also use "--disable=all" to
    +# disable everything first and then reenable specific checks. For example, if
    +# you want to run only the similarities checker, you can use "--disable=all
    +# --enable=similarities". If you want to run only the classes checker, but have
    +# no Warning level messages displayed, use"--disable=all --enable=classes
    +# --disable=W"
    +
     +# These errors are arranged in order of the number of warnings pylint reports for each.
     +# If you would like to improve the code quality of pyspark, remove any of these disabled errors,
     +# run ./dev/lint-python, and see if the errors raised by pylint can be fixed.
    +
    +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable
    +
    +
    +[REPORTS]
    +
    +# Set the output format. Available formats are text, parseable, colorized, msvs
    +# (visual studio) and html. You can also give a reporter class, eg
    +# mypackage.mymodule.MyReporterClass.
    +output-format=text
    +
    +# Put messages in a separate file for each module / package specified on the
    +# command line instead of printing them on stdout. Reports (if any) will be
    +# written in a file name "pylint_global.[txt|html]".
    +files-output=no
    +
    +# Tells whether to display a full report or only the messages
    +reports=no
    +
    +# Python expression which should return a note less than 10 (10 is the highest
    +# note). You have access to the variables errors warning, statement which
    +# respectively contain the number of errors / warnings messages and the total
    +# number of statements analyzed. This is used by the global evaluation report
    +# (RP0004).
    +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
    +
    +# Add a comment according to your evaluation note. This is used by the global
    +# evaluation report (RP0004).
    +comment=no
    +
    +# Template used to display messages. This is a python new-style format string
    +# used to format the message information. See doc for all details
    +#msg-template=
    +
    +
    +[MISCELLANEOUS]
    +
    +# List of note tags to take in consideration, separated by a comma.
    +notes=FIXME,XXX,TODO
    +
    +
    +[BASIC]
    +
    +# Required attributes for module, separated by a comma
    +required-attributes=
    +
    +# List of builtins function names that should not be used, separated by a comma
    +bad-functions=
    +
    +# Good variable names which should always be accepted, separated by a comma
    +good-names=i,j,k,ex,Run,_
    +
    +# Bad variable names which should always be refused, separated by a comma
    +bad-names=baz,toto,tutu,tata
    +
    +# Colon-delimited sets of names that determine each other's naming style when
    +# the name regexes allow several styles.
    +name-group=
    +
    +# Include a hint for the correct naming format with invalid-name
    +include-naming-hint=no
    +
    +# Regular expression matching correct function names
    +function-rgx=[a-z_][a-z0-9_]{2,30}$
    +
    +# Naming hint for function names
    +function-name-hint=[a-z_][a-z0-9_]{2,30}$
    +
    +# Regular expression matching correct variable names
    +variable-rgx=[a-z_][a-z0-9_]{2,30}$
    +
    +# Naming hint for variable names
    +variable-name-hint=[a-z_][a-z0-9_]{2,30}$
    +
    +# Regular expression matching correct constant names
    +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
    +
    +# Naming hint for constant names
    +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
    +
    +# Regular expression matching correct attribute names
    +attr-rgx=[a-z_][a-z0-9_]{2,30}$
    +
    +# Naming hint for attribute names
    +attr-name-hint=[a-z_][a-z0-9_]{2,30}$
    +
    +# Regular expression matching correct argument names
    +argument-rgx=[a-z_][a-z0-9_]{2,30}$
    +
    +# Naming hint for argument names
    +argument-name-hint=[a-z_][a-z0-9_]{2,30}$
    +
    +# Regular expression matching correct class attribute names
    +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
    +
    +# Naming hint for class attribute names
    +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
    +
    +# Regular expression matching correct inline iteration names
    +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
    +
    +# Naming hint for inline iteration names
    +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
    +
    +# Regular expression matching correct class names
    +class-rgx=[A-Z_][a-zA-Z0-9]+$
    +
    +# Naming hint for class names
    +class-name-hint=[A-Z_][a-zA-Z0-9]+$
    +
    +# Regular expression matching correct module names
    +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
    +
    +# Naming hint for module names
    +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
    +
    +# Regular expression matching correct method names
    +method-rgx=[a-z_][a-z0-9_]{2,30}$
    +
    +# Naming hint for method names
    +method-name-hint=[a-z_][a-z0-9_]{2,30}$
    +
    +# Regular expression which should only match function or class names that do
    +# not require a docstring.
    +no-docstring-rgx=__.*__
    +
    +# Minimum line length for functions/classes that require docstrings, shorter
    +# ones are exempt.
    +docstring-min-length=-1
    +
    +
    +[FORMAT]
    +
    +# Maximum number of characters on a single line.
    +max-line-length=100
    +
    +# Regexp for a line that is allowed to be longer than the limit.
     +ignore-long-lines=^\s*(# )?<?https?://\S+>?$
    +
    +# Allow the body of an if to be on the same line as the test if there is no
    +# else.
    +single-line-if-stmt=no
    +
    +# List of optional constructs for which whitespace checking is disabled
    +no-space-check=trailing-comma,dict-separator
    +
    +# Maximum number of lines in a module
    +max-module-lines=1000
    +
    +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
    +# tab).
    +indent-string='    '
    +
    +# Number of spaces of indent required inside a hanging or continued line.
    +indent-after-paren=4
    +
    +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
    +expected-line-ending-format=
    +
    +
    +[SIMILARITIES]
    +
    +# Minimum lines number of a similarity.
    +min-similarity-lines=4
    +
    +# Ignore comments when computing similarities.
    +ignore-comments=yes
    +
    +# Ignore docstrings when computing similarities.
    +ignore-docstrings=yes
    +
    +# Ignore imports when computing similarities.
    +ignore-imports=no
    +
    +
    +[VARIABLES]
    +
    +# Tells whether we should check for unused import in __init__ files.
    +init-import=no
    +
    +# A regular expression matching the name of dummy variables (i.e. expectedly
    +# not used).
    +dummy-variables-rgx=_$|dummy
    +
    +# List of additional names supposed to be defined in builtins. Remember that
    +# you should avoid to define new builtins when possible.
    +additional-builtins=
    +
    +# List of strings which can identify a callback function by name. A callback
    +# name must start or end with one of those strings.
    +callbacks=cb_,_cb
    +
    +
    +[SPELLING]
    +
     +# Spelling dictionary name. Available dictionaries: none. To make it work,
     +# install the python-enchant package.
    +spelling-dict=
    +
    +# List of comma separated words that should not be checked.
    +spelling-ignore-words=
    +
    +# A path to a file that contains private dictionary; one word per line.
    +spelling-private-dict-file=
    +
    +# Tells whether to store unknown words to indicated private dictionary in
    +# --spelling-private-dict-file option instead of raising a message.
    +spelling-store-unknown-words=no
    +
    +
    +[LOGGING]
    +
    +# Logging modules to check that the string format arguments are in logging
    +# function parameter format
    +logging-modules=logging
    +
    +
    +[TYPECHECK]
    +
    +# Tells whether missing members accessed in mixin class should be ignored. A
    +# mixin class is detected if its name ends with "mixin" (case insensitive).
    +ignore-mixin-members=yes
    +
    +# List of module names for which member attributes should not be checked
    +# (useful for modules/projects where namespaces are manipulated during runtime
    +# and thus existing member attributes cannot be deduced by static analysis
    +ignored-modules=
    +
    +# List of classes names for which member attributes should not be checked
    +# (useful for classes with attributes dynamically set).
    +ignored-classes=SQLObject
    +
    +# When zope mode is activated, add a predefined set of Zope acquired attributes
    +# to generated-members.
    +zope=no
    +
    +# List of members which are set dynamically and missed by pylint inference
    +# system, and so shouldn't trigger E0201 when accessed. Python regular
    +# expressions are accepted.
    +generated-members=REQUEST,acl_users,aq_parent
    +
    +
    +[CLASSES]
    +
    +# List of interface methods to ignore, separated by a comma. This is used for
    +# instance to not check methods defines in Zope's Interface base class.
    +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by
    +
    +# List of method names used to declare (i.e. assign) instance attributes.
    +defining-attr-methods=__init__,__new__,setUp
    +
    +# List of valid names for the first argument in a class method.
    +valid-classmethod-first-arg=cls
    +
    +# List of valid names for the first argument in a metaclass class method.
    +valid-metaclass-classmethod-first-arg=mcs
    +
    +# List of member names, which should be excluded from the protected access
    +# warning.
    +exclude-protected=_asdict,_fields,_replace,_source,_make
    +
    +
    +[IMPORTS]
    +
    +# Deprecated modules which should not be used, separated by a comma
    +deprecated-modules=regsub,TERMIOS,Bastion,rexec
    +
    +# Create a graph of every (i.e. internal and external) dependencies in the
    +# given file (report RP0402 must not be disabled)
    +import-graph=
    +
    +# Create a graph of external dependencies in the given file (report RP0402 must
    +# not be disabled)
    +ext-import-graph=
    +
    +# Create a graph of internal dependencies in the given file (report RP0402 must
    +# not be disabled)
    +int-import-graph=
    +
    +
    +[DESIGN]
    +
    +# Maximum number of arguments for function / method
    +max-args=5
    +
    +# Argument names that match this expression will be ignored. Default to name
    +# with leading underscore
    +ignored-argument-names=_.*
    +
    +# Maximum number of locals for function / method body
    +max-locals=15
    +
    +# Maximum number of return / yield for function / method body
    +max-returns=6
    +
    +# Maximum number of branch for function / method body
    +max-branches=12
    +
    +# Maximum number of statements in function / method body
    +max-statements=50
    +
    +# Maximum number of parents for a class (see R0901).
    +max-parents=7
    +
    +# Maximum number of attributes for a class (see R0902).
    +max-attributes=7
    +
    +# Minimum number of public methods for a class (see R0903).
    +min-public-methods=2
    +
    +# Maximum number of public methods for a class (see R0904).
    +max-public-methods=20
    +
    +
    +[EXCEPTIONS]
    +
    +# Exceptions that will emit a warning when being caught. Defaults to
    +# "Exception"
    +overgeneral-exceptions=Exception
    diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
    index 518b8e774dd5f..86d4186a2c798 100644
    --- a/python/docs/pyspark.ml.rst
    +++ b/python/docs/pyspark.ml.rst
    @@ -33,6 +33,14 @@ pyspark.ml.classification module
         :undoc-members:
         :inherited-members:
     
    +pyspark.ml.clustering module
    +----------------------------
    +
    +.. automodule:: pyspark.ml.clustering
    +    :members:
    +    :undoc-members:
    +    :inherited-members:
    +
     pyspark.ml.recommendation module
     --------------------------------
     
    diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
    index 9ef93071d2e77..3b647985801b7 100644
    --- a/python/pyspark/cloudpickle.py
    +++ b/python/pyspark/cloudpickle.py
    @@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack):
                 if new_override:
                     d['__new__'] = obj.__new__
     
    -            self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
    +            self.save(_load_class)
    +            self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
    +            d.pop('__doc__', None)
     +            # handle property, staticmethod and classmethod
    +            dd = {}
    +            for k, v in d.items():
    +                if isinstance(v, property):
    +                    k = ('property', k)
    +                    v = (v.fget, v.fset, v.fdel, v.__doc__)
    +                elif isinstance(v, staticmethod) and hasattr(v, '__func__'):
    +                    k = ('staticmethod', k)
    +                    v = v.__func__
    +                elif isinstance(v, classmethod) and hasattr(v, '__func__'):
    +                    k = ('classmethod', k)
    +                    v = v.__func__
    +                dd[k] = v
    +            self.save(dd)
    +            self.write(pickle.TUPLE2)
    +            self.write(pickle.REDUCE)
    +
             else:
                 raise pickle.PicklingError("Can't pickle %r" % obj)
     
    @@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None):
                                   None, None, closure)
     
     
    +def _load_class(cls, d):
    +    """
    +    Loads additional properties into class `cls`.
    +    """
    +    for k, v in d.items():
    +        if isinstance(k, tuple):
    +            typ, k = k
    +            if typ == 'property':
    +                v = property(*v)
    +            elif typ == 'staticmethod':
    +                v = staticmethod(v)
    +            elif typ == 'classmethod':
    +                v = classmethod(v)
    +        setattr(cls, k, v)
    +    return cls
    +
    +
     """Constructors for 3rd party libraries
     Note: These can never be renamed due to client compatibility issues"""
     
    diff --git a/python/pyspark/context.py b/python/pyspark/context.py
    index d7466729b8f36..eb5b0bbbdac4b 100644
    --- a/python/pyspark/context.py
    +++ b/python/pyspark/context.py
    @@ -152,6 +152,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
             self.master = self._conf.get("spark.master")
             self.appName = self._conf.get("spark.app.name")
             self.sparkHome = self._conf.get("spark.home", None)
    +
    +        # Let YARN know it's a pyspark app, so it distributes needed libraries.
    +        if self.master == "yarn-client":
    +            self._conf.set("spark.yarn.isPython", "true")
    +
             for (k, v) in self._conf.getAll():
                 if k.startswith("spark.executorEnv."):
                     varName = k[len("spark.executorEnv."):]
    @@ -908,8 +913,7 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
             # by runJob() in order to avoid having to pass a Python lambda into
             # SparkContext#runJob.
             mappedRDD = rdd.mapPartitions(partitionFunc)
    -        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
    -                                          allowLocal)
    +        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
             return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
     
         def show_profiles(self):
    diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
    index 90cd342a6cf7f..60be85e53e2aa 100644
    --- a/python/pyspark/java_gateway.py
    +++ b/python/pyspark/java_gateway.py
    @@ -52,7 +52,11 @@ def launch_gateway():
             script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
             submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
             if os.environ.get("SPARK_TESTING"):
    -            submit_args = "--conf spark.ui.enabled=false " + submit_args
    +            submit_args = ' '.join([
    +                "--conf spark.ui.enabled=false",
    +                "--conf spark.buffer.pageSize=4mb",
    +                submit_args
    +            ])
             command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
     
             # Start a socket that will be used by PythonGatewayServer to communicate its port to us
    diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
    index 7abbde8b260eb..5a82bc286d1e8 100644
    --- a/python/pyspark/ml/classification.py
    +++ b/python/pyspark/ml/classification.py
    @@ -18,7 +18,8 @@
     from pyspark.ml.util import keyword_only
     from pyspark.ml.wrapper import JavaEstimator, JavaModel
     from pyspark.ml.param.shared import *
    -from pyspark.ml.regression import RandomForestParams
    +from pyspark.ml.regression import (
    +    RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
     from pyspark.mllib.common import inherit_doc
     
     
    @@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
         >>> td = si_model.transform(df)
         >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
         >>> model = dt.fit(td)
    +    >>> model.numNodes
    +    3
    +    >>> model.depth
    +    1
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -269,7 +274,8 @@ def getImpurity(self):
             return self.getOrDefault(self.impurity)
     
     
    -class DecisionTreeClassificationModel(JavaModel):
    +@inherit_doc
    +class DecisionTreeClassificationModel(DecisionTreeModel):
         """
         Model fitted by DecisionTreeClassifier.
         """
    @@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
         It supports both binary and multiclass labels, as well as both continuous and categorical
         features.
     
    +    >>> from numpy import allclose
         >>> from pyspark.mllib.linalg import Vectors
         >>> from pyspark.ml.feature import StringIndexer
         >>> df = sqlContext.createDataFrame([
    @@ -292,8 +299,10 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
         >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
         >>> si_model = stringIndexer.fit(df)
         >>> td = si_model.transform(df)
    -    >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
    +    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
         >>> model = rf.fit(td)
    +    >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
    +    True
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -423,7 +432,7 @@ def getFeatureSubsetStrategy(self):
             return self.getOrDefault(self.featureSubsetStrategy)
     
     
    -class RandomForestClassificationModel(JavaModel):
    +class RandomForestClassificationModel(TreeEnsembleModels):
         """
         Model fitted by RandomForestClassifier.
         """
    @@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
         It supports binary labels, as well as both continuous and categorical features.
         Note: Multiclass labels are not currently supported.
     
    +    >>> from numpy import allclose
         >>> from pyspark.mllib.linalg import Vectors
         >>> from pyspark.ml.feature import StringIndexer
         >>> df = sqlContext.createDataFrame([
    @@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
         >>> td = si_model.transform(df)
         >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
         >>> model = gbt.fit(td)
    +    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    +    True
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -558,7 +570,7 @@ def getStepSize(self):
             return self.getOrDefault(self.stepSize)
     
     
    -class GBTClassificationModel(JavaModel):
    +class GBTClassificationModel(TreeEnsembleModels):
         """
         Model fitted by GBTClassifier.
         """
    diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
    new file mode 100644
    index 0000000000000..b5e9b6549d9f1
    --- /dev/null
    +++ b/python/pyspark/ml/clustering.py
    @@ -0,0 +1,206 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +from pyspark.ml.util import keyword_only
    +from pyspark.ml.wrapper import JavaEstimator, JavaModel
    +from pyspark.ml.param.shared import *
    +from pyspark.mllib.common import inherit_doc
    +from pyspark.mllib.linalg import _convert_to_vector
    +
    +__all__ = ['KMeans', 'KMeansModel']
    +
    +
    +class KMeansModel(JavaModel):
    +    """
    +    Model fitted by KMeans.
    +    """
    +
    +    def clusterCenters(self):
    +        """Get the cluster centers, represented as a list of NumPy arrays."""
    +        return [c.toArray() for c in self._call_java("clusterCenters")]
    +
    +
    +@inherit_doc
    +class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed):
    +    """
    +    K-means Clustering
    +
    +    >>> from pyspark.mllib.linalg import Vectors
    +    >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
    +    ...         (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
    +    >>> df = sqlContext.createDataFrame(data, ["features"])
    +    >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features")
    +    >>> model = kmeans.fit(df)
    +    >>> centers = model.clusterCenters()
    +    >>> len(centers)
    +    2
    +    >>> transformed = model.transform(df).select("features", "prediction")
    +    >>> rows = transformed.collect()
    +    >>> rows[0].prediction == rows[1].prediction
    +    True
    +    >>> rows[2].prediction == rows[3].prediction
    +    True
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    k = Param(Params._dummy(), "k", "number of clusters to create")
    +    epsilon = Param(Params._dummy(), "epsilon",
    +                    "distance threshold within which " +
    +                    "we consider centers to have converged")
    +    runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel")
    +    initMode = Param(Params._dummy(), "initMode",
    +                     "the initialization algorithm. This can be either \"random\" to " +
    +                     "choose random points as initial cluster centers, or \"k-means||\" " +
    +                     "to use a parallel variant of k-means++")
    +    initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode")
    +
    +    @keyword_only
    +    def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5):
    +        super(KMeans, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid)
    +        self.k = Param(self, "k", "number of clusters to create")
    +        self.epsilon = Param(self, "epsilon",
    +                             "distance threshold within which " +
    +                             "we consider centers to have converged")
    +        self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel")
    +        self.seed = Param(self, "seed", "random seed")
    +        self.initMode = Param(self, "initMode",
    +                              "the initialization algorithm. This can be either \"random\" to " +
    +                              "choose random points as initial cluster centers, or \"k-means||\" " +
    +                              "to use a parallel variant of k-means++")
    +        self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode")
    +        self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5)
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    def _create_model(self, java_model):
    +        return KMeansModel(java_model)
    +
    +    @keyword_only
    +    def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5):
    +        """
    +        setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5)
    +
    +        Sets params for KMeans.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def setK(self, value):
    +        """
    +        Sets the value of :py:attr:`k`.
    +
    +        >>> algo = KMeans().setK(10)
    +        >>> algo.getK()
    +        10
    +        """
    +        self._paramMap[self.k] = value
    +        return self
    +
    +    def getK(self):
    +        """
    +        Gets the value of `k`
    +        """
    +        return self.getOrDefault(self.k)
    +
    +    def setEpsilon(self, value):
    +        """
    +        Sets the value of :py:attr:`epsilon`.
    +
    +        >>> algo = KMeans().setEpsilon(1e-5)
    +        >>> abs(algo.getEpsilon() - 1e-5) < 1e-5
    +        True
    +        """
    +        self._paramMap[self.epsilon] = value
    +        return self
    +
    +    def getEpsilon(self):
    +        """
    +        Gets the value of `epsilon`
    +        """
    +        return self.getOrDefault(self.epsilon)
    +
    +    def setRuns(self, value):
    +        """
    +        Sets the value of :py:attr:`runs`.
    +
    +        >>> algo = KMeans().setRuns(10)
    +        >>> algo.getRuns()
    +        10
    +        """
    +        self._paramMap[self.runs] = value
    +        return self
    +
    +    def getRuns(self):
    +        """
    +        Gets the value of `runs`
    +        """
    +        return self.getOrDefault(self.runs)
    +
    +    def setInitMode(self, value):
    +        """
    +        Sets the value of :py:attr:`initMode`.
    +
    +        >>> algo = KMeans()
    +        >>> algo.getInitMode()
    +        'k-means||'
    +        >>> algo = algo.setInitMode("random")
    +        >>> algo.getInitMode()
    +        'random'
    +        """
    +        self._paramMap[self.initMode] = value
    +        return self
    +
    +    def getInitMode(self):
    +        """
    +        Gets the value of `initMode`
    +        """
    +        return self.getOrDefault(self.initMode)
    +
    +    def setInitSteps(self, value):
    +        """
    +        Sets the value of :py:attr:`initSteps`.
    +
    +        >>> algo = KMeans().setInitSteps(10)
    +        >>> algo.getInitSteps()
    +        10
    +        """
    +        self._paramMap[self.initSteps] = value
    +        return self
    +
    +    def getInitSteps(self):
    +        """
    +        Gets the value of `initSteps`
    +        """
    +        return self.getOrDefault(self.initSteps)
    +
    +
    +if __name__ == "__main__":
    +    import doctest
    +    from pyspark.context import SparkContext
    +    from pyspark.sql import SQLContext
    +    globs = globals().copy()
    +    # Create a local SparkContext and SQLContext so the doctest
    +    # examples in this module can run against a real backend.
    +    sc = SparkContext("local[2]", "ml.clustering tests")
    +    sqlContext = SQLContext(sc)
    +    globs['sc'] = sc
    +    globs['sqlContext'] = sqlContext
    +    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    +    sc.stop()
    +    if failure_count:
    +        exit(-1)
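As a usage note, the new estimator follows the usual `fit`/`transform` pattern, and keyword arguments passed to the constructor are routed through `setParams`, while `seed` comes from the `HasSeed` mixin. A hedged sketch, assuming a live `sqlContext` as in the doctests above:

```python
from pyspark.ml.clustering import KMeans
from pyspark.mllib.linalg import Vectors

df = sqlContext.createDataFrame(
    [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
     (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)], ["features"])
kmeans = KMeans(k=2, maxIter=10, initMode="k-means||").setSeed(1)
model = kmeans.fit(df)
print(model.clusterCenters())   # two centers, roughly (0.5, 0.5) and (8.5, 8.5)
model.transform(df).select("features", "prediction").show()
```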
    diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
    index 8804dace849b3..015e7a9d4900a 100644
    --- a/python/pyspark/ml/feature.py
    +++ b/python/pyspark/ml/feature.py
    @@ -24,7 +24,7 @@
     __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
                'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
                'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
    -           'Word2Vec', 'Word2VecModel']
    +           'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel']
     
     
     @inherit_doc
    @@ -525,7 +525,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         A regex based tokenizer that extracts tokens either by using the
         provided regex pattern (in Java dialect) to split the text
    -    (default) or repeatedly matching the regex (if gaps is true).
    +    (default) or repeatedly matching the regex (if gaps is false).
         Optional parameters also allow filtering tokens using a minimal
         length.
         It returns an array of strings that can be empty.
    @@ -627,6 +627,10 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
         >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
         >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled")
         >>> model = standardScaler.fit(df)
    +    >>> model.mean
    +    DenseVector([1.0])
    +    >>> model.std
    +    DenseVector([1.4142])
         >>> model.transform(df).collect()[1].scaled
         DenseVector([1.4142])
         """
    @@ -692,6 +696,20 @@ class StandardScalerModel(JavaModel):
         Model fitted by StandardScaler.
         """
     
    +    @property
    +    def std(self):
    +        """
    +        Standard deviation of the StandardScalerModel.
    +        """
    +        return self._call_java("std")
    +
    +    @property
    +    def mean(self):
    +        """
    +        Mean of the StandardScalerModel.
    +        """
    +        return self._call_java("mean")
    +
     
     @inherit_doc
     class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
    @@ -1030,6 +1048,68 @@ class Word2VecModel(JavaModel):
         """
     
     
    +@inherit_doc
    +class PCA(JavaEstimator, HasInputCol, HasOutputCol):
    +    """
    +    PCA trains a model to project vectors onto a lower-dimensional space of principal components.
    +
    +    >>> from pyspark.mllib.linalg import Vectors
    +    >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
    +    ...     (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
    +    ...     (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
    +    >>> df = sqlContext.createDataFrame(data,["features"])
    +    >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    +    >>> model = pca.fit(df)
    +    >>> model.transform(df).collect()[0].pca_features
    +    DenseVector([1.648..., -4.013...])
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    k = Param(Params._dummy(), "k", "the number of principal components")
    +
    +    @keyword_only
    +    def __init__(self, k=None, inputCol=None, outputCol=None):
    +        """
    +        __init__(self, k=None, inputCol=None, outputCol=None)
    +        """
    +        super(PCA, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid)
    +        self.k = Param(self, "k", "the number of principal components")
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, k=None, inputCol=None, outputCol=None):
    +        """
    +        setParams(self, k=None, inputCol=None, outputCol=None)
    +        Set params for this PCA.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def setK(self, value):
    +        """
    +        Sets the value of :py:attr:`k`.
    +        """
    +        self._paramMap[self.k] = value
    +        return self
    +
    +    def getK(self):
    +        """
    +        Gets the value of k or its default value.
    +        """
    +        return self.getOrDefault(self.k)
    +
    +    def _create_model(self, java_model):
    +        return PCAModel(java_model)
    +
    +
    +class PCAModel(JavaModel):
    +    """
    +    Model fitted by PCA.
    +    """
    +
    +
     if __name__ == "__main__":
         import doctest
         from pyspark.context import SparkContext
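A hedged sketch that exercises the new `StandardScalerModel.std`/`mean` properties together with the new `PCA` stage (assuming a live `sqlContext`; the data and column names are arbitrary, and this is illustrative rather than an excerpt from the patch):

```python
from pyspark.ml.feature import PCA, StandardScaler
from pyspark.mllib.linalg import Vectors

df = sqlContext.createDataFrame(
    [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
     (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
     (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)], ["features"])

scalerModel = StandardScaler(inputCol="features", outputCol="scaled").fit(df)
print(scalerModel.std)          # per-feature standard deviations (new property)
scaled = scalerModel.transform(df)

pcaModel = PCA(k=2, inputCol="scaled", outputCol="pca_features").fit(scaled)
pcaModel.transform(scaled).select("pca_features").show()
```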
    diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
    index bc088e4c29e26..595124726366d 100644
    --- a/python/pyspark/ml/param/shared.py
    +++ b/python/pyspark/ml/param/shared.py
    @@ -444,7 +444,7 @@ class DecisionTreeParams(Params):
         minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
         maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
         cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
    -    
    +
     
         def __init__(self):
             super(DecisionTreeParams, self).__init__()
    @@ -460,7 +460,7 @@ def __init__(self):
             self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
             #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
             self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
    -        
    +
         def setMaxDepth(self, value):
             """
             Sets the value of :py:attr:`maxDepth`.
    diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
    index b139e27372d80..44f60a769566d 100644
    --- a/python/pyspark/ml/regression.py
    +++ b/python/pyspark/ml/regression.py
    @@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
         >>> dt = DecisionTreeRegressor(maxDepth=2)
         >>> model = dt.fit(df)
    +    >>> model.depth
    +    1
    +    >>> model.numNodes
    +    3
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -239,7 +243,37 @@ def getImpurity(self):
             return self.getOrDefault(self.impurity)
     
     
    -class DecisionTreeRegressionModel(JavaModel):
    +@inherit_doc
    +class DecisionTreeModel(JavaModel):
    +
    +    @property
    +    def numNodes(self):
    +        """Return number of nodes of the decision tree."""
    +        return self._call_java("numNodes")
    +
    +    @property
    +    def depth(self):
    +        """Return depth of the decision tree."""
    +        return self._call_java("depth")
    +
    +    def __repr__(self):
    +        return self._call_java("toString")
    +
    +
    +@inherit_doc
    +class TreeEnsembleModels(JavaModel):
    +
    +    @property
    +    def treeWeights(self):
    +        """Return the weights for each tree"""
    +        return list(self._call_java("javaTreeWeights"))
    +
    +    def __repr__(self):
    +        return self._call_java("toString")
    +
    +
    +@inherit_doc
    +class DecisionTreeRegressionModel(DecisionTreeModel):
         """
         Model fitted by DecisionTreeRegressor.
         """
    @@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         learning algorithm for regression.
         It supports both continuous and categorical features.
     
    +    >>> from numpy import allclose
         >>> from pyspark.mllib.linalg import Vectors
         >>> df = sqlContext.createDataFrame([
         ...     (1.0, Vectors.dense(1.0)),
         ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
         >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
         >>> model = rf.fit(df)
    +    >>> allclose(model.treeWeights, [1.0, 1.0])
    +    True
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -389,7 +426,7 @@ def getFeatureSubsetStrategy(self):
             return self.getOrDefault(self.featureSubsetStrategy)
     
     
    -class RandomForestRegressionModel(JavaModel):
    +class RandomForestRegressionModel(TreeEnsembleModels):
         """
         Model fitted by RandomForestRegressor.
         """
    @@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
         learning algorithm for regression.
         It supports both continuous and categorical features.
     
    +    >>> from numpy import allclose
         >>> from pyspark.mllib.linalg import Vectors
         >>> df = sqlContext.createDataFrame([
         ...     (1.0, Vectors.dense(1.0)),
         ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
         >>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
         >>> model = gbt.fit(df)
    +    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    +    True
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -518,7 +558,7 @@ def getStepSize(self):
             return self.getOrDefault(self.stepSize)
     
     
    -class GBTRegressionModel(JavaModel):
    +class GBTRegressionModel(TreeEnsembleModels):
         """
         Model fitted by GBTRegressor.
         """
    diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
    index a3eab635282f6..900ade248c386 100644
    --- a/python/pyspark/mllib/clustering.py
    +++ b/python/pyspark/mllib/clustering.py
    @@ -20,6 +20,7 @@
     
     if sys.version > '3':
         xrange = range
    +    basestring = str
     
     from math import exp, log
     
    @@ -31,13 +32,15 @@
     from pyspark.rdd import RDD, ignore_unicode_prefix
     from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
     from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
    +from pyspark.mllib.regression import LabeledPoint
     from pyspark.mllib.stat.distribution import MultivariateGaussian
     from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
     from pyspark.streaming import DStream
     
     __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
                'PowerIterationClusteringModel', 'PowerIterationClustering',
    -           'StreamingKMeans', 'StreamingKMeansModel']
    +           'StreamingKMeans', 'StreamingKMeansModel',
    +           'LDA', 'LDAModel']
     
     
     @inherit_doc
    @@ -149,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"
             return KMeansModel([c.toArray() for c in centers])
     
     
    -class GaussianMixtureModel(object):
    +@inherit_doc
    +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     
    -    """A clustering model derived from the Gaussian Mixture Model method.
    +    """
    +    .. note:: Experimental
    +
    +    A clustering model derived from the Gaussian Mixture Model method.
     
         >>> from pyspark.mllib.linalg import Vectors, DenseMatrix
    +    >>> from numpy.testing import assert_equal
    +    >>> from shutil import rmtree
    +    >>> import os, tempfile
    +
         >>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
         ...                                         0.9,0.8,0.75,0.935,
         ...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
    @@ -166,6 +177,25 @@ class GaussianMixtureModel(object):
         True
         >>> labels[4]==labels[5]
         True
    +
    +    >>> path = tempfile.mkdtemp()
    +    >>> model.save(sc, path)
    +    >>> sameModel = GaussianMixtureModel.load(sc, path)
    +    >>> assert_equal(model.weights, sameModel.weights)
    +    >>> mus, sigmas = list(
    +    ...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
    +    >>> sameMus, sameSigmas = list(
    +    ...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
    +    >>> mus == sameMus
    +    True
    +    >>> sigmas == sameSigmas
    +    True
    +    >>> from shutil import rmtree
    +    >>> try:
    +    ...     rmtree(path)
    +    ... except OSError:
    +    ...     pass
    +
         >>> data =  array([-5.1971, -2.5359, -3.8220,
         ...                -5.2211, -5.0602,  4.7118,
         ...                 6.8989, 3.4592,  4.6322,
    @@ -179,25 +209,15 @@ class GaussianMixtureModel(object):
         True
         >>> labels[3]==labels[4]
         True
    -    >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
    -    >>> im = GaussianMixtureModel([0.5, 0.5],
    -    ...      [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
    -    ...      MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
    -    >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
         """
     
    -    def __init__(self, weights, gaussians):
    -        self._weights = weights
    -        self._gaussians = gaussians
    -        self._k = len(self._weights)
    -
         @property
         def weights(self):
             """
             Weights for each Gaussian distribution in the mixture, where weights[i] is
             the weight for Gaussian i, and weights.sum == 1.
             """
    -        return self._weights
    +        return array(self.call("weights"))
     
         @property
         def gaussians(self):
    @@ -205,12 +225,14 @@ def gaussians(self):
             Array of MultivariateGaussian where gaussians[i] represents
             the Multivariate Gaussian (Normal) Distribution for Gaussian i.
             """
    -        return self._gaussians
    +        return [
    +            MultivariateGaussian(gaussian[0], gaussian[1])
    +            for gaussian in zip(*self.call("gaussians"))]
     
         @property
         def k(self):
             """Number of gaussians in mixture."""
    -        return self._k
    +        return len(self.weights)
     
         def predict(self, x):
             """
    @@ -235,17 +257,30 @@ def predictSoft(self, x):
             :return:     membership_matrix. RDD of array of double values.
             """
             if isinstance(x, RDD):
    -            means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
    +            means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
                 membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
    -                                              _convert_to_vector(self._weights), means, sigmas)
    +                                              _convert_to_vector(self.weights), means, sigmas)
                 return membership_matrix.map(lambda x: pyarray.array('d', x))
             else:
                 raise TypeError("x should be represented by an RDD, "
                                 "but got %s." % type(x))
     
    +    @classmethod
    +    def load(cls, sc, path):
    +        """Load the GaussianMixtureModel from disk.
    +
    +        :param sc: SparkContext
    +        :param path: str, path to where the model is stored.
    +        """
    +        model = cls._load_java(sc, path)
    +        wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
    +        return cls(wrapper)
    +
     
     class GaussianMixture(object):
         """
    +    .. note:: Experimental
    +
         Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
     
         :param data:            RDD of data points
    @@ -268,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
                 initialModelWeights = initialModel.weights
                 initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
                 initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
    -        weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
    -                                          k, convergenceTol, maxIterations, seed,
    -                                          initialModelWeights, initialModelMu, initialModelSigma)
    -        mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
    -        return GaussianMixtureModel(weight, mvg_obj)
    +        java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
    +                                   k, convergenceTol, maxIterations, seed,
    +                                   initialModelWeights, initialModelMu, initialModelSigma)
    +        return GaussianMixtureModel(java_model)
     
     
     class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
    @@ -282,18 +316,30 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     
         Model produced by [[PowerIterationClustering]].
     
    -    >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0),
    -    ...     (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)]
    +    >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0),
    +    ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0),
    +    ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0),
    +    ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)]
         >>> rdd = sc.parallelize(data, 2)
         >>> model = PowerIterationClustering.train(rdd, 2, 100)
         >>> model.k
         2
    +    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
    +    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
    +    True
    +    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
    +    True
         >>> import os, tempfile
         >>> path = tempfile.mkdtemp()
         >>> model.save(sc, path)
         >>> sameModel = PowerIterationClusteringModel.load(sc, path)
         >>> sameModel.k
         2
    +    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
    +    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
    +    True
    +    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
    +    True
         >>> from shutil import rmtree
         >>> try:
         ...     rmtree(path)
    @@ -551,6 +597,108 @@ def predictOnValues(self, dstream):
             return dstream.mapValues(lambda x: self._model.predict(x))
     
     
    +class LDAModel(JavaModelWrapper):
    +
    +    """ A clustering model derived from the LDA method.
    +
    +    Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
    +    Terminology:
    +    - "word" = "term": an element of the vocabulary
    +    - "token": instance of a term appearing in a document
    +    - "topic": multinomial distribution over words representing some concept
    +    References:
    +    - Original LDA paper (journal version):
    +    Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
    +
    +    >>> from pyspark.mllib.linalg import Vectors
    +    >>> from numpy.testing import assert_almost_equal, assert_equal
    +    >>> data = [
    +    ...     [1, Vectors.dense([0.0, 1.0])],
    +    ...     [2, SparseVector(2, {0: 1.0})],
    +    ... ]
    +    >>> rdd =  sc.parallelize(data)
    +    >>> model = LDA.train(rdd, k=2)
    +    >>> model.vocabSize()
    +    2
    +    >>> topics = model.topicsMatrix()
    +    >>> topics_expect = array([[0.5,  0.5], [0.5, 0.5]])
    +    >>> assert_almost_equal(topics, topics_expect, 1)
    +
    +    >>> import os, tempfile
    +    >>> from shutil import rmtree
    +    >>> path = tempfile.mkdtemp()
    +    >>> model.save(sc, path)
    +    >>> sameModel = LDAModel.load(sc, path)
    +    >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
    +    >>> sameModel.vocabSize() == model.vocabSize()
    +    True
    +    >>> try:
    +    ...     rmtree(path)
    +    ... except OSError:
    +    ...     pass
    +    """
    +
    +    def topicsMatrix(self):
    +        """Inferred topics, where each topic is represented by a distribution over terms."""
    +        return self.call("topicsMatrix").toArray()
    +
    +    def vocabSize(self):
    +        """Vocabulary size (the number of terms in the vocabulary)."""
    +        return self.call("vocabSize")
    +
    +    def save(self, sc, path):
    +        """Save the LDAModel to disk.
    +
    +        :param sc: SparkContext
    +        :param path: str, path to where the model needs to be stored.
    +        """
    +        if not isinstance(sc, SparkContext):
    +            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
    +        if not isinstance(path, basestring):
    +            raise TypeError("path should be a basestring, got type %s" % type(path))
    +        self._java_model.save(sc._jsc.sc(), path)
    +
    +    @classmethod
    +    def load(cls, sc, path):
    +        """Load the LDAModel from disk.
    +
    +        :param sc: SparkContext
    +        :param path: str, path to where the model is stored.
    +        """
    +        if not isinstance(sc, SparkContext):
    +            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
    +        if not isinstance(path, basestring):
    +            raise TypeError("path should be a basestring, got type %s" % type(path))
    +        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
    +            sc._jsc.sc(), path)
    +        return cls(java_model)
    +
    +
    +class LDA(object):
    +
    +    @classmethod
    +    def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
    +              topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
    +        """Train a LDA model.
    +
    +        :param rdd:                 RDD of documents, as [document id, term-count vector] pairs
    +        :param k:                   Number of topics to infer
    +        :param maxIterations:       Number of iterations. Defaults to 20.
    +        :param docConcentration:    Concentration parameter (commonly named "alpha")
    +            for the prior placed on documents' distributions over topics ("theta").
    +        :param topicConcentration:  Concentration parameter (commonly named "beta" or "eta")
    +            for the prior placed on topics' distributions over terms.
    +        :param seed:                Random Seed
    +        :param checkpointInterval:  Period (in iterations) between checkpoints.
    +        :param optimizer:           LDAOptimizer used to perform the actual calculation.
    +            Currently "em" and "online" are supported. Defaults to "em".
    +        """
    +        model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
    +                              docConcentration, topicConcentration, seed,
    +                              checkpointInterval, optimizer)
    +        return LDAModel(model)
    +
    +
     def _test():
         import doctest
         import pyspark.mllib.clustering
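A hedged end-to-end sketch of the new `LDA` wrapper, assuming a live `SparkContext` named `sc` and a toy corpus chosen purely for illustration: each document is an `[id, term-count vector]` pair, and the fitted model exposes the term-topic matrix.

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.clustering import LDA

# Two toy documents over a four-term vocabulary.
corpus = sc.parallelize([
    [0, Vectors.dense([1.0, 2.0, 0.0, 0.0])],   # doc 0 uses terms 0 and 1
    [1, Vectors.dense([0.0, 0.0, 3.0, 1.0])],   # doc 1 uses terms 2 and 3
])
model = LDA.train(corpus, k=2, maxIterations=10, seed=1)
print(model.vocabSize())      # 4
print(model.topicsMatrix())   # vocabSize x k matrix of per-topic term weights
```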
    diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
    index c5cf3a4e7ff22..4398ca86f2ec2 100644
    --- a/python/pyspark/mllib/evaluation.py
    +++ b/python/pyspark/mllib/evaluation.py
    @@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper):
         ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
         >>> metrics = RegressionMetrics(predictionAndObservations)
         >>> metrics.explainedVariance
    -    0.95...
    +    8.859...
         >>> metrics.meanAbsoluteError
         0.5...
         >>> metrics.meanSquaredError
    @@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper):
         >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
         ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
         >>> metrics = MulticlassMetrics(predictionAndLabels)
    +    >>> metrics.confusionMatrix().toArray()
    +    array([[ 2.,  1.,  1.],
    +           [ 1.,  3.,  0.],
    +           [ 0.,  0.,  1.]])
         >>> metrics.falsePositiveRate(0.0)
         0.2...
         >>> metrics.precision(1.0)
    @@ -186,6 +190,13 @@ def __init__(self, predictionAndLabels):
             java_model = java_class(df._jdf)
             super(MulticlassMetrics, self).__init__(java_model)
     
    +    def confusionMatrix(self):
    +        """
    +        Returns the confusion matrix: predicted classes are in columns,
    +        ordered by ascending class label, as in "labels".
    +        """
    +        return self.call("confusionMatrix")
    +
         def truePositiveRate(self, label):
             """
             Returns true positive rate for a given label (category).
    diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
    index b7f00d60069e6..bdc4a132b1b18 100644
    --- a/python/pyspark/mllib/fpm.py
    +++ b/python/pyspark/mllib/fpm.py
    @@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper):
         >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
         >>> rdd = sc.parallelize(data, 2)
         >>> model = FPGrowth.train(rdd, 0.6, 2)
    -    >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items)
    -    [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ...
    +    >>> sorted(model.freqItemsets().collect())
    +    [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
         """
     
         def freqItemsets(self):
    diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py
    similarity index 86%
    rename from python/pyspark/mllib/linalg.py
    rename to python/pyspark/mllib/linalg/__init__.py
    index 9959a01cce7e0..334dc8e38bb8f 100644
    --- a/python/pyspark/mllib/linalg.py
    +++ b/python/pyspark/mllib/linalg/__init__.py
    @@ -30,7 +30,9 @@
         basestring = str
         xrange = range
         import copyreg as copy_reg
    +    long = int
     else:
    +    from itertools import izip as zip
         import copy_reg
     
     import numpy as np
    @@ -116,6 +118,10 @@ def _format_float(f, digits=4):
         return s
     
     
    +def _format_float_list(l):
    +    return [_format_float(x) for x in l]
    +
    +
     class VectorUDT(UserDefinedType):
         """
         SQL user-defined type (UDT) for Vector.
    @@ -440,8 +446,10 @@ def __init__(self, size, *args):
             values (sorted by index).
     
             :param size: Size of the vector.
    -        :param args: Non-zero entries, as a dictionary, list of tupes,
    -               or two sorted lists containing indices and values.
    +        :param args: Active entries, as a dictionary {index: value, ...},
    +          a list of tuples [(index, value), ...], or a list of strictly
    +          increasing indices and a list of corresponding values
    +          [index, ...], [value, ...]. Inactive entries are treated as zeros.
     
             >>> SparseVector(4, {1: 1.0, 3: 5.5})
             SparseVector(4, {1: 1.0, 3: 5.5})
    @@ -451,6 +459,7 @@ def __init__(self, size, *args):
             SparseVector(4, {1: 1.0, 3: 5.5})
             """
             self.size = int(size)
    +        """ Size of the vector. """
             assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
             if len(args) == 1:
                 pairs = args[0]
    @@ -458,7 +467,9 @@ def __init__(self, size, *args):
                     pairs = pairs.items()
                 pairs = sorted(pairs)
                 self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
    +            """ A list of indices corresponding to active entries. """
                 self.values = np.array([p[1] for p in pairs], dtype=np.float64)
    +            """ A list of values corresponding to active entries. """
             else:
                 if isinstance(args[0], bytes):
                     assert isinstance(args[1], bytes), "values should be string too"
    @@ -555,7 +566,7 @@ def dot(self, other):
             25.0
             >>> a.dot(array.array('d', [1., 2., 3., 4.]))
             22.0
    -        >>> b = SparseVector(4, [2, 4], [1.0, 2.0])
    +        >>> b = SparseVector(4, [2], [1.0])
             >>> a.dot(b)
             0.0
             >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
    @@ -590,18 +601,14 @@ def dot(self, other):
                 return np.dot(other.array[self.indices], self.values)
     
             elif isinstance(other, SparseVector):
    -            result = 0.0
    -            i, j = 0, 0
    -            while i < len(self.indices) and j < len(other.indices):
    -                if self.indices[i] == other.indices[j]:
    -                    result += self.values[i] * other.values[j]
    -                    i += 1
    -                    j += 1
    -                elif self.indices[i] < other.indices[j]:
    -                    i += 1
    -                else:
    -                    j += 1
    -            return result
    +            # Find out common indices.
    +            self_cmind = np.in1d(self.indices, other.indices, assume_unique=True)
    +            self_values = self.values[self_cmind]
    +            if self_values.size == 0:
    +                return 0.0
    +            else:
    +                other_cmind = np.in1d(other.indices, self.indices, assume_unique=True)
    +                return np.dot(self_values, other.values[other_cmind])
     
             else:
                 return self.dot(_convert_to_vector(other))
    @@ -617,11 +624,11 @@ def squared_distance(self, other):
             11.0
             >>> a.squared_distance(np.array([1., 2., 3., 4.]))
             11.0
    -        >>> b = SparseVector(4, [2, 4], [1.0, 2.0])
    +        >>> b = SparseVector(4, [2], [1.0])
             >>> a.squared_distance(b)
    -        30.0
    +        26.0
             >>> b.squared_distance(a)
    -        30.0
    +        26.0
             >>> b.squared_distance([1., 2.])
             Traceback (most recent call last):
                 ...
    @@ -764,14 +771,18 @@ def sparse(size, *args):
             return SparseVector(size, *args)
     
         @staticmethod
    -    def dense(elements):
    +    def dense(*elements):
             """
    -        Create a dense vector of 64-bit floats from a Python list. Always
    -        returns a NumPy array.
    +        Create a dense vector of 64-bit floats from a Python list or numbers.
     
             >>> Vectors.dense([1, 2, 3])
             DenseVector([1.0, 2.0, 3.0])
    +        >>> Vectors.dense(1.0, 2.0)
    +        DenseVector([1.0, 2.0])
             """
    +        if len(elements) == 1 and not isinstance(elements[0], (float, int, long)):
    +            # it's list, numpy.array or other iterable object.
    +            elements = elements[0]
             return DenseVector(elements)
     
         @staticmethod
    @@ -874,6 +885,50 @@ def __reduce__(self):
                 self.numRows, self.numCols, self.values.tostring(),
                 int(self.isTransposed))
     
    +    def __str__(self):
    +        """
    +        Pretty printing of a DenseMatrix
    +
    +        >>> dm = DenseMatrix(2, 2, range(4))
    +        >>> print(dm)
    +        DenseMatrix([[ 0.,  2.],
    +                     [ 1.,  3.]])
    +        >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True)
    +        >>> print(dm)
    +        DenseMatrix([[ 0.,  1.],
    +                     [ 2.,  3.]])
    +        """
    +        # Inspired by __repr__ in scipy matrices.
    +        array_lines = repr(self.toArray()).splitlines()
    +
    +        # Pad each line with six spaces, the difference in length between
    +        # "DenseMatrix" and "array", so that the rows line up.
    +        x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
    +        return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
    +
    +    def __repr__(self):
    +        """
    +        Representation of a DenseMatrix
    +
    +        >>> dm = DenseMatrix(2, 2, range(4))
    +        >>> dm
    +        DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)
    +        """
    +        # If there are fewer than seventeen values, return all of them.
    +        # Otherwise, return the first eight and the last eight values.
    +        if len(self.values) < 17:
    +            entries = _format_float_list(self.values)
    +        else:
    +            entries = (
    +                _format_float_list(self.values[:8]) +
    +                ["..."] +
    +                _format_float_list(self.values[-8:])
    +            )
    +
    +        entries = ", ".join(entries)
    +        return "DenseMatrix({0}, {1}, [{2}], {3})".format(
    +            self.numRows, self.numCols, entries, self.isTransposed)
    +
         def toArray(self):
             """
             Return an numpy.ndarray
    @@ -950,6 +1005,84 @@ def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
                 raise ValueError("Expected rowIndices of length %d, got %d."
                                  % (self.rowIndices.size, self.values.size))
     
    +    def __str__(self):
    +        """
    +        Pretty printing of a SparseMatrix
    +
    +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
    +        >>> print(sm1)
    +        2 X 2 CSCMatrix
    +        (0,0) 2.0
    +        (1,0) 3.0
    +        (1,1) 4.0
    +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
    +        >>> print(sm1)
    +        2 X 2 CSRMatrix
    +        (0,0) 2.0
    +        (0,1) 3.0
    +        (1,1) 4.0
    +        """
    +        spstr = "{0} X {1} ".format(self.numRows, self.numCols)
    +        if self.isTransposed:
    +            spstr += "CSRMatrix\n"
    +        else:
    +            spstr += "CSCMatrix\n"
    +
    +        cur_col = 0
    +        smlist = []
    +
    +        # Display first 16 values.
    +        if len(self.values) <= 16:
    +            zipindval = zip(self.rowIndices, self.values)
    +        else:
    +            zipindval = zip(self.rowIndices[:16], self.values[:16])
    +        for i, (rowInd, value) in enumerate(zipindval):
    +            if self.colPtrs[cur_col + 1] <= i:
    +                cur_col += 1
    +            if self.isTransposed:
    +                smlist.append('({0},{1}) {2}'.format(
    +                    cur_col, rowInd, _format_float(value)))
    +            else:
    +                smlist.append('({0},{1}) {2}'.format(
    +                    rowInd, cur_col, _format_float(value)))
    +        spstr += "\n".join(smlist)
    +
    +        if len(self.values) > 16:
    +            spstr += "\n.." * 2
    +        return spstr
    +
    +    def __repr__(self):
    +        """
    +        Representation of a SparseMatrix
    +
    +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
    +        >>> sm1
    +        SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False)
    +        """
    +        rowIndices = list(self.rowIndices)
    +        colPtrs = list(self.colPtrs)
    +
    +        if len(self.values) <= 16:
    +            values = _format_float_list(self.values)
    +
    +        else:
    +            values = (
    +                _format_float_list(self.values[:8]) +
    +                ["..."] +
    +                _format_float_list(self.values[-8:])
    +            )
    +            rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
    +
    +        if len(self.colPtrs) > 16:
    +            colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:]
    +
    +        values = ", ".join(values)
    +        rowIndices = ", ".join([str(ind) for ind in rowIndices])
    +        colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
    +        return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
    +            self.numRows, self.numCols, colPtrs, rowIndices,
    +            values, self.isTransposed)
    +
         def __reduce__(self):
             return SparseMatrix, (
                 self.numRows, self.numCols, self.colPtrs.tostring(),
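The rewritten `SparseVector.dot` replaces the explicit two-pointer merge with a vectorized index intersection via `numpy.in1d`; because both index arrays are sorted and unique, the surviving values pair up in order. A standalone sanity check of that idea, with illustrative names only (not part of the patch):

```python
import numpy as np

def intersect_dot(ind_a, val_a, ind_b, val_b):
    # Mask of a's indices that also occur in b; both index arrays are sorted
    # and unique, so the surviving values pair up in the same order.
    in_a = np.in1d(ind_a, ind_b, assume_unique=True)
    if not in_a.any():
        return 0.0
    in_b = np.in1d(ind_b, ind_a, assume_unique=True)
    return np.dot(val_a[in_a], val_b[in_b])

a_ind, a_val = np.array([1, 3]), np.array([1.0, 5.5])
b_ind, b_val = np.array([2, 3]), np.array([1.0, 2.0])
dense_a, dense_b = np.zeros(4), np.zeros(4)
dense_a[a_ind], dense_b[b_ind] = a_val, b_val
assert intersect_dot(a_ind, a_val, b_ind, b_val) == np.dot(dense_a, dense_b)  # 11.0
```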
    diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
    index 8e90adee5f4c2..5b7afc15ddfba 100644
    --- a/python/pyspark/mllib/regression.py
    +++ b/python/pyspark/mllib/regression.py
    @@ -97,9 +97,11 @@ class LinearRegressionModelBase(LinearModel):
     
         def predict(self, x):
             """
    -        Predict the value of the dependent variable given a vector x
    -        containing values for the independent variables.
    +        Predict the value of the dependent variable given a vector or
    +        an RDD of vectors containing values for the independent variables.
             """
    +        if isinstance(x, RDD):
    +            return x.map(self.predict)
             x = _convert_to_vector(x)
             return self.weights.dot(x) + self.intercept
     
    @@ -124,6 +126,8 @@ class LinearRegressionModel(LinearRegressionModelBase):
         True
         >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
         True
    +    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
    +    True
         >>> import os, tempfile
         >>> path = tempfile.mkdtemp()
         >>> lrm.save(sc, path)
    @@ -267,6 +271,8 @@ class LassoModel(LinearRegressionModelBase):
         True
         >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
         True
    +    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
    +    True
         >>> import os, tempfile
         >>> path = tempfile.mkdtemp()
         >>> lrm.save(sc, path)
    @@ -382,6 +388,8 @@ class RidgeRegressionModel(LinearRegressionModelBase):
         True
         >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
         True
    +    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
    +    True
         >>> import os, tempfile
         >>> path = tempfile.mkdtemp()
         >>> lrm.save(sc, path)
    diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
    index b475be4b4d953..36c8f48a4a882 100644
    --- a/python/pyspark/mllib/stat/_statistics.py
    +++ b/python/pyspark/mllib/stat/_statistics.py
    @@ -15,11 +15,15 @@
     # limitations under the License.
     #
     
    +import sys
    +if sys.version >= '3':
    +    basestring = str
    +
     from pyspark.rdd import RDD, ignore_unicode_prefix
     from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
     from pyspark.mllib.linalg import Matrix, _convert_to_vector
     from pyspark.mllib.regression import LabeledPoint
    -from pyspark.mllib.stat.test import ChiSqTestResult
    +from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
     
     
     __all__ = ['MultivariateStatisticalSummary', 'Statistics']
    @@ -238,6 +242,67 @@ def chiSqTest(observed, expected=None):
                 jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected)
             return ChiSqTestResult(jmodel)
     
    +    @staticmethod
    +    @ignore_unicode_prefix
    +    def kolmogorovSmirnovTest(data, distName="norm", *params):
    +        """
    +        .. note:: Experimental
    +
    +        Performs the Kolmogorov-Smirnov (KS) test for data sampled from
    +        a continuous distribution. It tests the null hypothesis that
    +        the data is generated from a particular distribution.
    +
    +        The given data is sorted and the Empirical Cumulative
    +        Distribution Function (ECDF) is calculated, which for a given
    +        point is the fraction of sample points less than or equal to
    +        that point.
    +
    +        Since the data is sorted, this is a step function
    +        that rises by (1 / length of data) for every ordered point.
    +
    +        The KS statistic gives us the maximum distance between the
    +        ECDF and the CDF. Intuitively, if this statistic is large, the
    +        probability that the null hypothesis is true becomes small.
    +        For specific details of the implementation, please have a look
    +        at the Scala documentation.
    +
    +        :param data: RDD, the sample data to test.
    +        :param distName: string, name of the theoretical distribution to
    +                         test the data against. Currently only "norm"
    +                         (the normal distribution) is supported.
    +        :param params: additional parameters of the chosen distribution
    +                       (for "norm": the mean and standard deviation).
    +                       If not provided, the default values are used.
    +        :return: KolmogorovSmirnovTestResult object containing the test
    +                 statistic, degrees of freedom, p-value,
    +                 the method used, and the null hypothesis.
    +
    +        >>> kstest = Statistics.kolmogorovSmirnovTest
    +        >>> data = sc.parallelize([-1.0, 0.0, 1.0])
    +        >>> ksmodel = kstest(data, "norm")
    +        >>> print(round(ksmodel.pValue, 3))
    +        1.0
    +        >>> print(round(ksmodel.statistic, 3))
    +        0.175
    +        >>> ksmodel.nullHypothesis
    +        u'Sample follows theoretical distribution'
    +
    +        >>> data = sc.parallelize([2.0, 3.0, 4.0])
    +        >>> ksmodel = kstest(data, "norm", 3.0, 1.0)
    +        >>> print(round(ksmodel.pValue, 3))
    +        1.0
    +        >>> print(round(ksmodel.statistic, 3))
    +        0.175
    +        """
    +        if not isinstance(data, RDD):
    +            raise TypeError("data should be an RDD, got %s." % type(data))
    +        if not isinstance(distName, basestring):
    +            raise TypeError("distName should be a string, got %s." % type(distName))
    +
    +        params = [float(param) for param in params]
    +        return KolmogorovSmirnovTestResult(
    +            callMLlibFunc("kolmogorovSmirnovTest", data, distName, params))
    +
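As a quick illustration of the method added above (a hedged sketch, assuming a running SparkContext bound to `sc`), the test can be run against the standard normal distribution or against a normal distribution with explicit parameters:

```
from pyspark.mllib.stat import Statistics

data = sc.parallelize([-0.5, 0.1, 0.2, 0.8, 1.3])
# Standard normal distribution (default parameters).
ks = Statistics.kolmogorovSmirnovTest(data, "norm")
print(ks.statistic, ks.pValue)
print(ks.nullHypothesis)
# Normal distribution with mean 0.0 and standard deviation 0.5.
ks2 = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 0.5)
```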
     
     def _test():
         import doctest
    diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py
    index 762506e952b43..0abe104049ff9 100644
    --- a/python/pyspark/mllib/stat/test.py
    +++ b/python/pyspark/mllib/stat/test.py
    @@ -15,24 +15,16 @@
     # limitations under the License.
     #
     
    -from pyspark.mllib.common import JavaModelWrapper
    +from pyspark.mllib.common import inherit_doc, JavaModelWrapper
     
     
    -__all__ = ["ChiSqTestResult"]
    +__all__ = ["ChiSqTestResult", "KolmogorovSmirnovTestResult"]
     
     
    -class ChiSqTestResult(JavaModelWrapper):
    +class TestResult(JavaModelWrapper):
         """
    -    .. note:: Experimental
    -
    -    Object containing the test results for the chi-squared hypothesis test.
    +    Base class for all test results.
         """
    -    @property
    -    def method(self):
    -        """
    -        Name of the test method
    -        """
    -        return self._java_model.method()
     
         @property
         def pValue(self):
    @@ -67,3 +59,24 @@ def nullHypothesis(self):
     
         def __str__(self):
             return self._java_model.toString()
    +
    +
    +@inherit_doc
    +class ChiSqTestResult(TestResult):
    +    """
    +    Contains test results for the chi-squared hypothesis test.
    +    """
    +
    +    @property
    +    def method(self):
    +        """
    +        Name of the test method
    +        """
    +        return self._java_model.method()
    +
    +
    +@inherit_doc
    +class KolmogorovSmirnovTestResult(TestResult):
    +    """
    +    Contains test results for the Kolmogorov-Smirnov test.
    +    """
    diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
    index d9f9874d50c1a..3f5a02af12e39 100644
    --- a/python/pyspark/mllib/tests.py
    +++ b/python/pyspark/mllib/tests.py
    @@ -27,7 +27,7 @@
     from shutil import rmtree
     
     from numpy import (
    -    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs)
    +    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
     from numpy import sum as array_sum
     
     from py4j.protocol import Py4JJavaError
    @@ -189,6 +189,53 @@ def test_matrix_indexing(self):
                 for j in range(2):
                     self.assertEquals(mat[i, j], expected[i][j])
     
    +    def test_repr_dense_matrix(self):
    +        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
    +        self.assertTrue(
    +            repr(mat),
    +            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
    +
    +        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
    +        self.assertTrue(
    +            repr(mat),
    +            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
    +
    +        mat = DenseMatrix(6, 3, zeros(18))
    +        self.assertTrue(
    +            repr(mat),
    +            'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
    +                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
    +
    +    def test_repr_sparse_matrix(self):
    +        sm1t = SparseMatrix(
    +            3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
    +            isTransposed=True)
    +        self.assertTrue(
    +            repr(sm1t),
    +            'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
    +
    +        indices = tile(arange(6), 3)
    +        values = ones(18)
    +        sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
    +        self.assertTrue(
    +            repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
    +                [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
    +                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
    +                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
    +
    +        self.assertTrue(
    +            str(sm),
    +            "6 X 3 CSCMatrix\n\
    +            (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
    +            (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
    +            (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
    +
    +        sm = SparseMatrix(1, 18, zeros(19), [], [])
    +        self.assertTrue(
    +            repr(sm),
    +            'SparseMatrix(1, 18, \
    +                [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
    +
         def test_sparse_matrix(self):
             # Test sparse matrix creation.
             sm1 = SparseMatrix(
    @@ -198,6 +245,9 @@ def test_sparse_matrix(self):
             self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
             self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2])
             self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
    +        self.assertTrue(
    +            repr(sm1),
    +            'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
     
             # Test indexing
             expected = [
    @@ -819,6 +869,25 @@ def test_right_number_of_results(self):
             self.assertIsNotNone(chi[1000])
     
     
    +class KolmogorovSmirnovTest(MLlibTestCase):
    +
    +    def test_R_implementation_equivalence(self):
    +        data = self.sc.parallelize([
    +            1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
    +            -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
    +            -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
    +            -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
    +            0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
    +        ])
    +        model = Statistics.kolmogorovSmirnovTest(data, "norm")
    +        self.assertAlmostEqual(model.statistic, 0.189, 3)
    +        self.assertAlmostEqual(model.pValue, 0.422, 3)
    +
    +        model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
    +        self.assertAlmostEqual(model.statistic, 0.189, 3)
    +        self.assertAlmostEqual(model.pValue, 0.422, 3)
    +
    +
     class SerDeTest(MLlibTestCase):
         def test_to_java_object_rdd(self):  # SPARK-6660
             data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
    diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
    index 875d3b2d642c6..916de2d6fcdbd 100644
    --- a/python/pyspark/mllib/util.py
    +++ b/python/pyspark/mllib/util.py
    @@ -21,7 +21,9 @@
     
     if sys.version > '3':
         xrange = range
    +    basestring = str
     
    +from pyspark import SparkContext
     from pyspark.mllib.common import callMLlibFunc, inherit_doc
     from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
     
    @@ -223,6 +225,10 @@ class JavaSaveable(Saveable):
         """
     
         def save(self, sc, path):
    +        if not isinstance(sc, SparkContext):
    +            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
    +        if not isinstance(path, basestring):
    +            raise TypeError("path should be a basestring, got type %s" % type(path))
             self._java_model.save(sc._jsc.sc(), path)
     
     
    diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
    index 79dafb0a4ef27..fa8e0a0574a62 100644
    --- a/python/pyspark/rdd.py
    +++ b/python/pyspark/rdd.py
    @@ -700,12 +700,14 @@ def groupBy(self, f, numPartitions=None):
             return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
     
         @ignore_unicode_prefix
    -    def pipe(self, command, env={}):
    +    def pipe(self, command, env={}, checkCode=False):
             """
             Return an RDD created by piping elements to a forked external process.
     
             >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
             [u'1', u'2', u'', u'3']
    +
    +        :param checkCode: whether or not to check the return value of the shell command.
             """
             def func(iterator):
                 pipe = Popen(
    @@ -717,7 +719,17 @@ def pipe_objs(out):
                         out.write(s.encode('utf-8'))
                     out.close()
                 Thread(target=pipe_objs, args=[pipe.stdin]).start()
    -            return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
    +
    +            def check_return_code():
    +                pipe.wait()
    +                if checkCode and pipe.returncode:
    +                    raise Exception("Pipe function `%s' exited "
    +                                    "with error code %d" % (command, pipe.returncode))
    +                else:
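+                    # yield nothing: this makes check_return_code an (empty) generator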
    +                    for i in range(0):
    +                        yield i
    +            return (x.rstrip(b'\n').decode('utf-8') for x in
    +                    chain(iter(pipe.stdout.readline, b''), check_return_code()))
             return self.mapPartitions(func)
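A minimal usage sketch of the new `checkCode` flag (assuming a SparkContext bound to `sc`); with `checkCode=True` a non-zero exit status of the external command now fails the job instead of being silently ignored:

```
rdd = sc.parallelize(['1', '2', '3'])
rdd.pipe('cat', checkCode=True).collect()    # [u'1', u'2', u'3']
# `false` exits with status 1, so with checkCode=True this job fails with
# "Pipe function `false' exited with error code 1".
rdd.pipe('false', checkCode=True).collect()
```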
     
         def foreach(self, f):
    @@ -850,6 +862,9 @@ def func(iterator):
                 for obj in iterator:
                     acc = op(obj, acc)
                 yield acc
    +        # collecting result of mapPartitions here ensures that the copy of
    +        # zeroValue provided to each partition is unique from the one provided
    +        # to the final reduce call
             vals = self.mapPartitions(func).collect()
             return reduce(op, vals, zeroValue)
     
    @@ -879,8 +894,11 @@ def func(iterator):
                 for obj in iterator:
                     acc = seqOp(acc, obj)
                 yield acc
    -
    -        return self.mapPartitions(func).fold(zeroValue, combOp)
    +        # collecting result of mapPartitions here ensures that the copy of
    +        # zeroValue provided to each partition is unique from the one provided
    +        # to the final reduce call
    +        vals = self.mapPartitions(func).collect()
    +        return reduce(combOp, vals, zeroValue)
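The comment above is easiest to see with a mutable `zeroValue`; a hedged sketch (assuming a SparkContext bound to `sc`) of the behaviour the change preserves when the ops mutate their accumulator in place:

```
def seq(acc, x):
    acc.append(x)      # mutates the per-partition accumulator
    return acc

def comb(a, b):
    a.extend(b)        # mutates the driver-side accumulator
    return a

rdd = sc.parallelize(range(10), 3)
# Each partition works on its own copy of the empty list, and the final
# reduce starts from another copy, so no accumulator is shared.
print(sorted(rdd.aggregate([], seq, comb)))   # [0, 1, 2, ..., 9]
```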
     
         def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
             """
    @@ -1275,7 +1293,7 @@ def takeUpToNumLeft(iterator):
                         taken += 1
     
                 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
    -            res = self.context.runJob(self, takeUpToNumLeft, p, True)
    +            res = self.context.runJob(self, takeUpToNumLeft, p)
     
                 items += res
                 partsScanned += numPartsToTry
    @@ -2175,7 +2193,7 @@ def lookup(self, key):
             values = self.filter(lambda kv: kv[0] == key).values()
     
             if self.partitioner is not None:
    -            return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
    +            return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)])
     
             return values.collect()
     
    diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
    index 144cdf0b0cdd5..99331297c19f0 100644
    --- a/python/pyspark/shell.py
    +++ b/python/pyspark/shell.py
    @@ -40,7 +40,7 @@
     if os.environ.get("SPARK_EXECUTOR_URI"):
         SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
     
    -sc = SparkContext(appName="PySparkShell", pyFiles=add_files)
    +sc = SparkContext(pyFiles=add_files)
     atexit.register(lambda: sc.stop())
     
     try:
    diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
    index 8fb71bac64a5e..b8118bdb7ca76 100644
    --- a/python/pyspark/shuffle.py
    +++ b/python/pyspark/shuffle.py
    @@ -606,7 +606,7 @@ def _open_file(self):
             if not os.path.exists(d):
                 os.makedirs(d)
             p = os.path.join(d, str(id(self)))
    -        self._file = open(p, "wb+", 65536)
    +        self._file = open(p, "w+b", 65536)
             self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
             os.unlink(p)
     
    diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
    index 309c11faf9319..917de24f3536b 100644
    --- a/python/pyspark/sql/context.py
    +++ b/python/pyspark/sql/context.py
    @@ -30,10 +30,11 @@
     from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
     from pyspark.sql import since
     from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
    -    _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
    +    _infer_schema, _has_nulltype, _merge_type, _create_converter
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.readwriter import DataFrameReader
     from pyspark.sql.utils import install_exception_handler
    +from pyspark.sql.functions import UserDefinedFunction
     
     try:
         import pandas
    @@ -191,19 +192,8 @@ def registerFunction(self, name, f, returnType=StringType()):
             >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
             [Row(_c0=4)]
             """
    -        func = lambda _, it: map(lambda x: f(*x), it)
    -        ser = AutoBatchedSerializer(PickleSerializer())
    -        command = (func, None, ser, ser)
    -        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
    -        self._ssql_ctx.udf().registerPython(name,
    -                                            bytearray(pickled_cmd),
    -                                            env,
    -                                            includes,
    -                                            self._sc.pythonExec,
    -                                            self._sc.pythonVer,
    -                                            bvars,
    -                                            self._sc._javaAccumulator,
    -                                            returnType.json())
    +        udf = UserDefinedFunction(f, returnType, name)
    +        self._ssql_ctx.udf().registerPython(name, udf._judf)
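From the caller's point of view nothing changes; a minimal sketch (assuming a `sqlContext`) of registering and invoking a UDF, which now goes through `UserDefinedFunction` internally:

```
from pyspark.sql.types import IntegerType

sqlContext.registerFunction("slen", lambda s: len(s), IntegerType())
sqlContext.sql("SELECT slen('test')").collect()   # [Row(_c0=4)]
```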
     
         def _inferSchemaFromList(self, data):
             """
    @@ -287,6 +277,66 @@ def applySchema(self, rdd, schema):
     
             return self.createDataFrame(rdd, schema)
     
    +    def _createFromRDD(self, rdd, schema, samplingRatio):
    +        """
    +        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
    +        """
    +        if schema is None or isinstance(schema, (list, tuple)):
    +            struct = self._inferSchema(rdd, samplingRatio)
    +            converter = _create_converter(struct)
    +            rdd = rdd.map(converter)
    +            if isinstance(schema, (list, tuple)):
    +                for i, name in enumerate(schema):
    +                    struct.fields[i].name = name
    +                    struct.names[i] = name
    +            schema = struct
    +
    +        elif isinstance(schema, StructType):
    +            # take the first few rows to verify schema
    +            rows = rdd.take(10)
    +            for row in rows:
    +                _verify_type(row, schema)
    +
    +        else:
    +            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
    +
    +        # convert python objects to sql data
    +        rdd = rdd.map(schema.toInternal)
    +        return rdd, schema
    +
    +    def _createFromLocal(self, data, schema):
    +        """
+        Create an RDD for DataFrame from a list or pandas.DataFrame, returns
    +        the RDD and schema.
    +        """
    +        if has_pandas and isinstance(data, pandas.DataFrame):
    +            if schema is None:
    +                schema = [str(x) for x in data.columns]
    +            data = [r.tolist() for r in data.to_records(index=False)]
    +
+        # make sure data can be consumed multiple times
    +        if not isinstance(data, list):
    +            data = list(data)
    +
    +        if schema is None or isinstance(schema, (list, tuple)):
    +            struct = self._inferSchemaFromList(data)
    +            if isinstance(schema, (list, tuple)):
    +                for i, name in enumerate(schema):
    +                    struct.fields[i].name = name
    +                    struct.names[i] = name
    +            schema = struct
    +
    +        elif isinstance(schema, StructType):
    +            for row in data:
    +                _verify_type(row, schema)
    +
    +        else:
    +            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
    +
    +        # convert python objects to sql data
    +        data = [schema.toInternal(row) for row in data]
    +        return self._sc.parallelize(data), schema
    +
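A short sketch (assuming a `sqlContext`) of the cases these helpers handle for local data: schema inference, a list of column names, and an explicit StructType:

```
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

rows = [("Alice", 1), ("Bob", 2)]
df1 = sqlContext.createDataFrame(rows)                   # schema inferred from the data
df2 = sqlContext.createDataFrame(rows, ["name", "age"])  # names supplied, types inferred
schema = StructType([StructField("name", StringType(), True),
                     StructField("age", IntegerType(), True)])
df3 = sqlContext.createDataFrame(rows, schema)           # explicit schema, rows verified
```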
         @since(1.3)
         @ignore_unicode_prefix
         def createDataFrame(self, data, schema=None, samplingRatio=None):
    @@ -350,50 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
             if isinstance(data, DataFrame):
                 raise TypeError("data is already a DataFrame")
     
    -        if has_pandas and isinstance(data, pandas.DataFrame):
    -            if schema is None:
    -                schema = [str(x) for x in data.columns]
    -            data = [r.tolist() for r in data.to_records(index=False)]
    -
    -        if not isinstance(data, RDD):
    -            if not isinstance(data, list):
    -                data = list(data)
    -            try:
    -                # data could be list, tuple, generator ...
    -                rdd = self._sc.parallelize(data)
    -            except Exception:
    -                raise TypeError("cannot create an RDD from type: %s" % type(data))
    -        else:
    -            rdd = data
    -
    -        if schema is None or isinstance(schema, (list, tuple)):
    -            if isinstance(data, RDD):
    -                struct = self._inferSchema(rdd, samplingRatio)
    -            else:
    -                struct = self._inferSchemaFromList(data)
    -            if isinstance(schema, (list, tuple)):
    -                for i, name in enumerate(schema):
    -                    struct.fields[i].name = name
    -            schema = struct
    -            converter = _create_converter(schema)
    -            rdd = rdd.map(converter)
    -
    -        elif isinstance(schema, StructType):
    -            # take the first few rows to verify schema
    -            rows = rdd.take(10)
    -            for row in rows:
    -                _verify_type(row, schema)
    -
    +        if isinstance(data, RDD):
    +            rdd, schema = self._createFromRDD(data, schema, samplingRatio)
             else:
    -            raise TypeError("schema should be StructType or list or None")
    -
    -        # convert python objects to sql data
    -        converter = _python_to_sql_converter(schema)
    -        rdd = rdd.map(converter)
    -
    +            rdd, schema = self._createFromLocal(data, schema)
             jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
    -        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
    -        return DataFrame(df, self)
    +        jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
    +        df = DataFrame(jdf, self)
    +        df._schema = schema
    +        return df
     
         @since(1.3)
         def registerDataFrameAsTable(self, df, tableName):
    diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
    index 1e9c657cf81b3..0f3480c239187 100644
    --- a/python/pyspark/sql/dataframe.py
    +++ b/python/pyspark/sql/dataframe.py
    @@ -31,7 +31,7 @@
     from pyspark.storagelevel import StorageLevel
     from pyspark.traceback_utils import SCCallSiteSync
     from pyspark.sql import since
    -from pyspark.sql.types import _create_cls, _parse_datatype_json_string
    +from pyspark.sql.types import _parse_datatype_json_string
     from pyspark.sql.column import Column, _to_seq, _to_java_column
     from pyspark.sql.readwriter import DataFrameWriter
     from pyspark.sql.types import *
    @@ -83,15 +83,7 @@ def rdd(self):
             """
             if self._lazy_rdd is None:
                 jrdd = self._jdf.javaToPython()
    -            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
    -            schema = self.schema
    -
    -            def applySchema(it):
    -                cls = _create_cls(schema)
    -                return map(cls, it)
    -
    -            self._lazy_rdd = rdd.mapPartitions(applySchema)
    -
    +            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
             return self._lazy_rdd
     
         @property
    @@ -287,9 +279,7 @@ def collect(self):
             """
             with SCCallSiteSync(self._sc) as css:
                 port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
    -        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
    -        cls = _create_cls(self.schema)
    -        return [cls(r) for r in rs]
    +        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
     
         @ignore_unicode_prefix
         @since(1.3)
    @@ -451,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None):
             rdd = self._jdf.sample(withReplacement, fraction, long(seed))
             return DataFrame(rdd, self.sql_ctx)
     
    +    @since(1.5)
    +    def sampleBy(self, col, fractions, seed=None):
    +        """
    +        Returns a stratified sample without replacement based on the
    +        fraction given on each stratum.
    +
    +        :param col: column that defines strata
    +        :param fractions:
    +            sampling fraction for each stratum. If a stratum is not
    +            specified, we treat its fraction as zero.
    +        :param seed: random seed
    +        :return: a new DataFrame that represents the stratified sample
    +
    +        >>> from pyspark.sql.functions import col
    +        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
    +        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
    +        >>> sampled.groupBy("key").count().orderBy("key").show()
    +        +---+-----+
    +        |key|count|
    +        +---+-----+
    +        |  0|    3|
    +        |  1|    8|
    +        +---+-----+
    +
    +        """
    +        if not isinstance(col, str):
    +            raise ValueError("col must be a string, but got %r" % type(col))
    +        if not isinstance(fractions, dict):
    +            raise ValueError("fractions must be a dict but got %r" % type(fractions))
    +        for k, v in fractions.items():
    +            if not isinstance(k, (float, int, long, basestring)):
    +                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
    +            fractions[k] = float(v)
    +        seed = seed if seed is not None else random.randint(0, sys.maxsize)
    +        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
    +
         @since(1.4)
         def randomSplit(self, weights, seed=None):
             """Randomly splits this :class:`DataFrame` with the provided weights.
    @@ -1140,7 +1166,7 @@ def crosstab(self, col1, col2):
             non-zero pair frequencies will be returned.
             The first column of each row will be the distinct values of `col1` and the column names
             will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
    -        Pairs that have no occurrences will have `null` as their counts.
    +        Pairs that have no occurrences will have zero as their counts.
             :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
     
             :param col1: The name of the first column. Distinct items will make the first item of
    @@ -1324,6 +1350,11 @@ def freqItems(self, cols, support=None):
     
         freqItems.__doc__ = DataFrame.freqItems.__doc__
     
    +    def sampleBy(self, col, fractions, seed=None):
    +        return self.df.sampleBy(col, fractions, seed)
    +
    +    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
    +
     
     def _test():
         import doctest
    diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
    index 69e563ef36e87..a7295e25f0aa5 100644
    --- a/python/pyspark/sql/functions.py
    +++ b/python/pyspark/sql/functions.py
    @@ -39,21 +39,30 @@
         'coalesce',
         'countDistinct',
         'explode',
    +    'format_number',
    +    'length',
         'log2',
         'md5',
         'monotonicallyIncreasingId',
         'rand',
         'randn',
    +    'regexp_extract',
    +    'regexp_replace',
         'sha1',
         'sha2',
    +    'size',
         'sparkPartitionId',
    -    'strlen',
         'struct',
         'udf',
         'when']
     
     __all__ += ['lag', 'lead', 'ntile']
     
    +__all__ += [
    +    'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
    +    'year', 'quarter', 'month', 'hour', 'minute', 'second',
    +    'dayofmonth', 'dayofyear', 'weekofyear']
    +
     
     def _create_function(name, doc=""):
         """ Create a function for aggregator by name"""
    @@ -323,6 +332,48 @@ def explode(col):
         return Column(jc)
     
     
    +@ignore_unicode_prefix
    +@since(1.5)
    +def levenshtein(left, right):
    +    """Computes the Levenshtein distance of the two given strings.
    +
    +    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
    +    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
    +    [Row(d=3)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def regexp_extract(str, pattern, idx):
    +    """Extract a specific(idx) group identified by a java regex, from the specified string column.
    +
    +    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
    +    >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
    +    [Row(d=u'100')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def regexp_replace(str, pattern, replacement):
    +    """Replace all substrings of the specified string value that match regexp with rep.
    +
    +    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
    +    >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect()
    +    [Row(d=u'##-##')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
    +    return Column(jc)
    +
    +
     @ignore_unicode_prefix
     @since(1.5)
     def md5(col):
    @@ -381,6 +432,34 @@ def randn(seed=None):
         return Column(jc)
     
     
    +@ignore_unicode_prefix
    +@since(1.5)
    +def hex(col):
    +    """Computes hex value of the given column, which could be StringType,
    +    BinaryType, IntegerType or LongType.
    +
    +    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
    +    [Row(hex(a)=u'414243', hex(b)=u'3')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.hex(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def unhex(col):
    +    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
    +    and converts to the byte representation of number.
    +
    +    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
    +    [Row(unhex(a)=bytearray(b'ABC'))]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.unhex(_to_java_column(col))
    +    return Column(jc)
    +
    +
     @ignore_unicode_prefix
     @since(1.5)
     def sha1(col):
    @@ -462,16 +541,40 @@ def sparkPartitionId():
         return Column(sc._jvm.functions.sparkPartitionId())
     
     
    +def expr(str):
    +    """Parses the expression string into the column that it represents
    +
    +    >>> df.select(expr("length(name)")).collect()
    +    [Row('length(name)=5), Row('length(name)=3)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.expr(str))
    +
    +
     @ignore_unicode_prefix
     @since(1.5)
    -def strlen(col):
    -    """Calculates the length of a string expression.
    +def length(col):
    +    """Calculates the length of a string or binary expression.
     
    -    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
    +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
         [Row(length=3)]
         """
         sc = SparkContext._active_spark_context
    -    return Column(sc._jvm.functions.strlen(_to_java_column(col)))
    +    return Column(sc._jvm.functions.length(_to_java_column(col)))
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def format_number(col, d):
    +    """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
    +       and returns the result as a string.
    +    :param col: the column name of the numeric value to be formatted
    +    :param d: the N decimal places
    +    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
    +    [Row(v=u'5.0000')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
     
     
     @ignore_unicode_prefix
    @@ -595,29 +698,234 @@ def ntile(n):
         return Column(sc._jvm.functions.ntile(int(n)))
     
     
    +@ignore_unicode_prefix
    +@since(1.5)
    +def date_format(dateCol, format):
    +    """
    +    Converts a date/timestamp/string to a value of string in the format specified by the date
    +    format given by the second argument.
    +
    +    A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All
    +    pattern letters of the Java class `java.text.SimpleDateFormat` can be used.
    +
+    NOTE: Whenever possible, use specialized functions like `year`, since they benefit
+    from a specialized implementation.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect()
    +    [Row(date=u'04/08/2015')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))
    +
    +
    +@since(1.5)
    +def year(col):
    +    """
    +    Extract the year of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(year('a').alias('year')).collect()
    +    [Row(year=2015)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.year(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def quarter(col):
    +    """
    +    Extract the quarter of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(quarter('a').alias('quarter')).collect()
    +    [Row(quarter=2)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.quarter(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def month(col):
    +    """
    +    Extract the month of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(month('a').alias('month')).collect()
    +    [Row(month=4)]
    +   """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.month(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def dayofmonth(col):
    +    """
    +    Extract the day of the month of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(dayofmonth('a').alias('day')).collect()
    +    [Row(day=8)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def dayofyear(col):
    +    """
    +    Extract the day of the year of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(dayofyear('a').alias('day')).collect()
    +    [Row(day=98)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def hour(col):
    +    """
    +    Extract the hours of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
    +    >>> df.select(hour('a').alias('hour')).collect()
    +    [Row(hour=13)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.hour(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def minute(col):
    +    """
    +    Extract the minutes of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
    +    >>> df.select(minute('a').alias('minute')).collect()
    +    [Row(minute=8)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.minute(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def second(col):
    +    """
    +    Extract the seconds of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
    +    >>> df.select(second('a').alias('second')).collect()
    +    [Row(second=15)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.second(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def weekofyear(col):
    +    """
    +    Extract the week number of a given date as integer.
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
    +    >>> df.select(weekofyear(df.a).alias('week')).collect()
    +    [Row(week=15)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))
    +
    +
    +@since(1.5)
    +def date_add(start, days):
    +    """
    +    Returns the date that is `days` days after `start`
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    +    >>> df.select(date_add(df.d, 1).alias('d')).collect()
    +    [Row(d=datetime.date(2015, 4, 9))]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
    +
    +
    +@since(1.5)
    +def date_sub(start, days):
    +    """
    +    Returns the date that is `days` days before `start`
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    +    >>> df.select(date_sub(df.d, 1).alias('d')).collect()
    +    [Row(d=datetime.date(2015, 4, 7))]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
    +
    +
    +@since(1.5)
    +def add_months(start, months):
    +    """
    +    Returns the date that is `months` months after `start`
    +
    +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    +    >>> df.select(add_months(df.d, 1).alias('d')).collect()
    +    [Row(d=datetime.date(2015, 5, 8))]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
    +
    +
    +@since(1.5)
    +def months_between(date1, date2):
    +    """
    +    Returns the number of months between date1 and date2.
    +
    +    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
    +    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
    +    [Row(months=3.9495967...)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
    +
    +
    +@since(1.5)
    +def size(col):
    +    """
    +    Collection function: returns the length of the array or map stored in the column.
+
+    :param col: name of column or expression
    +
    +    >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
    +    >>> df.select(size(df.data)).collect()
    +    [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.size(_to_java_column(col)))
    +
    +
     class UserDefinedFunction(object):
         """
         User defined function in Python
     
         .. versionadded:: 1.3
         """
    -    def __init__(self, func, returnType):
    +    def __init__(self, func, returnType, name=None):
             self.func = func
             self.returnType = returnType
             self._broadcast = None
    -        self._judf = self._create_judf()
    +        self._judf = self._create_judf(name)
     
    -    def _create_judf(self):
    -        f = self.func  # put it in closure `func`
    -        func = lambda _, it: map(lambda x: f(*x), it)
    +    def _create_judf(self, name):
    +        f, returnType = self.func, self.returnType  # put them in closure `func`
    +        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
             ser = AutoBatchedSerializer(PickleSerializer())
             command = (func, None, ser, ser)
             sc = SparkContext._active_spark_context
             pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
             ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
             jdt = ssql_ctx.parseDataType(self.returnType.json())
    -        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
    -        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
    +        if name is None:
    +            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
    +        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
                                                      sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                      sc._javaAccumulator, jdt)
             return judf
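A hedged sketch of why `returnType.toInternal` is now applied inside the closure: return types that need conversion, such as `DateType`, round-trip through a Python UDF (assuming a `sqlContext`; the column name `d` and the parsing lambda are illustrative only):

```
import datetime
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import DateType

to_date = UserDefinedFunction(
    lambda s: datetime.datetime.strptime(s, '%Y-%m-%d').date(), DateType())
df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
df.select(to_date(df.d).alias('d')).collect()   # [Row(d=datetime.date(2015, 4, 8))]
```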
    diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
    index 882a03090ec13..dea8bad79e187 100644
    --- a/python/pyspark/sql/readwriter.py
    +++ b/python/pyspark/sql/readwriter.py
    @@ -146,14 +146,28 @@ def table(self, tableName):
             return self._df(self._jreader.table(tableName))
     
         @since(1.4)
    -    def parquet(self, *path):
    +    def parquet(self, *paths):
             """Loads a Parquet file, returning the result as a :class:`DataFrame`.
     
             >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
             >>> df.dtypes
             [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
             """
    -        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))
    +        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths)))
    +
    +    @since(1.5)
    +    def orc(self, path):
    +        """
    +        Loads an ORC file, returning the result as a :class:`DataFrame`.
    +
+        .. note:: Currently ORC support is only available together with
+            :class:`HiveContext`.
    +
    +        >>> df = hiveContext.read.orc('python/test_support/sql/orc_partitioned')
    +        >>> df.dtypes
    +        [('a', 'bigint'), ('b', 'int'), ('c', 'int')]
    +        """
    +        return self._df(self._jreader.orc(path))
     
         @since(1.4)
         def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
    @@ -378,6 +392,29 @@ def parquet(self, path, mode=None, partitionBy=None):
                 self.partitionBy(partitionBy)
             self._jwrite.parquet(path)
     
    +    def orc(self, path, mode=None, partitionBy=None):
    +        """Saves the content of the :class:`DataFrame` in ORC format at the specified path.
    +
+        .. note:: Currently ORC support is only available together with
+            :class:`HiveContext`.
    +
    +        :param path: the path in any Hadoop supported file system
    +        :param mode: specifies the behavior of the save operation when data already exists.
    +
    +            * ``append``: Append contents of this :class:`DataFrame` to existing data.
    +            * ``overwrite``: Overwrite existing data.
    +            * ``ignore``: Silently ignore this operation if data already exists.
    +            * ``error`` (default case): Throw an exception if data already exists.
    +        :param partitionBy: names of partitioning columns
    +
    +        >>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned')
    +        >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        self.mode(mode)
    +        if partitionBy is not None:
    +            self.partitionBy(partitionBy)
    +        self._jwrite.orc(path)
    +
         @since(1.4)
         def jdbc(self, url, table, mode=None, properties={}):
             """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
    @@ -408,7 +445,7 @@ def _test():
         import os
         import tempfile
         from pyspark.context import SparkContext
    -    from pyspark.sql import Row, SQLContext
    +    from pyspark.sql import Row, SQLContext, HiveContext
         import pyspark.sql.readwriter
     
         os.chdir(os.environ["SPARK_HOME"])
    @@ -420,6 +457,7 @@ def _test():
         globs['os'] = os
         globs['sc'] = sc
         globs['sqlContext'] = SQLContext(sc)
    +    globs['hiveContext'] = HiveContext(sc)
         globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
     
         (failure_count, test_count) = doctest.testmod(
    diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
    index 333378c7f1854..ebd3ea8db6a43 100644
    --- a/python/pyspark/sql/tests.py
    +++ b/python/pyspark/sql/tests.py
    @@ -45,9 +45,9 @@
     from pyspark.sql.types import *
     from pyspark.sql.types import UserDefinedType, _infer_type
     from pyspark.tests import ReusedPySparkTestCase
    -from pyspark.sql.functions import UserDefinedFunction
    +from pyspark.sql.functions import UserDefinedFunction, sha2
     from pyspark.sql.window import Window
    -from pyspark.sql.utils import AnalysisException
    +from pyspark.sql.utils import AnalysisException, IllegalArgumentException
     
     
     class UTC(datetime.tzinfo):
    @@ -75,7 +75,7 @@ def sqlType(self):
     
         @classmethod
         def module(cls):
    -        return 'pyspark.tests'
    +        return 'pyspark.sql.tests'
     
         @classmethod
         def scalaUDT(cls):
    @@ -106,10 +106,45 @@ def __str__(self):
             return "(%s,%s)" % (self.x, self.y)
     
         def __eq__(self, other):
    -        return isinstance(other, ExamplePoint) and \
    +        return isinstance(other, self.__class__) and \
                 other.x == self.x and other.y == self.y
     
     
    +class PythonOnlyUDT(UserDefinedType):
    +    """
    +    User-defined type (UDT) for ExamplePoint.
    +    """
    +
    +    @classmethod
    +    def sqlType(self):
    +        return ArrayType(DoubleType(), False)
    +
    +    @classmethod
    +    def module(cls):
    +        return '__main__'
    +
    +    def serialize(self, obj):
    +        return [obj.x, obj.y]
    +
    +    def deserialize(self, datum):
    +        return PythonOnlyPoint(datum[0], datum[1])
    +
    +    @staticmethod
    +    def foo():
    +        pass
    +
    +    @property
    +    def props(self):
    +        return {}
    +
    +
    +class PythonOnlyPoint(ExamplePoint):
    +    """
    +    An example class to demonstrate UDT in only Python
    +    """
    +    __UDT__ = PythonOnlyUDT()
    +
    +
     class DataTypeTests(unittest.TestCase):
         # regression test for SPARK-6055
         def test_data_type_eq(self):
    @@ -151,6 +186,17 @@ def test_range(self):
             self.assertEqual(self.sqlCtx.range(-2).count(), 0)
             self.assertEqual(self.sqlCtx.range(3).count(), 3)
     
    +    def test_duplicated_column_names(self):
    +        df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
    +        row = df.select('*').first()
    +        self.assertEqual(1, row[0])
    +        self.assertEqual(2, row[1])
    +        self.assertEqual("Row(c=1, c=2)", str(row))
    +        # Cannot access columns
    +        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
    +        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
    +        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
    +
         def test_explode(self):
             from pyspark.sql.functions import explode
             d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
    @@ -322,6 +368,10 @@ def test_infer_nested_schema(self):
             df = self.sqlCtx.inferSchema(rdd)
             self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
     
    +    def test_select_null_literal(self):
    +        df = self.sqlCtx.sql("select null as col")
    +        self.assertEquals(Row(col=None), df.first())
    +
         def test_apply_schema(self):
             from datetime import date, datetime
             rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
    @@ -380,10 +430,39 @@ def test_convert_row_to_dict(self):
             self.assertEqual(1, row.asDict()["l"][0].a)
             self.assertEqual(1.0, row.asDict()['d']['key'].c)
     
    +    def test_udt(self):
    +        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
    +        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
    +
    +        def check_datatype(datatype):
    +            pickled = pickle.loads(pickle.dumps(datatype))
    +            assert datatype == pickled
    +            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
    +            python_datatype = _parse_datatype_json_string(scala_datatype.json())
    +            assert datatype == python_datatype
    +
    +        check_datatype(ExamplePointUDT())
    +        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
    +                                          StructField("point", ExamplePointUDT(), False)])
    +        check_datatype(structtype_with_udt)
    +        p = ExamplePoint(1.0, 2.0)
    +        self.assertEqual(_infer_type(p), ExamplePointUDT())
    +        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
    +        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
    +
    +        check_datatype(PythonOnlyUDT())
    +        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
    +                                          StructField("point", PythonOnlyUDT(), False)])
    +        check_datatype(structtype_with_udt)
    +        p = PythonOnlyPoint(1.0, 2.0)
    +        self.assertEqual(_infer_type(p), PythonOnlyUDT())
    +        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
    +        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
    +
         def test_infer_schema_with_udt(self):
             from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
             row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
    -        df = self.sc.parallelize([row]).toDF()
    +        df = self.sqlCtx.createDataFrame([row])
             schema = df.schema
             field = [f for f in schema.fields if f.name == "point"][0]
             self.assertEqual(type(field.dataType), ExamplePointUDT)
    @@ -391,26 +470,66 @@ def test_infer_schema_with_udt(self):
             point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
             self.assertEqual(point, ExamplePoint(1.0, 2.0))
     
    +        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
    +        df = self.sqlCtx.createDataFrame([row])
    +        schema = df.schema
    +        field = [f for f in schema.fields if f.name == "point"][0]
    +        self.assertEqual(type(field.dataType), PythonOnlyUDT)
    +        df.registerTempTable("labeled_point")
    +        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
    +        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
    +
         def test_apply_schema_with_udt(self):
             from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
             row = (1.0, ExamplePoint(1.0, 2.0))
    -        rdd = self.sc.parallelize([row])
             schema = StructType([StructField("label", DoubleType(), False),
                                  StructField("point", ExamplePointUDT(), False)])
    -        df = rdd.toDF(schema)
    +        df = self.sqlCtx.createDataFrame([row], schema)
             point = df.head().point
             self.assertEquals(point, ExamplePoint(1.0, 2.0))
     
    +        row = (1.0, PythonOnlyPoint(1.0, 2.0))
    +        schema = StructType([StructField("label", DoubleType(), False),
    +                             StructField("point", PythonOnlyUDT(), False)])
    +        df = self.sqlCtx.createDataFrame([row], schema)
    +        point = df.head().point
    +        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
    +
    +    def test_udf_with_udt(self):
    +        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
    +        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
    +        df = self.sqlCtx.createDataFrame([row])
    +        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
    +        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
    +        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
    +        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
    +        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
    +
    +        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
    +        df = self.sqlCtx.createDataFrame([row])
    +        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
    +        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
    +        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
    +        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
    +        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
    +
         def test_parquet_with_udt(self):
    -        from pyspark.sql.tests import ExamplePoint
    +        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
             row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
    -        df0 = self.sc.parallelize([row]).toDF()
    +        df0 = self.sqlCtx.createDataFrame([row])
             output_dir = os.path.join(self.tempdir.name, "labeled_point")
    -        df0.saveAsParquetFile(output_dir)
    +        df0.write.parquet(output_dir)
             df1 = self.sqlCtx.parquetFile(output_dir)
             point = df1.head().point
             self.assertEquals(point, ExamplePoint(1.0, 2.0))
     
    +        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
    +        df0 = self.sqlCtx.createDataFrame([row])
    +        df0.write.parquet(output_dir, mode='overwrite')
    +        df1 = self.sqlCtx.parquetFile(output_dir)
    +        point = df1.head().point
    +        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
    +
         def test_column_operators(self):
             ci = self.df.key
             cs = self.df.value
    @@ -686,19 +805,31 @@ def test_filter_with_datetime(self):
         def test_time_with_timezone(self):
             day = datetime.date.today()
             now = datetime.datetime.now()
    -        ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
    +        ts = time.mktime(now.timetuple())
             # class in __main__ is not serializable
             from pyspark.sql.tests import UTC
             utc = UTC()
    -        utcnow = datetime.datetime.fromtimestamp(ts, utc)
    +        utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
    +        # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
    +        utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
             df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
             day1, now1, utcnow1 = df.first()
    -        # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version
    -        self.assertEqual(day1.date(), day)
    -        # Pyrolite does not support microsecond, the error should be
    -        # less than 1 millisecond
    -        self.assertTrue(now - now1 < datetime.timedelta(0.001))
    -        self.assertTrue(now - utcnow1 < datetime.timedelta(0.001))
    +        self.assertEqual(day1, day)
    +        self.assertEqual(now, now1)
    +        self.assertEqual(now, utcnow1)
    +
    +    def test_decimal(self):
    +        from decimal import Decimal
    +        schema = StructType([StructField("decimal", DecimalType(10, 5))])
    +        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
    +        row = df.select(df.decimal + 1).first()
    +        self.assertEqual(row[0], Decimal("4.14159"))
    +        tmpPath = tempfile.mkdtemp()
    +        shutil.rmtree(tmpPath)
    +        df.write.parquet(tmpPath)
    +        df2 = self.sqlCtx.read.parquet(tmpPath)
    +        row = df2.first()
    +        self.assertEqual(row[0], Decimal("3.14159"))
     
         def test_dropna(self):
             schema = StructType([
    @@ -809,6 +940,13 @@ def test_bitwise_operations(self):
             result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
             self.assertEqual(~75, result['~b'])
     
    +    def test_expr(self):
    +        from pyspark.sql import functions
    +        row = Row(a="length string", b=75)
    +        df = self.sqlCtx.createDataFrame([row])
    +        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
    +        self.assertEqual(13, result["'length(a)"])
    +
         def test_replace(self):
             schema = StructType([
                 StructField("name", StringType(), True),
    @@ -863,6 +1001,13 @@ def test_capture_analysis_exception(self):
             # RuntimeException should not be captured
             self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
     
    +    def test_capture_illegalargument_exception(self):
    +        self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
    +                                lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
    +        df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
    +        self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
    +                                lambda: df.select(sha2(df.a, 1024)).collect())
    +
     
     class HiveContextSQLTests(ReusedPySparkTestCase):
     
    diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
    index 160df40d65cc1..6f74b7162f7cc 100644
    --- a/python/pyspark/sql/types.py
    +++ b/python/pyspark/sql/types.py
    @@ -20,13 +20,10 @@
     import time
     import datetime
     import calendar
    -import keyword
    -import warnings
     import json
     import re
    -import weakref
    +import base64
     from array import array
    -from operator import itemgetter
     
     if sys.version >= "3":
         long = int
    @@ -35,6 +32,8 @@
     from py4j.protocol import register_input_converter
     from py4j.java_gateway import JavaClass
     
    +from pyspark.serializers import CloudPickleSerializer
    +
     __all__ = [
         "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
         "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
    @@ -71,6 +70,26 @@ def json(self):
                               separators=(',', ':'),
                               sort_keys=True)
     
    +    def needConversion(self):
    +        """
    +        Does this type need conversion between Python objects and the internal SQL representation?
    +
    +        This is used to avoid unnecessary conversion for ArrayType/MapType/StructType.
    +        """
    +        return False
    +
    +    def toInternal(self, obj):
    +        """
    +        Converts a Python object into an internal SQL object.
    +        """
    +        return obj
    +
    +    def fromInternal(self, obj):
    +        """
    +        Converts an internal SQL object into a native Python object.
    +        """
    +        return obj
    +
     
     # This singleton pattern does not work with pickle, you will get
     # another object after pickle and unpickle
    @@ -143,6 +162,17 @@ class DateType(AtomicType):
     
         __metaclass__ = DataTypeSingleton
     
    +    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
    +
    +    def needConversion(self):
    +        return True
    +
    +    def toInternal(self, d):
    +        return d and d.toordinal() - self.EPOCH_ORDINAL
    +
    +    def fromInternal(self, v):
    +        return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
    +
     
     class TimestampType(AtomicType):
         """Timestamp (datetime.datetime) data type.
    @@ -150,33 +180,50 @@ class TimestampType(AtomicType):
     
         __metaclass__ = DataTypeSingleton
     
    +    def needConversion(self):
    +        return True
    +
    +    def toInternal(self, dt):
    +        if dt is not None:
    +            seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
    +                       else time.mktime(dt.timetuple()))
    +            return int(seconds * 1e6 + dt.microsecond)
    +
    +    def fromInternal(self, ts):
    +        if ts is not None:
    +            # using int to avoid precision loss in float
    +            return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)
    +
     
     class DecimalType(FractionalType):
         """Decimal (decimal.Decimal) data type.
    +
    +    The DecimalType must have a fixed precision (the maximum total number of digits)
    +    and scale (the number of digits to the right of the decimal point). For example,
    +    (5, 2) can support values in the range [-999.99, 999.99].
    +
    +    The precision can be up to 38; the scale must be less than or equal to the precision.
    +
    +    When creating a DecimalType, the default precision and scale is (10, 0). When
    +    inferring a schema from decimal.Decimal objects, it will be DecimalType(38, 18).
    +
    +    :param precision: the maximum total number of digits (default: 10)
    +    :param scale: the number of digits to the right of the decimal point (default: 0)
         """
     
    -    def __init__(self, precision=None, scale=None):
    +    def __init__(self, precision=10, scale=0):
             self.precision = precision
             self.scale = scale
    -        self.hasPrecisionInfo = precision is not None
    +        self.hasPrecisionInfo = True  # this is public API
     
         def simpleString(self):
    -        if self.hasPrecisionInfo:
    -            return "decimal(%d,%d)" % (self.precision, self.scale)
    -        else:
    -            return "decimal(10,0)"
    +        return "decimal(%d,%d)" % (self.precision, self.scale)
     
         def jsonValue(self):
    -        if self.hasPrecisionInfo:
    -            return "decimal(%d,%d)" % (self.precision, self.scale)
    -        else:
    -            return "decimal"
    +        return "decimal(%d,%d)" % (self.precision, self.scale)
     
         def __repr__(self):
    -        if self.hasPrecisionInfo:
    -            return "DecimalType(%d,%d)" % (self.precision, self.scale)
    -        else:
    -            return "DecimalType()"
    +        return "DecimalType(%d,%d)" % (self.precision, self.scale)
     
     
     class DoubleType(FractionalType):
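The conversions introduced above can be exercised without a running SparkContext. A minimal sketch, assuming a PySpark build that includes these changes (the sample values are arbitrary):

import datetime
from pyspark.sql.types import DateType, TimestampType, DecimalType

# DateType stores dates as days since 1970-01-01 and round-trips them.
d = datetime.date(2015, 7, 15)
assert DateType().fromInternal(DateType().toInternal(d)) == d

# TimestampType stores naive datetimes as microseconds since the epoch
# (local time via time.mktime), so the round trip keeps microseconds.
ts = datetime.datetime(2015, 7, 15, 12, 30, 45, 123456)
assert TimestampType().fromInternal(TimestampType().toInternal(ts)) == ts

# DecimalType now always carries precision/scale, defaulting to (10, 0).
print(DecimalType().simpleString())    # decimal(10,0)
print(DecimalType(5, 2).jsonValue())   # decimal(5,2)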
    @@ -259,6 +306,19 @@ def fromJson(cls, json):
             return ArrayType(_parse_datatype_json_value(json["elementType"]),
                              json["containsNull"])
     
    +    def needConversion(self):
    +        return self.elementType.needConversion()
    +
    +    def toInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and [self.elementType.toInternal(v) for v in obj]
    +
    +    def fromInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and [self.elementType.fromInternal(v) for v in obj]
    +
     
     class MapType(DataType):
         """Map data type.
    @@ -304,6 +364,21 @@ def fromJson(cls, json):
                            _parse_datatype_json_value(json["valueType"]),
                            json["valueContainsNull"])
     
    +    def needConversion(self):
    +        return self.keyType.needConversion() or self.valueType.needConversion()
    +
    +    def toInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
    +                            for k, v in obj.items())
    +
    +    def fromInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
    +                            for k, v in obj.items())
    +
     
     class StructField(DataType):
         """A field in :class:`StructType`.
    @@ -311,7 +386,7 @@ class StructField(DataType):
         :param name: string, name of the field.
         :param dataType: :class:`DataType` of the field.
         :param nullable: boolean, whether the field can be null (None) or not.
    -    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
    +    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
         """
     
         def __init__(self, name, dataType, nullable=True, metadata=None):
    @@ -351,6 +426,15 @@ def fromJson(cls, json):
                                json["nullable"],
                                json["metadata"])
     
    +    def needConversion(self):
    +        return self.dataType.needConversion()
    +
    +    def toInternal(self, obj):
    +        return self.dataType.toInternal(obj)
    +
    +    def fromInternal(self, obj):
    +        return self.dataType.fromInternal(obj)
    +
     
     class StructType(DataType):
         """Struct type, consisting of a list of :class:`StructField`.
    @@ -371,10 +455,13 @@ def __init__(self, fields=None):
             """
             if not fields:
                 self.fields = []
    +            self.names = []
             else:
                 self.fields = fields
    +            self.names = [f.name for f in fields]
                 assert all(isinstance(f, StructField) for f in fields),\
                     "fields should be a list of StructField"
    +        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
     
         def add(self, field, data_type=None, nullable=True, metadata=None):
             """
    @@ -406,6 +493,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
             """
             if isinstance(field, StructField):
                 self.fields.append(field)
    +            self.names.append(field.name)
             else:
                 if isinstance(field, str) and data_type is None:
                     raise ValueError("Must specify DataType if passing name of struct_field to create.")
    @@ -415,6 +503,8 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
                 else:
                     data_type_f = data_type
                 self.fields.append(StructField(field, data_type_f, nullable, metadata))
    +            self.names.append(field)
    +        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
             return self
     
         def simpleString(self):
    @@ -432,6 +522,41 @@ def jsonValue(self):
         def fromJson(cls, json):
             return StructType([StructField.fromJson(f) for f in json["fields"]])
     
    +    def needConversion(self):
    +        # We need to convert Row()/namedtuple into tuple()
    +        return True
    +
    +    def toInternal(self, obj):
    +        if obj is None:
    +            return
    +
    +        if self._needSerializeAnyField:
    +            if isinstance(obj, dict):
    +                return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
    +            elif isinstance(obj, (tuple, list)):
    +                return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
    +            else:
    +                raise ValueError("Unexpected tuple %r with StructType" % obj)
    +        else:
    +            if isinstance(obj, dict):
    +                return tuple(obj.get(n) for n in self.names)
    +            elif isinstance(obj, (list, tuple)):
    +                return tuple(obj)
    +            else:
    +                raise ValueError("Unexpected tuple %r with StructType" % obj)
    +
    +    def fromInternal(self, obj):
    +        if obj is None:
    +            return
    +        if isinstance(obj, Row):
    +            # it's already converted by pickler
    +            return obj
    +        if self._needSerializeAnyField:
    +            values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
    +        else:
    +            values = obj
    +        return _create_row(self.names, values)
    +
     
     class UserDefinedType(DataType):
         """User-defined type (UDT).
    @@ -460,21 +585,40 @@ def module(cls):
         @classmethod
         def scalaUDT(cls):
             """
    -        The class name of the paired Scala UDT.
    +        The class name of the paired Scala UDT (could be '', if there
    +        is no corresponding one).
    +        """
    +        return ''
    +
    +    def needConversion(self):
    +        return True
    +
    +    @classmethod
    +    def _cachedSqlType(cls):
    +        """
    +        Cache the sqlType() on the class, because it is heavily used in `toInternal`.
             """
    -        raise NotImplementedError("UDT must have a paired Scala UDT.")
    +        if not hasattr(cls, "_cached_sql_type"):
    +            cls._cached_sql_type = cls.sqlType()
    +        return cls._cached_sql_type
    +
    +    def toInternal(self, obj):
    +        return self._cachedSqlType().toInternal(self.serialize(obj))
    +
    +    def fromInternal(self, obj):
    +        return self.deserialize(self._cachedSqlType().fromInternal(obj))
     
         def serialize(self, obj):
             """
             Converts a user-type object into a SQL datum.
             """
    -        raise NotImplementedError("UDT must implement serialize().")
    +        raise NotImplementedError("UDT must implement serialize() (used by toInternal()).")
     
         def deserialize(self, datum):
             """
             Converts a SQL datum into a user-type object.
             """
    -        raise NotImplementedError("UDT must implement deserialize().")
    +        raise NotImplementedError("UDT must implement deserialize() (used by fromInternal()).")
     
         def simpleString(self):
             return 'udt'
    @@ -483,22 +627,37 @@ def json(self):
             return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
     
         def jsonValue(self):
    -        schema = {
    -            "type": "udt",
    -            "class": self.scalaUDT(),
    -            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
    -            "sqlType": self.sqlType().jsonValue()
    -        }
    +        if self.scalaUDT():
    +            assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
    +            schema = {
    +                "type": "udt",
    +                "class": self.scalaUDT(),
    +                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
    +                "sqlType": self.sqlType().jsonValue()
    +            }
    +        else:
    +            ser = CloudPickleSerializer()
    +            b = ser.dumps(type(self))
    +            schema = {
    +                "type": "udt",
    +                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
    +                "serializedClass": base64.b64encode(b).decode('utf8'),
    +                "sqlType": self.sqlType().jsonValue()
    +            }
             return schema
     
         @classmethod
         def fromJson(cls, json):
    -        pyUDT = json["pyClass"]
    +        pyUDT = str(json["pyClass"])  # convert unicode to str
             split = pyUDT.rfind(".")
             pyModule = pyUDT[:split]
             pyClass = pyUDT[split+1:]
             m = __import__(pyModule, globals(), locals(), [pyClass])
    -        UDT = getattr(m, pyClass)
    +        if not hasattr(m, pyClass):
    +            s = base64.b64decode(json['serializedClass'].encode('utf-8'))
    +            UDT = CloudPickleSerializer().loads(s)
    +        else:
    +            UDT = getattr(m, pyClass)
             return UDT()
     
         def __eq__(self, other):
    @@ -506,7 +665,7 @@ def __eq__(self, other):
     
     
     _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
    -                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
    +                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType]
     _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
     _all_complex_types = dict((v.typeName(), v)
                               for v in [ArrayType, MapType, StructType])
    @@ -557,11 +716,6 @@ def _parse_datatype_json_string(json_string):
         >>> complex_maptype = MapType(complex_structtype,
         ...                           complex_arraytype, False)
         >>> check_datatype(complex_maptype)
    -
    -    >>> check_datatype(ExamplePointUDT())
    -    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
    -    ...                                   StructField("point", ExamplePointUDT(), False)])
    -    >>> check_datatype(structtype_with_udt)
         """
         return _parse_datatype_json_value(json.loads(json_string))
     
    @@ -613,10 +767,6 @@ def _parse_datatype_json_value(json_value):
     
     def _infer_type(obj):
         """Infer the DataType from obj
    -
    -    >>> p = ExamplePoint(1.0, 2.0)
    -    >>> _infer_type(p)
    -    ExamplePointUDT
         """
         if obj is None:
             return NullType()
    @@ -625,7 +775,10 @@ def _infer_type(obj):
             return obj.__UDT__
     
         dataType = _type_mappings.get(type(obj))
    -    if dataType is not None:
    +    if dataType is DecimalType:
    +        # the precision and scale of `obj` may be different from row to row.
    +        return DecimalType(38, 18)
    +    elif dataType is not None:
             return dataType()
     
         if isinstance(obj, dict):
    @@ -671,117 +824,6 @@ def _infer_schema(row):
         return StructType(fields)
     
     
    -def _need_python_to_sql_conversion(dataType):
    -    """
    -    Checks whether we need python to sql conversion for the given type.
    -    For now, only UDTs need this conversion.
    -
    -    >>> _need_python_to_sql_conversion(DoubleType())
    -    False
    -    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
    -    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
    -    >>> _need_python_to_sql_conversion(schema0)
    -    True
    -    >>> _need_python_to_sql_conversion(ExamplePointUDT())
    -    True
    -    >>> schema1 = ArrayType(ExamplePointUDT(), False)
    -    >>> _need_python_to_sql_conversion(schema1)
    -    True
    -    >>> schema2 = StructType([StructField("label", DoubleType(), False),
    -    ...                       StructField("point", ExamplePointUDT(), False)])
    -    >>> _need_python_to_sql_conversion(schema2)
    -    True
    -    """
    -    if isinstance(dataType, StructType):
    -        # convert namedtuple or Row into tuple
    -        return True
    -    elif isinstance(dataType, ArrayType):
    -        return _need_python_to_sql_conversion(dataType.elementType)
    -    elif isinstance(dataType, MapType):
    -        return _need_python_to_sql_conversion(dataType.keyType) or \
    -            _need_python_to_sql_conversion(dataType.valueType)
    -    elif isinstance(dataType, UserDefinedType):
    -        return True
    -    elif isinstance(dataType, (DateType, TimestampType)):
    -        return True
    -    else:
    -        return False
    -
    -
    -EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
    -
    -
    -def _python_to_sql_converter(dataType):
    -    """
    -    Returns a converter that converts a Python object into a SQL datum for the given type.
    -
    -    >>> conv = _python_to_sql_converter(DoubleType())
    -    >>> conv(1.0)
    -    1.0
    -    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
    -    >>> conv([1.0, 2.0])
    -    [1.0, 2.0]
    -    >>> conv = _python_to_sql_converter(ExamplePointUDT())
    -    >>> conv(ExamplePoint(1.0, 2.0))
    -    [1.0, 2.0]
    -    >>> schema = StructType([StructField("label", DoubleType(), False),
    -    ...                      StructField("point", ExamplePointUDT(), False)])
    -    >>> conv = _python_to_sql_converter(schema)
    -    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
    -    (1.0, [1.0, 2.0])
    -    """
    -    if not _need_python_to_sql_conversion(dataType):
    -        return lambda x: x
    -
    -    if isinstance(dataType, StructType):
    -        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
    -        if any(_need_python_to_sql_conversion(t) for t in types):
    -            converters = [_python_to_sql_converter(t) for t in types]
    -
    -            def converter(obj):
    -                if isinstance(obj, dict):
    -                    return tuple(c(obj.get(n)) for n, c in zip(names, converters))
    -                elif isinstance(obj, tuple):
    -                    if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
    -                        return tuple(c(v) for c, v in zip(converters, obj))
    -                    else:
    -                        return tuple(c(v) for c, v in zip(converters, obj))
    -                elif obj is not None:
    -                    raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
    -        else:
    -            def converter(obj):
    -                if isinstance(obj, dict):
    -                    return tuple(obj.get(n) for n in names)
    -                else:
    -                    return tuple(obj)
    -        return converter
    -    elif isinstance(dataType, ArrayType):
    -        element_converter = _python_to_sql_converter(dataType.elementType)
    -        return lambda a: a and [element_converter(v) for v in a]
    -    elif isinstance(dataType, MapType):
    -        key_converter = _python_to_sql_converter(dataType.keyType)
    -        value_converter = _python_to_sql_converter(dataType.valueType)
    -        return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
    -
    -    elif isinstance(dataType, UserDefinedType):
    -        return lambda obj: obj and dataType.serialize(obj)
    -
    -    elif isinstance(dataType, DateType):
    -        return lambda d: d and d.toordinal() - EPOCH_ORDINAL
    -
    -    elif isinstance(dataType, TimestampType):
    -
    -        def to_posix_timstamp(dt):
    -            if dt:
    -                seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
    -                           else time.mktime(dt.timetuple()))
    -                return int(seconds * 1e7 + dt.microsecond * 10)
    -        return to_posix_timstamp
    -
    -    else:
    -        raise ValueError("Unexpected type %r" % dataType)
    -
    -
     def _has_nulltype(dt):
         """ Return whether there is NullType in `dt` or not """
         if isinstance(dt, StructType):
    @@ -1059,20 +1101,19 @@ def _verify_type(obj, dataType):
         Traceback (most recent call last):
             ...
         ValueError:...
    -    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
    -    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
    -    Traceback (most recent call last):
    -        ...
    -    ValueError:...
         """
         # all objects are nullable
         if obj is None:
             return
     
    +    # StringType can accept an object of any type
    +    if isinstance(dataType, StringType):
    +        return
    +
         if isinstance(dataType, UserDefinedType):
             if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                 raise ValueError("%r is not an instance of type %r" % (obj, dataType))
    -        _verify_type(dataType.serialize(obj), dataType.sqlType())
    +        _verify_type(dataType.toInternal(obj), dataType.sqlType())
             return
     
         _type = type(dataType)
    @@ -1082,7 +1123,7 @@ def _verify_type(obj, dataType):
             if not isinstance(obj, (tuple, list)):
                 raise TypeError("StructType can not accept object in type %s" % type(obj))
         else:
    -        # subclass of them can not be deserialized in JVM
    +        # subclasses of these types can not be deserialized in the JVM
             if type(obj) not in _acceptable_types[_type]:
                 raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
     
    @@ -1102,159 +1143,10 @@ def _verify_type(obj, dataType):
             for v, f in zip(obj, dataType.fields):
                 _verify_type(v, f.dataType)
     
    -_cached_cls = weakref.WeakValueDictionary()
    -
    -
    -def _restore_object(dataType, obj):
    -    """ Restore object during unpickling. """
    -    # use id(dataType) as key to speed up lookup in dict
    -    # Because of batched pickling, dataType will be the
    -    # same object in most cases.
    -    k = id(dataType)
    -    cls = _cached_cls.get(k)
    -    if cls is None or cls.__datatype is not dataType:
    -        # use dataType as key to avoid create multiple class
    -        cls = _cached_cls.get(dataType)
    -        if cls is None:
    -            cls = _create_cls(dataType)
    -            _cached_cls[dataType] = cls
    -        cls.__datatype = dataType
    -        _cached_cls[k] = cls
    -    return cls(obj)
    -
    -
    -def _create_object(cls, v):
    -    """ Create an customized object with class `cls`. """
    -    # datetime.date would be deserialized as datetime.datetime
    -    # from java type, so we need to set it back.
    -    if cls is datetime.date and isinstance(v, datetime.datetime):
    -        return v.date()
    -    return cls(v) if v is not None else v
    -
    -
    -def _create_getter(dt, i):
    -    """ Create a getter for item `i` with schema """
    -    cls = _create_cls(dt)
    -
    -    def getter(self):
    -        return _create_object(cls, self[i])
    -
    -    return getter
    -
    -
    -def _has_struct_or_date(dt):
    -    """Return whether `dt` is or has StructType/DateType in it"""
    -    if isinstance(dt, StructType):
    -        return True
    -    elif isinstance(dt, ArrayType):
    -        return _has_struct_or_date(dt.elementType)
    -    elif isinstance(dt, MapType):
    -        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
    -    elif isinstance(dt, DateType):
    -        return True
    -    elif isinstance(dt, UserDefinedType):
    -        return True
    -    return False
    -
    -
    -def _create_properties(fields):
    -    """Create properties according to fields"""
    -    ps = {}
    -    for i, f in enumerate(fields):
    -        name = f.name
    -        if (name.startswith("__") and name.endswith("__")
    -                or keyword.iskeyword(name)):
    -            warnings.warn("field name %s can not be accessed in Python,"
    -                          "use position to access it instead" % name)
    -        if _has_struct_or_date(f.dataType):
    -            # delay creating object until accessing it
    -            getter = _create_getter(f.dataType, i)
    -        else:
    -            getter = itemgetter(i)
    -        ps[name] = property(getter)
    -    return ps
    -
    -
    -def _create_cls(dataType):
    -    """
    -    Create an class by dataType
    -
    -    The created class is similar to namedtuple, but can have nested schema.
    -
    -    >>> schema = _parse_schema_abstract("a b c")
    -    >>> row = (1, 1.0, "str")
    -    >>> schema = _infer_schema_type(row, schema)
    -    >>> obj = _create_cls(schema)(row)
    -    >>> import pickle
    -    >>> pickle.loads(pickle.dumps(obj))
    -    Row(a=1, b=1.0, c='str')
    -
    -    >>> row = [[1], {"key": (1, 2.0)}]
    -    >>> schema = _parse_schema_abstract("a[] b{c d}")
    -    >>> schema = _infer_schema_type(row, schema)
    -    >>> obj = _create_cls(schema)(row)
    -    >>> pickle.loads(pickle.dumps(obj))
    -    Row(a=[1], b={'key': Row(c=1, d=2.0)})
    -    >>> pickle.loads(pickle.dumps(obj.a))
    -    [1]
    -    >>> pickle.loads(pickle.dumps(obj.b))
    -    {'key': Row(c=1, d=2.0)}
    -    """
    -
    -    if isinstance(dataType, ArrayType):
    -        cls = _create_cls(dataType.elementType)
    -
    -        def List(l):
    -            if l is None:
    -                return
    -            return [_create_object(cls, v) for v in l]
    -
    -        return List
    -
    -    elif isinstance(dataType, MapType):
    -        kcls = _create_cls(dataType.keyType)
    -        vcls = _create_cls(dataType.valueType)
    -
    -        def Dict(d):
    -            if d is None:
    -                return
    -            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
    -
    -        return Dict
    -
    -    elif isinstance(dataType, DateType):
    -        return datetime.date
    -
    -    elif isinstance(dataType, UserDefinedType):
    -        return lambda datum: dataType.deserialize(datum)
    -
    -    elif not isinstance(dataType, StructType):
    -        # no wrapper for atomic types
    -        return lambda x: x
    -
    -    class Row(tuple):
    -
    -        """ Row in DataFrame """
    -        __datatype = dataType
    -        __fields__ = tuple(f.name for f in dataType.fields)
    -        __slots__ = ()
    -
    -        # create property for fast access
    -        locals().update(_create_properties(dataType.fields))
    -
    -        def asDict(self):
    -            """ Return as a dict """
    -            return dict((n, getattr(self, n)) for n in self.__fields__)
    -
    -        def __repr__(self):
    -            # call collect __repr__ for nested objects
    -            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
    -                                          for n in self.__fields__))
    -
    -        def __reduce__(self):
    -            return (_restore_object, (self.__datatype, tuple(self)))
     
    -    return Row
    +# This is used to unpickle a Row from JVM
    +def _create_row_inbound_converter(dataType):
    +    return lambda *a: dataType.fromInternal(a)
     
     
     def _create_row(fields, values):
    @@ -1373,18 +1265,12 @@ def convert(self, obj, gateway_client):
     def _test():
         import doctest
         from pyspark.context import SparkContext
    -    # let doctest run in pyspark.sql.types, so DataTypes can be picklable
    -    import pyspark.sql.types
    -    from pyspark.sql import Row, SQLContext
    -    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
    -    globs = pyspark.sql.types.__dict__.copy()
    +    from pyspark.sql import SQLContext
    +    globs = globals()
         sc = SparkContext('local[4]', 'PythonTest')
         globs['sc'] = sc
         globs['sqlContext'] = SQLContext(sc)
    -    globs['ExamplePoint'] = ExamplePoint
    -    globs['ExamplePointUDT'] = ExamplePointUDT
    -    (failure_count, test_count) = doctest.testmod(
    -        pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
    +    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         globs['sc'].stop()
         if failure_count:
             exit(-1)
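As a quick check of the inference change above, a sketch using the module-private `_infer_type` helper that this file's doctests also exercise: `decimal.Decimal` values now map to a fixed `DecimalType(38, 18)`, because precision and scale can vary from row to row.

from decimal import Decimal
from pyspark.sql.types import _infer_type

print(_infer_type(Decimal("3.14159")))  # DecimalType(38,18)
print(_infer_type(1.0))                 # DoubleType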
    diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
    index cc5b2c088b7cc..0f795ca35b38a 100644
    --- a/python/pyspark/sql/utils.py
    +++ b/python/pyspark/sql/utils.py
    @@ -24,6 +24,12 @@ class AnalysisException(Exception):
         """
     
     
    +class IllegalArgumentException(Exception):
    +    """
    +    Passed an illegal or inappropriate argument.
    +    """
    +
    +
     def capture_sql_exception(f):
         def deco(*a, **kw):
             try:
    @@ -32,6 +38,8 @@ def deco(*a, **kw):
                 s = e.java_exception.toString()
                 if s.startswith('org.apache.spark.sql.AnalysisException: '):
                     raise AnalysisException(s.split(': ', 1)[1])
    +            if s.startswith('java.lang.IllegalArgumentException: '):
    +                raise IllegalArgumentException(s.split(': ', 1)[1])
                 raise
         return deco
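A hedged usage sketch for the mapping above, assuming a live SQLContext named `sqlContext` (the statement mirrors the one used in the test suite): JVM `java.lang.IllegalArgumentException`s now surface as the Python `IllegalArgumentException` rather than a raw Py4J error.

from pyspark.sql.utils import IllegalArgumentException

try:
    sqlContext.sql("SET mapred.reduce.tasks=-1")
except IllegalArgumentException as e:
    print("caught illegal argument: %s" % e)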
     
    diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
    index 10a859a532e28..33dd596335b47 100644
    --- a/python/pyspark/streaming/kafka.py
    +++ b/python/pyspark/streaming/kafka.py
    @@ -21,6 +21,8 @@
     from pyspark.storagelevel import StorageLevel
     from pyspark.serializers import PairDeserializer, NoOpSerializer
     from pyspark.streaming import DStream
    +from pyspark.streaming.dstream import TransformedDStream
    +from pyspark.streaming.util import TransformFunction
     
     __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
     
    @@ -122,8 +124,9 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
                 raise e
     
             ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
    -        stream = DStream(jstream, ssc, ser)
    -        return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        stream = DStream(jstream, ssc, ser) \
    +            .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
     
         @staticmethod
         def createRDD(sc, kafkaParams, offsetRanges, leaders={},
    @@ -161,8 +164,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
                 raise e
     
             ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
    -        rdd = RDD(jrdd, sc, ser)
    -        return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
     
         @staticmethod
         def _printErrorMsg(sc):
    @@ -200,14 +203,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset):
             :param fromOffset: Inclusive starting offset.
             :param untilOffset: Exclusive ending offset.
             """
    -        self._topic = topic
    -        self._partition = partition
    -        self._fromOffset = fromOffset
    -        self._untilOffset = untilOffset
    +        self.topic = topic
    +        self.partition = partition
    +        self.fromOffset = fromOffset
    +        self.untilOffset = untilOffset
    +
    +    def __eq__(self, other):
    +        if isinstance(other, self.__class__):
    +            return (self.topic == other.topic
    +                    and self.partition == other.partition
    +                    and self.fromOffset == other.fromOffset
    +                    and self.untilOffset == other.untilOffset)
    +        else:
    +            return False
    +
    +    def __ne__(self, other):
    +        return not self.__eq__(other)
    +
    +    def __str__(self):
    +        return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \
    +               % (self.topic, self.partition, self.fromOffset, self.untilOffset)
     
         def _jOffsetRange(self, helper):
    -        return helper.createOffsetRange(self._topic, self._partition, self._fromOffset,
    -                                        self._untilOffset)
    +        return helper.createOffsetRange(self.topic, self.partition, self.fromOffset,
    +                                        self.untilOffset)
     
     
     class TopicAndPartition(object):
    @@ -244,3 +263,87 @@ def __init__(self, host, port):
     
         def _jBroker(self, helper):
             return helper.createBroker(self._host, self._port)
    +
    +
    +class KafkaRDD(RDD):
    +    """
    +    A Python wrapper of KafkaRDD, exposing Kafka-specific information on top of a normal RDD.
    +    """
    +
    +    def __init__(self, jrdd, ctx, jrdd_deserializer):
    +        RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
    +
    +    def offsetRanges(self):
    +        """
    +        Get the offset ranges of this KafkaRDD.
    +        :return: A list of OffsetRange
    +        """
    +        try:
    +            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
    +                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
    +            helper = helperClass.newInstance()
    +            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
    +        except Py4JJavaError as e:
    +            if 'ClassNotFoundException' in str(e.java_exception):
    +                KafkaUtils._printErrorMsg(self.ctx)
    +            raise e
    +
    +        ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
    +                  for o in joffsetRanges]
    +        return ranges
    +
    +
    +class KafkaDStream(DStream):
    +    """
    +    A Python wrapper of KafkaDStream
    +    """
    +
    +    def __init__(self, jdstream, ssc, jrdd_deserializer):
    +        DStream.__init__(self, jdstream, ssc, jrdd_deserializer)
    +
    +    def foreachRDD(self, func):
    +        """
    +        Apply a function to each RDD in this DStream.
    +        """
    +        if func.__code__.co_argcount == 1:
    +            old_func = func
    +            func = lambda r, rdd: old_func(rdd)
    +        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \
    +            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
    +        api = self._ssc._jvm.PythonDStream
    +        api.callForeachRDD(self._jdstream, jfunc)
    +
    +    def transform(self, func):
    +        """
    +        Return a new DStream in which each RDD is generated by applying a function
    +        on each RDD of this DStream.
    +
    +        `func` can take a single `rdd` argument, or two arguments
    +        (`time`, `rdd`).
    +        """
    +        if func.__code__.co_argcount == 1:
    +            oldfunc = func
    +            func = lambda t, rdd: oldfunc(rdd)
    +        assert func.__code__.co_argcount == 2, "func should take one or two arguments"
    +
    +        return KafkaTransformedDStream(self, func)
    +
    +
    +class KafkaTransformedDStream(TransformedDStream):
    +    """
    +    Kafka-specific wrapper of TransformedDStream that transforms over Kafka RDDs.
    +    """
    +
    +    def __init__(self, prev, func):
    +        TransformedDStream.__init__(self, prev, func)
    +
    +    @property
    +    def _jdstream(self):
    +        if self._jdstream_val is not None:
    +            return self._jdstream_val
    +
    +        jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \
    +            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
    +        dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
    +        self._jdstream_val = dstream.asJavaDStream()
    +        return self._jdstream_val
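An end-to-end sketch of the offset-range API above, under stated assumptions: a StreamingContext `ssc`, a Kafka broker at `localhost:9092`, and a topic `my_topic` are all placeholders. Because transform()/foreachRDD() now receive KafkaRDDs, the consumed offsets can be read directly.

from pyspark.streaming.kafka import KafkaUtils

stream = KafkaUtils.createDirectStream(
    ssc, ["my_topic"], {"metadata.broker.list": "localhost:9092"})

def track_offsets(rdd):
    # rdd is a KafkaRDD here, so offsetRanges() is available
    for o in rdd.offsetRanges():
        print("%s[%d]: %d -> %d" % (o.topic, o.partition, o.fromOffset, o.untilOffset))
    return rdd

stream.transform(track_offsets).foreachRDD(lambda rdd: rdd.count())
ssc.start()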
    diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
    index 77f9ccf0b114a..0da312b89b72f 100644
    --- a/python/pyspark/streaming/tests.py
    +++ b/python/pyspark/streaming/tests.py
    @@ -679,6 +679,70 @@ def test_kafka_rdd_with_leaders(self):
             rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
             self._validateRddResult(sendData, rdd)
     
    +    @unittest.skipIf(sys.version >= "3", "long type not supported")
    +    def test_kafka_rdd_get_offsetRanges(self):
    +        """Test Python direct Kafka RDD get OffsetRanges."""
    +        topic = self._randomTopic()
    +        sendData = {"a": 3, "b": 4, "c": 5}
    +        offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
    +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
    +
    +        self._kafkaTestUtils.createTopic(topic)
    +        self._kafkaTestUtils.sendMessages(topic, sendData)
    +        rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
    +        self.assertEqual(offsetRanges, rdd.offsetRanges())
    +
    +    @unittest.skipIf(sys.version >= "3", "long type not supported")
    +    def test_kafka_direct_stream_foreach_get_offsetRanges(self):
    +        """Test the Python direct Kafka stream foreachRDD get offsetRanges."""
    +        topic = self._randomTopic()
    +        sendData = {"a": 1, "b": 2, "c": 3}
    +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
    +                       "auto.offset.reset": "smallest"}
    +
    +        self._kafkaTestUtils.createTopic(topic)
    +        self._kafkaTestUtils.sendMessages(topic, sendData)
    +
    +        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
    +
    +        offsetRanges = []
    +
    +        def getOffsetRanges(_, rdd):
    +            for o in rdd.offsetRanges():
    +                offsetRanges.append(o)
    +
    +        stream.foreachRDD(getOffsetRanges)
    +        self.ssc.start()
    +        self.wait_for(offsetRanges, 1)
    +
    +        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
    +
    +    @unittest.skipIf(sys.version >= "3", "long type not supported")
    +    def test_kafka_direct_stream_transform_get_offsetRanges(self):
    +        """Test the Python direct Kafka stream transform get offsetRanges."""
    +        topic = self._randomTopic()
    +        sendData = {"a": 1, "b": 2, "c": 3}
    +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
    +                       "auto.offset.reset": "smallest"}
    +
    +        self._kafkaTestUtils.createTopic(topic)
    +        self._kafkaTestUtils.sendMessages(topic, sendData)
    +
    +        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
    +
    +        offsetRanges = []
    +
    +        def transformWithOffsetRanges(rdd):
    +            for o in rdd.offsetRanges():
    +                offsetRanges.append(o)
    +            return rdd
    +
    +        stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count())
    +        self.ssc.start()
    +        self.wait_for(offsetRanges, 1)
    +
    +        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
    +
     
     class FlumeStreamTests(PySparkStreamingTestCase):
         timeout = 20  # seconds
    diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
    index a9bfec2aab8fc..b20613b1283bd 100644
    --- a/python/pyspark/streaming/util.py
    +++ b/python/pyspark/streaming/util.py
    @@ -37,6 +37,11 @@ def __init__(self, ctx, func, *deserializers):
             self.ctx = ctx
             self.func = func
             self.deserializers = deserializers
    +        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
    +
    +    def rdd_wrapper(self, func):
    +        self._rdd_wrapper = func
    +        return self
     
         def call(self, milliseconds, jrdds):
             try:
    @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds):
                 if len(sers) < len(jrdds):
                     sers += (sers[0],) * (len(jrdds) - len(sers))
     
    -            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
    +            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
                         for jrdd, ser in zip(jrdds, sers)]
                 t = datetime.fromtimestamp(milliseconds / 1000.0)
                 r = self.func(t, *rdds)
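A small, hypothetical sketch of the `rdd_wrapper` hook above (assumes a SparkContext `sc`; `TracingRDD` is an invented subclass, not part of this patch): the wrapper chooses which RDD class the user callback receives, which is how the Kafka DStream surfaces KafkaRDDs.

from pyspark.rdd import RDD
from pyspark.streaming.util import TransformFunction

class TracingRDD(RDD):
    """Illustrative RDD subclass; a real wrapper would add extra accessors."""
    pass

jfunc = TransformFunction(sc, lambda t, rdd: rdd.count(), sc.serializer) \
    .rdd_wrapper(lambda jrdd, ctx, ser: TracingRDD(jrdd, ctx, ser))
# Every RDD that jfunc.call(...) hands to the callback is now a TracingRDD.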
    diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
    index 17256dfc95744..8bfed074c9052 100644
    --- a/python/pyspark/tests.py
    +++ b/python/pyspark/tests.py
    @@ -529,10 +529,127 @@ def test_deleting_input_files(self):
     
         def test_sampling_default_seed(self):
             # Test for SPARK-3995 (default seed setting)
    -        data = self.sc.parallelize(range(1000), 1)
    +        data = self.sc.parallelize(xrange(1000), 1)
             subset = data.takeSample(False, 10)
             self.assertEqual(len(subset), 10)
     
    +    def test_aggregate_mutable_zero_value(self):
    +        # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
    +        # representing a counter of ints
    +        # NOTE: dict is used instead of collections.Counter for Python 2.6
    +        # compatibility
    +        from collections import defaultdict
    +
    +        # Show that single or multiple partitions work
    +        data1 = self.sc.range(10, numSlices=1)
    +        data2 = self.sc.range(10, numSlices=2)
    +
    +        def seqOp(x, y):
    +            x[y] += 1
    +            return x
    +
    +        def comboOp(x, y):
    +            for key, val in y.items():
    +                x[key] += val
    +            return x
    +
    +        counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
    +        counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
    +        counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
    +        counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
    +
    +        ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
    +        self.assertEqual(counts1, ground_truth)
    +        self.assertEqual(counts2, ground_truth)
    +        self.assertEqual(counts3, ground_truth)
    +        self.assertEqual(counts4, ground_truth)
    +
    +    def test_aggregate_by_key_mutable_zero_value(self):
    +        # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
    +        # contains lists of all values for each key in the original RDD
    +
    +        # list(range(...)) for Python 3.x compatibility (can't use * operator
    +        # on a range object)
    +        # list(zip(...)) for Python 3.x compatibility (want to parallelize a
    +        # collection, not a zip object)
    +        tuples = list(zip(list(range(10))*2, [1]*20))
    +        # Show that single or multiple partitions work
    +        data1 = self.sc.parallelize(tuples, 1)
    +        data2 = self.sc.parallelize(tuples, 2)
    +
    +        def seqOp(x, y):
    +            x.append(y)
    +            return x
    +
    +        def comboOp(x, y):
    +            x.extend(y)
    +            return x
    +
    +        values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
    +        values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
    +        # Sort lists to ensure clean comparison with ground_truth
    +        values1.sort()
    +        values2.sort()
    +
    +        ground_truth = [(i, [1]*2) for i in range(10)]
    +        self.assertEqual(values1, ground_truth)
    +        self.assertEqual(values2, ground_truth)
    +
    +    def test_fold_mutable_zero_value(self):
    +        # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
    +        # a single dict
    +        # NOTE: dict is used instead of collections.Counter for Python 2.6
    +        # compatibility
    +        from collections import defaultdict
    +
    +        counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
    +        counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
    +        counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
    +        counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
    +        all_counts = [counts1, counts2, counts3, counts4]
    +        # Show that single or multiple partitions work
    +        data1 = self.sc.parallelize(all_counts, 1)
    +        data2 = self.sc.parallelize(all_counts, 2)
    +
    +        def comboOp(x, y):
    +            for key, val in y.items():
    +                x[key] += val
    +            return x
    +
    +        fold1 = data1.fold(defaultdict(int), comboOp)
    +        fold2 = data2.fold(defaultdict(int), comboOp)
    +
    +        ground_truth = defaultdict(int)
    +        for counts in all_counts:
    +            for key, val in counts.items():
    +                ground_truth[key] += val
    +        self.assertEqual(fold1, ground_truth)
    +        self.assertEqual(fold2, ground_truth)
    +
    +    def test_fold_by_key_mutable_zero_value(self):
    +        # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
    +        # lists of all values for each key in the original RDD
    +
    +        tuples = [(i, range(i)) for i in range(10)]*2
    +        # Show that single or multiple partitions work
    +        data1 = self.sc.parallelize(tuples, 1)
    +        data2 = self.sc.parallelize(tuples, 2)
    +
    +        def comboOp(x, y):
    +            x.extend(y)
    +            return x
    +
    +        values1 = data1.foldByKey([], comboOp).collect()
    +        values2 = data2.foldByKey([], comboOp).collect()
    +        # Sort lists to ensure clean comparison with ground_truth
    +        values1.sort()
    +        values2.sort()
    +
    +        # list(range(...)) for Python 3.x compatibility
    +        ground_truth = [(i, list(range(i))*2) for i in range(10)]
    +        self.assertEqual(values1, ground_truth)
    +        self.assertEqual(values2, ground_truth)
    +
         def test_aggregate_by_key(self):
             data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
     
    @@ -624,8 +741,8 @@ def test_zip_with_different_serializers(self):
     
         def test_zip_with_different_object_sizes(self):
             # regress test for SPARK-5973
    -        a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i)
    -        b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i)
    +        a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
    +        b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
             self.assertEqual(10000, a.zip(b).count())
     
         def test_zip_with_different_number_of_items(self):
    @@ -647,7 +764,7 @@ def test_zip_with_different_number_of_items(self):
                 self.assertRaises(Exception, lambda: a.zip(b).count())
     
         def test_count_approx_distinct(self):
    -        rdd = self.sc.parallelize(range(1000))
    +        rdd = self.sc.parallelize(xrange(1000))
             self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
             self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
             self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
    @@ -777,7 +894,7 @@ def test_distinct(self):
         def test_external_group_by_key(self):
             self.sc._conf.set("spark.python.worker.memory", "1m")
             N = 200001
    -        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
    +        kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
             gkv = kv.groupByKey().cache()
             self.assertEqual(3, gkv.count())
             filtered = gkv.filter(lambda kv: kv[0] == 1)
    @@ -871,7 +988,7 @@ def test_narrow_dependency_in_join(self):
     
         # Regression test for SPARK-6294
         def test_take_on_jrdd(self):
    -        rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x))
    +        rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
             rdd._jrdd.first()
     
         def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
    @@ -885,6 +1002,19 @@ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
                 for size in sizes:
                     self.assertGreater(size, 0)
     
    +    def test_pipe_functions(self):
    +        data = ['1', '2', '3']
    +        rdd = self.sc.parallelize(data)
    +        with QuietTest(self.sc):
    +            self.assertEqual([], rdd.pipe('cc').collect())
    +            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
    +        result = rdd.pipe('cat').collect()
    +        result.sort()
    +        for x, y in zip(data, result):
    +            self.assertEqual(x, y)
    +        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
    +        self.assertEqual([], rdd.pipe('grep 4').collect())
    +
     
     class ProfilerTests(PySparkTestCase):
     
    @@ -1504,13 +1634,13 @@ def run():
                 self.fail("daemon had been killed")
     
             # run a normal job
    -        rdd = self.sc.parallelize(range(100), 1)
    +        rdd = self.sc.parallelize(xrange(100), 1)
             self.assertEqual(100, rdd.map(str).count())
     
         def test_after_exception(self):
             def raise_exception(_):
                 raise Exception()
    -        rdd = self.sc.parallelize(range(100), 1)
    +        rdd = self.sc.parallelize(xrange(100), 1)
             with QuietTest(self.sc):
                 self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
             self.assertEqual(100, rdd.map(str).count())
    @@ -1526,22 +1656,22 @@ def test_after_jvm_exception(self):
             with QuietTest(self.sc):
                 self.assertRaises(Exception, lambda: filtered_data.count())
     
    -        rdd = self.sc.parallelize(range(100), 1)
    +        rdd = self.sc.parallelize(xrange(100), 1)
             self.assertEqual(100, rdd.map(str).count())
     
         def test_accumulator_when_reuse_worker(self):
             from pyspark.accumulators import INT_ACCUMULATOR_PARAM
             acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
    -        self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
    +        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
             self.assertEqual(sum(range(100)), acc1.value)
     
             acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
    -        self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
    +        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
             self.assertEqual(sum(range(100)), acc2.value)
             self.assertEqual(sum(range(100)), acc1.value)
     
         def test_reuse_worker_after_take(self):
    -        rdd = self.sc.parallelize(range(100000), 1)
    +        rdd = self.sc.parallelize(xrange(100000), 1)
             self.assertEqual(0, rdd.first())
     
             def count():
    @@ -1693,7 +1823,7 @@ def test_module_dependency_on_cluster(self):
                 |    return x + 1
                 """)
             proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master",
    -                                "local-cluster[1,1,512]", script],
    +                                "local-cluster[1,1,1024]", script],
                                     stdout=subprocess.PIPE)
             out, err = proc.communicate()
             self.assertEqual(0, proc.returncode)
    @@ -1727,7 +1857,7 @@ def test_package_dependency_on_cluster(self):
             self.create_spark_package("a:mylib:0.1")
             proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
                                      "file:" + self.programDir, "--master",
    -                                 "local-cluster[1,1,512]", script], stdout=subprocess.PIPE)
    +                                 "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE)
             out, err = proc.communicate()
             self.assertEqual(0, proc.returncode)
             self.assertIn("[2, 3, 4]", out.decode('utf-8'))
    @@ -1746,7 +1876,7 @@ def test_single_script_on_cluster(self):
             # this will fail if you have different spark.executor.memory
             # in conf/spark-defaults.conf
             proc = subprocess.Popen(
    -            [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
    +            [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script],
                 stdout=subprocess.PIPE)
             out, err = proc.communicate()
             self.assertEqual(0, proc.returncode)
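A standalone version of the SPARK-9021 behaviour these tests pin down, as a sketch assuming a SparkContext named `sc`: a mutable zero value such as `defaultdict(int)` must not leak state across partitions, so the aggregate result is exact.

from collections import defaultdict

def seqOp(acc, x):
    acc[x] += 1
    return acc

def combOp(a, b):
    for k, v in b.items():
        a[k] += v
    return a

counts = sc.range(10, numSlices=4).aggregate(defaultdict(int), seqOp, combOp)
assert counts == dict((i, 1) for i in range(10))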
    diff --git a/python/run-tests.py b/python/run-tests.py
    index 7638854def2e8..cc560779373b3 100755
    --- a/python/run-tests.py
    +++ b/python/run-tests.py
    @@ -72,7 +72,8 @@ def print_red(text):
     
     
     def run_individual_python_test(test_name, pyspark_python):
    -    env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
    +    env = dict(os.environ)
    +    env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
         LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
         start_time = time.time()
         try:
    diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc
    new file mode 100644
    index 0000000000000..3b7b044936a89
    Binary files /dev/null and b/python/test_support/sql/orc_partitioned/._SUCCESS.crc differ
    diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/python/test_support/sql/orc_partitioned/_SUCCESS
    old mode 100644
    new mode 100755
    similarity index 100%
    rename from sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
    rename to python/test_support/sql/orc_partitioned/_SUCCESS
    diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc
    new file mode 100644
    index 0000000000000..834cf0b7f2272
    Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc differ
    diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc
    new file mode 100755
    index 0000000000000..4943801873356
    Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc differ
    diff --git a/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc
    new file mode 100644
    index 0000000000000..693dceeee3ef2
    Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc differ
    diff --git a/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc
    new file mode 100755
    index 0000000000000..4cbb95ae0242c
    Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc differ
    diff --git a/repl/pom.xml b/repl/pom.xml
    index 370b2bc2fa8ed..a5a0f1fc2c857 100644
    --- a/repl/pom.xml
    +++ b/repl/pom.xml
    @@ -38,11 +38,6 @@
       </properties>
     
       <dependencies>
    -    <dependency>
    -      <groupId>${jline.groupid}</groupId>
    -      <artifactId>jline</artifactId>
    -      <version>${jline.version}</version>
    -    </dependency>
         <dependency>
           <groupId>org.apache.spark</groupId>
           <artifactId>spark-core_${scala.binary.version}</artifactId>
    @@ -138,7 +133,6 @@
                 </goals>
                 <configuration>
                   <sources>
    -                <source>src/main/scala</source>
                     <source>${extra.source.dir}</source>
                   </sources>
                 </configuration>
    @@ -151,7 +145,6 @@
                 </goals>
                 <configuration>
                   <sources>
    -                <source>src/test/scala</source>
                     <source>${extra.testsource.dir}</source>
                   </sources>
                 </configuration>
    @@ -161,6 +154,20 @@
         </plugins>
       </build>
       <profiles>
    +    <profile>
    +      <id>scala-2.10</id>
    +      <activation>
    +        <property><name>!scala-2.11</name></property>
    +      </activation>
    +      <dependencies>
    +        <dependency>
    +          <groupId>${jline.groupid}</groupId>
    +          <artifactId>jline</artifactId>
    +          <version>${jline.version}</version>
    +        </dependency>
    +      </dependencies>
    +    </profile>
    +
         <profile>
           <id>scala-2.11</id>
           <properties>
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
    index 6480e2d24e044..24fbbc12c08da 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
    @@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings)
       }
     
       def this(args: List[String]) {
    +    // scalastyle:off println
         this(args, str => Console.println("Error: " + str))
    +    // scalastyle:on println
       }
     }
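
The // scalastyle:off println and // scalastyle:on println markers added here (and in the other REPL files below) scope the style checker's println ban to a single deliberate console write instead of disabling the rule file-wide. A minimal sketch of the pattern, with a hypothetical object and message:

    object ConsoleReporter {
      def reportError(msg: String): Unit = {
        // The println rule stays active everywhere except this intentional console write.
        // scalastyle:off println
        Console.println("Error: " + msg)
        // scalastyle:on println
      }
    }
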
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    index 2b235525250c2..8130868fe1487 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    @@ -1008,9 +1008,9 @@ class SparkILoop(
         val jars = SparkILoop.getAddedJars
         val conf = new SparkConf()
           .setMaster(getMaster())
    -      .setAppName("Spark shell")
           .setJars(jars)
           .set("spark.repl.class.uri", intp.classServerUri)
    +      .setIfMissing("spark.app.name", "Spark shell")
         if (execUri != null) {
           conf.set("spark.executor.uri", execUri)
         }
    @@ -1101,7 +1101,9 @@ object SparkILoop extends Logging {
                 val s = super.readLine()
                 // helping out by printing the line being interpreted.
                 if (s != null)
    +              // scalastyle:off println
                   output.println(s)
    +              // scalastyle:on println
                 s
               }
             }
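
Swapping setAppName("Spark shell") for setIfMissing("spark.app.name", "Spark shell") turns the shell's default name into a fallback, so an application name supplied through spark-submit --name or --conf spark.app.name=... is no longer overwritten. A small standalone sketch of the difference (the "MyShell" value is illustrative):

    import org.apache.spark.SparkConf

    object AppNamePrecedence {
      def main(args: Array[String]): Unit = {
        // Stand-in for a name the user passed on the spark-submit command line.
        val conf = new SparkConf().set("spark.app.name", "MyShell")

        conf.setIfMissing("spark.app.name", "Spark shell")
        println(conf.get("spark.app.name"))  // still "MyShell"

        conf.setAppName("Spark shell")
        println(conf.get("spark.app.name"))  // now "Spark shell": the user's choice is lost
      }
    }
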
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
    index 05faef8786d2c..bd3314d94eed6 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
    @@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit {
         if (!initIsComplete)
           withLock { while (!initIsComplete) initLoopCondition.await() }
         if (initError != null) {
    +      // scalastyle:off println
           println("""
             |Failed to initialize the REPL due to an unexpected error.
             |This is a bug, please, report it along with the error diagnostics printed below.
             |%s.""".stripMargin.format(initError)
           )
    +      // scalastyle:on println
           false
         } else true
       }
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    index 35fb625645022..4ee605fd7f11e 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    @@ -1079,8 +1079,10 @@ import org.apache.spark.annotation.DeveloperApi
           throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex)
     
         private def load(path: String): Class[_] = {
    +      // scalastyle:off classforname
           try Class.forName(path, true, classLoader)
           catch { case ex: Throwable => evalError(path, unwrap(ex)) }
    +      // scalastyle:on classforname
         }
     
         lazy val evalClass = load(evalPath)
    @@ -1761,7 +1763,9 @@ object SparkIMain {
             if (intp.totalSilence) ()
             else super.printMessage(msg)
           }
    +      // scalastyle:off println
           else Console.println(msg)
    +      // scalastyle:on println
         }
       }
     }
    diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
    index f150fec7db945..5674dcd669bee 100644
    --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
    +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
    @@ -211,7 +211,7 @@ class ReplSuite extends SparkFunSuite {
       }
     
       test("local-cluster mode") {
    -    val output = runInterpreter("local-cluster[1,1,512]",
    +    val output = runInterpreter("local-cluster[1,1,1024]",
           """
             |var v = 7
             |def getV() = v
    @@ -233,7 +233,7 @@ class ReplSuite extends SparkFunSuite {
       }
     
       test("SPARK-1199 two instances of same class don't type check.") {
    -    val output = runInterpreter("local-cluster[1,1,512]",
    +    val output = runInterpreter("local-cluster[1,1,1024]",
           """
             |case class Sum(exp: String, exp2: String)
             |val a = Sum("A", "B")
    @@ -256,7 +256,7 @@ class ReplSuite extends SparkFunSuite {
     
       test("SPARK-2576 importing SQLContext.implicits._") {
         // We need to use local-cluster to test this case.
    -    val output = runInterpreter("local-cluster[1,1,512]",
    +    val output = runInterpreter("local-cluster[1,1,1024]",
           """
             |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
             |import sqlContext.implicits._
    @@ -325,9 +325,9 @@ class ReplSuite extends SparkFunSuite {
         assertDoesNotContain("Exception", output)
         assertContains("ret: Array[Foo] = Array(Foo(1),", output)
       }
    -  
    +
       test("collecting objects of class defined in repl - shuffling") {
    -    val output = runInterpreter("local-cluster[1,1,512]",
    +    val output = runInterpreter("local-cluster[1,1,1024]",
           """
             |case class Foo(i: Int)
             |val list = List((1, Foo(1)), (1, Foo(2)))
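
The master strings updated in these tests follow the format local-cluster[numWorkers,coresPerWorker,memoryPerWorkerMB]; raising the last field from 512 to 1024 gives each spawned executor 1024 MB of memory instead of 512 MB. A minimal sketch of running against such a master (illustrative only; local-cluster mode expects a locally built Spark distribution):

    import org.apache.spark.{SparkConf, SparkContext}

    object LocalClusterSketch {
      def main(args: Array[String]): Unit = {
        // One worker JVM with one core and 1024 MB of memory.
        val conf = new SparkConf()
          .setMaster("local-cluster[1,1,1024]")
          .setAppName("local-cluster-sketch")
        val sc = new SparkContext(conf)
        try {
          println(sc.parallelize(1 to 100, 2).sum())
        } finally {
          sc.stop()
        }
      }
    }
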
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
    index f4f4b626988e9..be31eb2eda546 100644
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
    +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
    @@ -17,13 +17,14 @@
     
     package org.apache.spark.repl
     
    +import java.io.File
    +
    +import scala.tools.nsc.Settings
    +
     import org.apache.spark.util.Utils
     import org.apache.spark._
     import org.apache.spark.sql.SQLContext
     
    -import scala.tools.nsc.Settings
    -import scala.tools.nsc.interpreter.SparkILoop
    -
     object Main extends Logging {
     
       val conf = new SparkConf()
    @@ -32,7 +33,8 @@ object Main extends Logging {
       val outputDir = Utils.createTempDir(rootDir)
       val s = new Settings()
       s.processArguments(List("-Yrepl-class-based",
    -    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true)
    +    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
    +    "-classpath", getAddedJars.mkString(File.pathSeparator)), true)
       val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
       var sparkContext: SparkContext = _
       var sqlContext: SQLContext = _
    @@ -48,7 +50,6 @@ object Main extends Logging {
         Option(sparkContext).map(_.stop)
       }
     
    -
       def getAddedJars: Array[String] = {
         val envJars = sys.env.get("ADD_JARS")
         if (envJars.isDefined) {
    @@ -64,9 +65,9 @@ object Main extends Logging {
         val jars = getAddedJars
         val conf = new SparkConf()
           .setMaster(getMaster)
    -      .setAppName("Spark shell")
           .setJars(jars)
           .set("spark.repl.class.uri", classServer.uri)
    +      .setIfMissing("spark.app.name", "Spark shell")
         logInfo("Spark class server started at " + classServer.uri)
         if (execUri != null) {
           conf.set("spark.executor.uri", execUri)
    @@ -84,10 +85,9 @@ object Main extends Logging {
         val loader = Utils.getContextOrSparkClassLoader
         try {
           sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext])
    -        .newInstance(sparkContext).asInstanceOf[SQLContext] 
    +        .newInstance(sparkContext).asInstanceOf[SQLContext]
           logInfo("Created sql context (with Hive support)..")
    -    }
    -    catch {
    +    } catch {
           case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError =>
             sqlContext = new SQLContext(sparkContext)
             logInfo("Created sql context..")
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
    deleted file mode 100644
    index 8e519fa67f649..0000000000000
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
    +++ /dev/null
    @@ -1,86 +0,0 @@
    -/* NSC -- new Scala compiler
    - * Copyright 2005-2013 LAMP/EPFL
    - * @author  Paul Phillips
    - */
    -
    -package scala.tools.nsc
    -package interpreter
    -
    -import scala.tools.nsc.ast.parser.Tokens.EOF
    -
    -trait SparkExprTyper {
    -  val repl: SparkIMain
    -
    -  import repl._
    -  import global.{ reporter => _, Import => _, _ }
    -  import naming.freshInternalVarName
    -
    -  def symbolOfLine(code: String): Symbol = {
    -    def asExpr(): Symbol = {
    -      val name  = freshInternalVarName()
    -      // Typing it with a lazy val would give us the right type, but runs
    -      // into compiler bugs with things like existentials, so we compile it
    -      // behind a def and strip the NullaryMethodType which wraps the expr.
    -      val line = "def " + name + " = " + code
    -
    -      interpretSynthetic(line) match {
    -        case IR.Success =>
    -          val sym0 = symbolOfTerm(name)
    -          // drop NullaryMethodType
    -          sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
    -        case _          => NoSymbol
    -      }
    -    }
    -    def asDefn(): Symbol = {
    -      val old = repl.definedSymbolList.toSet
    -
    -      interpretSynthetic(code) match {
    -        case IR.Success =>
    -          repl.definedSymbolList filterNot old match {
    -            case Nil        => NoSymbol
    -            case sym :: Nil => sym
    -            case syms       => NoSymbol.newOverloaded(NoPrefix, syms)
    -          }
    -        case _ => NoSymbol
    -      }
    -    }
    -    def asError(): Symbol = {
    -      interpretSynthetic(code)
    -      NoSymbol
    -    }
    -    beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
    -  }
    -
    -  private var typeOfExpressionDepth = 0
    -  def typeOfExpression(expr: String, silent: Boolean = true): Type = {
    -    if (typeOfExpressionDepth > 2) {
    -      repldbg("Terminating typeOfExpression recursion for expression: " + expr)
    -      return NoType
    -    }
    -    typeOfExpressionDepth += 1
    -    // Don't presently have a good way to suppress undesirable success output
    -    // while letting errors through, so it is first trying it silently: if there
    -    // is an error, and errors are desired, then it re-evaluates non-silently
    -    // to induce the error message.
    -    try beSilentDuring(symbolOfLine(expr).tpe) match {
    -      case NoType if !silent => symbolOfLine(expr).tpe // generate error
    -      case tpe               => tpe
    -    }
    -    finally typeOfExpressionDepth -= 1
    -  }
    -
    -  // This only works for proper types.
    -  def typeOfTypeString(typeString: String): Type = {
    -    def asProperType(): Option[Type] = {
    -      val name = freshInternalVarName()
    -      val line = "def %s: %s = ???" format (name, typeString)
    -      interpretSynthetic(line) match {
    -        case IR.Success =>
    -          val sym0 = symbolOfTerm(name)
    -          Some(sym0.asMethod.returnType)
    -        case _          => None
    -      }
    -    }
    -    beSilentDuring(asProperType()) getOrElse NoType
    -  }
    -}
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    index 7a5e94da5cbf3..bf609ff0f65fc 100644
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    @@ -1,88 +1,64 @@
    -/* NSC -- new Scala compiler
    - * Copyright 2005-2013 LAMP/EPFL
    - * @author Alexander Spoon
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
      */
     
    -package scala
    -package tools.nsc
    -package interpreter
    +package org.apache.spark.repl
     
    -import scala.language.{ implicitConversions, existentials }
    -import scala.annotation.tailrec
    -import Predef.{ println => _, _ }
    -import interpreter.session._
    -import StdReplTags._
    -import scala.reflect.api.{Mirror, Universe, TypeCreator}
    -import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName }
    -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
    -import scala.reflect.{ClassTag, classTag}
    -import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader }
    -import ScalaClassLoader._
    -import scala.reflect.io.{ File, Directory }
    -import scala.tools.util._
    -import scala.collection.generic.Clearable
    -import scala.concurrent.{ ExecutionContext, Await, Future, future }
    -import ExecutionContext.Implicits._
    -import java.io.{ BufferedReader, FileReader }
    +import java.io.{BufferedReader, FileReader}
     
    -/** The Scala interactive shell.  It provides a read-eval-print loop
    -  *  around the Interpreter class.
    -  *  After instantiation, clients should call the main() method.
    -  *
    -  *  If no in0 is specified, then input will come from the console, and
    -  *  the class will attempt to provide input editing feature such as
    -  *  input history.
    -  *
    -  *  @author Moez A. Abdel-Gawad
    -  *  @author  Lex Spoon
    -  *  @version 1.2
    -  */
    -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
    -  extends AnyRef
    -  with LoopCommands
    -{
    -  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
    -  def this() = this(None, new JPrintWriter(Console.out, true))
    -//
    -//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
    -//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i
    -
    -  var in: InteractiveReader = _   // the input stream from which commands come
    -  var settings: Settings = _
    -  var intp: SparkIMain = _
    +import Predef.{println => _, _}
    +import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName}
     
    -  var globalFuture: Future[Boolean] = _
    +import scala.tools.nsc.interpreter.{JPrintWriter, ILoop}
    +import scala.tools.nsc.Settings
    +import scala.tools.nsc.util.stringFromStream
     
    -  protected def asyncMessage(msg: String) {
    -    if (isReplInfo || isReplPower)
    -      echoAndRefresh(msg)
    -  }
    +/**
    + *  A Spark-specific interactive shell.
    + */
    +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
    +    extends ILoop(in0, out) {
    +  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
    +  def this() = this(None, new JPrintWriter(Console.out, true))
     
       def initializeSpark() {
         intp.beQuietDuring {
    -      command( """
    +      processLine("""
              @transient val sc = {
                val _sc = org.apache.spark.repl.Main.createSparkContext()
                println("Spark context available as sc.")
                _sc
              }
             """)
    -      command( """
    +      processLine("""
              @transient val sqlContext = {
                val _sqlContext = org.apache.spark.repl.Main.createSQLContext()
                println("SQL context available as sqlContext.")
                _sqlContext
              }
             """)
    -      command("import org.apache.spark.SparkContext._")
    -      command("import sqlContext.implicits._")
    -      command("import sqlContext.sql")
    -      command("import org.apache.spark.sql.functions._")
    +      processLine("import org.apache.spark.SparkContext._")
    +      processLine("import sqlContext.implicits._")
    +      processLine("import sqlContext.sql")
    +      processLine("import org.apache.spark.sql.functions._")
         }
       }
     
       /** Print a welcome message */
    -  def printWelcome() {
    +  override def printWelcome() {
         import org.apache.spark.SPARK_VERSION
         echo("""Welcome to
           ____              __
    @@ -98,875 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
         echo("Type :help for more information.")
       }
     
    -  override def echoCommandMessage(msg: String) {
    -    intp.reporter printUntruncatedMessage msg
    -  }
    -
    -  // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
    -  def history = in.history
    -
    -  // classpath entries added via :cp
    -  var addedClasspath: String = ""
    -
    -  /** A reverse list of commands to replay if the user requests a :replay */
    -  var replayCommandStack: List[String] = Nil
    -
    -  /** A list of commands to replay if the user requests a :replay */
    -  def replayCommands = replayCommandStack.reverse
    -
    -  /** Record a command for replay should the user request a :replay */
    -  def addReplay(cmd: String) = replayCommandStack ::= cmd
    -
    -  def savingReplayStack[T](body: => T): T = {
    -    val saved = replayCommandStack
    -    try body
    -    finally replayCommandStack = saved
    -  }
    -  def savingReader[T](body: => T): T = {
    -    val saved = in
    -    try body
    -    finally in = saved
    -  }
    -
    -  /** Close the interpreter and set the var to null. */
    -  def closeInterpreter() {
    -    if (intp ne null) {
    -      intp.close()
    -      intp = null
    -    }
    -  }
    -
    -  class SparkILoopInterpreter extends SparkIMain(settings, out) {
    -    outer =>
    -
    -    override lazy val formatting = new Formatting {
    -      def prompt = SparkILoop.this.prompt
    -    }
    -    override protected def parentClassLoader =
    -      settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader )
    -  }
    -
    -  /** Create a new interpreter. */
    -  def createInterpreter() {
    -    if (addedClasspath != "")
    -      settings.classpath append addedClasspath
    -
    -    intp = new SparkILoopInterpreter
    -  }
    -
    -  /** print a friendly help message */
    -  def helpCommand(line: String): Result = {
    -    if (line == "") helpSummary()
    -    else uniqueCommand(line) match {
    -      case Some(lc) => echo("\n" + lc.help)
    -      case _        => ambiguousError(line)
    -    }
    -  }
    -  private def helpSummary() = {
    -    val usageWidth  = commands map (_.usageMsg.length) max
    -    val formatStr   = "%-" + usageWidth + "s %s"
    -
    -    echo("All commands can be abbreviated, e.g. :he instead of :help.")
    -
    -    commands foreach { cmd =>
    -      echo(formatStr.format(cmd.usageMsg, cmd.help))
    -    }
    -  }
    -  private def ambiguousError(cmd: String): Result = {
    -    matchingCommands(cmd) match {
    -      case Nil  => echo(cmd + ": no such command.  Type :help for help.")
    -      case xs   => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
    -    }
    -    Result(keepRunning = true, None)
    -  }
    -  private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
    -  private def uniqueCommand(cmd: String): Option[LoopCommand] = {
    -    // this lets us add commands willy-nilly and only requires enough command to disambiguate
    -    matchingCommands(cmd) match {
    -      case List(x)  => Some(x)
    -      // exact match OK even if otherwise appears ambiguous
    -      case xs       => xs find (_.name == cmd)
    -    }
    -  }
    -
    -  /** Show the history */
    -  lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
    -    override def usage = "[num]"
    -    def defaultLines = 20
    -
    -    def apply(line: String): Result = {
    -      if (history eq NoHistory)
    -        return "No history available."
    -
    -      val xs      = words(line)
    -      val current = history.index
    -      val count   = try xs.head.toInt catch { case _: Exception => defaultLines }
    -      val lines   = history.asStrings takeRight count
    -      val offset  = current - lines.size + 1
    -
    -      for ((line, index) <- lines.zipWithIndex)
    -        echo("%3d  %s".format(index + offset, line))
    -    }
    -  }
    -
    -  // When you know you are most likely breaking into the middle
    -  // of a line being typed.  This softens the blow.
    -  protected def echoAndRefresh(msg: String) = {
    -    echo("\n" + msg)
    -    in.redrawLine()
    -  }
    -  protected def echo(msg: String) = {
    -    out println msg
    -    out.flush()
    -  }
    -
    -  /** Search the history */
    -  def searchHistory(_cmdline: String) {
    -    val cmdline = _cmdline.toLowerCase
    -    val offset  = history.index - history.size + 1
    -
    -    for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
    -      echo("%d %s".format(index + offset, line))
    -  }
    -
    -  private val currentPrompt = Properties.shellPromptString
    -
    -  /** Prompt to print when awaiting input */
    -  def prompt = currentPrompt
    -
       import LoopCommand.{ cmd, nullary }
     
    -  /** Standard commands **/
    -  lazy val standardCommands = List(
    -    cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
    -    cmd("edit", "|", "edit history", editCommand),
    -    cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
    -    historyCommand,
    -    cmd("h?", "", "search the history", searchHistory),
    -    cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
    -    //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand),
    -    cmd("javap", "", "disassemble a file or class name", javapCommand),
    -    cmd("line", "|", "place line(s) at the end of history", lineCommand),
    -    cmd("load", "", "interpret lines in a file", loadCommand),
    -    cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand),
    -    // nullary("power", "enable power user mode", powerCmd),
    -    nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)),
    -    nullary("replay", "reset execution and replay all previous commands", replay),
    -    nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
    -    cmd("save", "", "save replayable session to a file", saveCommand),
    -    shCommand,
    -    cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings),
    -    nullary("silent", "disable/enable automatic printing of results", verbosity),
    -//    cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand),
    -//    cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand),
    -    nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
    -  )
    -
    -  /** Power user commands */
    -//  lazy val powerCommands: List[LoopCommand] = List(
    -//    cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
    -//  )
    -
    -  private def importsCommand(line: String): Result = {
    -    val tokens    = words(line)
    -    val handlers  = intp.languageWildcardHandlers ++ intp.importHandlers
    -
    -    handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
    -      case (handler, idx) =>
    -        val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
    -        val imps           = handler.implicitSymbols
    -        val found          = tokens filter (handler importsSymbolNamed _)
    -        val typeMsg        = if (types.isEmpty) "" else types.size + " types"
    -        val termMsg        = if (terms.isEmpty) "" else terms.size + " terms"
    -        val implicitMsg    = if (imps.isEmpty) "" else imps.size + " are implicit"
    -        val foundMsg       = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
    -        val statsMsg       = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
    -
    -        intp.reporter.printMessage("%2d) %-30s %s%s".format(
    -          idx + 1,
    -          handler.importString,
    -          statsMsg,
    -          foundMsg
    -        ))
    -    }
    -  }
    -
    -  private def findToolsJar() = PathResolver.SupplementalLocations.platformTools
    +  private val blockedCommands = Set("implicits", "javap", "power", "type", "kind")
     
    -  private def addToolsJarToLoader() = {
    -    val cl = findToolsJar() match {
    -      case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
    -      case _           => intp.classLoader
    -    }
    -    if (Javap.isAvailable(cl)) {
    -      repldbg(":javap available.")
    -      cl
    -    }
    -    else {
    -      repldbg(":javap unavailable: no tools.jar at " + jdkHome)
    -      intp.classLoader
    -    }
    -  }
    -//
    -//  protected def newJavap() =
    -//    JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp))
    -//
    -//  private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
    -
    -  // Still todo: modules.
    -//  private def typeCommand(line0: String): Result = {
    -//    line0.trim match {
    -//      case "" => ":type [-v] "
    -//      case s  => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
    -//    }
    -//  }
    -
    -//  private def kindCommand(expr: String): Result = {
    -//    expr.trim match {
    -//      case "" => ":kind [-v] "
    -//      case s  => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
    -//    }
    -//  }
    -
    -  private def warningsCommand(): Result = {
    -    if (intp.lastWarnings.isEmpty)
    -      "Can't find any cached warnings."
    -    else
    -      intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
    -  }
    -
    -  private def changeSettings(args: String): Result = {
    -    def showSettings() = {
    -      for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString)
    -    }
    -    def updateSettings() = {
    -      // put aside +flag options
    -      val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+"))
    -      val tmps = new Settings
    -      val (ok, leftover) = tmps.processArguments(rest, processAll = true)
    -      if (!ok) echo("Bad settings request.")
    -      else if (leftover.nonEmpty) echo("Unprocessed settings.")
    -      else {
    -        // boolean flags set-by-user on tmp copy should be off, not on
    -        val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting])
    -        val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg))
    -        // update non-flags
    -        settings.processArguments(nonbools, processAll = true)
    -        // also snag multi-value options for clearing, e.g. -Ylog: and -language:
    -        for {
    -          s <- settings.userSetSettings
    -          if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting]
    -          if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init))
    -        } s match {
    -          case c: Clearable => c.clear()
    -          case _ =>
    -        }
    -        def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = {
    -          for (b <- bs)
    -            settings.lookupSetting(name(b)) match {
    -              case Some(s) =>
    -                if (s.isInstanceOf[Settings#BooleanSetting]) setter(s)
    -                else echo(s"Not a boolean flag: $b")
    -              case _ =>
    -                echo(s"Not an option: $b")
    -            }
    -        }
    -        update(minuses, identity, _.tryToSetFromPropertyValue("false"))  // turn off
    -        update(pluses, "-" + _.drop(1), _.tryToSet(Nil))                 // turn on
    -      }
    -    }
    -    if (args.isEmpty) showSettings() else updateSettings()
    -  }
    -
    -  private def javapCommand(line: String): Result = {
    -//    if (javap == null)
    -//      ":javap unavailable, no tools.jar at %s.  Set JDK_HOME.".format(jdkHome)
    -//    else if (line == "")
    -//      ":javap [-lcsvp] [path1 path2 ...]"
    -//    else
    -//      javap(words(line)) foreach { res =>
    -//        if (res.isError) return "Failed: " + res.value
    -//        else res.show()
    -//      }
    -  }
    -
    -  private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent"
    -
    -  private def phaseCommand(name: String): Result = {
    -//    val phased: Phased = power.phased
    -//    import phased.NoPhaseName
    -//
    -//    if (name == "clear") {
    -//      phased.set(NoPhaseName)
    -//      intp.clearExecutionWrapper()
    -//      "Cleared active phase."
    -//    }
    -//    else if (name == "") phased.get match {
    -//      case NoPhaseName => "Usage: :phase  (e.g. typer, erasure.next, erasure+3)"
    -//      case ph          => "Active phase is '%s'.  (To clear, :phase clear)".format(phased.get)
    -//    }
    -//    else {
    -//      val what = phased.parse(name)
    -//      if (what.isEmpty || !phased.set(what))
    -//        "'" + name + "' does not appear to represent a valid phase."
    -//      else {
    -//        intp.setExecutionWrapper(pathToPhaseWrapper)
    -//        val activeMessage =
    -//          if (what.toString.length == name.length) "" + what
    -//          else "%s (%s)".format(what, name)
    -//
    -//        "Active phase is now: " + activeMessage
    -//      }
    -//    }
    -  }
    +  /** Standard commands **/
    +  lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
    +    standardCommands.filter(cmd => !blockedCommands(cmd.name))
     
       /** Available commands */
    -  def commands: List[LoopCommand] = standardCommands ++ (
    -    // if (isReplPower)
    -    //  powerCommands
    -    // else
    -      Nil
    -    )
    -
    -  val replayQuestionMessage =
    -    """|That entry seems to have slain the compiler.  Shall I replay
    -      |your session? I can re-run each line except the last one.
    -      |[y/n]
    -    """.trim.stripMargin
    -
    -  private val crashRecovery: PartialFunction[Throwable, Boolean] = {
    -    case ex: Throwable =>
    -      val (err, explain) = (
    -        if (intp.isInitializeComplete)
    -          (intp.global.throwableAsString(ex), "")
    -        else
    -          (ex.getMessage, "The compiler did not initialize.\n")
    -        )
    -      echo(err)
    -
    -      ex match {
    -        case _: NoSuchMethodError | _: NoClassDefFoundError =>
    -          echo("\nUnrecoverable error.")
    -          throw ex
    -        case _  =>
    -          def fn(): Boolean =
    -            try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
    -            catch { case _: RuntimeException => false }
    -
    -          if (fn()) replay()
    -          else echo("\nAbandoning crashed session.")
    -      }
    -      true
    -  }
    -
    -  // return false if repl should exit
    -  def processLine(line: String): Boolean = {
    -    import scala.concurrent.duration._
    -    Await.ready(globalFuture, 60.seconds)
    -
    -    (line ne null) && (command(line) match {
    -      case Result(false, _)      => false
    -      case Result(_, Some(line)) => addReplay(line) ; true
    -      case _                     => true
    -    })
    -  }
    -
    -  private def readOneLine() = {
    -    out.flush()
    -    in readLine prompt
    -  }
    -
    -  /** The main read-eval-print loop for the repl.  It calls
    -    *  command() for each line of input, and stops when
    -    *  command() returns false.
    -    */
    -  @tailrec final def loop() {
    -    if ( try processLine(readOneLine()) catch crashRecovery )
    -      loop()
    -  }
    -
    -  /** interpret all lines from a specified file */
    -  def interpretAllFrom(file: File) {
    -    savingReader {
    -      savingReplayStack {
    -        file applyReader { reader =>
    -          in = SimpleReader(reader, out, interactive = false)
    -          echo("Loading " + file + "...")
    -          loop()
    -        }
    -      }
    -    }
    -  }
    -
    -  /** create a new interpreter and replay the given commands */
    -  def replay() {
    -    reset()
    -    if (replayCommandStack.isEmpty)
    -      echo("Nothing to replay.")
    -    else for (cmd <- replayCommands) {
    -      echo("Replaying: " + cmd)  // flush because maybe cmd will have its own output
    -      command(cmd)
    -      echo("")
    -    }
    -  }
    -  def resetCommand() {
    -    echo("Resetting interpreter state.")
    -    if (replayCommandStack.nonEmpty) {
    -      echo("Forgetting this session history:\n")
    -      replayCommands foreach echo
    -      echo("")
    -      replayCommandStack = Nil
    -    }
    -    if (intp.namedDefinedTerms.nonEmpty)
    -      echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
    -    if (intp.definedTypes.nonEmpty)
    -      echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
    -
    -    reset()
    -  }
    -  def reset() {
    -    intp.reset()
    -    unleashAndSetPhase()
    -  }
    -
    -  def lineCommand(what: String): Result = editCommand(what, None)
    -
    -  // :edit id or :edit line
    -  def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR"))
    -
    -  def editCommand(what: String, editor: Option[String]): Result = {
    -    def diagnose(code: String) = {
    -      echo("The edited code is incomplete!\n")
    -      val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}")
    -      if (errless) echo("The compiler reports no errors.")
    -    }
    -    def historicize(text: String) = history match {
    -      case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true
    -      case _ => false
    -    }
    -    def edit(text: String): Result = editor match {
    -      case Some(ed) =>
    -        val tmp = File.makeTemp()
    -        tmp.writeAll(text)
    -        try {
    -          val pr = new ProcessResult(s"$ed ${tmp.path}")
    -          pr.exitCode match {
    -            case 0 =>
    -              tmp.safeSlurp() match {
    -                case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.")
    -                case Some(edited) =>
    -                  echo(edited.lines map ("+" + _) mkString "\n")
    -                  val res = intp interpret edited
    -                  if (res == IR.Incomplete) diagnose(edited)
    -                  else {
    -                    historicize(edited)
    -                    Result(lineToRecord = Some(edited), keepRunning = true)
    -                  }
    -                case None => echo("Can't read edited text. Did you delete it?")
    -              }
    -            case x => echo(s"Error exit from $ed ($x), ignoring")
    -          }
    -        } finally {
    -          tmp.delete()
    -        }
    -      case None =>
    -        if (historicize(text)) echo("Placing text in recent history.")
    -        else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text")
    -    }
    -
    -    // if what is a number, use it as a line number or range in history
    -    def isNum = what forall (c => c.isDigit || c == '-' || c == '+')
    -    // except that "-" means last value
    -    def isLast = (what == "-")
    -    if (isLast || !isNum) {
    -      val name = if (isLast) intp.mostRecentVar else what
    -      val sym = intp.symbolOfIdent(name)
    -      intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match {
    -        case Some(req) => edit(req.line)
    -        case None      => echo(s"No symbol in scope: $what")
    -      }
    -    } else try {
    -      val s = what
    -      // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)
    -      val (start, len) =
    -        if ((s indexOf '+') > 0) {
    -          val (a,b) = s splitAt (s indexOf '+')
    -          (a.toInt, b.drop(1).toInt)
    -        } else {
    -          (s indexOf '-') match {
    -            case -1 => (s.toInt, 1)
    -            case 0  => val n = s.drop(1).toInt ; (history.index - n, n)
    -            case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n)
    -            case i  => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n)
    -          }
    -        }
    -      import scala.collection.JavaConverters._
    -      val index = (start - 1) max 0
    -      val text = history match {
    -        case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n"
    -        case _ => history.asStrings.slice(index, index + len) mkString "\n"
    -      }
    -      edit(text)
    -    } catch {
    -      case _: NumberFormatException => echo(s"Bad range '$what'")
    -        echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)")
    -    }
    -  }
    -
    -  /** fork a shell and run a command */
    -  lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
    -    override def usage = ""
    -    def apply(line: String): Result = line match {
    -      case ""   => showUsage()
    -      case _    =>
    -        val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})"
    -        intp interpret toRun
    -        ()
    -    }
    -  }
    -
    -  def withFile[A](filename: String)(action: File => A): Option[A] = {
    -    val res = Some(File(filename)) filter (_.exists) map action
    -    if (res.isEmpty) echo("That file does not exist")  // courtesy side-effect
    -    res
    -  }
    -
    -  def loadCommand(arg: String) = {
    -    var shouldReplay: Option[String] = None
    -    withFile(arg)(f => {
    -      interpretAllFrom(f)
    -      shouldReplay = Some(":load " + arg)
    -    })
    -    Result(keepRunning = true, shouldReplay)
    -  }
    -
    -  def saveCommand(filename: String): Result = (
    -    if (filename.isEmpty) echo("File name is required.")
    -    else if (replayCommandStack.isEmpty) echo("No replay commands in session")
    -    else File(filename).printlnAll(replayCommands: _*)
    -    )
    -
    -  def addClasspath(arg: String): Unit = {
    -    val f = File(arg).normalize
    -    if (f.exists) {
    -      addedClasspath = ClassPath.join(addedClasspath, f.path)
    -      val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
    -      echo("Added '%s'.  Your new classpath is:\n\"%s\"".format(f.path, totalClasspath))
    -      replay()
    -    }
    -    else echo("The path '" + f + "' doesn't seem to exist.")
    -  }
    -
    -  def powerCmd(): Result = {
    -    if (isReplPower) "Already in power mode."
    -    else enablePowerMode(isDuringInit = false)
    -  }
    -  def enablePowerMode(isDuringInit: Boolean) = {
    -    replProps.power setValue true
    -    unleashAndSetPhase()
    -    // asyncEcho(isDuringInit, power.banner)
    -  }
    -  private def unleashAndSetPhase() {
    -    if (isReplPower) {
    -    //  power.unleash()
    -      // Set the phase to "typer"
    -      // intp beSilentDuring phaseCommand("typer")
    -    }
    -  }
    -
    -  def asyncEcho(async: Boolean, msg: => String) {
    -    if (async) asyncMessage(msg)
    -    else echo(msg)
    -  }
    -
    -  def verbosity() = {
    -    val old = intp.printResults
    -    intp.printResults = !old
    -    echo("Switched " + (if (old) "off" else "on") + " result printing.")
    -  }
    -
    -  /** Run one command submitted by the user.  Two values are returned:
    -    * (1) whether to keep running, (2) the line to record for replay,
    -    * if any. */
    -  def command(line: String): Result = {
    -    if (line startsWith ":") {
    -      val cmd = line.tail takeWhile (x => !x.isWhitespace)
    -      uniqueCommand(cmd) match {
    -        case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
    -        case _        => ambiguousError(cmd)
    -      }
    -    }
    -    else if (intp.global == null) Result(keepRunning = false, None)  // Notice failure to create compiler
    -    else Result(keepRunning = true, interpretStartingWith(line))
    -  }
    -
    -  private def readWhile(cond: String => Boolean) = {
    -    Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
    -  }
    -
    -  def pasteCommand(arg: String): Result = {
    -    var shouldReplay: Option[String] = None
    -    def result = Result(keepRunning = true, shouldReplay)
    -    val (raw, file) =
    -      if (arg.isEmpty) (false, None)
    -      else {
    -        val r = """(-raw)?(\s+)?([^\-]\S*)?""".r
    -        arg match {
    -          case r(flag, sep, name) =>
    -            if (flag != null && name != null && sep == null)
    -              echo(s"""I assume you mean "$flag $name"?""")
    -            (flag != null, Option(name))
    -          case _ =>
    -            echo("usage: :paste -raw file")
    -            return result
    -        }
    -      }
    -    val code = file match {
    -      case Some(name) =>
    -        withFile(name)(f => {
    -          shouldReplay = Some(s":paste $arg")
    -          val s = f.slurp.trim
    -          if (s.isEmpty) echo(s"File contains no code: $f")
    -          else echo(s"Pasting file $f...")
    -          s
    -        }) getOrElse ""
    -      case None =>
    -        echo("// Entering paste mode (ctrl-D to finish)\n")
    -        val text = (readWhile(_ => true) mkString "\n").trim
    -        if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n")
    -        else echo("\n// Exiting paste mode, now interpreting.\n")
    -        text
    -    }
    -    def interpretCode() = {
    -      val res = intp interpret code
    -      // if input is incomplete, let the compiler try to say why
    -      if (res == IR.Incomplete) {
    -        echo("The pasted code is incomplete!\n")
    -        // Remembrance of Things Pasted in an object
    -        val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}")
    -        if (errless) echo("...but compilation found no error? Good luck with that.")
    -      }
    -    }
    -    def compileCode() = {
    -      val errless = intp compileSources new BatchSourceFile("", code)
    -      if (!errless) echo("There were compilation errors!")
    -    }
    -    if (code.nonEmpty) {
    -      if (raw) compileCode() else interpretCode()
    -    }
    -    result
    -  }
    -
    -  private object paste extends Pasted {
    -    val ContinueString = "     | "
    -    val PromptString   = "scala> "
    -
    -    def interpret(line: String): Unit = {
    -      echo(line.trim)
    -      intp interpret line
    -      echo("")
    -    }
    -
    -    def transcript(start: String) = {
    -      echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
    -      apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
    -    }
    -  }
    -  import paste.{ ContinueString, PromptString }
    -
    -  /** Interpret expressions starting with the first line.
    -    * Read lines until a complete compilation unit is available
    -    * or until a syntax error has been seen.  If a full unit is
    -    * read, go ahead and interpret it.  Return the full string
    -    * to be recorded for replay, if any.
    -    */
    -  def interpretStartingWith(code: String): Option[String] = {
    -    // signal completion non-completion input has been received
    -    in.completion.resetVerbosity()
    -
    -    def reallyInterpret = {
    -      val reallyResult = intp.interpret(code)
    -      (reallyResult, reallyResult match {
    -        case IR.Error       => None
    -        case IR.Success     => Some(code)
    -        case IR.Incomplete  =>
    -          if (in.interactive && code.endsWith("\n\n")) {
    -            echo("You typed two blank lines.  Starting a new command.")
    -            None
    -          }
    -          else in.readLine(ContinueString) match {
    -            case null =>
    -              // we know compilation is going to fail since we're at EOF and the
    -              // parser thinks the input is still incomplete, but since this is
    -              // a file being read non-interactively we want to fail.  So we send
    -              // it straight to the compiler for the nice error message.
    -              intp.compileString(code)
    -              None
    -
    -            case line => interpretStartingWith(code + "\n" + line)
    -          }
    -      })
    -    }
    -
    -    /** Here we place ourselves between the user and the interpreter and examine
    -      *  the input they are ostensibly submitting.  We intervene in several cases:
    -      *
    -      *  1) If the line starts with "scala> " it is assumed to be an interpreter paste.
    -      *  2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
    -      *     on the previous result.
    -      *  3) If the Completion object's execute returns Some(_), we inject that value
    -      *     and avoid the interpreter, as it's likely not valid scala code.
    -      */
    -    if (code == "") None
    -    else if (!paste.running && code.trim.startsWith(PromptString)) {
    -      paste.transcript(code)
    -      None
    -    }
    -    else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
    -      interpretStartingWith(intp.mostRecentVar + code)
    -    }
    -    else if (code.trim startsWith "//") {
    -      // line comment, do nothing
    -      None
    -    }
    -    else
    -      reallyInterpret._2
    -  }
    -
    -  // runs :load `file` on any files passed via -i
    -  def loadFiles(settings: Settings) = settings match {
    -    case settings: GenericRunnerSettings =>
    -      for (filename <- settings.loadfiles.value) {
    -        val cmd = ":load " + filename
    -        command(cmd)
    -        addReplay(cmd)
    -        echo("")
    -      }
    -    case _ =>
    -  }
    -
    -  /** Tries to create a JLineReader, falling back to SimpleReader:
    -    *  unless settings or properties are such that it should start
    -    *  with SimpleReader.
    -    */
    -  def chooseReader(settings: Settings): InteractiveReader = {
    -    if (settings.Xnojline || Properties.isEmacsShell)
    -      SimpleReader()
    -    else try new JLineReader(
    -      if (settings.noCompletion) NoCompletion
    -      else new SparkJLineCompletion(intp)
    -    )
    -    catch {
    -      case ex @ (_: Exception | _: NoClassDefFoundError) =>
    -        echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.")
    -        SimpleReader()
    -    }
    -  }
    -  protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
    -    u.TypeTag[T](
    -      m,
    -      new TypeCreator {
    -        def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type =
    -          m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
    -      })
    -
    -  private def loopPostInit() {
    -    // Bind intp somewhere out of the regular namespace where
    -    // we can get at it in generated code.
    -    intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain]))
    -    // Auto-run code via some setting.
    -    ( replProps.replAutorunCode.option
    -      flatMap (f => io.File(f).safeSlurp())
    -      foreach (intp quietRun _)
    -      )
    -    // classloader and power mode setup
    -    intp.setContextClassLoader()
    -    if (isReplPower) {
    -     // replProps.power setValue true
    -     // unleashAndSetPhase()
    -     // asyncMessage(power.banner)
    -    }
    -    // SI-7418 Now, and only now, can we enable TAB completion.
    -    in match {
    -      case x: JLineReader => x.consoleReader.postInit
    -      case _              =>
    -    }
    -  }
    -  def process(settings: Settings): Boolean = savingContextLoader {
    -    this.settings = settings
    -    createInterpreter()
    -
    -    // sets in to some kind of reader depending on environmental cues
    -    in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true))
    -    globalFuture = future {
    -      intp.initializeSynchronous()
    -      loopPostInit()
    -      !intp.reporter.hasErrors
    -    }
    -    import scala.concurrent.duration._
    -    Await.ready(globalFuture, 10 seconds)
    -    printWelcome()
    +  override def commands: List[LoopCommand] = sparkStandardCommands
    +
    +  /** 
    +   * We override `loadFiles` because we need to initialize Spark *before* the REPL
    +   * sees any files, so that the Spark context is visible in those files. This is a bit of a
    +   * hack, but there isn't another hook available to us at this point.
    +   */
    +  override def loadFiles(settings: Settings): Unit = {
         initializeSpark()
    -    loadFiles(settings)
    -
    -    try loop()
    -    catch AbstractOrMissingHandler()
    -    finally closeInterpreter()
    -
    -    true
    +    super.loadFiles(settings)
       }
    -
    -  @deprecated("Use `process` instead", "2.9.0")
    -  def main(settings: Settings): Unit = process(settings) //used by sbt
     }
     
     object SparkILoop {
    -  implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
     
    -  // Designed primarily for use by test code: take a String with a
    -  // bunch of code, and prints out a transcript of what it would look
    -  // like if you'd just typed it into the repl.
    -  def runForTranscript(code: String, settings: Settings): String = {
    -    import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
    -
    -    stringFromStream { ostream =>
    -      Console.withOut(ostream) {
    -        val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
    -          override def write(str: String) = {
    -            // completely skip continuation lines
    -            if (str forall (ch => ch.isWhitespace || ch == '|')) ()
    -            else super.write(str)
    -          }
    -        }
    -        val input = new BufferedReader(new StringReader(code.trim + "\n")) {
    -          override def readLine(): String = {
    -            val s = super.readLine()
    -            // helping out by printing the line being interpreted.
    -            if (s != null)
    -              output.println(s)
    -            s
    -          }
    -        }
    -        val repl = new SparkILoop(input, output)
    -        if (settings.classpath.isDefault)
    -          settings.classpath.value = sys.props("java.class.path")
    -
    -        repl process settings
    -      }
    -    }
    -  }
    -
    -  /** Creates an interpreter loop with default settings and feeds
    -    *  the given code to it as input.
    -    */
    +  /** 
    +   * Creates an interpreter loop with default settings and feeds
    +   * the given code to it as input.
    +   */
       def run(code: String, sets: Settings = new Settings): String = {
         import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
     
         stringFromStream { ostream =>
           Console.withOut(ostream) {
    -        val input    = new BufferedReader(new StringReader(code))
    -        val output   = new JPrintWriter(new OutputStreamWriter(ostream), true)
    -        val repl     = new SparkILoop(input, output)
    +        val input = new BufferedReader(new StringReader(code))
    +        val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
    +        val repl = new SparkILoop(input, output)
     
             if (sets.classpath.isDefault)
               sets.classpath.value = sys.props("java.class.path")
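For reference, the retained `run` helper is typically driven from test code along these lines; the interpreted snippet below is made up purely for illustration, but the `SparkILoop.run(code)` entry point and its `String` transcript return value match the signature kept by this patch:

```scala
// Feed a small snippet to a fresh SparkILoop and capture the transcript it prints.
import org.apache.spark.repl.SparkILoop

val transcript: String = SparkILoop.run(
  """
    |val xs = 1 to 3
    |xs.map(_ * 2).sum
  """.stripMargin)

println(transcript)
```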
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    deleted file mode 100644
    index 1cb910f376060..0000000000000
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    +++ /dev/null
    @@ -1,1319 +0,0 @@
    -/* NSC -- new Scala compiler
    - * Copyright 2005-2013 LAMP/EPFL
    - * @author  Martin Odersky
    - */
    -
    -package scala
    -package tools.nsc
    -package interpreter
    -
    -import PartialFunction.cond
    -import scala.language.implicitConversions
    -import scala.beans.BeanProperty
    -import scala.collection.mutable
    -import scala.concurrent.{ Future, ExecutionContext }
    -import scala.reflect.runtime.{ universe => ru }
    -import scala.reflect.{ ClassTag, classTag }
    -import scala.reflect.internal.util.{ BatchSourceFile, SourceFile }
    -import scala.tools.util.PathResolver
    -import scala.tools.nsc.io.AbstractFile
    -import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings }
    -import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps }
    -import scala.tools.nsc.util.Exceptional.unwrap
    -import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable}
    -
    -/** An interpreter for Scala code.
    -  *
    -  *  The main public entry points are compile(), interpret(), and bind().
    -  *  The compile() method loads a complete Scala file.  The interpret() method
    -  *  executes one line of Scala code at the request of the user.  The bind()
    -  *  method binds an object to a variable that can then be used by later
    -  *  interpreted code.
    -  *
    -  *  The overall approach is based on compiling the requested code and then
    -  *  using a Java classloader and Java reflection to run the code
    -  *  and access its results.
    -  *
    -  *  In more detail, a single compiler instance is used
    -  *  to accumulate all successfully compiled or interpreted Scala code.  To
    -  *  "interpret" a line of code, the compiler generates a fresh object that
    -  *  includes the line of code and which has public member(s) to export
    -  *  all variables defined by that code.  To extract the result of an
    -  *  interpreted line to show the user, a second "result object" is created
    -  *  which imports the variables exported by the above object and then
-  *  exports members called "$eval" and "$print". To accommodate user expressions
    -  *  that read from variables or methods defined in previous statements, "import"
    -  *  statements are used.
    -  *
    -  *  This interpreter shares the strengths and weaknesses of using the
    -  *  full compiler-to-Java.  The main strength is that interpreted code
    -  *  behaves exactly as does compiled code, including running at full speed.
    -  *  The main weakness is that redefining classes and methods is not handled
    -  *  properly, because rebinding at the Java level is technically difficult.
    -  *
    -  *  @author Moez A. Abdel-Gawad
    -  *  @author Lex Spoon
    -  */
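To make the wrapping scheme described above concrete, the code generated for a single interpreted line has roughly the following shape. This is a simplified illustration only; the real synthetic wrappers use fresh `$lineNN` names, nested objects, and extra plumbing that this sketch omits:

```scala
// Illustrative approximation of the wrappers generated for `val x = 1 + 1`.
object $read {
  val x = 1 + 1                      // the user's line, with x exported as a public member
}

object $eval {
  import $read._                     // later requests see earlier results through imports
  val $print: String = "x: Int = " + x
}
```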
    -class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings,
    -  protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports {
    -  imain =>
    -
    -  setBindings(createBindings, ScriptContext.ENGINE_SCOPE)
    -  object replOutput extends ReplOutput(settings.Yreploutdir) { }
    -
    -  @deprecated("Use replOutput.dir instead", "2.11.0")
    -  def virtualDirectory = replOutput.dir
    -  // Used in a test case.
    -  def showDirectory() = replOutput.show(out)
    -
    -  private[nsc] var printResults               = true      // whether to print result lines
    -  private[nsc] var totalSilence               = false     // whether to print anything
    -  private var _initializeComplete             = false     // compiler is initialized
    -  private var _isInitialized: Future[Boolean] = null      // set up initialization future
    -  private var bindExceptions                  = true      // whether to bind the lastException variable
    -  private var _executionWrapper               = ""        // code to be wrapped around all lines
    -
    -  /** We're going to go to some trouble to initialize the compiler asynchronously.
    -    *  It's critical that nothing call into it until it's been initialized or we will
    -    *  run into unrecoverable issues, but the perceived repl startup time goes
    -    *  through the roof if we wait for it.  So we initialize it with a future and
    -    *  use a lazy val to ensure that any attempt to use the compiler object waits
    -    *  on the future.
    -    */
    -  private var _classLoader: util.AbstractFileClassLoader = null                              // active classloader
    -  private val _compiler: ReplGlobal                 = newCompiler(settings, reporter)   // our private compiler
    -
    -  def compilerClasspath: Seq[java.net.URL] = (
    -    if (isInitializeComplete) global.classPath.asURLs
    -    else new PathResolver(settings).result.asURLs  // the compiler's classpath
    -    )
    -  def settings = initialSettings
    -  // Run the code body with the given boolean settings flipped to true.
    -  def withoutWarnings[T](body: => T): T = beQuietDuring {
    -    val saved = settings.nowarn.value
    -    if (!saved)
    -      settings.nowarn.value = true
    -
    -    try body
    -    finally if (!saved) settings.nowarn.value = false
    -  }
    -
    -  /** construct an interpreter that reports to Console */
    -  def this(settings: Settings, out: JPrintWriter) = this(null, settings, out)
    -  def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true))
    -  def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
    -  def this(factory: ScriptEngineFactory) = this(factory, new Settings())
    -  def this() = this(new Settings())
    -
    -  lazy val formatting: Formatting = new Formatting {
    -    val prompt = Properties.shellPromptString
    -  }
    -  lazy val reporter: SparkReplReporter = new SparkReplReporter(this)
    -
    -  import formatting._
    -  import reporter.{ printMessage, printUntruncatedMessage }
    -
    -  // This exists mostly because using the reporter too early leads to deadlock.
    -  private def echo(msg: String) { Console println msg }
-  private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
    -  private def _initialize() = {
    -    try {
    -      // if this crashes, REPL will hang its head in shame
    -      val run = new _compiler.Run()
    -      assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
    -      run compileSources _initSources
    -      _initializeComplete = true
    -      true
    -    }
    -    catch AbstractOrMissingHandler()
    -  }
    -  private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
    -  private val logScope = scala.sys.props contains "scala.repl.scope"
    -  private def scopelog(msg: String) = if (logScope) Console.err.println(msg)
    -
    -  // argument is a thunk to execute after init is done
    -  def initialize(postInitSignal: => Unit) {
    -    synchronized {
    -      if (_isInitialized == null) {
    -        _isInitialized =
    -          Future(try _initialize() finally postInitSignal)(ExecutionContext.global)
    -      }
    -    }
    -  }
    -  def initializeSynchronous(): Unit = {
    -    if (!isInitializeComplete) {
    -      _initialize()
    -      assert(global != null, global)
    -    }
    -  }
    -  def isInitializeComplete = _initializeComplete
    -
    -  lazy val global: Global = {
    -    if (!isInitializeComplete) _initialize()
    -    _compiler
    -  }
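The comment above describes the pattern implemented by `initialize`, `initializeSynchronous`, and the `global` lazy val: start the expensive initialization in a `Future`, and gate every real use behind a lazy val that waits for it (or falls back to doing the work synchronously). Stripped of REPL details, a self-contained sketch of that pattern, with all names invented for illustration, looks like this:

```scala
// Minimal sketch of the async-init pattern used by this interpreter.
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

class SlowEngine {
  @volatile private var initialized = false
  private def doInit(): Unit = synchronized {
    if (!initialized) { Thread.sleep(100); initialized = true } // stand-in for compiler setup
  }

  // Kick initialization off in the background so construction returns immediately.
  private val initFuture: Future[Unit] = Future(doInit())

  // Every real use goes through this lazy val, which blocks until init is done.
  lazy val engine: String = {
    if (!initialized) Await.result(initFuture, 10.seconds)
    "engine ready"
  }
}
```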
    -
    -  import global._
    -  import definitions.{ ObjectClass, termMember, dropNullaryMethod}
    -
    -  lazy val runtimeMirror = ru.runtimeMirror(classLoader)
    -
    -  private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol }
    -
    -  def getClassIfDefined(path: String)  = (
    -    noFatal(runtimeMirror staticClass path)
    -      orElse noFatal(rootMirror staticClass path)
    -    )
    -  def getModuleIfDefined(path: String) = (
    -    noFatal(runtimeMirror staticModule path)
    -      orElse noFatal(rootMirror staticModule path)
    -    )
    -
    -  implicit class ReplTypeOps(tp: Type) {
    -    def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
    -  }
    -
    -  // TODO: If we try to make naming a lazy val, we run into big time
    -  // scalac unhappiness with what look like cycles.  It has not been easy to
    -  // reduce, but name resolution clearly takes different paths.
    -  object naming extends {
    -    val global: imain.global.type = imain.global
    -  } with Naming {
    -    // make sure we don't overwrite their unwisely named res3 etc.
    -    def freshUserTermName(): TermName = {
    -      val name = newTermName(freshUserVarName())
    -      if (replScope containsName name) freshUserTermName()
    -      else name
    -    }
    -    def isInternalTermName(name: Name) = isInternalVarName("" + name)
    -  }
    -  import naming._
    -
    -  object deconstruct extends {
    -    val global: imain.global.type = imain.global
    -  } with StructuredTypeStrings
    -
    -  lazy val memberHandlers = new {
    -    val intp: imain.type = imain
    -  } with SparkMemberHandlers
    -  import memberHandlers._
    -
    -  /** Temporarily be quiet */
    -  def beQuietDuring[T](body: => T): T = {
    -    val saved = printResults
    -    printResults = false
    -    try body
    -    finally printResults = saved
    -  }
    -  def beSilentDuring[T](operation: => T): T = {
    -    val saved = totalSilence
    -    totalSilence = true
    -    try operation
    -    finally totalSilence = saved
    -  }
    -
    -  def quietRun[T](code: String) = beQuietDuring(interpret(code))
    -
    -  /** takes AnyRef because it may be binding a Throwable or an Exceptional */
    -  private def withLastExceptionLock[T](body: => T, alt: => T): T = {
    -    assert(bindExceptions, "withLastExceptionLock called incorrectly.")
    -    bindExceptions = false
    -
    -    try     beQuietDuring(body)
    -    catch   logAndDiscard("withLastExceptionLock", alt)
    -    finally bindExceptions = true
    -  }
    -
    -  def executionWrapper = _executionWrapper
    -  def setExecutionWrapper(code: String) = _executionWrapper = code
    -  def clearExecutionWrapper() = _executionWrapper = ""
    -
    -  /** interpreter settings */
    -  lazy val isettings = new SparkISettings(this)
    -
    -  /** Instantiate a compiler.  Overridable. */
    -  protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = {
    -    settings.outputDirs setSingleOutput replOutput.dir
    -    settings.exposeEmptyPackage.value = true
-    new Global(settings, reporter) with ReplGlobal { override def toString: String = "<global>" }
    -  }
    -
    -  /** Parent classloader.  Overridable. */
    -  protected def parentClassLoader: ClassLoader =
    -    settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() )
    -
    -  /* A single class loader is used for all commands interpreted by this Interpreter.
    -     It would also be possible to create a new class loader for each command
    -     to interpret.  The advantages of the current approach are:
    -
    -       - Expressions are only evaluated one time.  This is especially
    -         significant for I/O, e.g. "val x = Console.readLine"
    -
    -     The main disadvantage is:
    -
    -       - Objects, classes, and methods cannot be rebound.  Instead, definitions
    -         shadow the old ones, and old code objects refer to the old
    -         definitions.
    -  */
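In a session, the shadowing behaviour described in the comment above looks roughly like this (an illustrative transcript; exact result names will differ):

```
scala> def greet = "hi"
scala> val g = () => greet
scala> def greet = "hello"   // the new definition shadows the old one ...
scala> g()
res0: String = hi            // ... but code compiled earlier still calls the old greet
```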
    -  def resetClassLoader() = {
    -    repldbg("Setting new classloader: was " + _classLoader)
    -    _classLoader = null
    -    ensureClassLoader()
    -  }
    -  final def ensureClassLoader() {
    -    if (_classLoader == null)
    -      _classLoader = makeClassLoader()
    -  }
    -  def classLoader: util.AbstractFileClassLoader = {
    -    ensureClassLoader()
    -    _classLoader
    -  }
    -
    -  def backticked(s: String): String = (
    -    (s split '.').toList map {
    -      case "_"                               => "_"
    -      case s if nme.keywords(newTermName(s)) => s"`$s`"
    -      case s                                 => s
    -    } mkString "."
    -    )
    -  def readRootPath(readPath: String) = getModuleIfDefined(readPath)
    -
    -  abstract class PhaseDependentOps {
    -    def shift[T](op: => T): T
    -
    -    def path(name: => Name): String = shift(path(symbolOfName(name)))
    -    def path(sym: Symbol): String = backticked(shift(sym.fullName))
    -    def sig(sym: Symbol): String  = shift(sym.defString)
    -  }
    -  object typerOp extends PhaseDependentOps {
    -    def shift[T](op: => T): T = exitingTyper(op)
    -  }
    -  object flatOp extends PhaseDependentOps {
    -    def shift[T](op: => T): T = exitingFlatten(op)
    -  }
    -
    -  def originalPath(name: String): String = originalPath(name: TermName)
    -  def originalPath(name: Name): String   = typerOp path name
    -  def originalPath(sym: Symbol): String  = typerOp path sym
    -  def flatPath(sym: Symbol): String      = flatOp shift sym.javaClassName
    -  def translatePath(path: String) = {
    -    val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path)
    -    sym.toOption map flatPath
    -  }
    -  def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath
    -
    -  private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) {
    -    /** Overridden here to try translating a simple name to the generated
    -      *  class name if the original attempt fails.  This method is used by
    -      *  getResourceAsStream as well as findClass.
    -      */
    -    override protected def findAbstractFile(name: String): AbstractFile =
    -      super.findAbstractFile(name) match {
    -        case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull
    -        case file => file
    -      }
    -  }
    -  private def makeClassLoader(): util.AbstractFileClassLoader =
    -    new TranslatingClassLoader(parentClassLoader match {
    -      case null   => ScalaClassLoader fromURLs compilerClasspath
    -      case p      => new ScalaClassLoader.URLClassLoader(compilerClasspath, p)
    -    })
    -
    -  // Set the current Java "context" class loader to this interpreter's class loader
    -  def setContextClassLoader() = classLoader.setAsContext()
    -
    -  def allDefinedNames: List[Name]  = exitingTyper(replScope.toList.map(_.name).sorted)
    -  def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted
    -
    -  /** Most recent tree handled which wasn't wholly synthetic. */
    -  private def mostRecentlyHandledTree: Option[Tree] = {
    -    prevRequests.reverse foreach { req =>
    -      req.handlers.reverse foreach {
    -        case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
    -        case _ => ()
    -      }
    -    }
    -    None
    -  }
    -
    -  private def updateReplScope(sym: Symbol, isDefined: Boolean) {
    -    def log(what: String) {
    -      val mark = if (sym.isType) "t " else "v "
    -      val name = exitingTyper(sym.nameString)
    -      val info = cleanTypeAfterTyper(sym)
    -      val defn = sym defStringSeenAs info
    -
    -      scopelog(f"[$mark$what%6s] $name%-25s $defn%s")
    -    }
    -    if (ObjectClass isSubClass sym.owner) return
    -    // unlink previous
    -    replScope lookupAll sym.name foreach { sym =>
    -      log("unlink")
    -      replScope unlink sym
    -    }
    -    val what = if (isDefined) "define" else "import"
    -    log(what)
    -    replScope enter sym
    -  }
    -
    -  def recordRequest(req: Request) {
    -    if (req == null)
    -      return
    -
    -    prevRequests += req
    -
    -    // warning about serially defining companions.  It'd be easy
    -    // enough to just redefine them together but that may not always
    -    // be what people want so I'm waiting until I can do it better.
    -    exitingTyper {
    -      req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym =>
    -        val oldSym = replScope lookup newSym.name.companionName
    -        if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) {
    -          replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")
    -          replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
    -        }
    -      }
    -    }
    -    exitingTyper {
    -      req.imports foreach (sym => updateReplScope(sym, isDefined = false))
    -      req.defines foreach (sym => updateReplScope(sym, isDefined = true))
    -    }
    -  }
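The companion-symbol check above corresponds to sessions like the following, where a class and its `object` are entered as separate requests and therefore never become true companions; the warning text is interpolated from the `replwarn` calls shown above (illustrative transcript):

```
scala> class Box(val v: Int)
scala> object Box { def apply(v: Int) = new Box(v) }
warning: previously defined class Box is not a companion to object Box.
Companions must be defined together; you may wish to use :paste mode for this.
```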
    -
    -  private[nsc] def replwarn(msg: => String) {
    -    if (!settings.nowarnings)
    -      printMessage(msg)
    -  }
    -
    -  def compileSourcesKeepingRun(sources: SourceFile*) = {
    -    val run = new Run()
    -    assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
    -    reporter.reset()
    -    run compileSources sources.toList
    -    (!reporter.hasErrors, run)
    -  }
    -
    -  /** Compile an nsc SourceFile.  Returns true if there are
    -    *  no compilation errors, or false otherwise.
    -    */
    -  def compileSources(sources: SourceFile*): Boolean =
    -    compileSourcesKeepingRun(sources: _*)._1
    -
    -  /** Compile a string.  Returns true if there are no
    -    *  compilation errors, or false otherwise.
    -    */
    -  def compileString(code: String): Boolean =
    -    compileSources(new BatchSourceFile("