diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 39bcc84225b4a..989e95ccd0135 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -3,7 +3,7 @@ (Please fill in changes proposed in this fix) -## How was the this patch tested? +## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f194a46303e0d..636d39e1e9cae 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -13,7 +13,9 @@ export("print.jobj") # MLlib integration exportMethods("glm", "predict", - "summary") + "summary", + "kmeans", + "fitted") # Job group lifecycle management methods export("setJobGroup", @@ -109,6 +111,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "approxQuantile", "array_contains", "asc", "ascii", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 2dba71abec689..3db72b57954d7 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -67,6 +67,13 @@ setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) +# @rdname statfunctions +# @export +setGeneric("approxQuantile", + function(x, col, probabilities, relativeError) { + standardGeneric("approxQuantile") + }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -1160,3 +1167,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind #' @export setGeneric("rbind", signature = "...") + +#' @rdname kmeans +#' @export +setGeneric("kmeans") + +#' @rdname fitted +#' @export +setGeneric("fitted") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 8d3b4388ae575..346f33d7dab2c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"), setMethod("summary", signature(object = "PipelineModel"), function(object, ...) { modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", object@model) + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", object@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", object@model) + "getModelCoefficients", object@model) if (modelName == "LinearRegressionModel") { devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelDevianceResiduals", object@model) @@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- c("Estimate", "Std. 
Error", "t value", "Pr(>|t|)") rownames(coefficients) <- unlist(features) return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) - } else { + } else if (modelName == "LogisticRegressionModel") { coefficients <- as.matrix(unlist(coefficients)) colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) + } else if (modelName == "KMeansModel") { + modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansModelSize", object@model) + cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansCluster", object@model, "classes") + k <- unlist(modelSize)[1] + size <- unlist(modelSize)[-1] + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) + } else { + stop(paste("Unsupported model", modelName, sep = " ")) + } + }) + +#' Fit a k-means model +#' +#' Fit a k-means model, similarly to R's kmeans(). +#' +#' @param x DataFrame for training +#' @param centers Number of centers +#' @param iter.max Maximum iteration number +#' @param algorithm Algorithm choosen to fit the model +#' @return A fitted k-means model +#' @rdname kmeans +#' @export +#' @examples +#'\dontrun{ +#' model <- kmeans(x, centers = 2, algorithm="random") +#'} +setMethod("kmeans", signature(x = "DataFrame"), + function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) { + columnNames <- as.array(colnames(x)) + algorithm <- match.arg(algorithm) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf, + algorithm, iter.max, centers, columnNames) + return(new("PipelineModel", model = model)) + }) + +#' Get fitted result from a model +#' +#' Get fitted result from a model, similarly to R's fitted(). +#' +#' @param object A fitted MLlib model +#' @return DataFrame containing fitted values +#' @rdname fitted +#' @export +#' @examples +#'\dontrun{ +#' model <- kmeans(trainingData, 2) +#' fitted.model <- fitted(model) +#' showDF(fitted.model) +#'} +setMethod("fitted", signature(object = "PipelineModel"), + function(object, method = c("centers", "classes"), ...) { + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", object@model) + + if (modelName == "KMeansModel") { + method <- match.arg(method) + fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansCluster", object@model, method) + return(dataFrame(fittedResult)) + } else { + stop(paste("Unsupported model", modelName, sep = " ")) } }) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 2e8076843f08a..edf72937c633a 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -130,6 +130,45 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), collect(dataFrame(sct)) }) +#' approxQuantile +#' +#' Calculates the approximate quantiles of a numerical column of a DataFrame. +#' +#' The result of this algorithm has the following deterministic bound: +#' If the DataFrame has N elements and if we request the quantile at probability `p` up to error +#' `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank +#' of `x` is close to (p * N). More precisely, +#' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). +#' This method implements a variation of the Greenwald-Khanna algorithm (with some speed +#' optimizations). 
The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 +#' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. +#' +#' @param x A SparkSQL DataFrame. +#' @param col The name of the numerical column. +#' @param probabilities A list of quantile probabilities. Each number must belong to [0, 1]. +#' For example 0 is the minimum, 0.5 is the median, 1 is the maximum. +#' @param relativeError The relative target precision to achieve (>= 0). If set to zero, +#' the exact quantiles are computed, which could be very expensive. +#' Note that values greater than 1 are accepted but give the same result as 1. +#' @return The approximate quantiles at the given probabilities. +#' +#' @rdname statfunctions +#' @name approxQuantile +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) +#' } +setMethod("approxQuantile", + signature(x = "DataFrame", col = "character", + probabilities = "numeric", relativeError = "numeric"), + function(x, col, probabilities, relativeError) { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "approxQuantile", col, + as.list(probabilities), relativeError) + }) + #' sampleBy #' #' Returns a stratified sample without replacement based on the fraction given on each stratum. diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 08099dd96a87b..af84a0abcf94d 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -113,3 +113,31 @@ test_that("summary works on base GLM models", { baseSummary <- summary(baseModel) expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) }) + +test_that("kmeans", { + newIris <- iris + newIris$Species <- NULL + training <- suppressWarnings(createDataFrame(sqlContext, newIris)) + + # Cache the DataFrame here to work around the bug SPARK-13178. 
+ cache(training) + take(training, 1) + + model <- kmeans(x = training, centers = 2) + sample <- take(select(predict(model, training), "prediction"), 1) + expect_equal(typeof(sample$prediction), "integer") + expect_equal(sample$prediction, 1) + + # Test stats::kmeans is working + statsModel <- kmeans(x = newIris, centers = 2) + expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) + + # Test fitted works on KMeans + fitted.model <- fitted(model) + expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) + + # Test summary works on KMeans + summary.model <- summary(model) + cluster <- summary.model$cluster + expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) +}) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cc118108f61cc..236bae6bded25 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1785,6 +1785,14 @@ test_that("sampleBy() on a DataFrame", { expect_identical(as.list(result[2, ]), list(key = "1", count = 7)) }) +test_that("approxQuantile() on a DataFrame", { + l <- lapply(c(0:99), function(i) { i }) + df <- createDataFrame(sqlContext, l, "key") + quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) + expect_equal(quantiles[[1]], 50) + expect_equal(quantiles[[2]], 80) +}) + test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) expect_equal(grepl("Table not found: blah", retError), TRUE) diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f121b62a53d24..f301606933a95 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. 
-cmd /V /E /C spark-submit2.cmd %* +cmd /V /E /C "%~dp0spark-submit2.cmd" %* diff --git a/network/common/pom.xml b/common/network-common/pom.xml similarity index 100% rename from network/common/pom.xml rename to common/network-common/pom.xml diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/TransportContext.java rename to common/network-common/src/main/java/org/apache/spark/network/TransportContext.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index f55b884bc45ce..631d767715256 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -28,7 +28,7 @@ /** * A {@link ManagedBuffer} backed by {@link ByteBuffer}. 
*/ -public final class NioManagedBuffer extends ManagedBuffer { +public class NioManagedBuffer extends ManagedBuffer { private final ByteBuffer buf; public NioManagedBuffer(ByteBuffer buf) { diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java rename to common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClient.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java 
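For reference, a minimal SparkR sketch (not part of this patch) of how the kmeans() and fitted() wrappers added in R/pkg/R/mllib.R above are meant to be used; it assumes an initialized sqlContext, as in test_mllib.R:

library(SparkR)

# k-means works on numeric columns only, so drop the factor column first
newIris <- iris
newIris$Species <- NULL
training <- createDataFrame(sqlContext, newIris)

# Fit with the new wrapper; algorithm is matched against c("random", "k-means||")
model <- kmeans(x = training, centers = 2, iter.max = 10, algorithm = "random")

# Per-row cluster assignments come back as a DataFrame with a "prediction" column
fitted.model <- fitted(model)
showDF(fitted.model)

# summary() returns the centers (coefficients), cluster sizes and assignments
summary(model)

# predict() scores a DataFrame, mirroring the existing glm() wrapper
head(select(predict(model, training), "prediction"))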
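A second illustrative sketch (also not part of this patch): the approxQuantile() statistics function added in R/pkg/R/stats.R above, again assuming a working sqlContext. With relativeError = 0 the result is exact, consistent with the bound floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) described in the Roxygen comment:

df <- createDataFrame(sqlContext, data.frame(key = 0:99))

# Median and 80th percentile of the "key" column; the result is a list of numerics
quantiles <- approxQuantile(df, "key", probabilities = c(0.5, 0.8), relativeError = 0.0)
# With 0..99 as input this yields 50 and 80, as asserted in test_sparkSQL.R above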
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java similarity index 100% rename from 
network/common/src/main/java/org/apache/spark/network/protocol/Message.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java diff --git 
a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java rename to common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java rename to 
common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/StreamManager.java rename to common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportServer.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java similarity index 100% 
rename from network/common/src/main/java/org/apache/spark/network/util/IOMode.java rename to common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java rename to common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java rename to common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java rename to common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java similarity index 98% rename from network/common/src/main/java/org/apache/spark/network/util/TransportConf.java rename to common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 115135d44adbd..9f030da2b3cec 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -132,7 +132,8 @@ public int ioRetryWaitTimeMs() { * memory mapping has high overhead for blocks close to or below the page size of the OS. 
*/ public int memoryMapBytes() { - return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.storage.memoryMapThreshold", "2m"))); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java rename to common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/StreamSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java rename to common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/common/network-common/src/test/java/org/apache/spark/network/TestUtils.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TestUtils.java rename to common/network-common/src/test/java/org/apache/spark/network/TestUtils.java diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java similarity index 100% 
rename from network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java rename to common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java diff --git a/network/common/src/test/resources/log4j.properties b/common/network-common/src/test/resources/log4j.properties similarity index 100% rename from network/common/src/test/resources/log4j.properties rename to common/network-common/src/test/resources/log4j.properties diff --git a/network/shuffle/pom.xml b/common/network-shuffle/pom.xml similarity index 100% rename from network/shuffle/pom.xml rename to common/network-shuffle/pom.xml diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java rename to 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java rename to 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java rename to 
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java diff --git a/network/yarn/pom.xml b/common/network-yarn/pom.xml similarity index 100% rename from network/yarn/pom.xml rename to common/network-yarn/pom.xml diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java similarity index 100% rename from network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java rename to common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java similarity index 100% rename from network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java rename to common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 2c9aa93582e6f..40fa20c4a3e37 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -50,7 +50,7 @@ * * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. 
*/ -abstract public class CountMinSketch { +public abstract class CountMinSketch { public enum Version { /** diff --git a/tags/README.md b/common/tags/README.md similarity index 100% rename from tags/README.md rename to common/tags/README.md diff --git a/tags/pom.xml b/common/tags/pom.xml similarity index 97% rename from tags/pom.xml rename to common/tags/pom.xml index 3e8e6f6182875..8e702b4fefe8c 100644 --- a/tags/pom.xml +++ b/common/tags/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/main/java/org/apache/spark/tags/DockerTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/unsafe/pom.xml b/common/unsafe/pom.xml similarity index 98% rename from unsafe/pom.xml rename to common/unsafe/pom.xml index 75fea556eeae1..5250014739da2 100644 --- a/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/Platform.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java diff --git 
a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java rename to 
common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala similarity index 100% rename from unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala rename to common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 66d111e439096..a2b3826dd324b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -64,7 +64,7 @@ {{#applications}} - {{id}} + {{id}} {{name}} {{#attempts}} {{attemptId}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 689c92e86129e..167c8020850d5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -37,6 +37,22 @@ function formatDuration(milliseconds) { return hours.toFixed(1) + " h"; } +function makeIdNumeric(id) { + var strs = id.split("_"); + if (strs.length < 3) { + return id; + } + var appSeqNum = strs[2]; + var resl = strs[0] + "_" + strs[1] + "_"; + var diff = 10 - appSeqNum.length; + while (diff > 0) { + resl += "0"; // padding 0 before the app sequence number to make sure it has 10 characters + diff--; + } + resl += appSeqNum; + return resl; +} + function formatDate(date) { return date.split(".")[0].replace("T", " "); } @@ -62,6 +78,21 @@ jQuery.extend( jQuery.fn.dataTableExt.oSort, { } } ); +jQuery.extend( jQuery.fn.dataTableExt.oSort, { + "appid-numeric-pre": function ( a ) { + var x = a.match(/title="*(-?[0-9a-zA-Z\-\_]+)/)[1]; + return 
makeIdNumeric(x); + }, + + "appid-numeric-asc": function ( a, b ) { + return ((a < b) ? -1 : ((a > b) ? 1 : 0)); + }, + + "appid-numeric-desc": function ( a, b ) { + return ((a < b) ? 1 : ((a > b) ? -1 : 0)); + } +} ); + $(document).ajaxStop($.unblockUI); $(document).ajaxStart(function(){ $.blockUI({ message: '
Loading history summary...
'}); @@ -109,7 +140,7 @@ $(document).ready(function() { var selector = "#history-summary-table"; var conf = { "columns": [ - {name: 'first'}, + {name: 'first', type: "appid-numeric"}, {name: 'second'}, {name: 'third'}, {name: 'fourth'}, @@ -118,7 +149,8 @@ $(document).ready(function() { {name: 'seventh'}, {name: 'eighth'}, ], - "autoWidth": false + "autoWidth": false, + "order": [[ 0, "desc" ]] }; var rowGroupConf = { diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 4337c42087e79..1b0d4692d9cd0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -222,10 +222,11 @@ function renderDagVizForJob(svgContainer) { var attemptId = 0 var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) .select("a.name-link") - .attr("href") + "&expandDagViz=true"; + .attr("href"); container = svgContainer .append("a") .attr("xlink:href", stageLink) + .attr("onclick", "window.localStorage.setItem(expandDagVizArrowKey(false), true)") .append("g") .attr("id", containerId); } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 1ec9ba7755725..2b456facd9439 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.rdd.RDD import org.apache.spark.storage._ +import org.apache.spark.util.CompletionIterator /** * Spark class responsible for passing RDDs partition contents to the BlockManager and making @@ -47,6 +48,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { existingMetrics.incBytesReadInternal(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] + new InterruptibleIterator[T](context, iter) { override def next(): T = { existingMetrics.incRecordsReadInternal(1) @@ -156,7 +158,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Left(arr) => // We have successfully unrolled the entire partition, so cache it in memory blockManager.putArray(key, arr, level, tellMaster = true, effectiveStorageLevel) - arr.iterator.asInstanceOf[Iterator[T]] + CompletionIterator[T, Iterator[T]]( + arr.iterator.asInstanceOf[Iterator[T]], + blockManager.releaseLock(key)) case Right(it) => // There is not enough space to cache this partition in memory val returnValues = it.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 36e240e618490..b81bfb3182212 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -503,6 +503,31 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { set("spark.executor.instances", value) } } + + if (contains("spark.master") && get("spark.master").startsWith("yarn-")) { + val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " + + "instead use \"yarn\" with specified deploy mode." 
+ + get("spark.master") match { + case "yarn-cluster" => + logWarning(warning) + set("spark.master", "yarn") + set("spark.submit.deployMode", "cluster") + case "yarn-client" => + logWarning(warning) + set("spark.master", "yarn") + set("spark.submit.deployMode", "client") + case _ => // Any other unexpected master will be checked when creating scheduler backend. + } + } + + if (contains("spark.submit.deployMode")) { + get("spark.submit.deployMode") match { + case "cluster" | "client" => + case e => throw new SparkException("spark.submit.deployMode can only be \"cluster\" or " + + "\"client\".") + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cd7eed382e346..0e8b735b923bb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -237,13 +237,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def jars: Seq[String] = _jars def files: Seq[String] = _files def master: String = _conf.get("spark.master") + def deployMode: String = _conf.getOption("spark.submit.deployMode").getOrElse("client") def appName: String = _conf.get("spark.app.name") private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec - def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + def isLocal: Boolean = Utils.isLocalMaster(_conf) /** * @return true if context is stopped or in the midst of stopping. @@ -375,10 +376,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster - // yarn-standalone is deprecated, but still supported - if ((master == "yarn-cluster" || master == "yarn-standalone") && - !_conf.contains("spark.yarn.app.id")) { - throw new SparkException("Detected yarn-cluster mode, but isn't running on a cluster. " + + if (master == "yarn" && deployMode == "cluster" && !_conf.contains("spark.yarn.app.id")) { + throw new SparkException("Detected yarn cluster mode, but isn't running on a cluster. " + "Deployment to YARN is not supported directly by SparkContext. Please use spark-submit.") } @@ -414,7 +413,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. @@ -491,7 +490,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) // Create and start the scheduler - val (sched, ts) = SparkContext.createTaskScheduler(this, master) + val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode) _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) @@ -527,10 +526,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
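To illustrate the convention established by the two hunks above (sketch only; the app name and deploy mode below are arbitrary choices): the composite "yarn-client" / "yarn-cluster" masters are replaced by master "yarn" plus an explicit spark.submit.deployMode, which SparkContext.deployMode reads with "client" as the default.

    import org.apache.spark.SparkConf

    // Instead of setMaster("yarn-client") or setMaster("yarn-cluster"),
    // configure the resource manager and the deploy mode separately.
    // SparkConf rewrites the deprecated composite form and logs a warning.
    val conf = new SparkConf()
      .setAppName("example")                     // arbitrary name
      .setMaster("yarn")
      .set("spark.submit.deployMode", "client")  // or "cluster"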
val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) - if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") - } - _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) @@ -1590,10 +1585,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - // yarn-standalone is deprecated, but still supported - if (SparkHadoopUtil.get.isYarnMode() && - (master == "yarn-standalone" || master == "yarn-cluster")) { - // In order for this to work in yarn-cluster mode the user must specify the + if (master == "yarn" && deployMode == "cluster") { + // In order for this to work in yarn cluster mode the user must specify the // --addJars option to the client to upload the file into the distributed cache // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() @@ -2319,7 +2312,8 @@ object SparkContext extends Logging { */ private def createTaskScheduler( sc: SparkContext, - master: String): (SchedulerBackend, TaskScheduler) = { + master: String, + deployMode: String): (SchedulerBackend, TaskScheduler) = { import SparkMasterRegex._ // When running locally, don't try to re-execute tasks on failure. @@ -2381,11 +2375,7 @@ object SparkContext extends Logging { } (backend, scheduler) - case "yarn-standalone" | "yarn-cluster" => - if (master == "yarn-standalone") { - logWarning( - "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.") - } + case "yarn" if deployMode == "cluster" => val scheduler = try { val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) @@ -2410,7 +2400,7 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case "yarn-client" => + case "yarn" if deployMode == "client" => val scheduler = try { val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) @@ -2451,7 +2441,7 @@ object SparkContext extends Logging { case zkUrl if zkUrl.startsWith("zk://") => logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + "in the form mesos://zk://host:port. 
Current Master URL will stop working in Spark 2.0.") - createTaskScheduler(sc, "mesos://" + zkUrl) + createTaskScheduler(sc, "mesos://" + zkUrl, deployMode) case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 9f49cf1c4c9bd..bfcacbf229b00 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener} object TaskContext { @@ -106,6 +106,8 @@ abstract class TaskContext extends Serializable { * Adds a (Java friendly) listener to be executed on task completion. * This will be called in all situation - success, failure, or cancellation. * An example use is for HadoopRDD to register a callback to close the input stream. + * + * Exceptions thrown by the listener will result in failure of the task. */ def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext @@ -113,8 +115,30 @@ abstract class TaskContext extends Serializable { * Adds a listener in the form of a Scala closure to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. * An example use is for HadoopRDD to register a callback to close the input stream. + * + * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext + def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + }) + } + + /** + * Adds a listener to be executed on task failure. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ + def addTaskFailureListener(listener: TaskFailureListener): TaskContext + + /** + * Adds a listener to be executed on task failure. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ + def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { + addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error) + }) + } /** * The ID of the stage that this task belong to. 
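A usage sketch of the failure-listener hook introduced above (log messages are illustrative): inside a running task, TaskContext.get() returns the current context, and a listener registered with addTaskFailureListener fires only when the task fails. Because onTaskFailure may be invoked more than once, the callback must be idempotent.

    import org.apache.spark.TaskContext

    // Inside task code, e.g. within rdd.mapPartitions { iter => ... }:
    val ctx = TaskContext.get()

    // Invoked only on task failure; keep it idempotent.
    ctx.addTaskFailureListener { (_, error: Throwable) =>
      println(s"task ${ctx.taskAttemptId()} failed: ${error.getMessage}")
    }

    // Invoked on success, failure, or cancellation, as before.
    ctx.addTaskCompletionListener { _ =>
      println("releasing per-task resources")
    }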
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 1d228b6b86c55..65f6f741f7f4e 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.util._ private[spark] class TaskContextImpl( val stageId: Int, @@ -41,9 +41,12 @@ private[spark] class TaskContextImpl( */ override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators) - // List of callback functions to execute when the task completes. + /** List of callback functions to execute when the task completes. */ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + /** List of callback functions to execute when the task fails. */ + @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] + // Whether the corresponding task has been killed. @volatile private var interrupted: Boolean = false @@ -55,14 +58,30 @@ private[spark] class TaskContextImpl( this } - override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } + override def addTaskFailureListener(listener: TaskFailureListener): this.type = { + onFailureCallbacks += listener this } - /** Marks the task as completed and triggers the listeners. */ + /** Marks the task as completed and triggers the failure listeners. */ + private[spark] def markTaskFailed(error: Throwable): Unit = { + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onFailureCallbacks.reverse.foreach { listener => + try { + listener.onTaskFailure(this, error) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskFailureListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs, Option(error)) + } + } + + /** Marks the task as completed and triggers the completion listeners. 
*/ private[spark] def markTaskCompleted(): Unit = { completed = true val errorMsgs = new ArrayBuffer[String](2) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f12e2dfafa19d..05d1c31a08f22 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} private[spark] class PythonRDD( parent: RDD[_], - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]]) + func: PythonFunction, + preservePartitoning: Boolean) extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) @@ -64,29 +58,37 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = new PythonRunner( - command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, - bufferSize, reuse_worker) + val runner = new PythonRunner(func, bufferSize, reuse_worker) runner.compute(firstParent.iterator(split, context), split.index, context) } } - /** - * A helper class to run Python UDFs in Spark. + * A wrapper for a Python function, contains all necessary context to run the function in Python + * runner. */ -private[spark] class PythonRunner( +private[spark] case class PythonFunction( command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + func: PythonFunction, bufferSize: Int, reuse_worker: Boolean) extends Logging { + private val envVars = func.envVars + private val pythonExec = func.pythonExec + private val accumulator = func.accumulator + def compute( inputIterator: Iterator[_], partitionIndex: Int, @@ -225,6 +227,11 @@ private[spark] class PythonRunner( @volatile private var _exception: Exception = null + private val pythonVer = func.pythonVer + private val pythonIncludes = func.pythonIncludes + private val broadcastVars = func.broadcastVars + private val command = func.command + setDaemon(true) /** Contains the exception thrown while writing the parent iterator to the Python process. 
*/ diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 9bd69727f6086..c08f87a8b45c1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -24,10 +24,10 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random -import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ByteArrayChunkOutputStream @@ -90,22 +90,29 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** * Divide the object into multiple blocks and put those blocks in the block manager. + * * @param value the object to divide * @return number of blocks this broadcast variable is divided into */ private def writeBlocks(value: T): Int = { + import StorageLevel._ // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. - SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK, - tellMaster = false) + val blockManager = SparkEnv.get.blockManager + if (blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { + blockManager.releaseLock(broadcastId) + } else { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => - SparkEnv.get.blockManager.putBytes( - BroadcastBlockId(id, "piece" + i), - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) + val pieceId = BroadcastBlockId(id, "piece" + i) + if (blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) { + blockManager.releaseLock(pieceId) + } else { + throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager") + } } blocks.length } @@ -127,16 +134,18 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => // If we found the block from remote executors/driver's BlockManager, put the block // in this executor's BlockManager. - SparkEnv.get.blockManager.putBytes( - pieceId, - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) + if (!bm.putBytes(pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException( + s"Failed to store $pieceId of $broadcastId in local BlockManager") + } block } val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( throw new SparkException(s"Failed to get $pieceId of $broadcastId")) + // At this point we are guaranteed to hold a read lock, since we either got the block locally + // or stored the remotely-fetched block and automatically downgraded the write lock. 
blocks(pid) = block + releaseLock(pieceId) } blocks } @@ -165,8 +174,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) - SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { + val blockManager = SparkEnv.get.blockManager + blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => + releaseLock(broadcastId) x.asInstanceOf[T] case None => @@ -179,13 +190,36 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. - SparkEnv.get.blockManager.putSingle( - broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + releaseLock(broadcastId) + } else { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } obj } } } + /** + * If running in a task, register the given block's locks for release upon task completion. + * Otherwise, if not running in a task then immediately release the lock. + */ + private def releaseLock(blockId: BlockId): Unit = { + val blockManager = SparkEnv.get.blockManager + Option(TaskContext.get()) match { + case Some(taskContext) => + taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + case None => + // This should only happen on the driver, where broadcast variables may be accessed + // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow + // broadcast variables to be garbage collected we need to free the reference here + // which is slightly unsafe but is technically okay because broadcast variables aren't + // stored off-heap. + blockManager.releaseLock(blockId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a6749f7e38802..d5a3383932a72 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -226,11 +226,17 @@ object SparkSubmit { // Set the cluster manager val clusterManager: Int = args.master match { - case m if m.startsWith("yarn") => YARN + case "yarn" => YARN + case "yarn-client" | "yarn-cluster" => + printWarning(s"Master ${args.master} is deprecated since 2.0." 
+ + " Please use master \"yarn\" with specified deploy mode instead.") + YARN case m if m.startsWith("spark") => STANDALONE case m if m.startsWith("mesos") => MESOS case m if m.startsWith("local") => LOCAL - case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1 + case _ => + printErrorAndExit("Master must either be yarn or start with spark, mesos, local") + -1 } // Set the deploy mode; default is client mode @@ -240,23 +246,20 @@ object SparkSubmit { case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 } - // Because "yarn-cluster" and "yarn-client" encapsulate both the master - // and deploy mode, we have some logic to infer the master and deploy mode + // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both + // the master and deploy mode, we have some logic to infer the master and deploy mode // from each other if only one is specified, or exit early if they are at odds. if (clusterManager == YARN) { - if (args.master == "yarn-standalone") { - printWarning("\"yarn-standalone\" is deprecated. Use \"yarn-cluster\" instead.") - args.master = "yarn-cluster" - } (args.master, args.deployMode) match { case ("yarn-cluster", null) => deployMode = CLUSTER + args.master = "yarn" case ("yarn-cluster", "client") => printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") case ("yarn-client", "cluster") => printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") case (_, mode) => - args.master = "yarn-" + Option(mode).getOrElse("client") + args.master = "yarn" } // Make sure YARN is included in our build if we're trying to use it diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 915ef81b4eae3..175756b80b6bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -255,6 +255,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } + + if (proxyUser != null && principal != null) { + SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + } } private def validateKillArguments(): Unit = { @@ -517,6 +521,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | | --proxy-user NAME User to impersonate when submitting the application. + | This argument does not work with --principal / --keytab. | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 00be3a240dbac..a959f200d4cc2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -114,6 +114,19 @@ private[spark] class Executor( private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. 
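For reference, a small sketch of tuning the new threshold (the value 30 is arbitrary; 60 stays the default): consecutive heartbeat failures are counted, reset on the first success, and once the limit is reached the executor exits with ExecutorExitCode.HEARTBEAT_FAILURE.

    import org.apache.spark.SparkConf

    // Let executors give up after 30 consecutive failed heartbeats instead of 60.
    val conf = new SparkConf()
      .set("spark.executor.heartbeat.maxFailures", "30")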
It means we will retry to send + * heartbeats about 10 minutes because the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) + + /** + * Count the failure times of heartbeat. It should only be acessed in the heartbeat thread. Each + * successful heartbeat will reset it to 0. + */ + private var heartbeatFailures = 0 + startDriverHeartbeater() def launchTask( @@ -218,7 +231,9 @@ private[spark] class Executor( threwException = false res } finally { + val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { @@ -227,6 +242,17 @@ private[spark] class Executor( logError(errMsg) } } + + if (releasedLocks.nonEmpty) { + val errMsg = + s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" + + releasedLocks.mkString("[", ", ", "]") + if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } } val taskFinish = System.currentTimeMillis() @@ -266,8 +292,11 @@ private[spark] class Executor( ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) } else if (resultSize >= maxRpcMessageSize) { val blockId = TaskResultBlockId(taskId) - env.blockManager.putBytes( + val putSucceeded = env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) + if (putSucceeded) { + env.blockManager.releaseLock(blockId) + } logInfo( s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) @@ -445,8 +474,16 @@ private[spark] class Executor( logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } + heartbeatFailures = 0 } catch { - case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ea36fb60bd540..99858f785600d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -39,6 +39,12 @@ object ExecutorExitCode { /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 + /** + * Executor is unable to send heartbeats to the driver more than + * "spark.executor.heartbeat.maxFailures" times. + */ + val HEARTBEAT_FAILURE = 56 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -51,6 +57,8 @@ object ExecutorExitCode { // TODO: replace external block store with concrete implementation name case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => "ExternalBlockStore failed to create a local temporary directory." 
+ case HEARTBEAT_FAILURE => + "Unable to send heartbeats to driver." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index a3321e3f179f6..6c57c98ea5c59 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -183,7 +183,17 @@ object UnifiedMemoryManager { val minSystemMemory = reservedMemory * 1.5 if (systemMemory < minSystemMemory) { throw new IllegalArgumentException(s"System memory $systemMemory must " + - s"be at least $minSystemMemory. Please use a larger heap size.") + s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " + + s"option or spark.driver.memory in Spark configuration.") + } + // SPARK-12759 Check executor memory to fail fast if memory is insufficient + if (conf.contains("spark.executor.memory")) { + val executorMemory = conf.getSizeAsBytes("spark.executor.memory") + if (executorMemory < minSystemMemory) { + throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " + + s"$minSystemMemory. Please increase executor memory using the " + + s"--executor-memory option or spark.executor.memory in Spark configuration.") + } } val usableMemory = systemMemory - reservedMemory val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 1745d52c81923..cc5e851c29b32 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -31,6 +31,14 @@ trait BlockDataManager { /** * Put the block locally, using the given storage level. + * + * Returns true if the block was stored and false if the put operation failed or the block + * already existed. */ - def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean + + /** + * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. 
+ */ + def releaseLock(blockId: BlockId): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index df8c21fb837ed..e4246df83a6ec 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -65,7 +65,11 @@ class NettyBlockRpcServer( val level: StorageLevel = serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) - blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) + val blockId = BlockId(uploadBlock.blockId) + val putSucceeded = blockManager.putBlockData(blockId, data, level) + if (putSucceeded) { + blockManager.releaseLock(blockId) + } responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index a49f3716e2702..d2b8ca90a9899 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -64,6 +64,7 @@ private[spark] abstract class Task[T]( taskAttemptId: Long, attemptNumber: Int, metricsSystem: MetricsSystem): T = { + SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, partitionId, @@ -79,7 +80,12 @@ private[spark] abstract class Task[T]( } try { runTask(context) + } catch { case e: Throwable => + // Catch all errors; run task failure callbacks, and rethrow the exception. + context.markTaskFailed(e) + throw e } finally { + // Call the task completion callbacks. context.markTaskCompleted() try { Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 0a5b09dc0d1fa..d151de5f6a830 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -179,6 +179,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case RemoveExecutor(executorId, reason) => + // We will remove the executor's state and cannot restore it. However, the connection + // between the driver and the executor may be still alive so that the executor won't exit + // automatically, so try to tell the executor to stop itself. See SPARK-13519. 
+ executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) context.reply(true) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index f803cc7a36a9a..622f361ec2a3c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -19,14 +19,12 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{Collections, List => JList} -import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.{Buffer, HashMap, HashSet} -import com.google.common.base.Stopwatch import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} @@ -493,12 +491,11 @@ private[spark] class CoarseMesosSchedulerBackend( // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. // See SPARK-12330 - val stopwatch = new Stopwatch() - stopwatch.start() + val startTime = System.nanoTime() // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent while (numExecutors() > 0 && - stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { + System.nanoTime() - startTime < shutdownTimeoutMS * 1000L * 1000L) { Thread.sleep(100) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 645ede26a0879..7750a096230cb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -28,7 +28,7 @@ private[v1] class AllRDDResource(ui: SparkUI) { @GET def rddList(): Seq[RDDStorageInfo] = { - val storageStatusList = ui.storageListener.storageStatusList + val storageStatusList = ui.storageListener.activeStorageStatusList val rddInfos = ui.storageListener.rddInfoList rddInfos.map{rddInfo => AllRDDResource.getRDDStorageInfo(rddInfo.id, rddInfo, storageStatusList, @@ -44,7 +44,7 @@ private[spark] object AllRDDResource { rddId: Int, listener: StorageListener, includeDetails: Boolean): Option[RDDStorageInfo] = { - val storageStatusList = listener.storageStatusList + val storageStatusList = listener.activeStorageStatusList listener.rddInfoList.find { _.id == rddId }.map { rddInfo => getRDDStorageInfo(rddId, rddInfo, storageStatusList, includeDetails) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index 3bdba922328c2..6ca59c2f3caeb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -31,9 +31,9 @@ private[v1] class ExecutorListResource(ui: SparkUI) { listener.synchronized { // The follow codes should be protected by `listener` to make sure no executors will be // removed before we query their status. See SPARK-12784. 
- val storageStatusList = listener.storageStatusList + val storageStatusList = listener.activeStorageStatusList (0 until storageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId) + ExecutorsPage.getExecInfo(listener, statusId, isActive = true) } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index d116e68c17f18..909dd0c07ea63 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -54,6 +54,7 @@ class ExecutorStageSummary private[spark]( class ExecutorSummary private[spark]( val id: String, val hostPort: String, + val isActive: Boolean, val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala deleted file mode 100644 index 22fdf73e9d1f4..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.ConcurrentHashMap - -private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - // To save space, 'pending' and 'failed' are encoded as special sizes: - @volatile var size: Long = BlockInfo.BLOCK_PENDING - private def pending: Boolean = size == BlockInfo.BLOCK_PENDING - private def failed: Boolean = size == BlockInfo.BLOCK_FAILED - private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this) - - setInitThread() - - private def setInitThread() { - /* Set current thread as init thread - waitForReady will not block this thread - * (in case there is non trivial initialization which ends up calling waitForReady - * as part of initialization itself) */ - BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (pending && initThread != Thread.currentThread()) { - synchronized { - while (pending) { - this.wait() - } - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. 
block is finished writing) */ - def markReady(sizeInBytes: Long) { - require(sizeInBytes >= 0, s"sizeInBytes was negative: $sizeInBytes") - assert(pending) - size = sizeInBytes - BlockInfo.blockInfoInitThreads.remove(this) - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert(pending) - size = BlockInfo.BLOCK_FAILED - BlockInfo.blockInfoInitThreads.remove(this) - synchronized { - this.notifyAll() - } - } -} - -private object BlockInfo { - /* initThread is logically a BlockInfo field, but we store it here because - * it's only needed while this block is in the 'pending' state and we want - * to minimize BlockInfo's memory footprint. */ - private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] - - private val BLOCK_PENDING: Long = -1L - private val BLOCK_FAILED: Long = -2L -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala new file mode 100644 index 0000000000000..0eda97e58d451 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.google.common.collect.ConcurrentHashMultiset + +import org.apache.spark.{Logging, SparkException, TaskContext} + + +/** + * Tracks metadata for an individual block. + * + * Instances of this class are _not_ thread-safe and are protected by locks in the + * [[BlockInfoManager]]. + * + * @param level the block's storage level. This is the requested persistence level, not the + * effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this + * does not imply that the block is actually resident in memory). + * @param tellMaster whether state changes for this block should be reported to the master. This + * is true for most blocks, but is false for broadcast blocks. + */ +private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + + /** + * The size of the block (in bytes) + */ + def size: Long = _size + def size_=(s: Long): Unit = { + _size = s + checkInvariants() + } + private[this] var _size: Long = 0 + + /** + * The number of times that this block has been locked for reading. 
+ */ + def readerCount: Int = _readerCount + def readerCount_=(c: Int): Unit = { + _readerCount = c + checkInvariants() + } + private[this] var _readerCount: Int = 0 + + /** + * The task attempt id of the task which currently holds the write lock for this block, or + * [[BlockInfo.NON_TASK_WRITER]] if the write lock is held by non-task code, or + * [[BlockInfo.NO_WRITER]] if this block is not locked for writing. + */ + def writerTask: Long = _writerTask + def writerTask_=(t: Long): Unit = { + _writerTask = t + checkInvariants() + } + private[this] var _writerTask: Long = 0 + + /** + * True if this block has been removed from the BlockManager and false otherwise. + * This field is used to communicate block deletion to blocked readers / writers (see its usage + * in [[BlockInfoManager]]). + */ + def removed: Boolean = _removed + def removed_=(r: Boolean): Unit = { + _removed = r + checkInvariants() + } + private[this] var _removed: Boolean = false + + private def checkInvariants(): Unit = { + // A block's reader count must be non-negative: + assert(_readerCount >= 0) + // A block is either locked for reading or for writing, but not for both at the same time: + assert(_readerCount == 0 || _writerTask == BlockInfo.NO_WRITER) + // If a block is removed then it is not locked: + assert(!_removed || (_readerCount == 0 && _writerTask == BlockInfo.NO_WRITER)) + } + + checkInvariants() +} + +private[storage] object BlockInfo { + + /** + * Special task attempt id constant used to mark a block's write lock as being unlocked. + */ + val NO_WRITER: Long = -1 + + /** + * Special task attempt id constant used to mark a block's write lock as being held by + * a non-task thread (e.g. by a driver thread or by unit test code). + */ + val NON_TASK_WRITER: Long = -1024 +} + +/** + * Component of the [[BlockManager]] which tracks metadata for blocks and manages block locking. + * + * The locking interface exposed by this class is readers-writer lock. Every lock acquisition is + * automatically associated with a running task and locks are automatically released upon task + * completion or failure. + * + * This class is thread-safe. + */ +private[storage] class BlockInfoManager extends Logging { + + private type TaskAttemptId = Long + + /** + * Used to look up metadata for individual blocks. Entries are added to this map via an atomic + * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed + * by [[removeBlock()]]. + */ + @GuardedBy("this") + private[this] val infos = new mutable.HashMap[BlockId, BlockInfo] + + /** + * Tracks the set of blocks that each task has locked for writing. + */ + @GuardedBy("this") + private[this] val writeLocksByTask = + new mutable.HashMap[TaskAttemptId, mutable.Set[BlockId]] + with mutable.MultiMap[TaskAttemptId, BlockId] + + /** + * Tracks the set of blocks that each task has locked for reading, along with the number of times + * that a block has been locked (since our read locks are re-entrant). + */ + @GuardedBy("this") + private[this] val readLocksByTask = + new mutable.HashMap[TaskAttemptId, ConcurrentHashMultiset[BlockId]] + + // ---------------------------------------------------------------------------------------------- + + // Initialization for special task attempt ids: + registerTask(BlockInfo.NON_TASK_WRITER) + + // ---------------------------------------------------------------------------------------------- + + /** + * Called at the start of a task in order to register that task with this [[BlockInfoManager]]. 
+ * This must be called prior to calling any other BlockInfoManager methods from that task. + */ + def registerTask(taskAttemptId: TaskAttemptId): Unit = synchronized { + require(!readLocksByTask.contains(taskAttemptId), + s"Task attempt $taskAttemptId is already registered") + readLocksByTask(taskAttemptId) = ConcurrentHashMultiset.create() + } + + /** + * Returns the current task's task attempt id (which uniquely identifies the task), or + * [[BlockInfo.NON_TASK_WRITER]] if called by a non-task thread. + */ + private def currentTaskAttemptId: TaskAttemptId = { + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(BlockInfo.NON_TASK_WRITER) + } + + /** + * Lock a block for reading and return its metadata. + * + * If another task has already locked this block for reading, then the read lock will be + * immediately granted to the calling task and its lock count will be incremented. + * + * If another task has locked this block for writing, then this call will block until the write + * lock is released or will return immediately if `blocking = false`. + * + * A single task can lock a block multiple times for reading, in which case each lock will need + * to be released separately. + * + * @param blockId the block to lock. + * @param blocking if true (default), this call will block until the lock is acquired. If false, + * this call will return immediately if the lock acquisition fails. + * @return None if the block did not exist or was removed (in which case no lock is held), or + * Some(BlockInfo) (in which case the block is locked for reading). + */ + def lockForReading( + blockId: BlockId, + blocking: Boolean = true): Option[BlockInfo] = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to acquire read lock for $blockId") + infos.get(blockId).map { info => + while (info.writerTask != BlockInfo.NO_WRITER) { + if (blocking) wait() else return None + } + if (info.removed) return None + info.readerCount += 1 + readLocksByTask(currentTaskAttemptId).add(blockId) + logTrace(s"Task $currentTaskAttemptId acquired read lock for $blockId") + info + } + } + + /** + * Lock a block for writing and return its metadata. + * + * If another task has already locked this block for either reading or writing, then this call + * will block until the other locks are released or will return immediately if `blocking = false`. + * + * If this is called by a task which already holds the block's exclusive write lock, then this + * method will throw an exception. + * + * @param blockId the block to lock. + * @param blocking if true (default), this call will block until the lock is acquired. If false, + * this call will return immediately if the lock acquisition fails. + * @return None if the block did not exist or was removed (in which case no lock is held), or + * Some(BlockInfo) (in which case the block is locked for writing). 
+ */ + def lockForWriting( + blockId: BlockId, + blocking: Boolean = true): Option[BlockInfo] = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to acquire write lock for $blockId") + infos.get(blockId).map { info => + if (info.writerTask == currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId has already locked $blockId for writing") + } else { + while (info.writerTask != BlockInfo.NO_WRITER || info.readerCount != 0) { + if (blocking) wait() else return None + } + if (info.removed) return None + } + info.writerTask = currentTaskAttemptId + writeLocksByTask.addBinding(currentTaskAttemptId, blockId) + logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId") + info + } + } + + /** + * Throws an exception if the current task does not hold a write lock on the given block. + * Otherwise, returns the block's BlockInfo. + */ + def assertBlockIsLockedForWriting(blockId: BlockId): BlockInfo = synchronized { + infos.get(blockId) match { + case Some(info) => + if (info.writerTask != currentTaskAttemptId) { + throw new SparkException( + s"Task $currentTaskAttemptId has not locked block $blockId for writing") + } else { + info + } + case None => + throw new SparkException(s"Block $blockId does not exist") + } + } + + /** + * Get a block's metadata without acquiring any locks. This method is only exposed for use by + * [[BlockManager.getStatus()]] and should not be called by other code outside of this class. + */ + private[storage] def get(blockId: BlockId): Option[BlockInfo] = synchronized { + infos.get(blockId) + } + + /** + * Downgrades an exclusive write lock to a shared read lock. + */ + def downgradeLock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId downgrading write lock for $blockId") + val info = get(blockId).get + require(info.writerTask == currentTaskAttemptId, + s"Task $currentTaskAttemptId tried to downgrade a write lock that it does not hold on" + + s" block $blockId") + unlock(blockId) + val lockOutcome = lockForReading(blockId, blocking = false) + assert(lockOutcome.isDefined) + } + + /** + * Release a lock on the given block. + */ + def unlock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId") + val info = get(blockId).getOrElse { + throw new IllegalStateException(s"Block $blockId not found") + } + if (info.writerTask != BlockInfo.NO_WRITER) { + info.writerTask = BlockInfo.NO_WRITER + writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + } else { + assert(info.readerCount > 0, s"Block $blockId is not locked for reading") + info.readerCount -= 1 + val countsForTask = readLocksByTask(currentTaskAttemptId) + val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1 + assert(newPinCountForTask >= 0, + s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it") + } + notifyAll() + } + + /** + * Atomically create metadata for a block and acquire a write lock for it, if it doesn't already + * exist. + * + * @param blockId the block id. + * @param newBlockInfo the block info for the new block. + * @return true if the block did not already exist, false otherwise. If this returns false, then + * no new locks are acquired. If this returns true, a write lock on the new block will + * be held. 
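// Illustrative sketch (not part of this patch): the atomic "create and write-lock" pattern this
// method enables. `blockInfoManager` and `blockId` are assumed to be in scope.
val newInfo = new BlockInfo(StorageLevel.MEMORY_AND_DISK, tellMaster = true)
if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) {
  try {
    // write the block's data; this task holds the exclusive write lock here
  } finally {
    blockInfoManager.unlock(blockId)      // or downgradeLock(blockId) to keep reading the block
  }
} else {
  // another task registered the block first; no lock is held by this task
}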
+ */ + def lockNewBlockForWriting( + blockId: BlockId, + newBlockInfo: BlockInfo): Boolean = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to put $blockId") + if (!infos.contains(blockId)) { + infos(blockId) = newBlockInfo + newBlockInfo.writerTask = currentTaskAttemptId + writeLocksByTask.addBinding(currentTaskAttemptId, blockId) + logTrace(s"Task $currentTaskAttemptId successfully locked new block $blockId") + true + } else { + logTrace(s"Task $currentTaskAttemptId did not create and lock block $blockId " + + s"because that block already exists") + false + } + } + + /** + * Release all lock held by the given task, clearing that task's pin bookkeeping + * structures and updating the global pin counts. This method should be called at the + * end of a task (either by a task completion handler or in `TaskRunner.run()`). + * + * @return the ids of blocks whose pins were released + */ + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() + + val readLocks = synchronized { + readLocksByTask.remove(taskAttemptId).get + } + val writeLocks = synchronized { + writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) + } + + for (blockId <- writeLocks) { + infos.get(blockId).foreach { info => + assert(info.writerTask == taskAttemptId) + info.writerTask = BlockInfo.NO_WRITER + } + blocksWithReleasedLocks += blockId + } + readLocks.entrySet().iterator().asScala.foreach { entry => + val blockId = entry.getElement + val lockCount = entry.getCount + blocksWithReleasedLocks += blockId + synchronized { + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) + } + } + } + + synchronized { + notifyAll() + } + blocksWithReleasedLocks + } + + /** + * Returns the number of blocks tracked. + */ + def size: Int = synchronized { + infos.size + } + + /** + * Return the number of map entries in this pin counter's internal data structures. + * This is used in unit tests in order to detect memory leaks. + */ + private[storage] def getNumberOfMapEntries: Long = synchronized { + size + + readLocksByTask.size + + readLocksByTask.map(_._2.size()).sum + + writeLocksByTask.size + + writeLocksByTask.map(_._2.size).sum + } + + /** + * Returns an iterator over a snapshot of all blocks' metadata. Note that the individual entries + * in this iterator are mutable and thus may reflect blocks that are deleted while the iterator + * is being traversed. + */ + def entries: Iterator[(BlockId, BlockInfo)] = synchronized { + infos.toArray.toIterator + } + + /** + * Removes the given block and releases the write lock on it. + * + * This can only be called while holding a write lock on the given block. + */ + def removeBlock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId") + infos.get(blockId) match { + case Some(blockInfo) => + if (blockInfo.writerTask != currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId called remove() on block $blockId without a write lock") + } else { + infos.remove(blockId) + blockInfo.readerCount = 0 + blockInfo.writerTask = BlockInfo.NO_WRITER + blockInfo.removed = true + } + case None => + throw new IllegalArgumentException( + s"Task $currentTaskAttemptId called remove() on non-existent block $blockId") + } + notifyAll() + } + + /** + * Delete all state. Called during shutdown. 
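// Illustrative sketch (not part of this patch): the per-task lifecycle this bookkeeping supports.
// `blockInfoManager` and `taskAttemptId` are assumed to be in scope.
blockInfoManager.registerTask(taskAttemptId)
// ... the task acquires and releases read/write locks while it runs ...
val leakedLocks = blockInfoManager.releaseAllLocksForTask(taskAttemptId)
// `leakedLocks` names every block whose lock the task still held when it finished; calling this
// from the task-completion path keeps locks from outliving their task.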
+ */ + def clear(): Unit = synchronized { + infos.valuesIterator.foreach { blockInfo => + blockInfo.readerCount = 0 + blockInfo.writerTask = BlockInfo.NO_WRITER + blockInfo.removed = true + } + infos.clear() + readLocksByTask.clear() + writeLocksByTask.clear() + notifyAll() + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 77fd03a6bcfc5..29124b368e405 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,9 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ @@ -77,7 +75,7 @@ private[spark] class BlockManager( val diskBlockManager = new DiskBlockManager(this, conf) - private val blockInfo = new ConcurrentHashMap[BlockId, BlockInfo] + private[storage] val blockInfoManager = new BlockInfoManager private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) @@ -223,8 +221,8 @@ private[spark] class BlockManager( * will be made then. */ private def reportAllBlocks(): Unit = { - logInfo(s"Reporting ${blockInfo.size} blocks to the master.") - for ((blockId, info) <- blockInfo.asScala) { + logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.") + for ((blockId, info) <- blockInfoManager.entries) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { logError(s"Failed to report $blockId to master; giving up.") @@ -286,7 +284,7 @@ private[spark] class BlockManager( .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - new NioManagedBuffer(buffer) + new BlockManagerManagedBuffer(this, blockId, buffer) } else { throw new BlockNotFoundException(blockId.toString) } @@ -296,7 +294,7 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean = { putBytes(blockId, data.nioByteBuffer(), level) } @@ -305,7 +303,7 @@ private[spark] class BlockManager( * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { - blockInfo.asScala.get(blockId).map { info => + blockInfoManager.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L BlockStatus(info.level, memSize = memSize, diskSize = diskSize) @@ -318,7 +316,12 @@ private[spark] class BlockManager( * may not know of). */ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { - (blockInfo.asScala.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + // The `toArray` is necessary here in order to force the list to be materialized so that we + // don't try to serialize a lazy iterator when responding to client requests. 
+ (blockInfoManager.entries.map(_._1) ++ diskBlockManager.getAllBlocks()) + .filter(filter) + .toArray + .toSeq } /** @@ -425,26 +428,11 @@ private[spark] class BlockManager( } private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { - val info = blockInfo.get(blockId) - if (info != null) { - info.synchronized { - // Double check to make sure the block is still there. There is a small chance that the - // block has been removed by removeBlock (which also synchronizes on the blockInfo object). - // Note that this only checks metadata tracking. If user intentionally deleted the block - // on disk or from off heap storage without using removeBlock, this conditional check will - // still pass but eventually we will get an exception because we can't find the block. - if (blockInfo.asScala.get(blockId).isEmpty) { - logWarning(s"Block $blockId had been removed") - return None - } - - // If another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure.") - return None - } - + blockInfoManager.lockForReading(blockId) match { + case None => + logDebug(s"Block $blockId was not found") + None + case Some(info) => val level = info.level logDebug(s"Level for block $blockId is $level") @@ -452,7 +440,10 @@ private[spark] class BlockManager( if (level.useMemory) { logDebug(s"Getting block $blockId from memory") val result = if (asBlockResult) { - memoryStore.getValues(blockId).map(new BlockResult(_, DataReadMethod.Memory, info.size)) + memoryStore.getValues(blockId).map { iter => + val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) + new BlockResult(ci, DataReadMethod.Memory, info.size) + } } else { memoryStore.getBytes(blockId) } @@ -470,6 +461,7 @@ private[spark] class BlockManager( val bytes: ByteBuffer = diskStore.getBytes(blockId) match { case Some(b) => b case None => + releaseLock(blockId) throw new BlockException( blockId, s"Block $blockId not found on disk, though it should be") } @@ -478,8 +470,9 @@ private[spark] class BlockManager( if (!level.useMemory) { // If the block shouldn't be stored in memory, we can just return it if (asBlockResult) { - return Some(new BlockResult(dataDeserialize(blockId, bytes), DataReadMethod.Disk, - info.size)) + val iter = CompletionIterator[Any, Iterator[Any]]( + dataDeserialize(blockId, bytes), releaseLock(blockId)) + return Some(new BlockResult(iter, DataReadMethod.Disk, info.size)) } else { return Some(bytes) } @@ -511,26 +504,34 @@ private[spark] class BlockManager( // space to unroll the block. Either way, the put here should return an iterator. putResult.data match { case Left(it) => - return Some(new BlockResult(it, DataReadMethod.Disk, info.size)) + val ci = CompletionIterator[Any, Iterator[Any]](it, releaseLock(blockId)) + return Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) case _ => // This only happens if we dropped the values back to disk (which is never) throw new SparkException("Memory store did not return an iterator!") } } else { - return Some(new BlockResult(values, DataReadMethod.Disk, info.size)) + val ci = CompletionIterator[Any, Iterator[Any]](values, releaseLock(blockId)) + return Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } } } + } else { + // This branch represents a case where the BlockInfoManager contained an entry for + // the block but the block could not be found in any of the block stores. 
This case + // should never occur, but for completeness's sake we address it here. + logError( + s"Block $blockId is supposedly stored locally but was not found in any block store") + releaseLock(blockId) + None } - } - } else { - logDebug(s"Block $blockId not registered locally") } - None } /** * Get block from remote block managers. + * + * This does not acquire a lock on this block in this JVM. */ def getRemote(blockId: BlockId): Option[BlockResult] = { logDebug(s"Getting remote block $blockId") @@ -597,6 +598,10 @@ private[spark] class BlockManager( /** * Get a block from the block manager (either local or remote). + * + * This acquires a read lock on the block if the block was stored locally and does not acquire + * any locks if the block was fetched from a remote block manager. The read lock will + * automatically be freed once the result's `data` iterator is fully consumed. */ def get(blockId: BlockId): Option[BlockResult] = { val local = getLocal(blockId) @@ -612,6 +617,36 @@ private[spark] class BlockManager( None } + /** + * Downgrades an exclusive write lock to a shared read lock. + */ + def downgradeLock(blockId: BlockId): Unit = { + blockInfoManager.downgradeLock(blockId) + } + + /** + * Release a lock on the given block. + */ + def releaseLock(blockId: BlockId): Unit = { + blockInfoManager.unlock(blockId) + } + + /** + * Registers a task with the BlockManager in order to initialize per-task bookkeeping structures. + */ + def registerTask(taskAttemptId: Long): Unit = { + blockInfoManager.registerTask(taskAttemptId) + } + + /** + * Release all locks for the given task. + * + * @return the blocks whose locks were released. + */ + def releaseAllLocksForTask(taskAttemptId: Long): Seq[BlockId] = { + blockInfoManager.releaseAllLocksForTask(taskAttemptId) + } + /** * @return true if the block was stored or false if the block was already stored or an * error occurred. @@ -703,19 +738,12 @@ private[spark] class BlockManager( * to be dropped right after it got put into memory. Note, however, that other threads will * not be able to get() this block until we call markReady on its BlockInfo. */ val putBlockInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = Option(blockInfo.putIfAbsent(blockId, tinfo)) - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning(s"Block $blockId already exists on this machine; not re-adding it") - return false - } - // TODO: So the block info exists - but previous attempt to load it (?) failed. - // What do we do now ? Retry on it ? - oldBlockOpt.get + val newInfo = new BlockInfo(level, tellMaster) + if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) { + newInfo } else { - tinfo + logWarning(s"Block $blockId already exists on this machine; not re-adding it") + return false } } @@ -750,7 +778,7 @@ private[spark] class BlockManager( case _ => null } - var marked = false + var blockWasSuccessfullyStored = false putBlockInfo.synchronized { logTrace("Put for block %s took %s to get into synchronized block" @@ -792,11 +820,11 @@ private[spark] class BlockManager( } val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) - if (putBlockStatus.storageLevel != StorageLevel.NONE) { + blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid + if (blockWasSuccessfullyStored) { // Now that the block is in either the memory, externalBlockStore, or disk store, // let other threads read it, and tell the master about it. 
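// Illustrative sketch (not part of this patch): how a caller of get() interacts with the read
// lock described above. For locally stored blocks the lock is freed once the result iterator is
// fully consumed; a caller that stops early should call releaseLock(blockId) itself.
blockManager.get(blockId) match {
  case Some(result) =>
    result.data.foreach(println)   // draining the iterator releases the read lock (if any)
  case None =>
    // the block was found neither locally nor on a remote block manager
}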
- marked = true - putBlockInfo.markReady(size) + putBlockInfo.size = size if (tellMaster) { reportBlockStatus(blockId, putBlockInfo, putBlockStatus) } @@ -805,13 +833,10 @@ private[spark] class BlockManager( } } } finally { - // If we failed in putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (!marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - putBlockInfo.markFailure() + if (blockWasSuccessfullyStored) { + blockInfoManager.downgradeLock(blockId) + } else { + blockInfoManager.removeBlock(blockId) logWarning(s"Putting block $blockId failed") } } @@ -852,7 +877,7 @@ private[spark] class BlockManager( .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } - marked + blockWasSuccessfullyStored } /** @@ -989,77 +1014,63 @@ private[spark] class BlockManager( * store reaches its limit and needs to free up space. * * If `data` is not put on disk, it won't be created. + * + * The caller of this method must hold a write lock on the block before calling this method. + * This method does not release the write lock. + * + * @return the block's new effective StorageLevel. */ def dropFromMemory( blockId: BlockId, - data: () => Either[Array[Any], ByteBuffer]): Unit = { - + data: () => Either[Array[Any], ByteBuffer]): StorageLevel = { logInfo(s"Dropping block $blockId from memory") - val info = blockInfo.get(blockId) - - // If the block has not already been dropped - if (info != null) { - info.synchronized { - // required ? As of now, this will be invoked only for blocks which are ready - // But in case this changes in future, adding for consistency sake. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure. 
Nothing to drop") - return - } else if (blockInfo.asScala.get(blockId).isEmpty) { - logWarning(s"Block $blockId was already dropped.") - return - } - var blockIsUpdated = false - val level = info.level - - // Drop to disk, if storage level requires - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo(s"Writing block $blockId to disk") - data() match { - case Left(elements) => - diskStore.putArray(blockId, elements, level, returnValues = false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) - } - blockIsUpdated = true - } + val info = blockInfoManager.assertBlockIsLockedForWriting(blockId) + var blockIsUpdated = false + val level = info.level + + // Drop to disk, if storage level requires + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo(s"Writing block $blockId to disk") + data() match { + case Left(elements) => + diskStore.putArray(blockId, elements, level, returnValues = false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + blockIsUpdated = true + } - // Actually drop from memory store - val droppedMemorySize = - if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val blockIsRemoved = memoryStore.remove(blockId) - if (blockIsRemoved) { - blockIsUpdated = true - } else { - logWarning(s"Block $blockId could not be dropped from memory as it does not exist") - } + // Actually drop from memory store + val droppedMemorySize = + if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val blockIsRemoved = memoryStore.remove(blockId) + if (blockIsRemoved) { + blockIsUpdated = true + } else { + logWarning(s"Block $blockId could not be dropped from memory as it does not exist") + } - val status = getCurrentBlockStatus(blockId, info) - if (info.tellMaster) { - reportBlockStatus(blockId, info, status, droppedMemorySize) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } - if (blockIsUpdated) { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status))) - } - } + val status = getCurrentBlockStatus(blockId, info) + if (info.tellMaster) { + reportBlockStatus(blockId, info, status, droppedMemorySize) + } + if (blockIsUpdated) { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status))) } } + status.storageLevel } /** * Remove all blocks belonging to the given RDD. + * * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. 
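// Illustrative sketch (not part of this patch): the new dropFromMemory contract. The caller must
// already hold the write lock and uses the returned StorageLevel to decide whether the block's
// metadata should survive. `data` stands for the usual Either[Array[Any], ByteBuffer] payload.
val newLevel = blockManager.dropFromMemory(blockId, () => data)
if (newLevel.isValid) {
  blockManager.releaseLock(blockId)                   // still stored somewhere (e.g. on disk)
} else {
  blockManager.blockInfoManager.removeBlock(blockId)  // gone everywhere; drop its metadata too
}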
logInfo(s"Removing RDD $rddId") - val blocksToRemove = blockInfo.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } @@ -1069,7 +1080,7 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { logDebug(s"Removing broadcast $broadcastId") - val blocksToRemove = blockInfo.asScala.keys.collect { + val blocksToRemove = blockInfoManager.entries.map(_._1).collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } @@ -1081,9 +1092,11 @@ private[spark] class BlockManager( */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { logDebug(s"Removing block $blockId") - val info = blockInfo.get(blockId) - if (info != null) { - info.synchronized { + blockInfoManager.lockForWriting(blockId) match { + case None => + // The block has already been removed; do nothing. + logWarning(s"Asked to remove block $blockId, which does not exist") + case Some(info) => // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) @@ -1091,15 +1104,11 @@ private[spark] class BlockManager( logWarning(s"Block $blockId could not be removed as it was not found in either " + "the disk, memory, or external block store") } - blockInfo.remove(blockId) + blockInfoManager.removeBlock(blockId) if (tellMaster && info.tellMaster) { val status = getCurrentBlockStatus(blockId, info) reportBlockStatus(blockId, info, status) } - } - } else { - // The block has already been removed; do nothing. - logWarning(s"Asked to remove block $blockId, which does not exist") } } @@ -1174,7 +1183,7 @@ private[spark] class BlockManager( } diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) - blockInfo.clear() + blockInfoManager.clear() memoryStore.clear() diskStore.clear() futureExecutionContext.shutdownNow() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala new file mode 100644 index 0000000000000..5886b9c00b557 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} + +/** + * This [[ManagedBuffer]] wraps a [[ByteBuffer]] which was retrieved from the [[BlockManager]] + * so that the corresponding block's read lock can be released once this buffer's references + * are released. + * + * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks + * to the network layer's notion of retain / release counts. + */ +private[storage] class BlockManagerManagedBuffer( + blockManager: BlockManager, + blockId: BlockId, + buf: ByteBuffer) extends NioManagedBuffer(buf) { + + override def retain(): ManagedBuffer = { + super.retain() + val locked = blockManager.blockInfoManager.lockForReading(blockId, blocking = false) + assert(locked.isDefined) + this + } + + override def release(): ManagedBuffer = { + blockManager.releaseLock(blockId) + super.release() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 69985c9759e2d..6f6a6773ba4fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -60,8 +60,10 @@ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends /** * Remove a block, if it exists. + * * @param blockId the block to remove. * @return True if the block was found and removed, False otherwise. + * @throws IllegalStateException if the block is pinned by a task. */ def remove(blockId: BlockId): Boolean diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index c34d49c0d9061..9cc4084497731 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -203,6 +203,7 @@ private[spark] class DiskBlockObjectWriter( numRecordsWritten += 1 writeMetrics.incRecordsWritten(1) + // TODO: call updateBytesWritten() less frequently. 
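// Illustrative sketch (not part of this patch): how BlockManagerManagedBuffer bridges the network
// layer's reference counting and the new read locks, assuming `buffer` came from getBlockData()
// for a locally cached block (and therefore already holds one read lock).
val buffer = blockManager.getBlockData(blockId)
buffer.retain()    // a second reference acquires another (non-blocking) read lock
buffer.release()   // gives one read lock back
buffer.release()   // releases the last lock; the block becomes eligible for eviction again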
if (numRecordsWritten % 32 == 0) { updateBytesWritten() } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 024b660ce6a7b..2f16c8f3d8bad 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -95,7 +95,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - tryToPut(blockId, bytes, bytes.limit, deserialized = false) + tryToPut(blockId, () => bytes, bytes.limit, deserialized = false) PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -127,11 +127,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo returnValues: Boolean): PutResult = { if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, values, sizeEstimate, deserialized = true) + tryToPut(blockId, () => values, sizeEstimate, deserialized = true) PutResult(sizeEstimate, Left(values.iterator)) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, bytes, bytes.limit, deserialized = false) + tryToPut(blockId, () => bytes, bytes.limit, deserialized = false) PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -208,7 +208,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { - val entry = entries.synchronized { entries.remove(blockId) } + val entry = entries.synchronized { + entries.remove(blockId) + } if (entry != null) { memoryManager.releaseStorageMemory(entry.size) logDebug(s"Block $blockId of size ${entry.size} dropped " + @@ -327,14 +329,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo blockId.asRDDId.map(_.rddId) } - private def tryToPut( - blockId: BlockId, - value: Any, - size: Long, - deserialized: Boolean): Boolean = { - tryToPut(blockId, () => value, size, deserialized) - } - /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size @@ -410,6 +404,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo var freedMemory = 0L val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] + def blockIsEvictable(blockId: BlockId): Boolean = { + rddToAdd.isEmpty || rddToAdd != getRddId(blockId) + } // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. @@ -418,9 +415,14 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { - selectedBlocks += blockId - freedMemory += pair.getValue.size + if (blockIsEvictable(blockId)) { + // We don't want to evict blocks which are currently being read, so we need to obtain + // an exclusive write lock on blocks which are candidates for eviction. 
We perform a + // non-blocking "tryLock" here in order to ignore blocks which are locked for reading: + if (blockManager.blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { + selectedBlocks += blockId + freedMemory += pair.getValue.size + } } } } @@ -438,7 +440,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } else { Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) } - blockManager.dropFromMemory(blockId, () => data) + val newEffectiveStorageLevel = blockManager.dropFromMemory(blockId, () => data) + if (newEffectiveStorageLevel.isValid) { + // The block is still present in at least one store, so release the lock + // but don't delete the block info + blockManager.releaseLock(blockId) + } else { + // The block isn't present in any store, so delete the block info so that the + // block can be stored again + blockManager.blockInfoManager.removeBlock(blockId) + } } } freedMemory @@ -447,6 +458,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo logInfo(s"Will not store $id as it would require dropping another block " + "from the same RDD") } + selectedBlocks.foreach { id => + blockManager.releaseLock(id) + } 0L } } @@ -463,6 +477,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo /** * Reserve memory for unrolling the given block for this task. + * * @return whether the request is granted. */ def reserveUnrollMemoryForThisTask(blockId: BlockId, memory: Long): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index d98aae8ff0c68..f552b498a76de 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import scala.collection.mutable +import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -29,14 +30,20 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi -class StorageStatusListener extends SparkListener { +class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. 
storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() + private[storage] val deadExecutorStorageStatus = new mutable.ListBuffer[StorageStatus]() + private[this] val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) def storageStatusList: Seq[StorageStatus] = synchronized { executorIdToStorageStatus.values.toSeq } + def deadStorageStatusList: Seq[StorageStatus] = synchronized { + deadExecutorStorageStatus.toSeq + } + /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { executorIdToStorageStatus.get(execId).foreach { storageStatus => @@ -87,8 +94,12 @@ class StorageStatusListener extends SparkListener { override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { synchronized { val executorId = blockManagerRemoved.blockManagerId.executorId - executorIdToStorageStatus.remove(executorId) + executorIdToStorageStatus.remove(executorId).foreach { status => + deadExecutorStorageStatus += status + } + if (deadExecutorStorageStatus.size > retainedDeadExecutors) { + deadExecutorStorageStatus.trimStart(1) + } } } - } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6cc30eeaf5d82..5324a7682960b 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -196,7 +196,7 @@ private[spark] object SparkUI { } val environmentListener = new EnvironmentListener - val storageStatusListener = new StorageStatusListener + val storageStatusListener = new StorageStatusListener(conf) val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ddd7f713fe417..0493513d667c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -408,13 +408,6 @@ private[spark] object UIUtils extends Logging { } - /** Return a script element that automatically expands the DAG visualization on page load. */ - def expandDagVizOnLoad(forJob: Boolean): Seq[Node] = { - - } - /** * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML * and make sure that it only contains anchors with root-relative links. Otherwise, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index e1f754999912b..eba7a312ba81f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -54,24 +54,30 @@ private[ui] class ExecutorsPage( private val GCTimePercent = 0.1 def render(request: HttpServletRequest): Seq[Node] = { - val (storageStatusList, execInfo) = listener.synchronized { + val (activeExecutorInfo, deadExecutorInfo) = listener.synchronized { // The follow codes should be protected by `listener` to make sure no executors will be // removed before we query their status. See SPARK-12784. 
- val _storageStatusList = listener.storageStatusList - val _execInfo = { - for (statusId <- 0 until _storageStatusList.size) - yield ExecutorsPage.getExecInfo(listener, statusId) + val _activeExecutorInfo = { + for (statusId <- 0 until listener.activeStorageStatusList.size) + yield ExecutorsPage.getExecInfo(listener, statusId, isActive = true) } - (_storageStatusList, _execInfo) + val _deadExecutorInfo = { + for (statusId <- 0 until listener.deadStorageStatusList.size) + yield ExecutorsPage.getExecInfo(listener, statusId, isActive = false) + } + (_activeExecutorInfo, _deadExecutorInfo) } + + val execInfo = activeExecutorInfo ++ deadExecutorInfo val execInfoSorted = execInfo.sortBy(_.id) val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty - val execTable = + val execTable = { + @@ -98,22 +104,28 @@ private[ui] class ExecutorsPage( {execInfoSorted.map(execRow(_, logsExist))}
Executor ID AddressStatus RDD Blocks Storage Memory Disk Used
+ } val content =
      <div class="row">
        <div class="span12">
-          <h4>Totals for {execInfo.size} Executors</h4>
-          {execSummary(execInfo)}
+          <h4>Dead Executors({deadExecutorInfo.size})</h4>
+        </div>
+      </div>
+      <div class="row">
+        <div class="span12">
+          <h4>Active Executors({activeExecutorInfo.size})</h4>
+          {execSummary(activeExecutorInfo)}
        </div>
      </div>
      <div class="row">
        <div class="span12">
-          <h4>Active Executors</h4>
+          <h4>Executors</h4>
          {execTable}
        </div>
      </div>
; - UIUtils.headerSparkPage("Executors (" + execInfo.size + ")", content, parent) + UIUtils.headerSparkPage("Executors", content, parent) } /** Render an HTML row representing an executor */ @@ -121,9 +133,19 @@ private[ui] class ExecutorsPage( val maximumMemory = info.maxMemory val memoryUsed = info.memoryUsed val diskUsed = info.diskUsed + val executorStatus = + if (info.isActive) { + "Active" + } else { + "Dead" + } + {info.id} {info.hostPort} + + {executorStatus} + {info.rddBlocks} {Utils.bytesToString(memoryUsed)} / @@ -161,10 +183,14 @@ private[ui] class ExecutorsPage( } { if (threadDumpEnabled) { - val encodedId = URLEncoder.encode(info.id, "UTF-8") - - Thread Dump - + if (info.isActive) { + val encodedId = URLEncoder.encode(info.id, "UTF-8") + + Thread Dump + + } else { + + } } else { Seq.empty } @@ -236,14 +262,13 @@ private[ui] class ExecutorsPage( } private def taskData( - maxTasks: Int, - activeTasks: Int, - failedTasks: Int, - completedTasks: Int, - totalTasks: Int, - totalDuration: Long, - totalGCTime: Long): - Seq[Node] = { + maxTasks: Int, + activeTasks: Int, + failedTasks: Int, + completedTasks: Int, + totalTasks: Int, + totalDuration: Long, + totalGCTime: Long): Seq[Node] = { // Determine Color Opacity from 0.5-1 // activeTasks range from 0 to maxTasks val activeTasksAlpha = @@ -302,8 +327,15 @@ private[ui] class ExecutorsPage( private[spark] object ExecutorsPage { /** Represent an executor's info as a map given a storage status index */ - def getExecInfo(listener: ExecutorsListener, statusId: Int): ExecutorSummary = { - val status = listener.storageStatusList(statusId) + def getExecInfo( + listener: ExecutorsListener, + statusId: Int, + isActive: Boolean): ExecutorSummary = { + val status = if (isActive) { + listener.activeStorageStatusList(statusId) + } else { + listener.deadStorageStatusList(statusId) + } val execId = status.blockManagerId.executorId val hostPort = status.blockManagerId.hostPort val rddBlocks = status.numBlocks @@ -326,6 +358,7 @@ private[spark] object ExecutorsPage { new ExecutorSummary( execId, hostPort, + isActive, rddBlocks, memUsed, diskUsed, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index dcfebe92ed805..788f35ec77d9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -61,7 +61,9 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar val executorToLogUrls = HashMap[String, Map[String, String]]() val executorIdToData = HashMap[String, ExecutorUIData]() - def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList + def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList + + def deadStorageStatusList: Seq[StorageStatus] = storageStatusListener.deadStorageStatusList override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId @@ -81,7 +83,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { applicationStart.driverLogs.foreach { logs => - val storageStatus = storageStatusList.find { s => + val storageStatus = activeStorageStatusList.find { s => s.blockManagerId.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER || s.blockManagerId.executorId == SparkContext.DRIVER_IDENTIFIER } 
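// Illustrative sketch (not part of this patch): with the listener changes above, the Executors
// page can show both live and dead executors, bounded by spark.ui.retainedDeadExecutors.
val conf = new SparkConf().set("spark.ui.retainedDeadExecutors", "50")
val listener = new StorageStatusListener(conf)
// ... after some executors have been removed ...
val active = listener.storageStatusList      // executors that are still registered
val dead = listener.deadStorageStatusList    // at most the 50 most recently removed executors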
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0b68b88566b70..689ab7dd5ed62 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -107,10 +107,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) - // If this is set, expand the dag visualization by default - val expandDagVizParam = request.getParameter("expandDagViz") - val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean - val stageId = parameterId.toInt val stageAttemptId = parameterAttempt.toInt val stageDataOption = progressListener.stageIdToData.get((stageId, stageAttemptId)) @@ -263,13 +259,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val dagViz = UIUtils.showDagVizForStage( stageId, operationGraphListener.getOperationGraphForStage(stageId)) - val maybeExpandDagViz: Seq[Node] = - if (expandDagViz) { - UIUtils.expandDagVizOnLoad(forJob = false) - } else { - Seq.empty - } - val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { (acc.name, acc.value) match { @@ -578,7 +567,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val content = summary ++ dagViz ++ - maybeExpandDagViz ++ showAdditionalMetrics ++ makeTimeline( // Only show the tasks in the table diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 003c218aada9c..cb2827199853a 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -20,10 +20,11 @@ package org.apache.spark.ui.scope import scala.collection.mutable import scala.collection.mutable.{ListBuffer, StringBuilder} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.CallSite /** * A representation of a generic cluster graph used for storing information on RDD operations. @@ -183,7 +184,7 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { val label = s"${node.name} [${node.id}]\n${node.callsite}" - s"""${node.id} [label="$label"]""" + s"""${node.id} [label="${StringEscapeUtils.escapeJava(label)}"]""" } /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. 
*/ @@ -192,7 +193,7 @@ private[ui] object RDDOperationGraph extends Logging { cluster: RDDOperationCluster, indent: String): Unit = { subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent).append(s""" label="${cluster.name}";\n""") + .append(indent).append(s""" label="${StringEscapeUtils.escapeJava(cluster.name)}";\n""") cluster.childNodes.foreach { node => subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index f1e28b4e1e9c2..8f75b586e1399 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -43,7 +43,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList + def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ def rddInfoList: Seq[RDDInfo] = synchronized { @@ -54,7 +54,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { val rddIdsToUpdate = updatedBlocks.flatMap { case (bid, _) => bid.asRDDId.map(_.rddId) }.toSet val rddInfosToUpdate = _rddInfoMap.values.toSeq.filter { s => rddIdsToUpdate.contains(s.id) } - StorageUtils.updateRddInfo(rddInfosToUpdate, storageStatusList) + StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList) } /** diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 39d1829310762..b562b58f1b6bf 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.Try import org.apache.commons.lang3.SystemUtils @@ -93,7 +94,10 @@ private[spark] object Benchmark { if (SystemUtils.IS_OS_MAC_OSX) { Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string")) } else if (SystemUtils.IS_OS_LINUX) { - Utils.executeAndGetOutput(Seq("/usr/bin/grep", "-m", "1", "\"model name\"", "/proc/cpuinfo")) + Try { + val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")) + Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo")) + }.getOrElse("Unknown processor") } else { System.getenv("PROCESSOR_IDENTIFIER") } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e0c9bf02a1a20..6103a10ccc50e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2195,6 +2195,16 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + + /** + * + * @return whether it is local mode + */ + def isLocalMaster(conf: SparkConf): Boolean = { + val master = conf.get("spark.master", "") + master == "local" || master.startsWith("local[") + } + /** * Return whether dynamic allocation is enabled in the given conf * Dynamic allocation and explicitly setting the number of 
executors are inherently @@ -2202,8 +2212,13 @@ private[spark] object Utils extends Logging { * the latter should override the former (SPARK-9092). */ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - conf.getBoolean("spark.dynamicAllocation.enabled", false) && - conf.getInt("spark.executor.instances", 0) == 0 + val numExecutor = conf.getInt("spark.executor.instances", 0) + val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) + if (numExecutor != 0 && dynamicAllocationEnabled) { + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + numExecutor == 0 && dynamicAllocationEnabled && + (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) } def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala new file mode 100644 index 0000000000000..1be31e88ab68e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.EventListener + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution completes. + */ +@DeveloperApi +trait TaskCompletionListener extends EventListener { + def onTaskCompletion(context: TaskContext): Unit +} + + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution encounters an error. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ +@DeveloperApi +trait TaskFailureListener extends EventListener { + def onTaskFailure(context: TaskContext, error: Throwable): Unit +} + + +/** + * Exception thrown when there is an exception in executing the callback in TaskCompletionListener. 
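// Illustrative sketch (not part of this patch): registering the new failure callback from task
// code. onTaskFailure may run more than once for the same task, so the body must be idempotent.
val context = TaskContext.get()
context.addTaskFailureListener((_, error) =>
  println(s"Task failed: ${error.getMessage}; cleaning up partially written output"))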
+ */ +private[spark] +class TaskCompletionListenerException( + errorMessages: Seq[String], + val previousError: Option[Throwable] = None) + extends RuntimeException { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + + previousError.map { e => + "\n\nPrevious exception in task: " + e.getMessage + "\n" + + e.getStackTrace.mkString("\t", "\n\t", "") + }.getOrElse("") + } +} diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 4a918f725dc91..f914081d7d5b2 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -18,6 +18,8 @@ package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskFailureListener; /** * Something to make sure that TaskContext can be used in Java. @@ -32,10 +34,38 @@ public static void test() { tc.isRunningLocally(); tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + tc.addTaskFailureListener(new JavaTaskFailureListenerImpl()); tc.attemptNumber(); tc.partitionId(); tc.stageId(); tc.taskAttemptId(); } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskCompletionListenerImpl implements TaskCompletionListener { + @Override + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.isRunningLocally(); + context.addTaskCompletionListener(this); + } + } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. 
+ */ + static class JavaTaskFailureListenerImpl implements TaskFailureListener { + @Override + public void onTaskFailure(TaskContext context, Throwable error) { + } + } + } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 9d5d224e55176..4a88eeee747dc 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -1,6 +1,7 @@ [ { "id" : "", "hostPort" : "localhost:57971", + "isActive" : true, "rddBlocks" : 8, "memoryUsed" : 28000128, "diskUsed" : 0, diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 48a0282b30cf0..ffc02bcb011f3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -87,6 +87,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val context = TaskContext.empty() try { TaskContext.setTaskContext(context) + sc.env.blockManager.registerTask(0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlockStatuses.size === 2) } finally { diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 7b0238091730d..d1e806b2eb80a 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -485,7 +485,8 @@ class CleanerTester( def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout) { try { eventually(waitTimeout, interval(100 millis)) { - assert(isAllCleanedUp) + assert(isAllCleanedUp, + "The following resources were not cleaned up:\n" + uncleanedResourcesToString) } postCleanupValidate() } finally { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index b96c937f02ecf..9b6ab7b6bcfee 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -29,15 +29,21 @@ class SparkContextSchedulerCreationSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { def createTaskScheduler(master: String): TaskSchedulerImpl = - createTaskScheduler(master, new SparkConf()) + createTaskScheduler(master, "client") - def createTaskScheduler(master: String, conf: SparkConf): TaskSchedulerImpl = { + def createTaskScheduler(master: String, deployMode: String): TaskSchedulerImpl = + createTaskScheduler(master, deployMode, new SparkConf()) + + def createTaskScheduler( + master: String, + deployMode: String, + conf: SparkConf): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. 
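// Illustrative sketch (not part of this patch): the "yarn-cluster" / "yarn-client" master strings
// are replaced by a plain master plus an explicit deploy mode, as the updated SparkSubmit tests
// below assert.
val conf = new SparkConf()
  .set("spark.master", "yarn")
  .set("spark.submit.deployMode", "cluster")   // previously: spark.master=yarn-cluster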
sc = new SparkContext("local", "test", conf) val createTaskSchedulerMethod = PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]]('createTaskScheduler) - val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) + val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, deployMode) sched.asInstanceOf[TaskSchedulerImpl] } @@ -107,7 +113,7 @@ class SparkContextSchedulerCreationSuite test("local-default-parallelism") { val conf = new SparkConf().set("spark.default.parallelism", "16") - val sched = createTaskScheduler("local", conf) + val sched = createTaskScheduler("local", "client", conf) sched.backend match { case s: LocalBackend => assert(s.defaultParallelism() === 16) @@ -122,9 +128,9 @@ class SparkContextSchedulerCreationSuite } } - def testYarn(master: String, expectedClassName: String) { + def testYarn(master: String, deployMode: String, expectedClassName: String) { try { - val sched = createTaskScheduler(master) + val sched = createTaskScheduler(master, deployMode) assert(sched.getClass === Utils.classForName(expectedClassName)) } catch { case e: SparkException => @@ -135,21 +141,17 @@ class SparkContextSchedulerCreationSuite } test("yarn-cluster") { - testYarn("yarn-cluster", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") - } - - test("yarn-standalone") { - testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + testYarn("yarn", "cluster", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") } test("yarn-client") { - testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnScheduler") + testYarn("yarn", "client", "org.apache.spark.scheduler.cluster.YarnScheduler") } def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { val conf = new SparkConf().set("spark.mesos.coarse", coarse.toString) try { - val sched = createTaskScheduler(master, conf) + val sched = createTaskScheduler(master, "client", conf) assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index fe2c8299a0d91..41ac60ece0eda 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -358,7 +358,8 @@ class SparkSubmitSuite val appArgs = new SparkSubmitArguments(clArgs) val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.master") should be ("yarn-cluster") + sysProps("spark.master") should be ("yarn") + sysProps("spark.submit.deployMode") should be ("cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") } @@ -454,7 +455,7 @@ class SparkSubmitSuite // Test files and archives (Yarn) val clArgs2 = Seq( - "--master", "yarn-client", + "--master", "yarn", "--class", "org.SomeClass", "--files", files, "--archives", archives, @@ -512,7 +513,7 @@ class SparkSubmitSuite writer2.println("spark.yarn.dist.archives " + archives) writer2.close() val clArgs2 = Seq( - "--master", "yarn-client", + "--master", "yarn", "--class", "org.SomeClass", "--properties-file", f2.getPath, "thejar.jar" diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 0c4359c3c2cd5..9686c6621b465 100644 --- 
a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -227,7 +227,25 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val exception = intercept[IllegalArgumentException] { UnifiedMemoryManager(conf2, numCores = 1) } - assert(exception.getMessage.contains("larger heap size")) + assert(exception.getMessage.contains("increase heap size")) + } + + test("insufficient executor memory") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + + // Try using an executor memory that's too small + val conf2 = conf.clone().set("spark.executor.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("increase executor memory")) } test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 850e470ca14d6..c4cf2f9f70755 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -66,6 +66,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(TaskContextSuite.completed === true) } + test("calls TaskFailureListeners after failure") { + TaskContextSuite.lastError = null + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc, List()) { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { + context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error) + sys.error("damn error") + } + } + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + intercept[RuntimeException] { + task.run(0, 0, null) + } + assert(TaskContextSuite.lastError.getMessage == "damn error") + } + test("all TaskCompletionListeners should be called even if some fail") { val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) @@ -80,6 +100,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark verify(listener, times(1)).onTaskCompletion(any()) } + test("all TaskFailureListeners should be called even if some fail") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskFailureListener]) + 
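// Illustrative sketch, not part of this patch: outside these unit tests, a task would
// typically register the new failure callback through TaskContext.get(), using the same
// two-argument lambda overload exercised below (rdd and the cleanup action are
// hypothetical placeholders):
//
//   rdd.mapPartitions { iter =>
//     TaskContext.get().addTaskFailureListener((_, error) =>
//       System.err.println(s"cleaning up after task failure: ${error.getMessage}"))
//     iter
//   }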
context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1")) + context.addTaskFailureListener(listener) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskFailed(new Exception("exception in task")) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskFailure(any(), any()) + + // also need to check failure in TaskFailureListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") { sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks // Check that attemptIds are 0 for all tasks' initial attempts @@ -153,6 +193,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark private object TaskContextSuite { @volatile var completed = false + + @volatile var lastError: Throwable = _ } private case class StubPartition(index: Int) extends Partition diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index a0483f6483889..c1484b0afa85f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ import org.apache.spark.util.Utils -class KryoSerializerDistributedSuite extends SparkFunSuite { +class KryoSerializerDistributedSuite extends SparkFunSuite with LocalSparkContext { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) @@ -34,7 +34,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) @@ -47,8 +47,6 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { // Join the two RDDs, and force evaluation assert(shuffledRDD.join(cachedRDD).collect().size == 1) - - LocalSparkContext.stop(sc) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala new file mode 100644 index 0000000000000..662b18f667b0d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkException, SparkFunSuite, TaskContext, TaskContextImpl} + + +class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { + + private implicit val ec = ExecutionContext.global + private var blockInfoManager: BlockInfoManager = _ + + override protected def beforeEach(): Unit = { + super.beforeEach() + blockInfoManager = new BlockInfoManager() + for (t <- 0 to 4) { + blockInfoManager.registerTask(t) + } + } + + override protected def afterEach(): Unit = { + try { + blockInfoManager = null + } finally { + super.afterEach() + } + } + + private implicit def stringToBlockId(str: String): BlockId = { + TestBlockId(str) + } + + private def newBlockInfo(): BlockInfo = { + new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false) + } + + private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { + try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null)) + block + } finally { + TaskContext.unset() + } + } + + test("initial memory usage") { + assert(blockInfoManager.size === 0) + } + + test("get non-existent block") { + assert(blockInfoManager.get("non-existent-block").isEmpty) + assert(blockInfoManager.lockForReading("non-existent-block").isEmpty) + assert(blockInfoManager.lockForWriting("non-existent-block").isEmpty) + } + + test("basic lockNewBlockForWriting") { + val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries + val blockInfo = newBlockInfo() + withTaskId(1) { + assert(blockInfoManager.lockNewBlockForWriting("block", blockInfo)) + assert(blockInfoManager.get("block").get eq blockInfo) + assert(!blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + assert(blockInfoManager.get("block").get eq blockInfo) + assert(blockInfo.readerCount === 0) + assert(blockInfo.writerTask === 1) + blockInfoManager.unlock("block") + assert(blockInfo.readerCount === 0) + assert(blockInfo.writerTask === BlockInfo.NO_WRITER) + } + assert(blockInfoManager.size === 1) + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 1) + } + + test("read locks are reentrant") { + withTaskId(1) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + assert(blockInfoManager.lockForReading("block").isDefined) + assert(blockInfoManager.lockForReading("block").isDefined) + assert(blockInfoManager.get("block").get.readerCount === 2) + assert(blockInfoManager.get("block").get.writerTask === BlockInfo.NO_WRITER) + blockInfoManager.unlock("block") + assert(blockInfoManager.get("block").get.readerCount === 1) + blockInfoManager.unlock("block") + assert(blockInfoManager.get("block").get.readerCount === 0) + } + } + + test("multiple tasks can hold read locks") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { 
assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(2) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(3) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(4) { assert(blockInfoManager.lockForReading("block").isDefined) } + assert(blockInfoManager.get("block").get.readerCount === 4) + } + + test("single task can hold write lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForWriting("block").isDefined) + assert(blockInfoManager.get("block").get.writerTask === 1) + } + withTaskId(2) { + assert(blockInfoManager.lockForWriting("block", blocking = false).isEmpty) + assert(blockInfoManager.get("block").get.writerTask === 1) + } + } + + test("cannot call lockForWriting while already holding a write lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForWriting("block").isDefined) + intercept[IllegalStateException] { + blockInfoManager.lockForWriting("block") + } + blockInfoManager.assertBlockIsLockedForWriting("block") + } + } + + test("assertBlockIsLockedForWriting throws exception if block is not locked") { + intercept[SparkException] { + blockInfoManager.assertBlockIsLockedForWriting("block") + } + withTaskId(BlockInfo.NON_TASK_WRITER) { + intercept[SparkException] { + blockInfoManager.assertBlockIsLockedForWriting("block") + } + } + } + + test("downgrade lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.downgradeLock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForReading("block").isDefined) + } + assert(blockInfoManager.get("block").get.readerCount === 2) + assert(blockInfoManager.get("block").get.writerTask === BlockInfo.NO_WRITER) + } + + test("write lock will block readers") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val get1Future = Future { + withTaskId(1) { + blockInfoManager.lockForReading("block") + } + } + val get2Future = Future { + withTaskId(2) { + blockInfoManager.lockForReading("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.unlock("block") + } + assert(Await.result(get1Future, 1.seconds).isDefined) + assert(Await.result(get2Future, 1.seconds).isDefined) + assert(blockInfoManager.get("block").get.readerCount === 2) + } + + test("read locks will block writer") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + blockInfoManager.lockForReading("block") + } + val write1Future = Future { + withTaskId(1) { + blockInfoManager.lockForWriting("block") + } + } + val write2Future = Future { + withTaskId(2) { + blockInfoManager.lockForWriting("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.unlock("block") + } + assert( + Await.result(Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) + val firstWriteWinner = if (write1Future.isCompleted) 1 else 2 + withTaskId(firstWriteWinner) { + blockInfoManager.unlock("block") + } + assert(Await.result(write1Future, 1.seconds).isDefined) + assert(Await.result(write2Future, 
1.seconds).isDefined) + } + + test("removing a non-existent block throws IllegalArgumentException") { + withTaskId(0) { + intercept[IllegalArgumentException] { + blockInfoManager.removeBlock("non-existent-block") + } + } + } + + test("removing a block without holding any locks throws IllegalStateException") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + intercept[IllegalStateException] { + blockInfoManager.removeBlock("block") + } + } + } + + test("removing a block while holding only a read lock throws IllegalStateException") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + assert(blockInfoManager.lockForReading("block").isDefined) + intercept[IllegalStateException] { + blockInfoManager.removeBlock("block") + } + } + } + + test("removing a block causes blocked callers to receive None") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val getFuture = Future { + withTaskId(1) { + blockInfoManager.lockForReading("block") + } + } + val writeFuture = Future { + withTaskId(2) { + blockInfoManager.lockForWriting("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.removeBlock("block") + } + assert(Await.result(getFuture, 1.seconds).isEmpty) + assert(Await.result(writeFuture, 1.seconds).isEmpty) + } + + test("releaseAllLocksForTask releases write locks") { + val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 3) + blockInfoManager.releaseAllLocksForTask(0) + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 3fd6fb4560465..a94d8b424d956 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -190,6 +190,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { stores.head.putSingle(blockId, new Array[Byte](blockSize), level) + stores.head.releaseLock(blockId) val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet stores.foreach { _.removeBlock(blockId) } master.removeBlock(blockId) @@ -251,6 +252,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo // Insert a block with 2x replication and return the number of copies of the block def replicateAndGetNumCopies(blockId: String): Int = { store.putSingle(blockId, new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK_2) + store.releaseLock(blockId) val numLocations = master.getLocations(blockId).size allStores.foreach { _.removeBlock(blockId) } numLocations @@ -288,6 +290,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo def replicateAndGetNumCopies(blockId: String, replicationFactor: Int): Int = { val storageLevel = StorageLevel(true, true, false, true, replicationFactor) initialStores.head.putSingle(blockId, new Array[Byte](blockSize), 
storageLevel) + initialStores.head.releaseLock(blockId) val numLocations = master.getLocations(blockId).size allStores.foreach { _.removeBlock(blockId) } numLocations @@ -355,6 +358,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val blockId = new TestBlockId( "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + stores(0).releaseLock(blockId) // Assert that master know two locations for the block val blockLocations = master.getLocations(blockId).map(_.executorId).toSet @@ -367,6 +371,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName") + testStore.releaseLock(blockId) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") @@ -392,6 +397,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo (1 to 10).foreach { i => testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER) + testStore.releaseLock(s"dummy-block-$i") } (1 to 10).foreach { i => testStore.removeBlock(s"dummy-block-$i") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e1b2c9633edca..e4ab9ee0ebb38 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -45,6 +45,8 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { + import BlockManagerSuite._ + var conf: SparkConf = null var store: BlockManager = null var store2: BlockManager = null @@ -66,6 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master): BlockManager = { + val serializer = new KryoSerializer(conf) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, @@ -169,14 +172,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + 
assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") // Checking whether master knows about the blocks or not assert(master.getLocations("a1").size > 0, "master was not told about a1") @@ -184,10 +187,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", () => null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", () => null: Either[Array[Any], ByteBuffer]) - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") + store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ByteBuffer]) + assert(store.getSingleAndReleaseLock("a1") === None, "a1 not removed from store") + assert(store.getSingleAndReleaseLock("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") assert(master.getLocations("a2").size === 0, "master did not remove a2") } @@ -202,7 +205,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY_2) store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") @@ -215,17 +218,17 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putSingleAndReleaseLock("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory and memory size val memStatus = master.getMemoryStatus.head._2 assert(memStatus._1 == 20000L, "total memory " + memStatus._1 + " should equal 20000") assert(memStatus._2 <= 12000L, "remaining memory " + memStatus._2 + " should <= 12000") - assert(store.getSingle("a1-to-remove").isDefined, "a1 was not in store") - assert(store.getSingle("a2-to-remove").isDefined, "a2 was not in store") - assert(store.getSingle("a3-to-remove").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1-to-remove").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2-to-remove").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3-to-remove").isDefined, "a3 was not in store") // Checking whether master knows about the blocks or not assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") @@ -238,15 +241,15 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeBlock("a3-to-remove") eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - 
store.getSingle("a1-to-remove") should be (None) + assert(!store.hasLocalBlock("a1-to-remove")) master.getLocations("a1-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a2-to-remove") should be (None) + assert(!store.hasLocalBlock("a2-to-remove")) master.getLocations("a2-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a3-to-remove") should not be (None) + assert(store.hasLocalBlock("a3-to-remove")) master.getLocations("a3-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { @@ -262,30 +265,30 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory. - store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) - store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("nonrddblock", a3, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle(rdd(0, 0)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 0)) should be (None) master.getLocations(rdd(0, 0)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle(rdd(0, 1)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 1)) should be (None) master.getLocations(rdd(0, 1)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("nonrddblock") should not be (None) + store.getSingleAndReleaseLock("nonrddblock") should not be (None) master.getLocations("nonrddblock") should have size (1) } - store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = true) - store.getSingle(rdd(0, 0)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 0)) should be (None) master.getLocations(rdd(0, 0)) should have size 0 - store.getSingle(rdd(0, 1)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 1)) should be (None) master.getLocations(rdd(0, 1)) should have size 0 } @@ -305,54 +308,54 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // insert broadcast blocks in both the stores Seq(driverStore, executorStore).foreach { case s => - s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) - s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) - s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) - s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) } // verify whether the blocks exist in both the stores Seq(driverStore, executorStore).foreach { case s => - s.getLocal(broadcast0BlockId) should not be (None) - 
s.getLocal(broadcast1BlockId) should not be (None) - s.getLocal(broadcast2BlockId) should not be (None) - s.getLocal(broadcast2BlockId2) should not be (None) + assert(s.hasLocalBlock(broadcast0BlockId)) + assert(s.hasLocalBlock(broadcast1BlockId)) + assert(s.hasLocalBlock(broadcast2BlockId)) + assert(s.hasLocalBlock(broadcast2BlockId2)) } // remove broadcast 0 block only from executors master.removeBroadcast(0, removeFromMaster = false, blocking = true) // only broadcast 0 block should be removed from the executor store - executorStore.getLocal(broadcast0BlockId) should be (None) - executorStore.getLocal(broadcast1BlockId) should not be (None) - executorStore.getLocal(broadcast2BlockId) should not be (None) + assert(!executorStore.hasLocalBlock(broadcast0BlockId)) + assert(executorStore.hasLocalBlock(broadcast1BlockId)) + assert(executorStore.hasLocalBlock(broadcast2BlockId)) // nothing should be removed from the driver store - driverStore.getLocal(broadcast0BlockId) should not be (None) - driverStore.getLocal(broadcast1BlockId) should not be (None) - driverStore.getLocal(broadcast2BlockId) should not be (None) + assert(driverStore.hasLocalBlock(broadcast0BlockId)) + assert(driverStore.hasLocalBlock(broadcast1BlockId)) + assert(driverStore.hasLocalBlock(broadcast2BlockId)) // remove broadcast 0 block from the driver as well master.removeBroadcast(0, removeFromMaster = true, blocking = true) - driverStore.getLocal(broadcast0BlockId) should be (None) - driverStore.getLocal(broadcast1BlockId) should not be (None) + assert(!driverStore.hasLocalBlock(broadcast0BlockId)) + assert(driverStore.hasLocalBlock(broadcast1BlockId)) // remove broadcast 1 block from both the stores asynchronously // and verify all broadcast 1 blocks have been removed master.removeBroadcast(1, removeFromMaster = true, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - driverStore.getLocal(broadcast1BlockId) should be (None) - executorStore.getLocal(broadcast1BlockId) should be (None) + assert(!driverStore.hasLocalBlock(broadcast1BlockId)) + assert(!executorStore.hasLocalBlock(broadcast1BlockId)) } // remove broadcast 2 from both the stores asynchronously // and verify all broadcast 2 blocks have been removed master.removeBroadcast(2, removeFromMaster = true, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - driverStore.getLocal(broadcast2BlockId) should be (None) - driverStore.getLocal(broadcast2BlockId2) should be (None) - executorStore.getLocal(broadcast2BlockId) should be (None) - executorStore.getLocal(broadcast2BlockId2) should be (None) + assert(!driverStore.hasLocalBlock(broadcast2BlockId)) + assert(!driverStore.hasLocalBlock(broadcast2BlockId2)) + assert(!executorStore.hasLocalBlock(broadcast2BlockId)) + assert(!executorStore.hasLocalBlock(broadcast2BlockId2)) } executorStore.stop() driverStore.stop() @@ -363,9 +366,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(2000) val a1 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") assert(master.getLocations("a1").size > 0, "master was not told about a1") master.removeExecutor(store.blockManagerId.executorId) @@ -381,13 +384,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with 
BeforeAndAfterE val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) assert(master.getLocations("a1").size > 0, "master was not told about a1") master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.MEMORY_ONLY) store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") @@ -404,12 +407,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { override def run() { - store.putIterator("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } } val t2 = new Thread { override def run() { - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) } } val t3 = new Thread { @@ -425,8 +429,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE t2.join() t3.join() - store.dropFromMemory("a1", () => null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -437,9 +441,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list2 = List(new Array[Byte](500), new Array[Byte](1000), new Array[Byte](1500)) val list1SizeEstimate = SizeEstimator.estimate(list1.iterator.toArray) val list2SizeEstimate = SizeEstimator.estimate(list2.iterator.toArray) - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in store") assert(list1Get.get.data.size === 2) @@ -479,8 +486,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store2 = makeBlockManager(8000, "executor2") store3 = makeBlockManager(8000, "executor3") val list1 = List(new Array[Byte](4000)) - store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store2.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") store2.stop() store2 = null @@ -506,18 +515,18 @@ class BlockManagerSuite extends 
SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, storageLevel) - store.putSingle("a2", a2, storageLevel) - store.putSingle("a3", a3, storageLevel) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") + store.putSingleAndReleaseLock("a1", a1, storageLevel) + store.putSingleAndReleaseLock("a2", a2, storageLevel) + store.putSingleAndReleaseLock("a3", a3, storageLevel) + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1") === None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, storageLevel) - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") + store.putSingleAndReleaseLock("a1", a1, storageLevel) + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3") === None, "a3 was in store") } test("in-memory LRU for partitions of same RDD") { @@ -525,34 +534,34 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") - assert(store.getSingle(rdd(0, 1)).isDefined, "rdd_0_1 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 1)).isDefined, "rdd_0_1 was not in store") // Check that rdd_0_3 doesn't replace them even after further accesses - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") } test("in-memory LRU for partitions of multiple RDDs") { store = makeBlockManager(12000) - store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 
2), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(1, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 2), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(1, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // At this point rdd_1_1 should've replaced rdd_0_1 assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store") assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle(rdd(0, 3), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 4), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 3), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 4), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped // when we try to add rdd_0_4. assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store") @@ -567,28 +576,28 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - store.putSingle("a2", a2, StorageLevel.DISK_ONLY) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - assert(store.getSingle("a2").isDefined, "a2 was in store") - assert(store.getSingle("a3").isDefined, "a3 was in store") - assert(store.getSingle("a1").isDefined, "a1 was in store") + store.putSingleAndReleaseLock("a1", a1, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a3", a3, StorageLevel.DISK_ONLY) + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingle) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingleAndReleaseLock) } test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytes) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytesAndReleaseLock) } test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingle) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingleAndReleaseLock) } test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytes) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytesAndReleaseLock) } def testDiskAndMemoryStorage( @@ -598,9 +607,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) 
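// Sketch under assumptions, not shown in this excerpt: the *AndReleaseLock helpers used
// throughout this suite are presumably thin wrappers (made available by the
// `import BlockManagerSuite._` added near the top of the file) that immediately drop the
// lock a put or get now pins. One plausible, hypothetical shape:
//
//   private object BlockManagerSuite {
//     implicit class LockReleasingBlockManager(store: BlockManager) {
//       def putSingleAndReleaseLock(
//           id: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true): Unit = {
//         store.putSingle(id, value, level, tellMaster)
//         store.releaseLock(id) // drop the write lock taken by the put
//       }
//       def getSingleAndReleaseLock(id: BlockId): Option[Any] = {
//         val result = store.getSingle(id)
//         if (result.isDefined) store.releaseLock(id) // drop the read lock taken by the get
//         result
//       }
//     }
//   }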
- store.putSingle("a1", a1, storageLevel) - store.putSingle("a2", a2, storageLevel) - store.putSingle("a3", a3, storageLevel) + store.putSingleAndReleaseLock("a1", a1, storageLevel) + store.putSingleAndReleaseLock("a2", a2, storageLevel) + store.putSingleAndReleaseLock("a3", a3, storageLevel) assert(accessMethod(store)("a2").isDefined, "a2 was not in store") assert(accessMethod(store)("a3").isDefined, "a3 was not in store") assert(store.memoryStore.getValues("a1").isEmpty, "a1 was in memory store") @@ -615,19 +624,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a3 = new Array[Byte](4000) val a4 = new Array[Byte](4000) // First store a1 and a2, both in memory, and a3, on disk only - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock("a3", a3, StorageLevel.DISK_ONLY) // At this point LRU should not kick in because a3 is only on disk - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") // Now let's add in a4, which uses both disk and memory; a1 should drop out - store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a1") == None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a4").isDefined, "a4 was not in store") + store.putSingleAndReleaseLock("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) + assert(store.getSingleAndReleaseLock("a1") == None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } test("in-memory LRU with streams") { @@ -635,23 +644,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list2").isDefined, "list2 was not in store") + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not 
in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list1") === None, "list1 was in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list1").isDefined, "list1 was not in store") + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3") === None, "list1 was in store") + assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } test("LRU with mixed storage levels and streams") { @@ -661,33 +674,37 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) val list4 = List(new Array[Byte](2000), new Array[Byte](2000)) // First store list1 and list2, both in memory, and list3, on disk only - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.putIterator("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val listForSizeEstimate = new ArrayBuffer[Any] listForSizeEstimate ++= list1.iterator val listSize = SizeEstimator.estimate(listForSizeEstimate) // At this point LRU should not kick in because list3 is only on disk - assert(store.get("list1").isDefined, "list1 was not in store") + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list1").isDefined, "list1 was not in store") + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") 
assert(store.get("list3").get.data.size === 2) // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.putIterator("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2").isDefined, "list2 was not in store") + store.putIteratorAndReleaseLock( + "list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) + assert(store.getAndReleaseLock("list1") === None, "list1 was in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list4").isDefined, "list4 was not in store") + assert(store.getAndReleaseLock("list4").isDefined, "list4 was not in store") assert(store.get("list4").get.data.size === 2) } @@ -705,18 +722,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("overly large block") { store = makeBlockManager(5000) - store.putSingle("a1", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") === None, "a1 was in store") - store.putSingle("a2", new Array[Byte](10000), StorageLevel.MEMORY_AND_DISK) + store.putSingleAndReleaseLock("a1", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) + assert(store.getSingleAndReleaseLock("a1") === None, "a1 was in store") + store.putSingleAndReleaseLock("a2", new Array[Byte](10000), StorageLevel.MEMORY_AND_DISK) assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") } test("block compression") { try { conf.set("spark.shuffle.compress", "true") store = makeBlockManager(20000, "exec1") - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") store.stop() @@ -724,7 +742,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.compress", "false") store = makeBlockManager(20000, "exec2") - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 10000, "shuffle_0_0_0 was compressed") store.stop() @@ -732,7 +751,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.broadcast.compress", "true") store = makeBlockManager(20000, "exec3") - store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 1000, "broadcast_0 was not compressed") store.stop() @@ -740,28 +760,29 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.broadcast.compress", "false") store = makeBlockManager(20000, "exec4") - 
store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 10000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") store = makeBlockManager(20000, "exec5") - store.putSingle(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 1000, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") store = makeBlockManager(20000, "exec6") - store.putSingle(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 10000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed store = makeBlockManager(20000, "exec7") - store.putSingle("other_block", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("other_block", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 10000, "other_block was compressed") store.stop() store = null @@ -789,12 +810,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE class UnserializableClass val a1 = new UnserializableClass intercept[java.io.NotSerializableException] { - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.DISK_ONLY) } // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1").isEmpty, "a1 should not be in store") + assert(store.getSingleAndReleaseLock("a1").isEmpty, "a1 should not be in store") } } @@ -844,6 +865,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("updated block statuses") { store = makeBlockManager(12000) + store.registerTask(0) val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) @@ -860,7 +882,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 1 updated block (i.e. list1) val updatedBlocks1 = getUpdatedBlocks { - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks1.size === 1) assert(updatedBlocks1.head._1 === TestBlockId("list1")) @@ -868,7 +891,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 1 updated block (i.e. 
list2) val updatedBlocks2 = getUpdatedBlocks { - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } assert(updatedBlocks2.size === 1) assert(updatedBlocks2.head._1 === TestBlockId("list2")) @@ -876,7 +900,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 2 updated blocks - list1 is kicked out of memory while list3 is added val updatedBlocks3 = getUpdatedBlocks { - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks3.size === 2) updatedBlocks3.foreach { case (id, status) => @@ -890,7 +915,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 2 updated blocks - list2 is kicked out of memory (but put on disk) while list4 is added val updatedBlocks4 = getUpdatedBlocks { - store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks4.size === 2) updatedBlocks4.foreach { case (id, status) => @@ -905,7 +931,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // No updated blocks - list5 is too big to fit in store and nothing is kicked out val updatedBlocks5 = getUpdatedBlocks { - store.putIterator("list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks5.size === 0) @@ -929,9 +956,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](2000)) // Tell master. By LRU, only list2 and list3 remains. - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getLocations("list1").size === 0) @@ -945,9 +975,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. 
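// Side note (assumption, not from this patch): with tellMaster = false the driver's
// BlockManagerMaster is never informed of these blocks, so only queries that ask the
// slaves directly can observe them, roughly:
//
//   master.getBlockStatus("list5", askSlaves = false)  // expected to be empty
//   master.getBlockStatus("list5", askSlaves = true)   // expected to report the block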
- store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) - store.putIterator("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.putIterator("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIteratorAndReleaseLock( + "list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIteratorAndReleaseLock( + "list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIteratorAndReleaseLock( + "list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) // getLocations should return nothing because the master is not informed // getBlockStatus without asking slaves should have the same result @@ -968,9 +1001,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](100)) // insert some blocks - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size @@ -979,9 +1015,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE === 1) // insert some more blocks - store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIteratorAndReleaseLock( + "newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIteratorAndReleaseLock( + "newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size @@ -991,7 +1030,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => - store.putIterator(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } val matchedBlockIds = store.master.getMatchingBlockIds(_ match { case RDDBlockId(1, _) => true @@ -1002,12 +1042,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = makeBlockManager(12000) - store.putSingle(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 0), new Array[Byte](4000), 
StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. - assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") + assert(store.getSingleAndReleaseLock(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") // According to the same-RDD rule, rdd_1_0 should be replaced here. - store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // rdd_1_0 should have been replaced, even it's not least recently used. assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store") assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") @@ -1086,8 +1126,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. - store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) - store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) assert(memoryStore.currentUnrollMemoryForThisTask === 0) @@ -1098,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. // In the mean time, however, we kicked out someBlock2 before giving up. - store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator @@ -1130,8 +1170,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // would not know how to drop them from memory later. memoryStore.remove("b1") memoryStore.remove("b2") - store.putIterator("b1", smallIterator, memOnly) - store.putIterator("b2", smallIterator, memOnly) + store.putIteratorAndReleaseLock("b1", smallIterator, memOnly) + store.putIteratorAndReleaseLock("b2", smallIterator, memOnly) // Unroll with not enough space. This should succeed but kick out b1 in the process. val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) @@ -1142,7 +1182,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.contains("b3")) assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") - store.putIterator("b3", smallIterator, memOnly) + store.putIteratorAndReleaseLock("b3", smallIterator, memOnly) // Unroll huge block with not enough space. This should fail and kick out b2 in the process. 
val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, returnValues = true) @@ -1169,8 +1209,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisTask === 0) - store.putIterator("b1", smallIterator, memAndDisk) - store.putIterator("b2", smallIterator, memAndDisk) + store.putIteratorAndReleaseLock("b1", smallIterator, memAndDisk) + store.putIteratorAndReleaseLock("b2", smallIterator, memAndDisk) // Unroll with not enough space. This should succeed but kick out b1 in the process. // Memory store should contain b2 and b3, while disk store should contain only b1 @@ -1183,7 +1223,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b2")) assert(!diskStore.contains("b3")) memoryStore.remove("b3") - store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("b3", smallIterator, StorageLevel.MEMORY_ONLY) assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk @@ -1244,6 +1284,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore val blockId = BlockId("rdd_3_10") + store.blockInfoManager.lockNewBlockForWriting( + blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false)) val result = memoryStore.putBytes(blockId, 13000, () => { fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") }) @@ -1263,4 +1305,104 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result.size === 10000) assert(result.data === Right(bytes)) } + + test("read-locked blocks cannot be evicted from the MemoryStore") { + store = makeBlockManager(12000) + val arr = new Array[Byte](4000) + // First store a1 and a2, both in memory, and a3, on disk only + store.putSingleAndReleaseLock("a1", arr, StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock("a2", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a2").isDefined, "a2 was not in store") + // This put should fail because both a1 and a2 should be read-locked: + store.putSingleAndReleaseLock("a3", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a3").isEmpty, "a3 was in store") + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a2").isDefined, "a2 was not in store") + // Release both pins of block a2: + store.releaseLock("a2") + store.releaseLock("a2") + // Block a1 is the least-recently accessed, so an LRU eviction policy would evict it before + // block a2. 
However, a1 is still pinned so this put of a3 should evict a2 instead: + store.putSingleAndReleaseLock("a3", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a2").isEmpty, "a2 was in store") + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a3").isDefined, "a3 was not in store") + } +} + +private object BlockManagerSuite { + + private implicit class BlockManagerTestUtils(store: BlockManager) { + + def putSingleAndReleaseLock( + block: BlockId, + value: Any, + storageLevel: StorageLevel, + tellMaster: Boolean): Unit = { + if (store.putSingle(block, value, storageLevel, tellMaster)) { + store.releaseLock(block) + } + } + + def putSingleAndReleaseLock(block: BlockId, value: Any, storageLevel: StorageLevel): Unit = { + if (store.putSingle(block, value, storageLevel)) { + store.releaseLock(block) + } + } + + def putIteratorAndReleaseLock( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel): Unit = { + if (store.putIterator(blockId, values, level)) { + store.releaseLock(blockId) + } + } + + def putIteratorAndReleaseLock( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + tellMaster: Boolean): Unit = { + if (store.putIterator(blockId, values, level, tellMaster)) { + store.releaseLock(blockId) + } + } + + def dropFromMemoryIfExists( + blockId: BlockId, + data: () => Either[Array[Any], ByteBuffer]): Unit = { + store.blockInfoManager.lockForWriting(blockId).foreach { info => + val newEffectiveStorageLevel = store.dropFromMemory(blockId, data) + if (newEffectiveStorageLevel.isValid) { + // The block is still present in at least one store, so release the lock + // but don't delete the block info + store.releaseLock(blockId) + } else { + // The block isn't present in any store, so delete the block info so that the + // block can be stored again + store.blockInfoManager.removeBlock(blockId) + } + } + } + + private def wrapGet[T](f: BlockId => Option[T]): BlockId => Option[T] = (blockId: BlockId) => { + val result = f(blockId) + if (result.isDefined) { + store.releaseLock(blockId) + } + result + } + + def hasLocalBlock(blockId: BlockId): Boolean = { + getLocalAndReleaseLock(blockId).isDefined + } + + val getLocalAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.getLocal) + val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) + val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) + val getLocalBytesAndReleaseLock: (BlockId) => Option[ByteBuffer] = wrapGet(store.getLocalBytes) + } + } diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 9de434166bba3..14daa003bc5a6 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.{SparkFunSuite, Success} +import org.apache.spark.{SparkConf, SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -29,9 +29,11 @@ class StorageStatusListenerSuite extends SparkFunSuite { private val bm2 = BlockManagerId("fat", "duck", 2) private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) private val taskInfo2 = new TaskInfo(0, 0, 0, 0, "fat", "duck", TaskLocality.ANY, false) + private val conf = new SparkConf() test("block manager 
added/removed") { - val listener = new StorageStatusListener + conf.set("spark.ui.retainedDeadExecutors", "1") + val listener = new StorageStatusListener(conf) // Block manager add assert(listener.executorIdToStorageStatus.size === 0) @@ -53,14 +55,18 @@ class StorageStatusListenerSuite extends SparkFunSuite { assert(listener.executorIdToStorageStatus.size === 1) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus.get("fat").isDefined) + assert(listener.deadExecutorStorageStatus.size === 1) + assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("big")) listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2)) assert(listener.executorIdToStorageStatus.size === 0) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(!listener.executorIdToStorageStatus.get("fat").isDefined) + assert(listener.deadExecutorStorageStatus.size === 1) + assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("fat")) } test("task end without updated blocks") { - val listener = new StorageStatusListener + val listener = new StorageStatusListener(conf) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics = new TaskMetrics @@ -77,7 +83,7 @@ class StorageStatusListenerSuite extends SparkFunSuite { } test("task end with updated blocks") { - val listener = new StorageStatusListener + val listener = new StorageStatusListener(conf) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics1 = new TaskMetrics @@ -126,7 +132,7 @@ class StorageStatusListenerSuite extends SparkFunSuite { } test("unpersist RDD") { - val listener = new StorageStatusListener + val listener = new StorageStatusListener(conf) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index d1dbf7c1558b2..6b7c538ac8549 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.storage import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFunSuite, Success} +import org.apache.spark.{SparkConf, SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.storage._ @@ -44,7 +44,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { before { bus = new LiveListenerBus - storageStatusListener = new StorageStatusListener + storageStatusListener = new StorageStatusListener(new SparkConf()) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) bus.addListener(storageListener) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 7c6778b065467..412c0ac9d9be3 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -722,6 +722,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { 
test("isDynamicAllocationEnabled") { val conf = new SparkConf() + conf.set("spark.master", "yarn-client") assert(Utils.isDynamicAllocationEnabled(conf) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "false")) === false) @@ -731,6 +732,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "1")) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.executor.instances", "0")) === true) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true"))) } test("encodeFileNameToURIRawPath") { diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 2fd7fcc39ea28..c08b6d7de6fe0 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -23,8 +23,8 @@ usage: release-build.sh Creates build deliverables from a Spark commit. Top level targets are - package: Create binary packages and copy them to people.apache - docs: Build docs and copy them to people.apache + package: Create binary packages and copy them to home.apache + docs: Build docs and copy them to home.apache publish-snapshot: Publish snapshot release to Apache snapshots publish-release: Publish a release to Apache release repo @@ -64,13 +64,16 @@ for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do fi done +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + # Commit ref to checkout when building GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" GPG="gpg --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads @@ -97,7 +100,20 @@ if [ -z "$SPARK_PACKAGE_VERSION" ]; then fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" -USER_HOST="$ASF_USERNAME@people.apache.org" + +function LFTP { + SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" + COMMANDS=$(cat <postgresql test + + prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
-{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// Fit the pipeline to training documents. -PipelineModel model = pipeline.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. -DataFrame predictions = model.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
-{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row - -# Prepare training documents from a list of (id, text, label) tuples. -LabeledDocument = Row("id", "text", "label") -training = sqlContext.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) - -# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. -tokenizer = Tokenizer(inputCol="text", outputCol="words") -hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") -lr = LogisticRegression(maxIter=10, regParam=0.01) -pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) - -# Fit the pipeline to training documents. -model = pipeline.fit(training) - -# Prepare test documents, which are unlabeled (id, text) tuples. -test = sqlContext.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")], ["id", "text"]) - -# Make predictions on test documents and print columns of interest. -prediction = model.transform(test) -selected = prediction.select("id", "text", "prediction") -for row in selected.collect(): - print(row) - -{% endhighlight %} +{% include_example python/ml/pipeline_example.py %}
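For reviewers reading this diff without the example sources checked out, here is a condensed Scala sketch of roughly what the referenced `PipelineExample` covers. It is based on the Java listing removed above, not on the new example file itself, and it assumes an existing `SQLContext` named `sqlContext`.

{% highlight scala %}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Prepare labeled training documents (id, text, label).
val training = sqlContext.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Configure a three-stage pipeline: tokenizer, hashingTF, then logistic regression.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// Fit the pipeline on the training documents, then score unlabeled test documents.
val model: PipelineModel = pipeline.fit(training)
val test = sqlContext.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "mapreduce spark"),
  (7L, "apache hadoop")
)).toDF("id", "text")
model.transform(test).select("id", "text", "probability", "prediction").show()
{% endhighlight %}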
@@ -646,201 +276,16 @@ However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.
-{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training data from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0), - (4L, "b spark who", 1.0), - (5L, "g d a y", 0.0), - (6L, "spark fly", 1.0), - (7L, "was mapreduce", 0.0), - (8L, "e spark program", 1.0), - (9L, "a e c l", 0.0), - (10L, "spark compile", 1.0), - (11L, "hadoop software", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -val cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2) // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -val cvModel = cv.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
-{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. 
-CrossValidator cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2); // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = cv.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -DataFrame predictions = cvModel.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %} +
+ +
+ +{% include_example python/ml/cross_validator.py %}
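For convenience while reviewing, a minimal Scala sketch of the cross-validation workflow, condensed from the Scala listing removed earlier in this section. It is only an approximation of what the new `ModelSelectionViaCrossValidationExample` files contain, and it reuses `sqlContext`, `pipeline`, `hashingTF`, `lr`, and `test` from the pipeline sketch above.

{% highlight scala %}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// A slightly larger labeled training set than in the basic pipeline sketch.
val cvTraining = sqlContext.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0), (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0),
  (4L, "b spark who", 1.0), (5L, "g d a y", 0.0),
  (6L, "spark fly", 1.0), (7L, "was mapreduce", 0.0),
  (8L, "e spark program", 1.0), (9L, "a e c l", 0.0),
  (10L, "spark compile", 1.0), (11L, "hadoop software", 0.0)
)).toDF("id", "text", "label")

// 3 values for numFeatures x 2 values for regParam = 6 parameter settings to search.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// Wrap the whole Pipeline as the Estimator so parameters are chosen jointly for all
// stages; BinaryClassificationEvaluator's default metric is areaUnderROC.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2) // use 3+ folds in practice

// Fit chooses the best parameter combination; the result transforms like any other model.
val cvModel = cv.fit(cvTraining)
cvModel.transform(test).select("id", "text", "probability", "prediction").show()
{% endhighlight %}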
@@ -864,91 +309,11 @@ The `ParamMap` which produces the best evaluation metric is selected as the best model.
-{% highlight scala %} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} - -// Prepare training and test data. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") -val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - -val lr = new LinearRegression() - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - // 80% of the data will be used for training and the remaining 20% for validation. - .setTrainRatio(0.8) - -// Run train validation split, and choose the best set of parameters. -val model = trainValidationSplit.fit(training) - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
-{% highlight java %} -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.DataFrame; - -DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_linear_regression_data.txt"); - -// Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); -DataFrame training = splits[0]; -DataFrame test = splits[1]; - -LinearRegression lr = new LinearRegression(); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation - -// Run train validation split, and choose the best set of parameters. -TrainValidationSplitModel model = trainValidationSplit.fit(training); - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
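Similarly, for reviewers who want the gist without opening the new example files, a minimal Scala sketch of the train-validation split flow, condensed from the listing removed above; it assumes an existing `sqlContext` and the bundled `data/mllib/sample_linear_regression_data.txt`.

{% highlight scala %}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

// Prepare training and test data.
val data = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_linear_regression_data.txt")
val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)

val lr = new LinearRegression()

// TrainValidationSplit tries every combination of values in the grid.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// Unlike CrossValidator, each combination is evaluated against a single held-out
// validation set (20% of the training data here) rather than k folds.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)

// Fit with the best parameter combination, then predict on the held-out test split.
val model = trainValidationSplit.fit(training)
model.transform(test).select("features", "label", "prediction").show()
{% endhighlight %}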
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index a8612b6c84fe9..9af48357b3dfc 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -121,12 +121,12 @@ The parameters are listed below roughly in order of descending importance. New These parameters describe the problem you want to solve and your dataset. They should be specified and do not require tuning. -* **`algo`**: `Classification` or `Regression` +* **`algo`**: Type of decision tree, either `Classification` or `Regression`. -* **`numClasses`**: Number of classes (for `Classification` only) +* **`numClasses`**: Number of classes (for `Classification` only). * **`categoricalFeaturesInfo`**: Specifies which features are categorical and how many categorical values each of those features can take. This is given as a map from feature indices to feature arity (number of categories). Any features not in this map are treated as continuous. - * E.g., `Map(0 -> 2, 4 -> 10)` specifies that feature `0` is binary (taking values `0` or `1`) and that feature `4` has 10 categories (values `{0, 1, ..., 9}`). Note that feature indices are 0-based: features `0` and `4` are the 1st and 5th elements of an instance's feature vector. + * For example, `Map(0 -> 2, 4 -> 10)` specifies that feature `0` is binary (taking values `0` or `1`) and that feature `4` has 10 categories (values `{0, 1, ..., 9}`). Note that feature indices are 0-based: features `0` and `4` are the 1st and 5th elements of an instance's feature vector. * Note that you do not have to specify `categoricalFeaturesInfo`. The algorithm will still run and may get reasonable results. However, performance should be better if categorical features are properly designated. ### Stopping criteria diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 11d8e0bd1d23d..cceddce9f79a6 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -64,19 +64,7 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver.
Refer to the [`SingularValueDecomposition` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.SingularValueDecomposition) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.SingularValueDecomposition - -val mat: RowMatrix = ... - -// Compute the top 20 singular values and corresponding singular vectors. -val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(20, computeU = true) -val U: RowMatrix = svd.U // The U factor is a RowMatrix. -val s: Vector = svd.s // The singular values are stored in a local dense vector. -val V: Matrix = svd.V // The V factor is a local dense matrix. -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SVDExample.scala %} The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. @@ -84,43 +72,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an
Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/mllib/linalg/SingularValueDecomposition.html) for details on the API. -{% highlight java %} -import java.util.LinkedList; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.SingularValueDecomposition; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.rdd.RDD; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class SVD { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVD Example"); - SparkContext sc = new SparkContext(conf); - - double[][] array = ... - LinkedList rowsList = new LinkedList(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); - - // Create a RowMatrix from JavaRDD. - RowMatrix mat = new RowMatrix(rows.rdd()); - - // Compute the top 4 singular values and corresponding singular vectors. - SingularValueDecomposition svd = mat.computeSVD(4, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSVDExample.java %} The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. @@ -151,36 +103,14 @@ and use them to project the vectors into a low-dimensional space. Refer to the [`RowMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix - -val mat: RowMatrix = ... - -// Compute the top 10 principal components. -val pc: Matrix = mat.computePrincipalComponents(10) // Principal components are stored in a local dense matrix. - -// Project the rows to the linear space spanned by the top 10 principal components. -val projected: RowMatrix = mat.multiply(pc) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala %} The following code demonstrates how to compute principal components on source vectors and use them to project the vectors into a low-dimensional space while keeping associated labels: Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.feature.PCA - -val data: RDD[LabeledPoint] = ... - -// Compute the top 10 principal components. -val pca = new PCA(10).fit(data.map(_.features)) - -// Project vectors to the linear space spanned by the top 10 principal components, keeping the label -val projected = data.map(p => p.copy(features = pca.transform(p.features))) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala %}
@@ -192,40 +122,7 @@ The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. -{% highlight java %} -import java.util.LinkedList; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.rdd.RDD; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class PCA { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("PCA Example"); - SparkContext sc = new SparkContext(conf); - - double[][] array = ... - LinkedList rowsList = new LinkedList(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); - - // Create a RowMatrix from JavaRDD. - RowMatrix mat = new RowMatrix(rows.rdd()); - - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); - RowMatrix projected = mat.multiply(pc); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %}
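Since the inline listings for this page are being moved out to example files, here is a compact Scala sketch of the two `RowMatrix` operations discussed above, condensed from the removed listings. The tiny 3x3 input matrix is made up purely for illustration, and `sc` is assumed to be an existing `SparkContext`.

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrix, SingularValueDecomposition, Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Build a small RowMatrix from an RDD of dense vectors.
val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 7.0),
  Vectors.dense(2.0, 3.0, 4.0),
  Vectors.dense(4.0, 5.0, 6.0)))
val mat = new RowMatrix(rows)

// Compute the top 2 singular values and corresponding singular vectors.
val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(2, computeU = true)
val U: RowMatrix = svd.U // distributed left singular vectors
val s: Vector = svd.s    // singular values in a local dense vector
val V: Matrix = svd.V    // local dense matrix of right singular vectors

// Compute the top 2 principal components and project the rows onto them.
val pc: Matrix = mat.computePrincipalComponents(2)
val projected: RowMatrix = mat.multiply(pc)
{% endhighlight %}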
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index aac8f7560a4f8..63665c49bc972 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -170,42 +170,7 @@ error. Refer to the [`SVMWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMWithSGD) and [`SVMModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -// Run training algorithm to build the model -val numIterations = 100 -val model = SVMWithSGD.train(training, numIterations) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Area under ROC = " + auROC) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = SVMModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala %} The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we @@ -216,6 +181,7 @@ variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. {% highlight scala %} + import org.apache.spark.mllib.optimization.L1Updater val svmAlg = new SVMWithSGD() @@ -237,61 +203,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`SVMWithSGD` Java docs](api/java/org/apache/spark/mllib/classification/SVMWithSGD.html) and [`SVMModel` Java docs](api/java/org/apache/spark/mllib/classification/SVMModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.*; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; - -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class SVMClassifier { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD training = data.sample(false, 0.6, 11L); - training.cache(); - JavaRDD test = data.subtract(training); - - // Run training algorithm to build the model. - int numIterations = 100; - final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. 
- JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - } - ); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); - double auROC = metrics.areaUnderROC(); - - System.out.println("Area under ROC = " + auROC); - - // Save and load model - model.save(sc, "myModelPath"); - SVMModel sameModel = SVMModel.load(sc, "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java %} The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we @@ -325,30 +237,7 @@ and make predictions with the resulting model to compute the training error. Refer to the [`SVMWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMWithSGD) and [`SVMModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import SVMWithSGD, SVMModel -from pyspark.mllib.regression import LabeledPoint - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/sample_svm_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = SVMWithSGD.train(parsedData, iterations=100) - -# Evaluating the model on training data -labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) -print("Training Error = " + str(trainErr)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = SVMModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/svm_with_sgd_example.py %} @@ -406,42 +295,7 @@ Then the model is evaluated against the test dataset and saved to disk. Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(10) - .run(training) - -// Compute raw scores on the test set. -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Get evaluation metrics. 
-val metrics = new MulticlassMetrics(predictionAndLabels) -val precision = metrics.precision -println("Precision = " + precision) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = LogisticRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala %} @@ -454,57 +308,7 @@ Then the model is evaluated against the test dataset and saved to disk. Refer to the [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) and [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MultinomialLogisticRegressionExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(10) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - double precision = metrics.precision(); - System.out.println("Precision = " + precision); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java %}
@@ -516,30 +320,7 @@ will in the future. Refer to the [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel -from pyspark.mllib.regression import LabeledPoint - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/sample_svm_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LogisticRegressionWithLBFGS.train(parsedData) - -# Evaluating the model on training data -labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) -print("Training Error = " + str(trainErr)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LogisticRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/logistic_regression_with_lbfgs_example.py %}
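For convenience while reviewing, a condensed Scala sketch of the multinomial logistic regression flow described above, based on the listing removed earlier in this hunk; `sc` is assumed to be an existing `SparkContext`, and `myModelPath` is just an illustrative output path.

{% highlight scala %}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format and split it into training (60%) and test (40%).
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)

// Run the L-BFGS-based training algorithm to build a 10-class model.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(10)
  .run(training)

// Compute predictions on the test set and report overall precision.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}
val metrics = new MulticlassMetrics(predictionAndLabels)
println("Precision = " + metrics.precision)

// The fitted model can be saved and loaded back.
model.save(sc, "myModelPath")
val sameModel = LogisticRegressionModel.load(sc, "myModelPath")
{% endhighlight %}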
@@ -575,36 +356,7 @@ values. We compute the mean squared error at the end to evaluate Refer to the [`LinearRegressionWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/ridge-data/lpsa.data") -val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -}.cache() - -// Building the model -val numIterations = 100 -val stepSize = 0.00000001 -val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) - -// Evaluate model on training examples and compute training error -val valuesAndPreds = parsedData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = LinearRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`. @@ -620,70 +372,7 @@ the Scala snippet provided, is presented below: Refer to the [`LinearRegressionWithSGD` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionWithSGD.html) and [`LinearRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionModel.html) for details on the API. 
-{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/ridge-data/lpsa.data"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(","); - String[] features = parts[1].split(" "); - double[] v = new double[features.length]; - for (int i = 0; i < features.length - 1; i++) - v[i] = Double.parseDouble(features[i]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - double stepSize = 0.00000001; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - double MSE = new JavaDoubleRDD(valuesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - return Math.pow(pair._1() - pair._2(), 2.0); - } - } - ).rdd()).mean(); - System.out.println("training Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java %}
@@ -696,29 +385,7 @@ Note that the Python API does not yet support model save/load but will in the fu Refer to the [`LinearRegressionWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.replace(',', ' ').split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/ridge-data/lpsa.data") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData, iterations=100, step=0.00000001) - -# Evaluate the model on training data -valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LinearRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/linear_regression_with_sgd_example.py %}
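As noted above, `RidgeRegressionWithSGD` and `LassoWithSGD` expose the same `train` entry points as `LinearRegressionWithSGD`. The sketch below is not one of the included examples; it is a minimal, illustrative Scala snippet that assumes an `RDD[LabeledPoint]` named `parsedData`, prepared as in the Scala example above, and uses the two-argument `train` overload so the default step size and regularization parameter apply.

{% highlight scala %}
import org.apache.spark.mllib.regression.{LassoWithSGD, RidgeRegressionWithSGD}

// parsedData is assumed to be an RDD[LabeledPoint] built as in the example above.
val numIterations = 100

// Ridge (L2) and Lasso (L1) regression are trained the same way as
// LinearRegressionWithSGD; only the regularization term differs.
val ridgeModel = RidgeRegressionWithSGD.train(parsedData, numIterations)
val lassoModel = LassoWithSGD.train(parsedData, numIterations)

// Prediction mirrors LinearRegressionModel.
val ridgePreds = parsedData.map(p => (p.label, ridgeModel.predict(p.features)))
{% endhighlight %}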
@@ -748,108 +415,50 @@ online to the first stream, and make predictions on the second stream. First, we import the necessary classes for parsing our input data and creating the model. -{% highlight scala %} - -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD - -{% endhighlight %} - Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. -{% highlight scala %} - -val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse).cache() -val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) - -{% endhighlight %} +We create our model by initializing the weights to zero, register the streams for training and +testing, and then start the job. Printing predictions alongside true labels lets us easily see the +result. -We create our model by initializing the weights to 0 - -{% highlight scala %} - -val numFeatures = 3 -val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.zeros(numFeatures)) - -{% endhighlight %} - -Now we register the streams for training and testing and start the job. -Printing predictions alongside true labels lets us easily see the result. - -{% highlight scala %} - -model.trainOn(trainingData) -model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() - -ssc.start() -ssc.awaitTermination() - -{% endhighlight %} - -We can now save text files with data to the training or testing folders. +Finally, we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` +the model will update. Anytime a text file is placed in `args(1)` you will see predictions. As you feed more data to the training directory, the predictions will get better! +Here is a complete example: +{% include_example scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala %} +
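For readers who want the shape of the code without opening the included file, here is a condensed Scala sketch of the flow described above. It assumes a `StreamingContext` named `ssc` has already been created and that `args(0)` and `args(1)` hold the training and testing directories; the full included example also handles argument checking and context setup.

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}

// Input streams of labeled points; args(0)/args(1) are the training/testing dirs.
val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse).cache()
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)

// Zero initial weights; numFeatures must match the dimensionality of the data.
val numFeatures = 3
val model = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(numFeatures))

// Train on the first stream and print (label, prediction) pairs for the second.
model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()
{% endhighlight %}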
First, we import the necessary classes for parsing our input data and creating the model. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.regression import StreamingLinearRegressionWithSGD -{% endhighlight %} - Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. -{% highlight python %} -def parse(lp): - label = float(lp[lp.find('(') + 1: lp.find(',')]) - vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) - return LabeledPoint(label, vec) - -trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() -testData = ssc.textFileStream("/testing/data/dir").map(parse) -{% endhighlight %} - -We create our model by initializing the weights to 0 - -{% highlight python %} -numFeatures = 3 -model = StreamingLinearRegressionWithSGD() -model.setInitialWeights([0.0, 0.0, 0.0]) -{% endhighlight %} +We create our model by initializing the weights to 0. Now we register the streams for training and testing and start the job. -{% highlight python %} -model.trainOn(trainingData) -print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) - -ssc.start() -ssc.awaitTermination() -{% endhighlight %} - We can now save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +and `x1,x2,x3` are the features. Anytime a text file is placed in `sys.argv[1]` +the model will update. Anytime a text file is placed in `sys.argv[2]` you will see predictions. As you feed more data to the training directory, the predictions will get better! +Here is a complete example: +{% include_example python/mllib/streaming_linear_regression_example.py %} +
diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 5ebafa40b0c7a..2f0ed5eca2b2b 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1177,7 +1177,7 @@ that originally created it. In addition, each persisted RDD can be stored using a different *storage level*, allowing you, for example, to persist the dataset on disk, persist it in memory but as serialized Java objects (to save space), -replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-project.org/). +replicate it across nodes. These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), @@ -1218,24 +1218,11 @@ storage levels is: MEMORY_ONLY_2, MEMORY_AND_DISK_2, etc. Same as the levels above, but replicate each partition on two cluster nodes. - - OFF_HEAP (experimental) - Store RDD in serialized format in Tachyon. - Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors - to be smaller and to share a pool of memory, making it attractive in environments with - large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory - in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon - out-of-the-box. Please refer to this page - for the suggested version pairings. - - **Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, -`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, `DISK_ONLY_2` and `OFF_HEAP`.* +`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1259,11 +1246,6 @@ requests from a web application). *All* the storage levels provide full fault to recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. -* In environments with high amounts of memory or multiple applications, the experimental `OFF_HEAP` -mode has several advantages: - * It allows multiple executors to share the same pool of memory in Tachyon. - * It significantly reduces garbage collection costs. - * Cached data is not lost if individual executors crash. ### Removing Data diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index b9f64c7ed146e..9816d030e90ac 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -349,8 +349,9 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.uris (none) - A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. - This applies to both coarse-grain and fine-grain mode. + A comma-separated list of URIs to be downloaded to the sandbox + when driver or executor is launched by Mesos. 
This applies to + both coarse-grained and fine-grained mode. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 413532f2f6cfa..cebdb6d910228 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -150,12 +150,6 @@ The master URL passed to Spark can be in one of the following formats: client or cluster mode depending on the value of --deploy-mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. - yarn-client Equivalent to yarn with --deploy-mode client, - which is preferred to `yarn-client` - - yarn-cluster Equivalent to yarn with --deploy-mode cluster, - which is preferred to `yarn-cluster` - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java new file mode 100644 index 0000000000000..e124c1cf18550 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +// $example on$ +import org.apache.spark.ml.clustering.BisectingKMeans; +import org.apache.spark.ml.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + + +/** + * An example demonstrating a bisecting k-means clustering. 
+ */ +public class JavaBisectingKMeansExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)), + RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)), + RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)), + RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)), + RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)), + RowFactory.create(Vectors.dense(18.9, 20.0, 19.7)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + BisectingKMeans bkm = new BisectingKMeans().setK(2); + BisectingKMeansModel model = bkm.fit(dataset); + + System.out.println("Compute Cost: " + model.computeCost(dataset)); + + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + jsc.stop(); + } +} diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java similarity index 58% rename from core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java index e38bc38949d7c..6459dabc0698b 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java @@ -15,25 +15,29 @@ * limitations under the License. */ -package test.org.apache.spark; - -import org.apache.spark.TaskContext; -import org.apache.spark.util.TaskCompletionListener; +package org.apache.spark.examples.ml; +import java.io.Serializable; /** - * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and - * TaskContext is Java friendly. + * Unlabeled instance type, Spark SQL can infer schema from Java Beans. */ -public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { +@SuppressWarnings("serial") +public class JavaDocument implements Serializable { + + private long id; + private String text; + + public JavaDocument(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { + return this.id; + } - @Override - public void onTaskCompletion(TaskContext context) { - context.isCompleted(); - context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.isRunningLocally(); - context.addTaskCompletionListener(this); + public String getText() { + return this.text; } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java new file mode 100644 index 0000000000000..44cf3507f3743 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Estimator, Transformer, and Param. + */ +public class JavaEstimatorTransformerParamExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaEstimatorTransformerParamExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training data. + // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into + // DataFrames, where it uses the bean metadata to infer the schema. + DataFrame training = sqlContext.createDataFrame( + Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) + ), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + LogisticRegression lr = new LogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10).setRegParam(0.01); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + LogisticRegressionModel model1 = lr.fit(training); + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + + // We may alternatively specify parameters using a ParamMap. + ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + + // One can also combine ParamMaps. + ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name + ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + + // Now learn a new model using the paramMapCombined parameters. 
+ // paramMapCombined overrides all parameters set earlier via lr.set* methods. + LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); + System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + + // Prepare test documents. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) + ), LabeledPoint.class); + + // Make predictions on test documents using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r : results.select("features", "label", "myProbability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java similarity index 64% rename from core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala rename to examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java index f64e069cd1724..68d1caf6ad3f3 100644 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java @@ -15,20 +15,24 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.examples.ml; + +import java.io.Serializable; /** - * Exception thrown when there is an exception in - * executing the callback in TaskCompletionListener. + * Labeled instance type, Spark SQL can infer schema from Java Beans. */ -private[spark] -class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { +@SuppressWarnings("serial") +public class JavaLabeledDocument extends JavaDocument implements Serializable { + + private double label; + + public JavaLabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } - override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } + public double getLabel() { + return this.label; } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java new file mode 100644 index 0000000000000..87ad119491e9a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Cross Validation. + */ +public class JavaModelSelectionViaCrossValidationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaCrossValidationExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L,"spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0), + new JavaLabeledDocument(4L, "b spark who", 1.0), + new JavaLabeledDocument(5L, "g d a y", 0.0), + new JavaLabeledDocument(6L, "spark fly", 1.0), + new JavaLabeledDocument(7L, "was mapreduce", 0.0), + new JavaLabeledDocument(8L, "e spark program", 1.0), + new JavaLabeledDocument(9L, "a e c l", 0.0), + new JavaLabeledDocument(10L, "spark compile", 1.0), + new JavaLabeledDocument(11L, "hadoop software", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000}) + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .build(); + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. 
+ // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + CrossValidatorModel cvModel = cv.fit(training); + + // Prepare test documents, which are unlabeled. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + DataFrame predictions = cvModel.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java new file mode 100644 index 0000000000000..77adb02dfd9a8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.ml.tuning.TrainValidationSplit; +import org.apache.spark.ml.tuning.TrainValidationSplitModel; +import org.apache.spark.sql.DataFrame; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Train Validation Split. + */ +public class JavaModelSelectionViaTrainValidationSplitExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaTrainValidationSplitExample"); + SparkContext sc = new SparkContext(conf); + SQLContext jsql = new SQLContext(sc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + // Prepare training and test data. 
+ DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java new file mode 100644 index 0000000000000..3407c25c83c37 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for simple text document 'Pipeline'. + */ +public class JavaPipelineExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPipelineExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. 
+ DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L, "spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. + DataFrame predictions = model.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java index 0001500f4fa5a..c600094947d5a 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -33,7 +33,7 @@ // $example off$ /** - * Java example for graph clustering using power iteration clustering (PIC). + * Java example for bisecting k-means clustering. */ public class JavaBisectingKMeansExample { public static void main(String[] args) { @@ -54,9 +54,7 @@ public static void main(String[] args) { BisectingKMeansModel model = bkm.run(data); System.out.println("Compute Cost: " + model.computeCost(data)); - for (Vector center: model.clusterCenters()) { - System.out.println(""); - } + Vector[] clusterCenters = model.clusterCenters(); for (int i = 0; i < clusterCenters.length; i++) { Vector clusterCenter = clusterCenters[i]; diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java new file mode 100644 index 0000000000000..3e50118c0d9ec --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +// $example off$ + +/** + * Example for LinearRegressionWithSGD. + */ +public class JavaLinearRegressionWithSGDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // $example on$ + // Load and parse the data + String path = "data/mllib/ridge-data/lpsa.data"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(","); + String[] features = parts[1].split(" "); + double[] v = new double[features.length]; + for (int i = 0; i < features.length - 1; i++) { + v[i] = Double.parseDouble(features[i]); + } + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + double stepSize = 0.00000001; + final LinearRegressionModel model = + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + double MSE = new JavaDoubleRDD(valuesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + return Math.pow(pair._1() - pair._2(), 2.0); + } + } + ).rdd()).mean(); + System.out.println("training Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/javaLinearRegressionWithSGDModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java new file mode 100644 index 0000000000000..9d8e4a90dbc99 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +/** + * Example for LogisticRegressionWithLBFGS. + */ +public class JavaLogisticRegressionWithLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + double precision = metrics.precision(); + System.out.println("Precision = " + precision); + + // Save and load model + model.save(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/javaLogisticRegressionWithLBFGSModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java new file mode 100644 index 0000000000000..faf76a9540e77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.LinkedList; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +// $example off$ + +/** + * Example for compute principal components on a 'RowMatrix'. + */ +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("PCA Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; + LinkedList rowsList = new LinkedList(); + for (int i = 0; i < array.length; i++) { + Vector currentRow = Vectors.dense(array[i]); + rowsList.add(currentRow); + } + JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + + // Create a RowMatrix from JavaRDD. + RowMatrix mat = new RowMatrix(rows.rdd()); + + // Compute the top 3 principal components. + Matrix pc = mat.computePrincipalComponents(3); + RowMatrix projected = mat.multiply(pc); + // $example off$ + Vector[] collectPartitions = (Vector[])projected.rows().collect(); + System.out.println("Projected vector of principal component:"); + for (Vector vector : collectPartitions) { + System.out.println("\t" + vector); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java new file mode 100644 index 0000000000000..f3685db9f2fb2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.LinkedList; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.SingularValueDecomposition; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +// $example off$ + +/** + * Example for SingularValueDecomposition. + */ +public class JavaSVDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVD Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; + LinkedList rowsList = new LinkedList(); + for (int i = 0; i < array.length; i++) { + Vector currentRow = Vectors.dense(array[i]); + rowsList.add(currentRow); + } + JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + + // Create a RowMatrix from JavaRDD. + RowMatrix mat = new RowMatrix(rows.rdd()); + + // Compute the top 3 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); + RowMatrix U = svd.U(); + Vector s = svd.s(); + Matrix V = svd.V(); + // $example off$ + Vector[] collectPartitions = (Vector[]) U.rows().collect(); + System.out.println("U factor is:"); + for (Vector vector : collectPartitions) { + System.out.println("\t" + vector); + } + System.out.println("Singular values are: " + s); + System.out.println("V factor is:\n" + V); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java new file mode 100644 index 0000000000000..720b167b2cadf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.SVMModel; +import org.apache.spark.mllib.classification.SVMWithSGD; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +/** + * Example for SVMWithSGD. 
+ */ +public class JavaSVMWithSGDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD training = data.sample(false, 0.6, 11L); + training.cache(); + JavaRDD test = data.subtract(training); + + // Run training algorithm to build the model. + int numIterations = 100; + final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); + double auROC = metrics.areaUnderROC(); + + System.out.println("Area under ROC = " + auROC); + + // Save and load model + model.save(sc, "target/tmp/javaSVMWithSGDModel"); + SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index f0ca97c724940..5f0ef20218c4a 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -18,12 +18,14 @@ from __future__ import print_function from pyspark import SparkContext +# $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml.tuning import CrossValidator, ParamGridBuilder from pyspark.sql import Row, SQLContext +# $example off$ """ A simple example demonstrating model selection using CrossValidator. @@ -36,7 +38,7 @@ if __name__ == "__main__": sc = SparkContext(appName="CrossValidatorExample") sqlContext = SQLContext(sc) - + # $example on$ # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") training = sc.parallelize([(0, "a b c d e spark", 1.0), @@ -92,5 +94,6 @@ selected = prediction.select("id", "text", "probability", "prediction") for row in selected.collect(): print(row) + # $example off$ sc.stop() diff --git a/examples/src/main/python/ml/estimator_transformer_param_example.py b/examples/src/main/python/ml/estimator_transformer_param_example.py new file mode 100644 index 0000000000000..9a8993dac4f65 --- /dev/null +++ b/examples/src/main/python/ml/estimator_transformer_param_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Estimator Transformer Param Example. +""" +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="EstimatorTransformerParamExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Prepare training data from a list of (label, features) tuples. + training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + + # Create a LogisticRegression instance. This instance is an Estimator. + lr = LogisticRegression(maxIter=10, regParam=0.01) + # Print out the parameters, documentation, and any default values. + print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print "Model 1 was fit using parameters: " + print model1.extractParamMap() + + # We may alternatively specify parameters using a Python dictionary as a paramMap + paramMap = {lr.maxIter: 20} + paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. + paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + + # You can combine paramMaps, which are python dictionaries. + paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name + paramMapCombined = paramMap.copy() + paramMapCombined.update(paramMap2) + + # Now learn a new model using the paramMapCombined parameters. + # paramMapCombined overrides all parameters set earlier via lr.set* methods. + model2 = lr.fit(training, paramMapCombined) + print "Model 2 was fit using parameters: " + print model2.extractParamMap() + + # Prepare test data + test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegression.transform will only use the 'features' column. + # Note that model2.transform() outputs a "myProbability" column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. 
+ prediction = model2.transform(test) + selected = prediction.select("features", "label", "myProbability", "prediction") + for row in selected.collect(): + print row + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py new file mode 100644 index 0000000000000..3288568f0c287 --- /dev/null +++ b/examples/src/main/python/ml/pipeline_example.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Pipeline Example. +""" +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PipelineExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Prepare training documents from a list of (id, text, label) tuples. + training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) + + # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled (id, text) tuples. + test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + selected = prediction.select("id", "text", "prediction") + for row in selected.collect(): + print(row) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/linear_regression_with_sgd_example.py b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py new file mode 100644 index 0000000000000..6fbaeff0cd5a0 --- /dev/null +++ b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Linear Regression With SGD Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonLinearRegressionWithSGDExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.replace(',', ' ').split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/ridge-data/lpsa.data") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData, iterations=100, step=0.00000001) + + # Evaluate the model on training data + valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + MSE = valuesAndPreds \ + .map(lambda (v, p): (v - p)**2) \ + .reduce(lambda x, y: x + y) / valuesAndPreds.count() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/pythonLinearRegressionWithSGDModel") + sameModel = LinearRegressionModel.load(sc, "target/tmp/pythonLinearRegressionWithSGDModel") + # $example off$ diff --git a/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py new file mode 100644 index 0000000000000..e030b74ba6b15 --- /dev/null +++ b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Logistic Regression With LBFGS Example. 
+""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel +from pyspark.mllib.regression import LabeledPoint +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/sample_svm_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LogisticRegressionWithLBFGS.train(parsedData) + + # Evaluating the model on training data + labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) + print("Training Error = " + str(trainErr)) + + # Save and load model + model.save(sc, "target/tmp/pythonLogisticRegressionWithLBFGSModel") + sameModel = LogisticRegressionModel.load(sc, + "target/tmp/pythonLogisticRegressionWithLBFGSModel") + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index f5e120c678fcf..e7d5893d67413 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -17,9 +17,15 @@ """ NaiveBayes Example. + +Usage: + `spark-submit --master local[4] examples/src/main/python/mllib/naive_bayes_example.py` """ + from __future__ import print_function +import shutil + from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel @@ -50,8 +56,15 @@ def parseLine(line): # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + print('model accuracy {}'.format(accuracy)) # Save and load model - model.save(sc, "target/tmp/myNaiveBayesModel") - sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + output_dir = 'target/tmp/myNaiveBayesModel' + shutil.rmtree(output_dir, ignore_errors=True) + model.save(sc, output_dir) + sameModel = NaiveBayesModel.load(sc, output_dir) + predictionAndLabel = test.map(lambda p: (sameModel.predict(p.features), p.label)) + accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + print('sameModel accuracy {}'.format(accuracy)) + # $example off$ diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py new file mode 100644 index 0000000000000..f600496867c11 --- /dev/null +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Streaming Linear Regression Example. +""" +from __future__ import print_function + +# $example on$ +import sys +# $example off$ + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +# $example off$ + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: streaming_linear_regression_example.py <trainingDir> <testDir>", + file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingLinearRegressionExample") + ssc = StreamingContext(sc, 1) + + # $example on$ + def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + + trainingData = ssc.textFileStream(sys.argv[1]).map(parse).cache() + testData = ssc.textFileStream(sys.argv[2]).map(parse) + + numFeatures = 3 + model = StreamingLinearRegressionWithSGD() + model.setInitialWeights([0.0, 0.0, 0.0]) + + model.trainOn(trainingData) + print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + + ssc.start() + ssc.awaitTermination() + # $example off$ diff --git a/examples/src/main/python/mllib/svm_with_sgd_example.py b/examples/src/main/python/mllib/svm_with_sgd_example.py new file mode 100644 index 0000000000000..309ab09cc375a --- /dev/null +++ b/examples/src/main/python/mllib/svm_with_sgd_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import SVMWithSGD, SVMModel +from pyspark.mllib.regression import LabeledPoint +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVMWithSGDExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/sample_svm_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = SVMWithSGD.train(parsedData, iterations=100) + + # Evaluating the model on training data + labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) + print("Training Error = " + str(trainErr)) + + # Save and load model + model.save(sc, "target/tmp/pythonSVMWithSGDModel") + sameModel = SVMModel.load(sc, "target/tmp/pythonSVMWithSGDModel") + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0a477abae5679..7e608a281203e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -79,7 +79,7 @@ object DataFrameExample { labelSummary.show() // Convert features column to an RDD of vectors. - val features = df.select("features").map { case Row(v: Vector) => v } + val features = df.select("features").rdd.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index a37d12aa636cd..d2560cc00ba07 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -310,8 +310,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) // Print number of classes for reference val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { case Some(n) => n @@ -335,8 +335,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError println(s" Root mean squared error (RMSE): $RMSE") } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala new file mode 100644 index 0000000000000..65e3c365abb3f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object EstimatorTransformerParamExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("EstimatorTransformerParamExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training data from a list of (label, features) tuples. + val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) + )).toDF("label", "features") + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new LogisticRegression() + // Print out the parameters, documentation, and any default values. + println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + .setRegParam(0.01) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model1 = lr.fit(training) + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + + // We may alternatively specify parameters using a ParamMap, + // which supports several methods for specifying parameters. + val paramMap = ParamMap(lr.maxIter -> 20) + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + + // One can also combine ParamMaps. + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name + val paramMapCombined = paramMap ++ paramMap2 + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. 
+ val model2 = lr.fit(training, paramMapCombined) + println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + + // Prepare test data. + val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) + )).toDF("label", "features") + + // Make predictions on test data using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test) + .select("features", "label", "myProbability", "prediction") + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala new file mode 100644 index 0000000000000..0331d6e7b35df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object ModelSelectionViaCrossValidationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ModelSelectionViaCrossValidationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training data from a list of (id, text, label) tuples. 
+ val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) + )).toDF("id", "text", "label") + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + val cvModel = cv.fit(training) + + // Prepare test documents, which are unlabeled (id, text) tuples. + val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + )).toDF("id", "text") + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + cvModel.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala new file mode 100644 index 0000000000000..5a95344f223df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +// $example off$ +import org.apache.spark.sql.SQLContext + +object ModelSelectionViaTrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ModelSelectionViaTrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training and test data. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. 
+ model.transform(test) + .select("features", "label", "prediction") + .show() + // $example off$ + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index ccee3b2aef980..a0bb5dabf4574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -155,7 +155,7 @@ object OneVsRestExample { // evaluate the model val predictionsAndLabels = predictions.select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) + .rdd.map(row => (row.getDouble(0), row.getDouble(1))) val metrics = new MulticlassMetrics(predictionsAndLabels) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala new file mode 100644 index 0000000000000..6c29063626bac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object PipelineExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PipelineExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training documents from a list of (id, text, label) tuples. + val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) + )).toDF("id", "text", "label") + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. 
+ val model = pipeline.fit(training) + + // Now we can optionally save the fitted pipeline to disk + model.write.overwrite().save("/tmp/spark-logistic-regression-model") + + // We can also save this unfit pipeline to disk + pipeline.write.overwrite().save("/tmp/unfit-lr-model") + + // And load it back in during production + val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") + + // Prepare test documents, which are unlabeled (id, text) tuples. + val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + )).toDF("id", "text") + + // Make predictions on test documents. + model.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index d28323555b990..038b2fe6112e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -221,6 +221,7 @@ object LDAExample { val model = pipeline.fit(df) val documents = model.transform(df) .select("features") + .rdd .map { case Row(features: Vector) => features } .zipWithIndex() .map(_.swap) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala new file mode 100644 index 0000000000000..669868787e8f0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +// $example off$ + +object LinearRegressionWithSGDExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LinearRegressionWithSGDExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/ridge-data/lpsa.data") + val parsedData = data.map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) + }.cache() + + // Building the model + val numIterations = 100 + val stepSize = 0.00000001 + val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) + + // Evaluate model on training examples and compute training error + val valuesAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() + println("training Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") + val sameModel = LinearRegressionModel.load(sc, "target/tmp/scalaLinearRegressionWithSGDModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala new file mode 100644 index 0000000000000..632a2d537e5bc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object LogisticRegressionWithLBFGSExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionWithLBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format. 
+ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training) + + // Compute raw scores on the test set. + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Get evaluation metrics. + val metrics = new MulticlassMetrics(predictionAndLabels) + val precision = metrics.precision + println("Precision = " + precision) + + // Save and load model + model.save(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel") + val sameModel = LogisticRegressionModel.load(sc, + "target/tmp/scalaLogisticRegressionWithLBFGSModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala new file mode 100644 index 0000000000000..234de230eb201 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +// $example off$ + +object PCAOnRowMatrixExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAOnRowMatrixExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + + val dataRDD = sc.parallelize(data, 2) + + val mat: RowMatrix = new RowMatrix(dataRDD) + + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + val pc: Matrix = mat.computePrincipalComponents(4) + + // Project the rows to the linear space spanned by the top 4 principal components. 
+ val projected: RowMatrix = mat.multiply(pc) + // $example off$ + val collect = projected.rows.collect() + println("Projected Row Matrix of principal component:") + collect.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala new file mode 100644 index 0000000000000..f7694879dfbdb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +// $example off$ + +object PCAOnSourceVectorExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAOnSourceVectorExample") + val sc = new SparkContext(conf) + + // $example on$ + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)))) + + // Compute the top 5 principal components. + val pca = new PCA(5).fit(data.map(_.features)) + + // Project vectors to the linear space spanned by the top 5 principal + // components, keeping the label + val projected = data.map(p => p.copy(features = pca.transform(p.features))) + // $example off$ + val collect = projected.collect() + println("Projected vector of principal component:") + collect.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala new file mode 100644 index 0000000000000..c26580d4c1960 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +// $example off$ + +object SVDExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("SVDExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + + val dataRDD = sc.parallelize(data, 2) + + val mat: RowMatrix = new RowMatrix(dataRDD) + + // Compute the top 5 singular values and corresponding singular vectors. + val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) + val U: RowMatrix = svd.U // The U factor is a RowMatrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. + // $example off$ + val collect = U.rows.collect() + println("U factor is:") + collect.foreach { vector => println(vector) } + println(s"Singular values are: $s") + println(s"V factor is:\n$V") + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala new file mode 100644 index 0000000000000..b73fe9b2b3faa --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object SVMWithSGDExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("SVMWithSGDExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + // Run training algorithm to build the model + val numIterations = 100 + val model = SVMWithSGD.train(training, numIterations) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Area under ROC = " + auROC) + + // Save and load model + model.save(sc, "target/tmp/scalaSVMWithSGDModel") + val sameModel = SVMModel.load(sc, "target/tmp/scalaSVMWithSGDModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala new file mode 100644 index 0000000000000..0a1cd2d62d5b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +// $example off$ +import org.apache.spark.streaming._ + +object StreamingLinearRegressionExample { + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + System.err.println("Usage: StreamingLinearRegressionExample <trainingDir> <testDir>") + System.exit(1) + } + + val conf = new SparkConf().setAppName("StreamingLinearRegressionExample") + val ssc = new StreamingContext(conf, Seconds(1)) + + // $example on$ + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse).cache() + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val numFeatures = 3 + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.zeros(numFeatures)) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + // $example off$ + + ssc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index a2f0fcd0e486e..620ff07631c36 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -52,7 +52,7 @@ object RDDRelation { val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") - rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) + rddFromSql.rdd.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL.
df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 4e427f54daa52..b654a2c8d4a40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -63,7 +63,7 @@ object HiveFromSpark { val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") - val rddAsStrings = rddFromSql.map { + val rddAsStrings = rddFromSql.rdd.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } diff --git a/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala index c75dc92445b64..33415c15be2ef 100644 --- a/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala +++ b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala @@ -20,12 +20,15 @@ package org.apache.spark.streaming.akka import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} +import akka.pattern.ask +import akka.util.Timeout import com.typesafe.config.ConfigFactory import org.apache.spark.{Logging, TaskContext} @@ -105,13 +108,26 @@ abstract class ActorReceiver extends Actor { } /** - * Store a single item of received data to Spark's memory. + * Store a single item of received data to Spark's memory asynchronously. * These single items will be aggregated together into data blocks before * being pushed into Spark's memory. */ def store[T](item: T) { context.parent ! SingleItemData(item) } + + /** + * Store a single item of received data to Spark's memory and returns a `Future`. + * The `Future` will be completed when the operator finishes, or with an + * `akka.pattern.AskTimeoutException` after the given timeout has expired. + * These single items will be aggregated together into data blocks before + * being pushed into Spark's memory. + * + * This method allows the user to control the flow speed using `Future` + */ + def store[T](item: T, timeout: Timeout): Future[Unit] = { + context.parent.ask(AskStoreSingleItemData(item))(timeout).map(_ => ())(context.dispatcher) + } } /** @@ -162,6 +178,19 @@ abstract class JavaActorReceiver extends UntypedActor { def store[T](item: T) { context.parent ! SingleItemData(item) } + + /** + * Store a single item of received data to Spark's memory and returns a `Future`. + * The `Future` will be completed when the operator finishes, or with an + * `akka.pattern.AskTimeoutException` after the given timeout has expired. + * These single items will be aggregated together into data blocks before + * being pushed into Spark's memory. 
+ * + * This method allows the user to control the flow speed using `Future` + */ + def store[T](item: T, timeout: Timeout): Future[Unit] = { + context.parent.ask(AskStoreSingleItemData(item))(timeout).map(_ => ())(context.dispatcher) + } } /** @@ -179,8 +208,10 @@ case class Statistics(numberOfMsgs: Int, /** Case class to receive data sent by child actors */ private[akka] sealed trait ActorReceiverData private[akka] case class SingleItemData[T](item: T) extends ActorReceiverData +private[akka] case class AskStoreSingleItemData[T](item: T) extends ActorReceiverData private[akka] case class IteratorData[T](iterator: Iterator[T]) extends ActorReceiverData private[akka] case class ByteBufferData(bytes: ByteBuffer) extends ActorReceiverData +private[akka] object Ack extends ActorReceiverData /** * Provides Actors as receivers for receiving stream. @@ -233,6 +264,12 @@ private[akka] class ActorReceiverSupervisor[T: ClassTag]( store(msg.asInstanceOf[T]) n.incrementAndGet + case AskStoreSingleItemData(msg) => + logDebug("received single sync") + store(msg.asInstanceOf[T]) + n.incrementAndGet + sender() ! Ack + case ByteBufferData(bytes) => logDebug("received bytes") store(bytes) diff --git a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java index b732506767154..ac5ef31c8b355 100644 --- a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java +++ b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java @@ -20,6 +20,7 @@ import akka.actor.ActorSystem; import akka.actor.Props; import akka.actor.SupervisorStrategy; +import akka.util.Timeout; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.Test; @@ -62,5 +63,6 @@ class JavaTestActor extends JavaActorReceiver { @Override public void onReceive(Object message) throws Exception { store((String) message); + store((String) message, new Timeout(1000)); } } diff --git a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala index f437585a98e4f..ce95d9dd72f90 100644 --- a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala +++ b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.akka +import scala.concurrent.duration._ + import akka.actor.{Props, SupervisorStrategy} import org.apache.spark.SparkFunSuite @@ -60,5 +62,6 @@ class AkkaUtilsSuite extends SparkFunSuite { class TestActor extends ActorReceiver { override def receive: Receive = { case m: String => store(m) + case m => store(m, 10.seconds) } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 68af14397ba81..45c2c008f6dc4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,8 @@ List<String> buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List<String> projects = Arrays.asList("core", "repl", "mllib", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher", "network/common",
"network/shuffle", "network/yarn"); + "yarn", "launcher", + "common/network-common", "common/network-shuffle", "common/network-yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 931a24cfd4b1d..e575fd33080a2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -48,8 +48,8 @@ public List buildCommand(Map env) throws IOException { String memKey = null; String extraClassPath = null; - // Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + - // SPARK_DAEMON_MEMORY. + // Master, Worker, HistoryServer, ExternalShuffleService, MesosClusterDispatcher use + // SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. if (className.equals("org.apache.spark.deploy.master.Master")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); javaOptsKeys.add("SPARK_MASTER_OPTS"); @@ -69,6 +69,8 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.mesos.MesosClusterDispatcher")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") || className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); diff --git a/make-distribution.sh b/make-distribution.sh index 327659298e4d8..ac90ea317a6fc 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,11 +32,6 @@ set -x SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" -SPARK_TACHYON=false -TACHYON_VERSION="0.8.2" -TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" - MAKE_TGZ=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -45,7 +40,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for making binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--mvn ] [--with-tachyon]" + cl_options="[--name] [--tgz] [--mvn ]" echo "./make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -69,9 +64,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --with-tachyon) - SPARK_TACHYON=true - ;; --tgz) MAKE_TGZ=true ;; @@ -150,12 +142,6 @@ else echo "Making distribution for Spark $VERSION in $DISTDIR..." 
fi -if [ "$SPARK_TACHYON" == "true" ]; then - echo "Tachyon Enabled" -else - echo "Tachyon Disabled" -fi - # Build uber fat JAR cd "$SPARK_HOME" @@ -183,7 +169,7 @@ cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # This will fail if the -Pyarn profile is not provided # In this case, silence the error and ignore the return code of this command -cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : +cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" @@ -219,40 +205,6 @@ if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi -# Download and copy in tachyon, if requested -if [ "$SPARK_TACHYON" == "true" ]; then - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - - pushd "$TMPD" > /dev/null - echo "Fetching tachyon tgz" - - TACHYON_DL="${TACHYON_TGZ}.part" - if [ $(command -v curl) ]; then - curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" - elif [ $(command -v wget) ]; then - wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" - else - printf "You do not have curl or wget installed. please install Tachyon manually.\n" - exit -1 - fi - - tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" - mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" - cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" - - if [[ `uname -a` == Darwin* ]]; then - # need to run sed differently on osx - nl=$'\n'; sed -i "" -e "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\\$nl export TACHYON_JAR=\$TACHYON_HOME/../lib/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" - else - sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../lib/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" - fi - - popd > /dev/null - rm -rf "$TMPD" -fi - if [ "$MAKE_TGZ" == "true" ]; then TARDIR_NAME=spark-$VERSION-bin-$NAME TARDIR="$SPARK_HOME/$TARDIR_NAME" diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d1388b5e2eb5a..4b27ee6c5a414 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -122,8 +122,10 @@ abstract class Predictor[ * and put it in an RDD with strong types. 
*/ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { - dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } + dataset.select($(labelCol), $(featuresCol)).rdd.map { + case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ac0124513f283..0d329d2c08d50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -263,10 +263,11 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -790,6 +791,7 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * :: Experimental :: * Logistic regression training results. + * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. @@ -813,6 +815,7 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( /** * :: Experimental :: * Binary Logistic regression results for a given model. + * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. @@ -837,7 +840,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. 
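Stepping back from the individual hunks: the change repeated throughout these MLlib files is the insertion of .rdd between a DataFrame select and the pattern-matching map, so the typed lambda runs on the underlying RDD[Row] rather than on the DataFrame. A self-contained sketch of the pattern, with the column names and helper name assumed for illustration:

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

// Select the needed columns, drop down to RDD[Row], then pattern-match each Row.
def toLabeledPairs(df: DataFrame): RDD[(Double, Vector)] =
  df.select("label", "features").rdd.map {
    case Row(label: Double, features: Vector) => (label, features)
  }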
@transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).map { + predictions.select(probabilityCol, labelCol).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 45d293bc69618..f014a1d572387 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") def computeCost(dataset: DataFrame): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } } @@ -176,7 +176,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: DataFrame): BisectingKMeansModel = { - val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val bkm = new MLlibBisectingKMeans() .setK($(k)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index b2292e20e2124..79332b0d02157 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} @@ -129,12 +130,32 @@ class KMeansModel private[ml] ( @Since("1.6.0") def computeCost(dataset: DataFrame): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + + private var trainingSummary: Option[KMeansSummary] = None + + private[clustering] def setSummary(summary: KMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. 
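For context, a hedged sketch of how the training summary attached to KMeansModel below might be consumed after fitting; the training DataFrame and its column name are assumptions, not part of the patch.

import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans().setK(3).setFeaturesCol("features")
val model = kmeans.fit(training)        // `training` is an assumed DataFrame of feature vectors

val summary = model.summary             // throws SparkException if no summary was attached
println(summary.size.mkString(", "))    // number of points assigned to each cluster
summary.cluster.show()                  // single-column DataFrame of cluster indices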
+ */ + @Since("2.0.0") + def summary: KMeansSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}", + new NullPointerException()) + } } @Since("1.6.0") @@ -239,7 +260,7 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { - val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibKMeans() .setK($(k)) @@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") ( .setSeed($(seed)) .setEpsilon($(tol)) val parentModel = algo.run(rdd) - val model = new KMeansModel(uid, parentModel) - copyValues(model.setParent(this)) + val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) + val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol)) + model.setSummary(summary) } @Since("1.5.0") @@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] { override def load(path: String): KMeans = super.load(path) } +class KMeansSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val featuresCol: String) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of each cluster. + */ + @Since("2.0.0") + lazy val size: Array[Int] = cluster.rdd.map { + case Row(clusterIdx: Int) => (clusterIdx, 1) + }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 99383e77f7eb0..6304b20d544ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -803,6 +803,7 @@ private[clustering] object LDA extends DefaultParamsReadable[LDA] { dataset .withColumn("docId", monotonicallyIncreasingId()) .select("docId", featuresCol) + .rdd .map { case Row(docId: Long, features: Vector) => (docId, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index a1d36c4becfa2..00f31255845ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -84,11 +84,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. 
- val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) - .map { - case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) - case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) - } + val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map { + case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) + case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) + } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index a921153b9474f..55ff44323a790 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -74,9 +74,9 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) - .map { case Row(prediction: Double, label: Double) => - (prediction, label) + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map { + case Row(prediction: Double, label: Double) => + (prediction, label) } val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index b6b25ecd01b3d..adee61e297081 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -85,7 +85,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) - .map { case Row(prediction: Double, label: Double) => + .rdd. 
+ map { case Row(prediction: Double, label: Double) => (prediction, label) } val metrics = new RegressionMetrics(predictionAndLabels) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 7b565ef3ed922..4abc459f5369a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -79,7 +79,7 @@ final class ChiSqSelector(override val uid: String) override def fit(dataset: DataFrame): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).map { + val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index a6dfe58e56f0e..cf151458f0917 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -133,7 +133,7 @@ class CountVectorizer(override val uid: String) override def fit(dataset: DataFrame): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) - val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 9e7eee4f29988..cebbe5c162f79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -79,7 +79,7 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) copyValues(new IDFModel(uid, idf).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala new file mode 100644 index 0000000000000..09fad236422b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MaxAbsScaler]] and [[MaxAbsScalerModel]]. + */ +private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to range [-1, 1] by dividing through the largest maximum + * absolute value in each feature. It does not shift/center the data, and thus does not destroy + * any sparsity. + */ +@Experimental +class MaxAbsScaler @Since("2.0.0") (override val uid: String) + extends Estimator[MaxAbsScalerModel] with MaxAbsScalerParams with DefaultParamsWritable { + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("maxAbsScal")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame): MaxAbsScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + val minVals = summary.min.toArray + val maxVals = summary.max.toArray + val n = minVals.length + val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) } + + copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs)).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MaxAbsScaler = defaultCopy(extra) +} + +@Since("1.6.0") +object MaxAbsScaler extends DefaultParamsReadable[MaxAbsScaler] { + + @Since("1.6.0") + override def load(path: String): MaxAbsScaler = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[MaxAbsScaler]]. + * + */ +@Experimental +class MaxAbsScalerModel private[ml] ( + override val uid: String, + val maxAbs: Vector) + extends Model[MaxAbsScalerModel] with MaxAbsScalerParams with MLWritable { + + import MaxAbsScalerModel._ + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + // TODO: this looks hack, we may have to handle sparse and dense vectors separately. 
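A note on the rescaling step that follows: each feature is divided by its maximum absolute value, and zero maxima are replaced by 1 so that constant-zero columns stay zero instead of producing NaN. A tiny sketch of the same arithmetic on plain arrays, with made-up numbers:

// maxAbs per feature, as computed in fit(): max(|min|, |max|) column-wise.
val maxAbs = Array(4.0, 0.0, 2.0)
// Guard against division by zero for constant-zero columns.
val maxAbsUnzero = maxAbs.map(x => if (x == 0) 1.0 else x)
// Scaling one row: every value lands in [-1, 1].
val row = Array(-2.0, 0.0, 1.0)
val scaled = row.zip(maxAbsUnzero).map { case (v, m) => v / m }   // Array(-0.5, 0.0, 0.5)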
+ val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) + val reScale = udf { (vector: Vector) => + val brz = vector.toBreeze / maxAbsUnzero.toBreeze + Vectors.fromBreeze(brz) + } + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MaxAbsScalerModel = { + val copied = new MaxAbsScalerModel(uid, maxAbs) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new MaxAbsScalerModelWriter(this) +} + +@Since("1.6.0") +object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { + + private[MaxAbsScalerModel] + class MaxAbsScalerModelWriter(instance: MaxAbsScalerModel) extends MLWriter { + + private case class Data(maxAbs: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.maxAbs) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MaxAbsScalerModelReader extends MLReader[MaxAbsScalerModel] { + + private val className = classOf[MaxAbsScalerModel].getName + + override def load(path: String): MaxAbsScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath) + .select("maxAbs") + .head() + val model = new MaxAbsScalerModel(metadata.uid, maxAbs) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[MaxAbsScalerModel] = new MaxAbsScalerModelReader + + @Since("1.6.0") + override def load(path: String): MaxAbsScalerModel = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ad0458d0d0e1a..18be5c0701fb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -108,7 +108,7 @@ class MinMaxScaler(override val uid: String) override def fit(dataset: DataFrame): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 342540418fa80..e9df161c00b83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -130,7 +130,7 @@ class OneHotEncoder(override val uid: String) extends Transformer transformSchema(dataset.schema)(outputColName)) if (outputAttrGroup.size < 0) { // If the number of attributes is unknown, we check the values from the input column. 
- val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) .aggregate(0.0)( (m, x) => { assert(x >=0.0 && x == x.toInt, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 0e07dfabfeaab..80b124f74716d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -70,7 +70,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams */ override def fit(dataset: DataFrame): PCAModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 42b26c8ee836c..0a9b9719c15d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -46,7 +46,7 @@ class PolynomialExpansion(override val uid: String) * @group param */ val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)", - ParamValidators.gt(1)) + ParamValidators.gtEq(1)) setDefault(degree -> 2) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1f4cca123310a..769f4406e2dff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -103,6 +103,13 @@ final class QuantileDiscretizer(override val uid: String) @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + + /** + * Minimum number of samples required for finding splits, regardless of number of bins. If + * the dataset has fewer rows than this value, the entire dataset will be used. + */ + private[spark] val minSamplesRequired: Int = 10000 + /** * Sampling from the given dataset to collect quantile statistics. 
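A note on the fix in the hunk that follows: requiredSamples / dataset.count() was integer (Long) division, so whenever the dataset had more rows than requiredSamples the sampling fraction truncated to 0 and nothing was sampled; converting to Double preserves the intended proportion. With illustrative numbers:

val numBins = 2
val totalSamples = 30000L
val requiredSamples = math.max(numBins * numBins, 10000)                     // 10000
val brokenFraction = math.min(requiredSamples / totalSamples, 1.0)           // 0.0 (Long division)
val fixedFraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)   // ~0.333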
*/ @@ -110,8 +117,8 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") - val requiredSamples = math.max(numBins * numBins, 10000) - val fraction = math.min(requiredSamples / dataset.count(), 1.0) + val requiredSamples = math.max(numBins * numBins, minSamplesRequired) + val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0) dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 6a0b6c240ec60..9952d3bc9f1a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -87,7 +87,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 912bd95a2ec70..7dd794b9d7d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -83,6 +83,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) + .rdd .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray @@ -150,6 +151,7 @@ class StringIndexerModel ( "Skip StringIndexerModel.") return dataset } + validateAndTransformSchema(dataset.schema) val indexer = udf { label: String => if (labelToIndex.contains(label)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 2a5268406ddf2..5c11760fab9b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -113,7 +113,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size - val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val maxCats = $(maxCategories) val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 2b6b3c3a0fc58..a4c3d2751f50d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -138,7 +138,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def fit(dataset: DataFrame): Word2VecModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() .setLearningRate($(stepSize)) .setMinCount($(minCount)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 61b3642131810..55b7510656643 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -156,6 +156,12 @@ private[ml] class WeightedLeastSquares( private[ml] object WeightedLeastSquares { + /** + * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]] + * only supports the number of features is no more than 4096. + */ + val MAX_NUM_FEATURES: Int = 4096 + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares { private var aaSum: DenseVector = _ private def init(k: Int): Unit = { - require(k <= 4096, "In order to take the normal equation approach efficiently, " + - s"we set the max number of features to 4096 but got $k.") + require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to $MAX_NUM_FEATURES but got $k.") this.k = k triK = k * (k + 1) / 2 count = 0L diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 551e75dc0a02d..d23e4fc9d1f57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.clustering.{KMeans, KMeansModel} +import org.apache.spark.ml.feature.{RFormula, VectorAssembler} import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.sql.DataFrame @@ -51,6 +52,22 @@ private[r] object SparkRWrappers { pipeline.fit(df) } + def fitKMeans( + df: DataFrame, + initMode: String, + maxIter: Double, + k: Double, + columns: Array[String]): PipelineModel = { + val assembler = new VectorAssembler().setInputCols(columns) + val kMeans = new KMeans() + .setInitMode(initMode) + .setMaxIter(maxIter.toInt) + .setK(k.toInt) + .setFeaturesCol(assembler.getOutputCol) + val pipeline = new Pipeline().setStages(Array(assembler, kMeans)) + pipeline.fit(df) + } + def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => { @@ -72,6 +89,8 @@ private[r] object SparkRWrappers { m.coefficients.toArray } } + case m: KMeansModel => + 
m.clusterCenters.flatMap(_.toArray) } } @@ -85,6 +104,31 @@ private[r] object SparkRWrappers { } } + def getKMeansModelSize(model: PipelineModel): Array[Int] = { + model.stages.last match { + case m: KMeansModel => Array(m.getK) ++ m.summary.size + case other => throw new UnsupportedOperationException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") + } + } + + def getKMeansCluster(model: PipelineModel, method: String): DataFrame = { + model.stages.last match { + case m: KMeansModel => + if (method == "centers") { + // Drop the assembled vector for easy-print to R side. + m.summary.predictions.drop(m.summary.featuresCol) + } else if (method == "classes") { + m.summary.cluster + } else { + throw new UnsupportedOperationException( + s"Method (centers or classes) required but $method found.") + } + case other => throw new UnsupportedOperationException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") + } + } + def getModelFeatures(model: PipelineModel): Array[String] = { model.stages.last match { case m: LinearRegressionModel => @@ -103,6 +147,10 @@ private[r] object SparkRWrappers { } else { attrs.attributes.get.map(_.name.get) } + case m: KMeansModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + attrs.attributes.get.map(_.name.get) } } @@ -112,6 +160,8 @@ private[r] object SparkRWrappers { "LinearRegressionModel" case m: LogisticRegressionModel => "LogisticRegressionModel" + case m: KMeansModel => + "KMeansModel" } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4be4d6abed6b3..dacdac9a1df16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -392,6 +392,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) + .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e8a1ff2278a92..e4339d67b928d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -184,7 +184,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. 
*/ protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { - dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map { case Row(features: Vector, label: Double, censor: Double) => AFTPoint(features, label, censor) } @@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) extends Serializable { - // beta is the intercept and regression coefficients to the covariates - private val beta = parameters.slice(1, parameters.length) + // the regression coefficients to the covariates + private val coefficients = parameters.slice(2, parameters.length) + private val intercept = parameters.valueAt(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) + private var gradientInterceptSum = 0.0 private var gradientLogSigmaSum = 0.0 def count: Long = totalCnt def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - // Here we optimize loss function over beta and log(sigma) + // Here we optimize loss function over coefficients, intercept and log(sigma) def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - gradientBetaSum/totalCnt.toDouble) + BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) */ def add(data: AFTPoint): this.type = { - // TODO: Don't create a new xi vector each time. - val xi = if (fitIntercept) { - Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze - } else { - Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze - } + val interceptFlag = if (fitIntercept) 1.0 else 0.0 + + val xi = data.features.toBreeze val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma lossSum += math.log(sigma) * delta lossSum += (math.exp(epsilon) - delta * epsilon) @@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) assert(!lossSum.isInfinity, s"AFTAggregator loss sum is infinity. 
Error for unknown reason.") - gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma - gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + val deltaMinusExpEps = delta - math.exp(epsilon) + gradientCoefficientSum += xi * deltaMinusExpEps / sigma + gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma + gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon totalCnt += 1 this @@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientBetaSum += other.gradientBetaSum + gradientCoefficientSum += other.gradientCoefficientSum + gradientInterceptSum += other.gradientInterceptSum gradientLogSigmaSum += other.gradientLogSigmaSum } this diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala new file mode 100644 index 0000000000000..a850dfee0a452 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import breeze.stats.distributions.{Gaussian => GD} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ + +/** + * Params for Generalized Linear Regression. + */ +private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams + with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol + with HasSolver with Logging { + + /** + * Param for the name of family which is a description of the error distribution + * to be used in the model. + * Supported options: "gaussian", "binomial", "poisson" and "gamma". + * Default is "gaussian". + * @group param + */ + @Since("2.0.0") + final val family: Param[String] = new Param(this, "family", + "The name of family which is a description of the error distribution to be used in the " + + "model. 
Supported options: gaussian(default), binomial, poisson and gamma.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getFamily: String = $(family) + + /** + * Param for the name of link function which provides the relationship + * between the linear predictor and the mean of the distribution function. + * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * @group param + */ + @Since("2.0.0") + final val link: Param[String] = new Param(this, "link", "The name of link function " + + "which provides the relationship between the linear predictor and the mean of the " + + "distribution function. Supported options: identity, log, inverse, logit, probit, " + + "cloglog and sqrt.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getLink: String = $(link) + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + override def validateParams(): Unit = { + if ($(solver) == "irls") { + setDefault(maxIter -> 25) + } + if (isDefined(link)) { + require(supportedFamilyAndLinkPairs.contains( + Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + + s"with ${$(family)} family does not support ${$(link)} link function.") + } + } +} + +/** + * :: Experimental :: + * + * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]]) + * specified by giving a symbolic description of the linear predictor (link function) and + * a description of the error distribution (family). + * It supports "gaussian", "binomial", "poisson" and "gamma" as family. + * Valid link functions for each family is listed below. The first link function of each family + * is the default one. + * - "gaussian" -> "identity", "log", "inverse" + * - "binomial" -> "logit", "probit", "cloglog" + * - "poisson" -> "log", "identity", "sqrt" + * - "gamma" -> "inverse", "identity", "log" + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) + extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase with Logging { + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("glm")) + + /** + * Sets the value of param [[family]]. + * Default is "gaussian". + * @group setParam + */ + @Since("2.0.0") + def setFamily(value: String): this.type = set(family, value) + setDefault(family -> Gaussian.name) + + /** + * Sets the value of param [[link]]. + * @group setParam + */ + @Since("2.0.0") + def setLink(value: String): this.type = set(link, value) + + /** + * Sets if we should fit the intercept. + * Default is true. + * @group setParam + */ + @Since("2.0.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + + /** + * Sets the maximum number of iterations. + * Default is 25 if the solver algorithm is "irls". + * @group setParam + */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Sets the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Sets the regularization parameter. 
+ * Default is 0.0. + * @group setParam + */ + @Since("2.0.0") + def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is empty, so all instances have weight one. + * @group setParam + */ + @Since("2.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + /** + * Sets the solver algorithm used for optimization. + * Currently only support "irls" which is also the default solver. + * @group setParam + */ + @Since("2.0.0") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "irls") + + override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + val familyObj = Family.fromName($(family)) + val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd + .map { case Row(features: Vector) => + features.size + }.first() + if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { + val msg = "Currently, GeneralizedLinearRegression only supports number of features" + + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." + throw new SparkException(msg) + } + + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd + .map { case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + + if (familyObj == Gaussian && linkObj == Identity) { + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + standardizeFeatures = true, standardizeLabel = true) + val wlsModel = optimizer.fit(instances) + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) + .setParent(this)) + return model + } + + // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, + $(fitIntercept), $(regParam), $(maxIter), $(tol)) + val irlsModel = optimizer.fit(instances) + + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) + .setParent(this)) + model + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra) +} + +@Since("2.0.0") +private[ml] object GeneralizedLinearRegression { + + /** Set of family and link pairs that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyAndLinkPairs = Set( + Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, + Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, + Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, + Gamma -> Inverse, Gamma -> Identity, Gamma -> Log + ) + + /** Set of family names that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + + /** Set of link names that GeneralizedLinearRegression supports. 
*/ + lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + + val epsilon: Double = 1E-16 + + /** + * Wrapper of family and link combination used in the model. + */ + private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + + /** Linear predictor based on given mu. */ + def predict(mu: Double): Double = link.link(family.project(mu)) + + /** Fitted value based on linear predictor eta. */ + def fitted(eta: Double): Double = family.project(link.unlink(eta)) + + /** + * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. + */ + def initialize( + instances: RDD[Instance], + fitIntercept: Boolean, + regParam: Double): WeightedLeastSquaresModel = { + val newInstances = instances.map { instance => + val mu = family.initialize(instance.label, instance.weight) + val eta = predict(mu) + Instance(eta, instance.weight, instance.features) + } + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + standardizeFeatures = true, standardizeLabel = true) + .fit(newInstances) + initialModel + } + + /** + * The reweight function used to update offsets and weights + * at each iteration of [[IterativelyReweightedLeastSquares]]. + */ + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: Instance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + val mu = fitted(eta) + val offset = eta + (instance.label - mu) * link.deriv(mu) + val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (offset, weight) + } + } + } + + /** + * A description of the error distribution to be used in the model. + * @param name the name of the family. + */ + private[ml] abstract class Family(val name: String) extends Serializable { + + /** The default link instance of this family. */ + val defaultLink: Link + + /** Initialize the starting value for mu. */ + def initialize(y: Double, weight: Double): Double + + /** The variance of the endogenous variable's mean, given the value mu. */ + def variance(mu: Double): Double + + /** Trim the fitted value so that it will be in valid range. */ + def project(mu: Double): Double = mu + } + + private[ml] object Family { + + /** + * Gets the [[Family]] object from its name. + * @param name family name: "gaussian", "binomial", "poisson" or "gamma". + */ + def fromName(name: String): Family = { + name match { + case Gaussian.name => Gaussian + case Binomial.name => Binomial + case Poisson.name => Poisson + case Gamma.name => Gamma + } + } + } + + /** + * Gaussian exponential family distribution. + * The default link for the Gaussian family is the identity link. + */ + private[ml] object Gaussian extends Family("gaussian") { + + val defaultLink: Link = Identity + + override def initialize(y: Double, weight: Double): Double = y + + def variance(mu: Double): Double = 1.0 + + override def project(mu: Double): Double = { + if (mu.isNegInfinity) { + Double.MinValue + } else if (mu.isPosInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Binomial exponential family distribution. + * The default link for the Binomial family is the logit link. 
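For readers less familiar with iteratively reweighted least squares: the reweightFunc defined earlier in this file builds, per observation, the standard IRLS working response z = eta + (y - mu) * g'(mu) and working weight weight / (g'(mu)^2 * Var(mu)), where g is the link function and Var the family variance function. A plain-Scala restatement under those standard-notation assumptions (the function and parameter names are not from the patch):

import org.apache.spark.mllib.linalg.Vector

// One IRLS working point for a single observation (y, x, weight), given the current
// linear model `predict`, the link derivative gPrime, the inverse link gInv, and the
// family variance function.
def workingPoint(
    y: Double, x: Vector, weight: Double,
    predict: Vector => Double,
    gPrime: Double => Double, gInv: Double => Double,
    variance: Double => Double): (Double, Double) = {
  val eta = predict(x)                                         // linear predictor
  val mu = gInv(eta)                                           // fitted mean
  val z = eta + (y - mu) * gPrime(mu)                          // working response ("offset")
  val w = weight / (gPrime(mu) * gPrime(mu) * variance(mu))    // working weight
  (z, w)
}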
+ */ + private[ml] object Binomial extends Family("binomial") { + + val defaultLink: Link = Logit + + override def initialize(y: Double, weight: Double): Double = { + val mu = (weight * y + 0.5) / (weight + 1.0) + require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" + + s"should be in range (0, 1), but got $mu") + mu + } + + override def variance(mu: Double): Double = mu * (1.0 - mu) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu > 1.0 - epsilon) { + 1.0 - epsilon + } else { + mu + } + } + } + + /** + * Poisson exponential family distribution. + * The default link for the Poisson family is the log link. + */ + private[ml] object Poisson extends Family("poisson") { + + val defaultLink: Link = Log + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Poisson family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = mu + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Gamma exponential family distribution. + * The default link for the Gamma family is the inverse link. + */ + private[ml] object Gamma extends Family("gamma") { + + val defaultLink: Link = Inverse + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Gamma family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = math.pow(mu, 2.0) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * A description of the link function to be used in the model. + * The link function provides the relationship between the linear predictor + * and the mean of the distribution function. + * @param name the name of link function. + */ + private[ml] abstract class Link(val name: String) extends Serializable { + + /** The link function. */ + def link(mu: Double): Double + + /** Derivative of the link function. */ + def deriv(mu: Double): Double + + /** The inverse link function. */ + def unlink(eta: Double): Double + } + + private[ml] object Link { + + /** + * Gets the [[Link]] object from its name. + * @param name link name: "identity", "logit", "log", + * "inverse", "probit", "cloglog" or "sqrt". 
+ */ + def fromName(name: String): Link = { + name match { + case Identity.name => Identity + case Logit.name => Logit + case Log.name => Log + case Inverse.name => Inverse + case Probit.name => Probit + case CLogLog.name => CLogLog + case Sqrt.name => Sqrt + } + } + } + + private[ml] object Identity extends Link("identity") { + + override def link(mu: Double): Double = mu + + override def deriv(mu: Double): Double = 1.0 + + override def unlink(eta: Double): Double = eta + } + + private[ml] object Logit extends Link("logit") { + + override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) + + override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) + } + + private[ml] object Log extends Link("log") { + + override def link(mu: Double): Double = math.log(mu) + + override def deriv(mu: Double): Double = 1.0 / mu + + override def unlink(eta: Double): Double = math.exp(eta) + } + + private[ml] object Inverse extends Link("inverse") { + + override def link(mu: Double): Double = 1.0 / mu + + override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0) + + override def unlink(eta: Double): Double = 1.0 / eta + } + + private[ml] object Probit extends Link("probit") { + + override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu) + + override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu)) + + override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta) + } + + private[ml] object CLogLog extends Link("cloglog") { + + override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) + + override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) + } + + private[ml] object Sqrt extends Link("sqrt") { + + override def link(mu: Double): Double = math.sqrt(mu) + + override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu)) + + override def unlink(eta: Double): Double = math.pow(eta, 2.0) + } +} + +/** + * :: Experimental :: + * Model produced by [[GeneralizedLinearRegression]]. 
+ */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegressionModel private[ml] ( + @Since("2.0.0") override val uid: String, + @Since("2.0.0") val coefficients: Vector, + @Since("2.0.0") val intercept: Double) + extends RegressionModel[Vector, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase { + + import GeneralizedLinearRegression._ + + lazy val familyObj = Family.fromName($(family)) + lazy val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + override protected def predict(features: Vector): Double = { + val eta = BLAS.dot(features, coefficients) + intercept + familyAndLink.fitted(eta) + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { + copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) + .setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 1573bb4c1b745..36b006c10e1fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -90,9 +90,9 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w) - .map { case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + dataset.select(col($(labelCol)), f, w).rdd.map { + case Row(label: Double, feature: Double, weight: Double) => + (label, feature, weight) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e253f25c0ea65..b4f17b8e28982 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -158,19 +158,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. - val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || - $(solver) == "normal") { + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. 
(SPARK-10668) val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -196,10 +196,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -277,8 +278,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def effectiveL1RegFun = (index: Int) => { - if ($(standardization)) { + if (standardizationParam) { effectiveL1RegParam } else { // If `standardization` is false, we still standardize the data @@ -512,6 +514,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the * training coefficients except for the objective trace. + * * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -536,6 +539,7 @@ class LinearRegressionTrainingSummary private[regression] ( /** * :: Experimental :: * Linear regression results evaluated on a dataset. + * * @param predictions predictions outputted by the model's `transform` method. */ @Since("1.5.0") @@ -550,6 +554,7 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions .select(predictionCol, labelCol) + .rdd .map { case Row(pred: Double, label: Double) => (pred, label) }, !model.getFitIntercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 93cf16e6f0c2a..ca0ed95a483f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1052,7 +1052,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for the constructor of Python mllib RankingMetrics */ def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = { - new RankingMetrics(predictionAndLabels.map( + new RankingMetrics(predictionAndLabels.rdd.map( r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } @@ -1135,7 +1135,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = { // We use DataFrames for serialization of IndexedRows from Python, // so map each Row in the DataFrame back to an IndexedRow. 
- val indexedRows = rows.map { + val indexedRows = rows.rdd.map { case Row(index: Long, vector: Vector) => IndexedRow(index, vector) } new IndexedRowMatrix(indexedRows, numRows, numCols) @@ -1147,7 +1147,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = { // We use DataFrames for serialization of MatrixEntry entries from // Python, so map each Row in the DataFrame back to a MatrixEntry. - val entries = rows.map { + val entries = rows.rdd.map { case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value) } new CoordinateMatrix(entries, numRows, numCols) @@ -1161,7 +1161,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of sub-matrix blocks from // Python, so map each Row in the DataFrame back to a // ((blockRowIndex, blockColIndex), sub-matrix) tuple. - val blockTuples = blocks.map { + val blockTuples = blocks.rdd.map { case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index c3882606d7dbd..f807b5683c390 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -408,6 +408,10 @@ class LogisticRegressionWithLBFGS * defaults to the mllib implementation. If more than two classes * or feature scaling is disabled, always uses mllib implementation. * Uses user provided weights. + * + * In the ml LogisticRegression implementation, the number of corrections + * used in the LBFGS update can not be configured. So `optimizer.setNumCorrections()` + * will have no effect if we fall into that route. 
*/ override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { run(input, initialWeights, userSuppliedWeights = true) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 26c6235fe5907..3b91fe8643dd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -143,7 +143,7 @@ object KMeansModel extends Loader[KMeansModel] { val k = (metadata \ "k").extract[Int] val centroids = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centroids.schema) - val localCentroids = centroids.map(Cluster.apply).collect() + val localCentroids = centroids.rdd.map(Cluster.apply).collect() assert(k == localCentroids.size) new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b30ecb80209d9..25d67a3756f6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -896,11 +896,11 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { Loader.checkSchema[EdgeData](edgeDataFrame.schema) val globalTopicTotals: LDA.TopicCounts = dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector - val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map { + val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.rdd.map { case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) } - val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map { + val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.rdd.map { case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop) } val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 7a41f74191536..7491ab0d51cac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,6 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -188,7 +187,7 @@ final class EMLDAOptimizer extends LDAOptimizer { graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) .mapValues(_._2) // Update the vertex descriptors with the new counts. 
- val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + val newGraph = Graph(docTopicDistributions, graph.edges) graph = newGraph graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index feacafec7930f..9732dfa1744f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -93,7 +93,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) - val assignmentsRDD = assignments.map { + val assignmentsRDD = assignments.rdd.map { case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 12cf22095720a..319c54724dac1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -58,7 +58,7 @@ class BinaryClassificationMetrics @Since("1.3.0") ( * @param scoreAndLabels a DataFrame with two double columns: score and label */ private[mllib] def this(scoreAndLabels: DataFrame) = - this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Unpersist intermediate RDDs used in the computation. 
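Editor's note: most of the hunks in this part of the patch make the same mechanical change: calling .rdd before .map, so the lambda runs on the underlying RDD[Row] rather than on the DataFrame. A minimal sketch of the pattern; the app name, column names, and sample values below are made up for illustration.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}

object RddMapSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("rdd-map-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    // A tiny stand-in for the (prediction, label) DataFrames used by the metrics classes above.
    val df = sqlContext.createDataFrame(Seq((1.0, 0.0), (0.0, 0.0))).toDF("prediction", "label")
    // Old style relied on DataFrame.map; the patch rewrites such call sites as:
    val pairs = df.rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) }
    pairs.collect().foreach(println)
    sc.stop()
  }
}
```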
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index c5104960cfcb6..3029b15f588a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -38,7 +38,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param predictionAndLabels a DataFrame with two double columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() private lazy val labelCount: Long = labelCountByClass.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index c100b3c9ec14a..daf6ff4db4ed0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -35,7 +35,9 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] * @param predictionAndLabels a DataFrame with two double array columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) + this(predictionAndLabels.rdd.map { r => + (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray) + }) private lazy val numDocs: Long = predictionAndLabels.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 18c90b204a26a..0f4c97ec20c00 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -46,7 +46,7 @@ class RegressionMetrics @Since("2.0.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 33728bf5d77e5..4f0e13feae086 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) - val features = dataArray.map { + val features = dataArray.rdd.map { case Row(feature: Int) => (feature) }.collect() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 85d609386ff94..b35d7217d693e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -134,7 +134,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { } def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { - val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => + val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x => val items = x.getAs[Seq[Item]](0).toArray val freq = x.getLong(1) new FreqItemset(items, freq) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index a5bd77e6bee91..11179a21c815c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -41,7 +41,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var numCorrections = 10 - private var convergenceTol = 1E-4 + private var convergenceTol = 1E-6 private var maxNumIterations = 100 private var regParam = 0.0 @@ -59,7 +59,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) } /** - * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. + * Set the convergence tolerance of iterations for L-BFGS. Default 1E-6. * Smaller value will lead to higher accuracy with the cost of more iterations. * This value must be nonnegative. Lower convergence values are less tolerant * and therefore generally cause more iterations to be run. 
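Editor's note: two behavioral points surface in the L-BFGS changes above: the default convergence tolerance moves from 1E-4 to 1E-6, and (per the LogisticRegressionWithLBFGS comment earlier in this patch) setNumCorrections() is ignored when training falls through to the ml implementation. A hedged sketch of how a caller tunes these knobs; the training RDD is assumed to be supplied elsewhere.

```scala
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

object LbfgsTuningSketch {
  // `training` is an assumed, caller-supplied RDD of labeled points.
  def train(training: RDD[LabeledPoint]): LogisticRegressionModel = {
    val lr = new LogisticRegressionWithLBFGS().setNumClasses(2)
    lr.optimizer
      .setConvergenceTol(1e-4)  // override the new, tighter 1E-6 default if fewer iterations are acceptable
      .setNumIterations(100)
      .setNumCorrections(10)    // per the note above, this may have no effect on the ml code path
    lr.run(training)
  }
}
```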
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 628cf1dd57280..c91729a9fd495 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -369,13 +369,13 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.read.parquet(userPath(path)) - .map { case Row(id: Int, features: Seq[_]) => + val userFeatures = sqlContext.read.parquet(userPath(path)).rdd.map { + case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + val productFeatures = sqlContext.read.parquet(productPath(path)).rdd.map { + case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) - } - val productFeatures = sqlContext.read.parquet(productPath(path)) - .map { case Row(id: Int, features: Seq[_]) => - (id, features.asInstanceOf[Seq[Double]].toArray) } new MatrixFactorizationModel(rank, userFeatures, productFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 51235a23711a1..40440d50fc748 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -38,8 +38,9 @@ import org.apache.spark.util.random.XORShiftRandom /** * A class which implements a decision tree learning algorithm for classification and regression. * It supports both continuous and categorical features. + * * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, + * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. */ @Since("1.0.0") @@ -50,8 +51,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy) /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -77,9 +78,9 @@ object DecisionTree extends Serializable with Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, + * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction. 
*/ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -97,11 +98,11 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train( @@ -124,12 +125,12 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClasses number of classes for classification. Default value of 2. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @param numClasses Number of classes for classification. Default value of 2. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.2.0") def train( @@ -153,17 +154,17 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClasses number of classes for classification. Default value of 2. - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). 
+ * @param numClasses Number of classes for classification. Default value of 2. + * @param maxBins Maximum number of bins used for splitting features. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train( @@ -185,18 +186,18 @@ object DecisionTree extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 5) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 32) - * @return DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 5) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 32) + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.1.0") def trainClassifier( @@ -232,17 +233,17 @@ object DecisionTree extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 5) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 32) - * @return DecisionTreeModel that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 5) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 32) + * @return DecisionTreeModel that can be used for prediction. 
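Editor's note: the rewritten parameter docs above map directly onto a call. A minimal sketch of trainRegressor using the documented "variance" impurity and the suggested maxDepth/maxBins values; the training RDD and the categorical-feature arity are assumed for illustration.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

object DecisionTreeRegressorSketch {
  // `data` is an assumed, caller-supplied training set with real-valued labels.
  def trainRegressor(data: RDD[LabeledPoint]): DecisionTreeModel = {
    DecisionTree.trainRegressor(
      input = data,
      categoricalFeaturesInfo = Map(0 -> 3),  // feature 0 is categorical with categories {0, 1, 2}
      impurity = "variance",                  // the only supported impurity for regression
      maxDepth = 5,                           // suggested value from the docs above
      maxBins = 32)                           // suggested value from the docs above
  }
}
```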
*/ @Since("1.1.0") def trainRegressor( @@ -277,7 +278,7 @@ object DecisionTree extends Serializable with Logging { * * @param node Node in tree from which to classify the given data point. * @param binnedFeatures Binned feature vector for data point. - * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param bins Possible bins for all features, indexed (numFeatures)(numBins). * @param unorderedFeatures Set of indices of unordered features. * @return Leaf index if the data point reaches a leaf. * Otherwise, last node reachable in tree matching this example. @@ -333,12 +334,12 @@ object DecisionTree extends Serializable with Logging { * For unordered features, bins correspond to subsets of categories; either the left or right bin * for each subset is updated. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param splits possible splits indexed (numFeatures)(numSplits) - * @param unorderedFeatures Set of indices of unordered features. - * @param instanceWeight Weight (importance) of instance in dataset. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits Possible splits indexed (numFeatures)(numSplits). + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, @@ -394,10 +395,10 @@ object DecisionTree extends Serializable with Logging { * * For each feature, the sufficient statistics of one bin are updated. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param instanceWeight Weight (importance) of instance in dataset. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, @@ -430,17 +431,17 @@ object DecisionTree extends Serializable with Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]. + * @param metadata Learning and dataset metadata. * @param topNodes Root node for each tree. Used for matching instances with nodes. - * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree. * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, * where nodeIndexInfo stores the index in the group and the * feature subsets (if using feature subsets). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). - * Updated with new non-leaf nodes which are created. + * @param splits Possible splits for all features, indexed (numFeatures)(numSplits). 
+ * @param bins Possible bins for all features, indexed (numFeatures)(numBins). + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where * each value in the array is the data point's node Id * for a corresponding tree. This is used to prevent the need @@ -527,10 +528,10 @@ object DecisionTree extends Serializable with Logging { * Each data point contributes to one node. For each feature, * the aggregate sufficient statistics are updated for the relevant bins. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). - * @param baggedPoint Data point being aggregated. - * @return agg + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return Array of decision tree statistics. */ def binSeqOp( agg: Array[DTStatsAggregator], @@ -563,6 +564,7 @@ object DecisionTree extends Serializable with Logging { /** * Get node index in group --> features indices map, * which is a short cut to find feature indices for a node given node index in group + * * @param treeToNodeToIndexInfo * @return */ @@ -719,9 +721,10 @@ object DecisionTree extends Serializable with Logging { /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * + * @param leftImpurityCalculator Left node aggregates for this (feature, split). + * @param rightImpurityCalculator Right node aggregate for this (feature, split). + * @return Information gain and statistics for split. */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, @@ -771,9 +774,10 @@ object DecisionTree extends Serializable with Logging { /** * Calculate predict value for current node, given stats of any split. * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node + * + * @param leftImpurityCalculator Left node aggregates for a split. + * @param rightImpurityCalculator Right node aggregates for a split. + * @return Predict value and impurity for current node. */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, @@ -788,8 +792,9 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. + * * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) + * @return Tuple for best split: (Split, information gain, prediction at node). */ private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, @@ -955,8 +960,8 @@ object DecisionTree extends Serializable with Logging { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param metadata Learning and dataset metadata + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
+ * @param metadata Learning and dataset metadata. * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numSplits). @@ -1102,12 +1107,13 @@ object DecisionTree extends Serializable with Logging { * NOTE: Returned number of splits is set based on `featureSamples` and * could be different from the specified `numSplits`. * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. - * @param featureSamples feature values of each sample - * @param metadata decision tree metadata + * + * @param featureSamples Feature values of each sample. + * @param metadata Decision tree metadata. * NOTE: `metadata.numbins` will be changed accordingly - * if there are not enough splits to be found - * @param featureIndex feature index to find splits - * @return array of splits + * if there are not enough splits to be found. + * @param featureIndex Feature index to find splits. + * @return Array of splits. */ private[tree] def findSplitsForContinuousFeature( featureSamples: Array[Double], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1b71256c585bd..d131f5da6c7eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -54,8 +54,9 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to train a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -82,13 +83,14 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to validate a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.4.0") def runWithValidation( @@ -132,7 +134,7 @@ object GradientBoostedTrees extends Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def train( @@ -153,11 +155,11 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. - * @param input training dataset - * @param validationInput validation dataset, ignored if validate is set to false. - * @param boostingStrategy boosting parameters - * @param validate whether or not to use the validation dataset. 
- * @return a gradient boosted trees model that can be used for prediction + * @param input Training dataset. + * @param validationInput Validation dataset, ignored if validate is set to false. + * @param boostingStrategy Boosting parameters. + * @param validate Whether or not to use the validation dataset. + * @return GradientBoostedTreesModel that can be used for prediction. */ private def boost( input: RDD[LabeledPoint], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 570a76f960796..b7714b382a594 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -53,12 +53,12 @@ import org.apache.spark.util.random.SamplingUtils * random forests]] * * @param strategy The configuration parameters for the random forest algorithm which specify - * the type of algorithm (classification, regression, etc.), feature type + * the type of random forest (classification or regression), feature type * (continuous, categorical), depth of the tree, quantile calculation strategy, * etc. * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and @@ -121,8 +121,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return a random forest model that can be used for prediction + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return RandomForestModel that can be used for prediction. */ def run(input: RDD[LabeledPoint]): RandomForestModel = { @@ -269,12 +270,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -294,25 +295,25 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. 
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -358,12 +359,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( @@ -383,24 +384,24 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". 
- * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree. (e.g., depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 6c04403f1ad75..9e3e50192d507 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -34,8 +34,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], * [[org.apache.spark.mllib.tree.impurity.Entropy]]. * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. * (Ignored for regression.) * Default value is 2 (binary classification). @@ -45,10 +45,9 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * number of discrete values they take. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param minInstancesPerNode Minimum number of instances each child must have after split. * Default value is 1. If a split cause left or right child * to have less than minInstancesPerNode, @@ -60,7 +59,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * 256 MB. * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will - * maintain a separate RDD of node Id cache for each row. + * maintain a separate RDD of node Id cache for each row. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. * E.g. 10 means that the cache will get checkpointed every 10 updates. 
If * the checkpoint directory is not set in diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 89c470d573431..ec5d7b91892c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -247,7 +247,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) - val nodes = dataRDD.map(NodeData.apply) + val nodes = dataRDD.rdd.map(NodeData.apply) // Build node data into a tree. val trees = constructTrees(nodes) assert(trees.size == 1, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index feabcee24fa2c..59713c382e58b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -473,7 +473,7 @@ private[tree] object TreeEnsembleModel extends Logging { treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) - val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) + val nodes = sqlContext.read.parquet(datapath).rdd.map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 972c0868a454a..cfb9bbfd41ee7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -735,7 +735,7 @@ class LogisticRegressionSuite val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) - val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label } + val histogram = binaryDataset.rdd.map { case Row(label: Double, features: Vector) => label } .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index a326432d017fc..602b5a8116998 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,8 +76,9 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") - .map { case Row(p: Double, l: Double) => (p, l) } + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", 
"label").rdd.map { + case Row(p: Double, l: Double) => (p, l) + } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 445e50d867e15..2ae74a2090ecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -81,9 +81,9 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset - .select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) + val ovaResults = transformedDataset.select("prediction", "label").rdd.map { + row => (row.getDouble(0), row.getDouble(1)) + } val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fc4a4add5d755..b719a8c7e7340 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -77,7 +77,8 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e5357ba8e220c..c684bc11cccff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -93,7 +93,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 97dbfd9a4314a..03270401ad2bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -199,7 +199,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead // describeTopics val topics = model.describeTopics(3) assert(topics.count() === k) - assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + 
assert(topics.select("topic").rdd.map(_.getInt(0)).collect().toSet === Range(0, k).toSet) topics.select("termIndices").collect().foreach { case r: Row => val termIndices = r.getAs[Seq[Int]](0) assert(termIndices.length === 3 && termIndices.toSet.size === 3) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala new file mode 100644 index 0000000000000..e083d4713680e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + test("MaxAbsScaler fit basic case") { + val data = Array( + Vectors.dense(1, 0, 100), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(-2, -100)), + Vectors.sparse(3, Array(0), Array(-1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(0.5, 0, 1), + Vectors.dense(1, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(-1, -1)), + Vectors.sparse(3, Array(0), Array(-0.75))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MaxAbsScaler() + .setInputCol("features") + .setOutputCol("scaled") + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + } + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + } + + test("MaxAbsScaler read/write") { + val t = new MaxAbsScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } + + test("MaxAbsScalerModel read/write") { + val instance = new MaxAbsScalerModel( + "myMaxAbsScalerModel", Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.maxAbs === instance.maxAbs) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 76d12050f9677..e238b33ed8c64 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -51,7 +51,7 @@ class OneHotEncoderSuite .setDropLast(false) val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").map { r => + val output = encoded.select("id", "labelVec").rdd.map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet @@ -68,7 +68,7 @@ class OneHotEncoderSuite .setOutputCol("labelVec") val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").map { r => + val output = encoded.select("id", "labelVec").rdd.map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index dfdc5792c6dbc..86dbee1cf4a5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -34,22 +34,31 @@ class PolynomialExpansionSuite ParamsSuite.checkParams(new PolynomialExpansion) } - test("Polynomial expansion with default parameter") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0, 0.0), - Vectors.dense(0.6, -1.1, -3.0), - Vectors.sparse(3, Seq()) - ) - - val twoDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), - Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), - Vectors.dense(new Array[Double](9)), - Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), - Vectors.sparse(9, Array.empty, Array.empty)) + private val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + private val twoDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(new Array[Double](9)), + Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + Vectors.sparse(9, Array.empty, Array.empty)) + + private val threeDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(new Array[Double](19)), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), + Vectors.sparse(19, 
Array.empty, Array.empty)) + test("Polynomial expansion with default parameter") { val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() @@ -67,23 +76,6 @@ class PolynomialExpansionSuite } test("Polynomial expansion with setter") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0, 0.0), - Vectors.dense(0.6, -1.1, -3.0), - Vectors.sparse(3, Seq()) - ) - - val threeDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), - Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), - Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), - Vectors.dense(new Array[Double](19)), - Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, - -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), - Vectors.sparse(19, Array.empty, Array.empty)) - val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() @@ -101,6 +93,22 @@ class PolynomialExpansionSuite } } + test("Polynomial expansion with degree 1 is identity on vectors") { + val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(1) + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: Vector, expected: Vector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } + test("read/write") { val t = new PolynomialExpansion() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6a2c601bbed18..25fabf64d5594 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -71,6 +71,26 @@ class QuantileDiscretizerSuite } } + test("Test splits on dataset larger than minSamplesRequired") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ + + val datasetSize = QuantileDiscretizer.minSamplesRequired + 1 + val numBuckets = 5 + val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + .setSeed(1) + + val result = discretizer.fit(df).transform(df) + val observedNumBuckets = result.select("result").distinct.count + + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + } + test("read/write") { val t = new QuantileDiscretizer() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5d199ca9b51b1..b9533f881dd31 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -52,7 +52,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) 
.asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -83,7 +83,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 @@ -102,7 +102,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 @@ -118,6 +118,17 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexerModel can't overwrite output column") { + val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val indexer = new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + intercept[IllegalArgumentException] { + indexer.transform(df) + } + } + test("StringIndexer read/write") { val t = new StringIndexer() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 67817fa4baf56..d4f836ef33ad9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -159,7 +159,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) assert(featureAttrs.name === "indexed") assert(featureAttrs.attributes.get.length === model.numFeatures) @@ -216,7 +216,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + val indexedPoints = + model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() points.zip(indexedPoints).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index f094c550e545a..1671fb6f3a85e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -100,7 +100,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with 
Defaul .setSeed(42L) .fit(docDF) - val realVectors = model.getVectors.sort("word").select("vector").map { + val realVectors = model.getVectors.sort("word").select("vector").rdd.map { case Row(v: Vector) => v }.collect() // These expectations are just magic values, characterizing the current @@ -134,7 +134,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(docDF) val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) - val (synonyms, similarity) = model.findSynonyms("a", 2).map { + val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -161,7 +161,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val (synonyms, similarity) = model.findSynonyms("a", 6).map { + val (synonyms, similarity) = model.findSynonyms("a", 6).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -174,7 +174,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setWindowSize(10) .fit(docDF) - val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip // The similarity score should be very different with the larger window diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ff0d8f5568279..2bedd70ce93e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -342,11 +342,10 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()) - .select("rating", "prediction") - .map { case Row(rating: Float, prediction: Float) => + val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { + case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) - } + } val rmse = if (implicitPrefs) { // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 09326600e620f..244db8637bea0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -87,7 +87,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. 
MLTestingUtils.checkCopy(model) val preds = model.transform(df) - val predictions = preds.select("prediction").map(_.getDouble(0)) + val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) assert(predictions.max() > 2) assert(predictions.min() < -1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala new file mode 100644 index 0000000000000..8bfa9855ce4ea --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} +import org.apache.spark.mllib.random._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val seed: Int = 42 + @transient var datasetGaussianIdentity: DataFrame = _ + @transient var datasetGaussianLog: DataFrame = _ + @transient var datasetGaussianInverse: DataFrame = _ + @transient var datasetBinomial: DataFrame = _ + @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonIdentity: DataFrame = _ + @transient var datasetPoissonSqrt: DataFrame = _ + @transient var datasetGammaInverse: DataFrame = _ + @transient var datasetGammaIdentity: DataFrame = _ + @transient var datasetGammaLog: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + import GeneralizedLinearRegressionSuite._ + + datasetGaussianIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity"), 2)) + + datasetGaussianLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log"), 2)) + + datasetGaussianInverse = sqlContext.createDataFrame( + 
sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse"), 2)) + + datasetBinomial = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, + addIntercept = true, nPoints, seed) + + sqlContext.createDataFrame(sc.parallelize(testData, 2)) + } + + datasetPoissonLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "log"), 2)) + + datasetPoissonIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity"), 2)) + + datasetPoissonSqrt = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "sqrt"), 2)) + + datasetGammaInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse"), 2)) + + datasetGammaIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity"), 2)) + + datasetGammaLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log"), 2)) + } + + test("params") { + ParamsSuite.checkParams(new GeneralizedLinearRegression) + val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("generalized linear regression: default params") { + val glr = new GeneralizedLinearRegression + assert(glr.getLabelCol === "label") + assert(glr.getFeaturesCol === "features") + assert(glr.getPredictionCol === "prediction") + assert(glr.getFitIntercept) + assert(glr.getTol === 1E-6) + assert(glr.getWeightCol === "") + assert(glr.getRegParam === 0.0) + assert(glr.getSolver == "irls") + // TODO: Construct model directly instead of via fitting. + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + assert(model.getFamily === "gaussian") + assert(model.getLink === "identity") + } + + test("generalized linear regression: gaussian family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="gaussian", data=data) + print(as.vector(coef(model))) + } + + [1] 2.2960999 0.8087933 + [1] 2.5002642 2.2000403 0.5999485 + + data <- read.csv("path", header=FALSE) + model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0)) + model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0)) + print(as.vector(coef(model1))) + print(as.vector(coef(model2))) + + [1] 0.23069326 0.07993778 + [1] 0.25001858 0.22002452 0.05998789 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=gaussian(link=inverse), data=data) + print(as.vector(coef(model))) + } + + [1] 2.3010179 0.8198976 + [1] 2.4108902 2.2130248 0.6086152 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2960999, 0.8087933), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(0.0, 0.23069326, 0.07993778), + Vectors.dense(0.25001858, 0.22002452, 0.05998789), + Vectors.dense(0.0, 2.3010179, 0.8198976), + Vectors.dense(2.4108902, 2.2130248, 0.6086152)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog), + ("inverse", datasetGaussianInverse))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gaussian family against glmnet") { + /* + R code: + library(glmnet) + data <- read.csv("path", header=FALSE) + label = data$V1 + features = as.matrix(data.frame(data$V2, data$V3)) + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + model <- glmnet(features, label, family="gaussian", intercept=intercept, + lambda=lambda, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + + [1] 0.0000000 2.2961005 0.8087932 + [1] 0.0000000 2.2130368 0.8309556 + [1] 0.0000000 1.7176137 0.9610657 + [1] 2.5002642 2.2000403 0.5999485 + [1] 3.1106389 2.0935142 0.5712711 + [1] 6.7597127 1.4581054 0.3994266 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2961005, 0.8087932), + Vectors.dense(0.0, 2.2130368, 0.8309556), + Vectors.dense(0.0, 1.7176137, 0.9610657), + Vectors.dense(2.5002642, 2.2000403, 
0.5999485), + Vectors.dense(3.1106389, 2.0935142, 0.5712711), + Vectors.dense(6.7597127, 1.4581054, 0.3994266)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian") + .setFitIntercept(fitIntercept).setRegParam(regParam) + val model = trainer.fit(datasetGaussianIdentity) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"fitIntercept = $fitIntercept and regParam = $regParam.") + + idx += 1 + } + } + + test("generalized linear regression: binomial family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 + data <- read.csv("path", header=FALSE) + + for (formula in c(f1, f2)) { + model <- glm(formula, family="binomial", data=data) + print(as.vector(coef(model))) + } + + [1] -0.3560284 1.3010002 -0.3570805 -0.7406762 + [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=probit), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2134390 0.7800646 -0.2144267 -0.4438358 + [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=cloglog), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2832198 0.8434144 -0.2524727 -0.5293452 + [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758 + */ + val expected = Seq( + Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762), + Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989), + Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358), + Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850), + Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452), + Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial), + ("cloglog", datasetBinomial))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), + model.coefficients(2), model.coefficients(3)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: poisson family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + 
print(as.vector(coef(model))) + } + + [1] 0.22999393 0.08047088 + [1] 0.25022353 0.21998599 0.05998621 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2929501 0.8119415 + [1] 2.5012730 2.1999407 0.5999107 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=sqrt), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2958947 0.8090515 + [1] 2.5000480 2.1999972 0.5999968 + */ + val expected = Seq( + Vectors.dense(0.0, 0.22999393, 0.08047088), + Vectors.dense(0.25022353, 0.21998599, 0.05998621), + Vectors.dense(0.0, 2.2929501, 0.8119415), + Vectors.dense(2.5012730, 2.1999407, 0.5999107), + Vectors.dense(0.0, 2.2958947, 0.8090515), + Vectors.dense(2.5000480, 2.1999972, 0.5999968)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity), + ("sqrt", datasetPoissonSqrt))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gamma family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="Gamma", data=data) + print(as.vector(coef(model))) + } + + [1] 2.3392419 0.8058058 + [1] 2.3507700 2.2533574 0.6042991 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2908883 0.8147796 + [1] 2.5002406 2.1998346 0.6000059 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=log), data=data) + print(as.vector(coef(model))) + } + + [1] 0.22958970 0.08091066 + [1] 0.25003210 0.21996957 0.06000215 + */ + val expected = Seq( + Vectors.dense(0.0, 2.3392419, 0.8058058), + Vectors.dense(2.3507700, 2.2533574, 0.6042991), + Vectors.dense(0.0, 2.2908883, 0.8147796), + Vectors.dense(2.5002406, 2.1998346, 0.6000059), + Vectors.dense(0.0, 0.22958970, 0.08091066), + Vectors.dense(0.25003210, 0.21996957, 0.06000215)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), + ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + .setFitIntercept(fitIntercept) + val model = 
trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } +} + +object GeneralizedLinearRegressionSuite { + + def generateGeneralizedLinearRegressionInput( + intercept: Double, + coefficients: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + noiseLevel: Double, + family: String, + link: String): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + def rndElement(i: Int) = { + (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + } + val (generator, mean) = family match { + case "gaussian" => (new StandardNormalGenerator, 0.0) + case "poisson" => (new PoissonGenerator(1.0), 1.0) + case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0) + } + generator.setSeed(seed) + + (0 until nPoints).map { _ => + val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray) + val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept + val mu = link match { + case "identity" => eta + case "log" => math.exp(eta) + case "sqrt" => math.pow(eta, 2.0) + case "inverse" => 1.0 / eta + } + val label = mu + noiseLevel * (generator.nextValue() - mean) + // Return LabeledPoints with DenseVector + LabeledPoint(label, features) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index f067c29d27a7d..b8874b4cd32a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -46,7 +46,7 @@ class IsotonicRegressionSuite val predictions = model .transform(dataset) - .select("prediction").map { case Row(pred) => + .select("prediction").rdd.map { case Row(pred) => pred }.collect() @@ -66,7 +66,7 @@ class IsotonicRegressionSuite val predictions = model .transform(features) - .select("prediction").map { + .select("prediction").rdd.map { case Row(pred) => pred }.collect() @@ -160,7 +160,7 @@ class IsotonicRegressionSuite val predictions = model .transform(features) - .select("prediction").map { + .select("prediction").rdd.map { case Row(pred) => pred }.collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 3ae108d822de7..9dee04c8776db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -686,17 +686,18 @@ class LinearRegressionSuite // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = datasetWithDenseFeature.select("features", "label") + .rdd .map { case Row(features: DenseVector, label: 
Double) => - val prediction = - features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + - model.intercept - label - prediction - } - .zip(model.summary.residuals.map(_.getDouble(0))) + val prediction = + features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + + model.intercept + label - prediction + } + .zip(model.summary.residuals.rdd.map(_.getDouble(0))) .collect() .foreach { case (manualResidual: Double, resultResidual: Double) => - assert(manualResidual ~== resultResidual relTol 1E-5) - } + assert(manualResidual ~== resultResidual relTol 1E-5) + } /* # Use the following R code to generate model training results. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index d140545e377f2..cea0adc55c076 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -667,9 +667,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w test("binary logistic regression with intercept with L1 regularization") { val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) - trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12) val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) - trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12) val model1 = trainer1.run(binaryDataset) val model2 = trainer2.run(binaryDataset) @@ -726,9 +726,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w test("binary logistic regression without intercept with L1 regularization") { val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) - trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12) val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) - trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12) val model1 = trainer1.run(binaryDataset) val model2 = trainer2.run(binaryDataset) @@ -786,9 +786,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w test("binary logistic regression with intercept with L2 regularization") { val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) - trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) - trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) val model1 = trainer1.run(binaryDataset) val model2 = trainer2.run(binaryDataset) @@ -845,9 +845,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w test("binary logistic regression without intercept with L2 regularization") { 
val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) - trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) - trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) val model1 = trainer1.run(binaryDataset) val model2 = trainer2.run(binaryDataset) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala index 77a2773c36f56..dcb1f398b04b8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -42,6 +42,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { .collect() /* Verify results using the `R` code: + library(arules) transactions = as(sapply( list("r z h k p", "z y x w v u t s", @@ -52,7 +53,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { FUN=function(x) strsplit(x," ",fixed=TRUE)), "transactions") ars = apriori(transactions, - parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2)) arsDF = as(ars, "data.frame") arsDF$support = arsDF$support * length(transactions) names(arsDF)[names(arsDF) == "support"] = "freq" diff --git a/pom.xml b/pom.xml index 244355b080221..2148379896d35 100644 --- a/pom.xml +++ b/pom.xml @@ -87,19 +87,19 @@ common/sketch - tags + common/network-common + common/network-shuffle + common/unsafe + common/tags core graphx mllib tools - network/common - network/shuffle streaming sql/catalyst sql/core sql/hive docker-integration-tests - unsafe assembly external/twitter external/flume @@ -147,7 +147,7 @@ 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 - 0.5.0 + 0.7.4 2.4.0 2.0.8 3.1.2 @@ -2442,7 +2442,7 @@ yarn yarn - network/yarn + common/network-yarn diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 746223f39e4f6..9ce37fc753c46 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -262,12 +262,32 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") ) ++ Seq( - // SPARK-13426 Remove the support of SIMR - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") + // SPARK-13426 Remove the support of SIMR + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") ) ++ Seq( // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") + ) ++ Seq( + // SPARK-13220 Deprecate yarn-client and yarn-cluster mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-13465 TaskContext. 
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") + ) ++ Seq ( + // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ Seq( + // SPARK-13526 Move SQLContext per-session states to new class + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.sql.UDFRegistration.this") + ) ++ Seq( + // [SPARK-13486][SQL] Move SQLConf into an internal package + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") ) case v if v.startsWith("1.6") => Seq( diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 3179fb30ab4d7..253af15cb5cd9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,11 +26,12 @@ from pyspark.mllib.common import inherit_doc -__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', - 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel', 'MultilayerPerceptronClassifier', - 'MultilayerPerceptronClassificationModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', + 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel', + 'NaiveBayes', 'NaiveBayesModel', + 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel'] @inherit_doc diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 611b9190491c1..1cea477acb47d 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -21,7 +21,8 @@ from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -__all__ = ['KMeans', 'KMeansModel', 'BisectingKMeans', 'BisectingKMeansModel'] +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', + 'KMeans', 'KMeansModel'] class KMeansModel(JavaModel, MLWritable, MLReadable): diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 464c9446f2f39..fb31c7310c0a8 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,14 +27,34 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', - 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', - 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', - 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', - 'ChiSqSelector', 'ChiSqSelectorModel'] +__all__ = ['Binarizer', + 'Bucketizer', + 'ChiSqSelector', 'ChiSqSelectorModel', + 'CountVectorizer', 'CountVectorizerModel', + 'DCT', + 'ElementwiseProduct', + 'HashingTF', + 'IDF', 'IDFModel', + 
'IndexToString', + 'MaxAbsScaler', 'MaxAbsScalerModel', + 'MinMaxScaler', 'MinMaxScalerModel', + 'NGram', + 'Normalizer', + 'OneHotEncoder', + 'PCA', 'PCAModel', + 'PolynomialExpansion', + 'QuantileDiscretizer', + 'RegexTokenizer', + 'RFormula', 'RFormulaModel', + 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', + 'StringIndexer', 'StringIndexerModel', + 'Tokenizer', + 'VectorAssembler', + 'VectorIndexer', 'VectorIndexerModel', + 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -544,6 +564,66 @@ class IDFModel(JavaModel): """ +@inherit_doc +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to range [-1, 1] by dividing through the largest maximum + absolute value in each feature. It does not shift/center the data, and thus does not destroy + any sparsity. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled") + >>> model = maScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[1.0]| [0.5]| + |[2.0]| [1.0]| + +-----+------+ + ... + + .. versionadded:: 2.0.0 + """ + + @keyword_only + def __init__(self, inputCol=None, outputCol=None): + """ + __init__(self, inputCol=None, outputCol=None) + """ + super(MaxAbsScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) + self._setDefault() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, inputCol=None, outputCol=None): + """ + setParams(self, inputCol=None, outputCol=None) + Sets params for this MaxAbsScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return MaxAbsScalerModel(java_model) + + +class MaxAbsScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MaxAbsScaler`. + + .. versionadded:: 2.0.0 + """ + + @inherit_doc class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): """ @@ -939,7 +1019,7 @@ def getDegree(self): @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed): """ .. note:: Experimental @@ -951,7 +1031,9 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) >>> qds = QuantileDiscretizer(numBuckets=2, - ... inputCol="values", outputCol="buckets") + ... inputCol="values", outputCol="buckets", seed=123) + >>> qds.getSeed() + 123 >>> bucketizer = qds.fit(df) >>> splits = bucketizer.getSplits() >>> splits[0] @@ -971,9 +1053,9 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): "categories) into which data points are grouped. Must be >= 2. 
Default 2.") @keyword_only - def __init__(self, numBuckets=2, inputCol=None, outputCol=None): + def __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): """ - __init__(self, numBuckets=2, inputCol=None, outputCol=None) + __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) """ super(QuantileDiscretizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", @@ -987,9 +1069,9 @@ def __init__(self, numBuckets=2, inputCol=None, outputCol=None): @keyword_only @since("2.0.0") - def setParams(self, numBuckets=2, inputCol=None, outputCol=None): + def setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): """ - setParams(self, numBuckets=2, inputCol=None, outputCol=None) + setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) Set the params for the QuantileDiscretizer """ kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de4a751a54796..6b994fe9f93b4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -154,7 +154,7 @@ def intercept(self): @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol): + HasWeightCol, MLWritable, MLReadable): """ .. note:: Experimental @@ -172,6 +172,18 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti 0.0 >>> model.boundaries DenseVector([0.0, 1.0]) + >>> ir_path = temp_path + "/ir" + >>> ir.save(ir_path) + >>> ir2 = IsotonicRegression.load(ir_path) + >>> ir2.getIsotonic() + True + >>> model_path = temp_path + "/ir_model" + >>> model.save(model_path) + >>> model2 = IsotonicRegressionModel.load(model_path) + >>> model.boundaries == model2.boundaries + True + >>> model.predictions == model2.predictions + True """ isotonic = \ @@ -237,7 +249,7 @@ def getFeatureIndex(self): return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel): +class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable): """ .. note:: Experimental @@ -663,7 +675,7 @@ class GBTRegressionModel(TreeEnsembleModels): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol): + HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -690,6 +702,20 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | 0.0|(1,[],[])| 0.0| 1.0| +-----+---------+------+----------+ ... + >>> aftsr_path = temp_path + "/aftsr" + >>> aftsr.save(aftsr_path) + >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path) + >>> aftsr2.getMaxIter() + 100 + >>> model_path = temp_path + "/aftsr_model" + >>> model.save(model_path) + >>> model2 = AFTSurvivalRegressionModel.load(model_path) + >>> model.coefficients == model2.coefficients + True + >>> model.intercept == model2.intercept + True + >>> model.scale == model2.scale + True .. versionadded:: 1.6.0 """ @@ -787,7 +813,7 @@ def getQuantilesCol(self): return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel): +class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by AFTSurvivalRegression. 
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index b24592c3798e6..57106f8690a7d 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -294,7 +294,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: 0.01) :param regType: The type of regularizer used for training our model. - Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) @@ -326,8 +326,8 @@ class LogisticRegressionWithLBFGS(object): """ @classmethod @since('1.2.0') - def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): + def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2", + intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -341,10 +341,10 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: None) :param regParam: The regularizer parameter. - (default: 0.01) + (default: 0.0) :param regType: The type of regularizer used for training our model. - Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) @@ -356,10 +356,12 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: False) :param corrections: The number of corrections used in the LBFGS update. - (default: 10) + If a known updater is used for binary classification, + it calls the ml implementation and this parameter will + have no effect. (default: 10) :param tolerance: The convergence tolerance of iterations for L-BFGS. - (default: 1e-4) + (default: 1e-6) :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 7a2d77a4dad13..5c9706cb8cb29 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -21,14 +21,15 @@ from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc __all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel'] @inherit_doc @ignore_unicode_prefix -class FPGrowthModel(JavaModelWrapper): +class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + >>> model_path = temp_path + "/fpm" + >>> model.save(sc, model_path) + >>> sameModel = FPGrowthModel.load(sc, model_path) + >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect()) + True .. versionadded:: 1.4.0 """ @@ -51,6 +57,16 @@ def freqItemsets(self): """ return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) + @classmethod + @since("2.0.0") + def load(cls, sc, path): + """ + Load a model from the given path. 
+ """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.FPGrowthModelWrapper(model) + return FPGrowthModel(wrapper) + class FPGrowth(object): """ @@ -170,8 +186,19 @@ def _test(): import pyspark.mllib.fpm globs = pyspark.mllib.fpm.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + import tempfile + + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 4dd7083d79c8c..3b77a6200054f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -37,10 +37,11 @@ class LabeledPoint(object): """ Class that represents the features and labels of a data point. - :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + :param label: + Label for this data point. + :param features: + Vector of features for this point (NumPy array, list, + pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). Note: 'label' and 'features' are accessible as class attributes. @@ -66,8 +67,10 @@ class LinearModel(object): """ A linear model that has a vector of coefficients and an intercept. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. .. versionadded:: 0.9.0 """ @@ -217,19 +220,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): """ - Train a linear regression model with no regularization using Stochastic Gradient Descent. - This solves the least squares regression formulation - - f(weights) = 1/n ||A weights-y||^2 - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @@ -237,47 +229,52 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient - Descent (SGD). - This solves the least squares regression formulation - - f(weights) = 1/(2n) ||A weights - y||^2, - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.0). 
- :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization (lasso), - - "l2" for using L2 regularization (ridge), - - None for no regularization - - (default: None) - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Descent (SGD). This solves the least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + + which is the mean squared error. Here the data matrix has n rows, + and the input RDD holds the set of rows of A, each with its + corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.0) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization (default) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -368,56 +365,53 @@ def load(cls, sc, path): class LassoWithSGD(object): """ - Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the L1-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L1-regularization using - Stochastic Gradient Descent. - This solves the l1-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). 
- :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L1-regularization using Stochastic + Gradient Descent. This solves the l1-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -508,56 +502,53 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ - Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the L2-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L2-regularization using - Stochastic Gradient Descent. - This solves the l2-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). 
- :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L2-regularization using Stochastic + Gradient Descent. This solves the l2-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), @@ -572,12 +563,14 @@ class IsotonicRegressionModel(Saveable, Loader): """ Regression model for isotonic regression. - :param boundaries: Array of boundaries for which predictions are - known. Boundaries must be sorted in increasing order. - :param predictions: Array of predictions associated to the - boundaries at the same index. Results of isotonic - regression and therefore monotone. - :param isotonic: indicates whether this is isotonic or antitonic. + :param boundaries: + Array of boundaries for which predictions are known. Boundaries + must be sorted in increasing order. + :param predictions: + Array of predictions associated to the boundaries at the same + index. Results of isotonic regression and therefore monotone. + :param isotonic: + Indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -628,7 +621,8 @@ def predict(self, x): values with the same boundary then the same rules as in 2) are used. - :param x: Feature or RDD of Features to be labeled. + :param x: + Feature or RDD of Features to be labeled. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -657,8 +651,8 @@ def load(cls, sc, path): class IsotonicRegression(object): """ Isotonic regression. - Currently implemented using parallelized pool adjacent violators algorithm. - Only univariate (single feature) algorithm supported. + Currently implemented using parallelized pool adjacent violators + algorithm. Only univariate (single feature) algorithm supported. 
Sequential PAV implementation based on: @@ -684,8 +678,11 @@ def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. + :param data: + RDD of (label, feature, weight) tuples. + :param isotonic: + Whether this is isotonic (which is default) or antitonic. + (default: True) """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) @@ -721,9 +718,11 @@ def _validate(self, dstream): @since("1.5.0") def predictOn(self, dstream): """ - Make predictions on a dstream. + Use the model to make predictions on batches of data from a + DStream. - :return: Transformed dstream object. + :return: + DStream containing predictions. """ self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) @@ -731,9 +730,11 @@ def predictOn(self, dstream): @since("1.5.0") def predictOnValues(self, dstream): """ - Make predictions on a keyed dstream. + Use the model to make predictions on the values of a DStream and + carry over its keys. - :return: Transformed dstream object. + :return: + DStream containing the input keys and the predictions as values. """ self._validate(dstream) return dstream.mapValues(lambda x: self._model.predict(x)) @@ -742,14 +743,15 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Train or predict a linear regression model on streaming data. Training uses - Stochastic Gradient Descent to update the model based on each new batch of - incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + Train or predict a linear regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model + based on each new batch of incoming data from a DStream + (see `LinearRegressionWithSGD` for model equation). Each batch of data is assumed to be an RDD of LabeledPoints. The number of data points per batch can vary, but the number - of features must be constant. An initial weight - vector must be provided. + of features must be constant. An initial weight vector must + be provided. :param stepSize: Step size for each iteration of gradient descent. diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 0001b60093a69..f7ea466b43291 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -60,8 +60,7 @@ def numTrees(self): @since("1.3.0") def totalNumNodes(self): """ - Get total number of nodes, summed over all trees in the - ensemble. + Get total number of nodes, summed over all trees in the ensemble. """ return self.call("totalNumNodes") @@ -92,8 +91,9 @@ def predict(self, x): transformation or action. Call predict directly on the RDD instead. - :param x: Data point (feature vector), - or an RDD of data points (feature vectors). + :param x: + Data point (feature vector), or an RDD of data points (feature + vectors). """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -108,8 +108,9 @@ def numNodes(self): @since("1.1.0") def depth(self): - """Get depth of tree. - E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ + Get depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). 
""" return self._java_model.depth() @@ -152,24 +153,37 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ - Train a DecisionTreeModel for classification. - - :param data: Training data: RDD of LabeledPoint. - Labels are integers {0,1,...,numClasses}. - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: Supported values: "entropy" or "gini" - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split - :param minInfoGain: Min info gain required to create a split - :return: DecisionTreeModel + Train a decision tree model for classification. + + :param data: + Training data: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + :param numClasses: + Number of classes for classification. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param impurity: + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + :param maxBins: + Number of bins used for finding splits at each node. + (default: 32) + :param minInstancesPerNode: + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + :param minInfoGain: + Minimum info gain required to create a split. + (default: 0.0) + :return: + DecisionTreeModel. Example usage: @@ -211,23 +225,34 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ - Train a DecisionTreeModel for regression. - - :param data: Training data: RDD of LabeledPoint. - Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature - index to number of categories. - Any feature not in this map is treated as continuous. - :param impurity: Supported values: "variance" - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each - node. - :param minInstancesPerNode: Min number of instances required at - child nodes to create the parent split - :param minInfoGain: Min info gain required to create a split - :return: DecisionTreeModel + Train a decision tree model for regression. + + :param data: + Training data: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param impurity: + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). 
+ (default: 5) + :param maxBins: + Number of bins used for finding splits at each node. + (default: 32) + :param minInstancesPerNode: + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + :param minInfoGain: + Minimum info gain required to create a split. + (default: 0.0) + :return: + DecisionTreeModel. Example usage: @@ -302,34 +327,44 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): """ - Method to train a decision tree model for binary or multiclass + Train a random forest model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels - should take values {0, 1, ..., numClasses-1}. - :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that - feature n is categorical with k categories indexed - from 0: {0, 1, ..., k-1}. - :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain calculation. - Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. - E.g., depth 0 means 1 leaf node; depth 1 means - 1 internal node + 2 leaf nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting - features - (default: 32) - :param seed: Random seed for bootstrapping and choosing feature - subsets. - :return: RandomForestModel that can be used for prediction + :param data: + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + :param numClasses: + Number of classes for classification. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param numTrees: + Number of trees in the random forest. + :param featureSubsetStrategy: + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + (default: "auto") + :param impurity: + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + :param maxBins: + Maximum number of bins used for splitting features. + (default: 32) + :param seed: + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + :return: + RandomForestModel that can be used for prediction. Example usage: @@ -383,32 +418,40 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ - Method to train a decision tree model for regression. 
- - :param data: Training dataset: RDD of LabeledPoint. Labels are - real numbers. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain - calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting - features (default: 32) - :param seed: Random seed for bootstrapping and choosing feature - subsets. - :return: RandomForestModel that can be used for prediction + Train a random forest model for regression. + + :param data: + Training dataset: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param numTrees: + Number of trees in the random forest. + :param featureSubsetStrategy: + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + (default: "auto") + :param impurity: + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + :param maxBins: + Maximum number of bins used for splitting features. + (default: 32) + :param seed: + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + :return: + RandomForestModel that can be used for prediction. Example usage: @@ -480,31 +523,37 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ - Method to train a gradient-boosted trees model for - classification. - - :param data: Training dataset: RDD of LabeledPoint. - Labels should take values {0, 1}. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient - boosting. Supported: {"logLoss" (default), - "leastSquaresError", "leastAbsoluteError"}. - :param numIterations: Number of iterations of boosting. - (default: 100) - :param learningRate: Learning rate for shrinking the - contribution of each estimator. The learning rate - should be between in the interval (0, 1]. - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. 
(default: 3) - :param maxBins: maximum number of bins used for splitting - features (default: 32) DecisionTree requires maxBins >= max categories - :return: GradientBoostedTreesModel that can be used for - prediction + Train a gradient-boosted trees model for classification. + + :param data: + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1}. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param loss: + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "logLoss") + :param numIterations: + Number of iterations of boosting. + (default: 100) + :param learningRate: + Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + :param maxBins: + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + :return: + GradientBoostedTreesModel that can be used for prediction. Example usage: @@ -543,30 +592,36 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ - Method to train a gradient-boosted trees model for regression. - - :param data: Training dataset: RDD of LabeledPoint. Labels are - real numbers. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient - boosting. Supported: {"logLoss" (default), - "leastSquaresError", "leastAbsoluteError"}. - :param numIterations: Number of iterations of boosting. - (default: 100) - :param learningRate: Learning rate for shrinking the - contribution of each estimator. The learning rate - should be between in the interval (0, 1]. - (default: 0.1) - :param maxBins: maximum number of bins used for splitting - features (default: 32) DecisionTree requires maxBins >= max categories - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 3) - :return: GradientBoostedTreesModel that can be used for - prediction + Train a gradient-boosted trees model for regression. + + :param data: + Training dataset: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param loss: + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "leastSquaresError") + :param numIterations: + Number of iterations of boosting. + (default: 100) + :param learningRate: + Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). 
+ (default: 3) + :param maxBins: + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + :return: + GradientBoostedTreesModel that can be used for prediction. Example usage: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4eaf589ad5d46..37574cea0b687 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2309,7 +2309,7 @@ def toLocalIterator(self): yield row -def _prepare_for_python_RDD(sc, command, obj=None): +def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) @@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes +def _wrap_function(sc, func, deserializer, serializer, profiler=None): + assert deserializer, "deserializer should not be empty" + assert serializer, "serializer should not be empty" + command = (func, profiler, deserializer, serializer) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class PipelinedRDD(RDD): """ @@ -2390,14 +2399,10 @@ def _jrdd(self): else: profiler = None - command = (self.func, profiler, self._prev_jrdd_deserializer, - self._jrdd_deserializer) - pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) - python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_cmd), - env, includes, self.preservesPartitioning, - self.ctx.pythonExec, self.ctx.pythonVer, - bvars, self.ctx._javaAccumulator) + wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, + self._jrdd_deserializer, profiler) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, + self.preservesPartitioning) self._jrdd_val = python_rdd.asJavaRDD() if profiler: diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 89bf1443a68a1..87e32c04eac1c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -29,7 +29,7 @@ from py4j.protocol import Py4JError from pyspark import since -from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 83b034fe77435..76fbb0c9aa4c9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -173,7 +173,8 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan ExistingRDD[age#0,name#1] + WholeStageCodegen + : +- Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == @@ -551,8 +552,8 @@ def alias(self, alias): >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') - >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() - [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)] + >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect() + [Row(name=u'Alice', 
name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] """ assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @@ -1177,6 +1178,55 @@ def replace(self, to_replace, value, subset=None): return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + @since(2.0) + def approxQuantile(self, col, probabilities, relativeError): + """ + Calculates the approximate quantiles of a numerical column of a + DataFrame. + + The result of this algorithm has the following deterministic bound: + If the DataFrame has N elements and if we request the quantile at + probability `p` up to error `err`, then the algorithm will return + a sample `x` from the DataFrame so that the *exact* rank of `x` is + close to (p * N). More precisely, + + floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + + This method implements a variation of the Greenwald-Khanna + algorithm (with some speed optimizations). The algorithm was first + present in [[http://dx.doi.org/10.1145/375663.375670 + Space-efficient Online Computation of Quantile Summaries]] + by Greenwald and Khanna. + + :param col: the name of the numerical column + :param probabilities: a list of quantile probabilities + Each number must belong to [0, 1]. + For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + :param relativeError: The relative target precision to achieve + (>= 0). If set to zero, the exact quantiles are computed, which + could be very expensive. Note that values greater than 1 are + accepted but give the same result as 1. + :return: the approximate quantiles at the given probabilities + """ + if not isinstance(col, str): + raise ValueError("col should be a string.") + + if not isinstance(probabilities, (list, tuple)): + raise ValueError("probabilities should be a list or tuple") + if isinstance(probabilities, tuple): + probabilities = list(probabilities) + for p in probabilities: + if not isinstance(p, (float, int, long)) or p < 0 or p > 1: + raise ValueError("probabilities should be numerical (float, int, long) in [0,1].") + probabilities = _to_list(self._sc, probabilities) + + if not isinstance(relativeError, (float, int, long)) or relativeError < 0: + raise ValueError("relativeError should be numerical (float, int, long) >= 0.") + relativeError = float(relativeError) + + jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) + return list(jaq) + @since(1.4) def corr(self, col1, col2, method=None): """ @@ -1395,6 +1445,11 @@ class DataFrameStatFunctions(object): def __init__(self, df): self.df = df + def approxQuantile(self, col, probabilities, relativeError): + return self.df.approxQuantile(col, probabilities, relativeError) + + approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__ + def corr(self, col1, col2, method=None): return self.df.corr(col1, col2, method) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6894c2733828c..b30cc6799eb97 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,7 +25,7 @@ from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import _wrap_function, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq @@ -1645,16 +1645,14 @@ def 
_create_judf(self, name): f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) sc = SparkContext.getOrCreate() - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + wrapped_func = _wrap_function(sc, func, ser, ser) ctx = SQLContext.getOrCreate(sc) jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer, - broadcast_vars, sc._javaAccumulator, jdt) + name, wrapped_func, jdt) return judf def __del__(self): diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b1453c637f79e..7f5368d8bdbb2 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -233,6 +233,23 @@ def text(self, paths): paths = [paths] return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(2.0) + def csv(self, paths): + """Loads a CSV file and returns the result as a [[DataFrame]]. + + This function goes through the input once to determine the input schema. To avoid going + through the entire data once, specify the schema explicitly using [[schema]]. + + :param paths: string, or list of strings, for input path(s). + + >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + >>> df.dtypes + [('C0', 'string'), ('C1', 'string')] + """ + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(1.5) def orc(self, path): """Loads an ORC file, returning the result as a :class:`DataFrame`. @@ -448,6 +465,11 @@ def json(self, path, mode=None): * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + You can set the following JSON-specific option(s) for writing JSON files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode)._jwrite.json(path) @@ -476,11 +498,39 @@ def parquet(self, path, mode=None, partitionBy=None): def text(self, path): """Saves the content of the DataFrame in a text file at the specified path. + :param path: the path in any Hadoop supported file system + The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. + + You can set the following option(s) for writing text files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). """ self._jwrite.text(path) + @since(2.0) + def csv(self, path, mode=None): + """Saves the content of the [[DataFrame]] in CSV format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. 
+ * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + You can set the following CSV-specific option(s) for writing CSV files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + + >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode)._jwrite.csv(path) + @since(1.5) def orc(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc11c0f35cdc9..90fd7696910ed 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -669,6 +669,13 @@ def test_first_last_ignorenulls(self): functions.last(df2.id, True).alias('d')) self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_approxQuantile(self): + df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF() + aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) + self.assertTrue(all(isinstance(q, float) for q in aq)) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() diff --git a/python/test_support/sql/ages.csv b/python/test_support/sql/ages.csv new file mode 100644 index 0000000000000..18991feda788a --- /dev/null +++ b/python/test_support/sql/ages.csv @@ -0,0 +1,4 @@ +Joe,20 +Tom,30 +Hyukjin,25 + diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 6217f9bf28e3d..a5d30d274ea6e 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -25,22 +25,11 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -TACHYON_STR="" - -while (( "$#" )); do -case $1 in - --with-tachyon) - TACHYON_STR="--with-tachyon" - ;; - esac -shift -done - # Load the Spark configuration . "${SPARK_HOME}/sbin/spark-config.sh" # Start Master -"${SPARK_HOME}/sbin"/start-master.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-master.sh # Start Workers -"${SPARK_HOME}/sbin"/start-slaves.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-slaves.sh diff --git a/sbin/start-master.sh b/sbin/start-master.sh index 9f2e14dff609f..ce7f17795997e 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -39,21 +39,6 @@ fi ORIGINAL_ARGS="$@" -START_TACHYON=false - -while (( "$#" )); do -case $1 in - --with-tachyon) - if [ ! -e "${SPARK_HOME}"/tachyon/bin/tachyon ]; then - echo "Error: --with-tachyon specified, but tachyon not found." - exit -1 - fi - START_TACHYON=true - ;; - esac -shift -done - . "${SPARK_HOME}/sbin/spark-config.sh" . 
"${SPARK_HOME}/bin/load-spark-env.sh" @@ -73,9 +58,3 @@ fi "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS - -if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}"/tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "${SPARK_HOME}"/tachyon/bin/tachyon format -s - "${SPARK_HOME}"/tachyon/bin/tachyon-start.sh master -fi diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 51ca81e053b70..5bf2b83b42ce4 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -23,21 +23,6 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -START_TACHYON=false - -while (( "$#" )); do -case $1 in - --with-tachyon) - if [ ! -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - echo "Error: --with-tachyon specified, but tachyon not found." - exit -1 - fi - START_TACHYON=true - ;; - esac -shift -done - . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -50,12 +35,5 @@ if [ "$SPARK_MASTER_IP" = "" ]; then SPARK_MASTER_IP="`hostname`" fi -if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" - - # set -t so we can call sudo - SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 -fi - # Launch the slaves "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index e57962bb354d9..14644ea72d43b 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -26,7 +26,3 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 - -if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master -fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 63956377629d6..a57441b52a04a 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -25,9 +25,4 @@ fi . "${SPARK_HOME}/bin/load-spark-env.sh" -# do before the below calls as they exec -if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker -fi - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-slave.sh diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 10f2e2416bb64..13a6a2d276a57 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -398,198 +398,3 @@ precedenceOrExpression : precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* ; - - -booleanValue - : - KW_TRUE^ | KW_FALSE^ - ; - -booleanValueTok - : - KW_TRUE -> TOK_TRUE - | KW_FALSE -> TOK_FALSE - ; - -tableOrPartition - : - tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) 
- ; - -partitionSpec - : - KW_PARTITION - LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) - ; - -partitionVal - : - identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) - ; - -dropPartitionSpec - : - KW_PARTITION - LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) - ; - -dropPartitionVal - : - identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) - ; - -dropPartitionOperator - : - EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -sysFuncNames - : - KW_AND - | KW_OR - | KW_NOT - | KW_LIKE - | KW_IF - | KW_CASE - | KW_WHEN - | KW_TINYINT - | KW_SMALLINT - | KW_INT - | KW_BIGINT - | KW_FLOAT - | KW_DOUBLE - | KW_BOOLEAN - | KW_STRING - | KW_BINARY - | KW_ARRAY - | KW_MAP - | KW_STRUCT - | KW_UNIONTYPE - | EQUAL - | EQUAL_NS - | NOTEQUAL - | LESSTHANOREQUALTO - | LESSTHAN - | GREATERTHANOREQUALTO - | GREATERTHAN - | DIVIDE - | PLUS - | MINUS - | STAR - | MOD - | DIV - | AMPERSAND - | TILDE - | BITWISEOR - | BITWISEXOR - | KW_RLIKE - | KW_REGEXP - | KW_IN - | KW_BETWEEN - ; - -descFuncNames - : - (sysFuncNames) => sysFuncNames - | StringLiteral - | functionIdentifier - ; - -//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. -looseIdentifier - : - Identifier - | looseNonReserved -> Identifier[$looseNonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -identifier - : - Identifier - | nonReserved -> Identifier[$nonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -functionIdentifier -@init { gParent.pushMsg("function identifier", state); } -@after { gParent.popMsg(state); } - : - identifier (DOT identifier)? -> identifier+ - ; - -principalIdentifier -@init { gParent.pushMsg("identifier for principal spec", state); } -@after { gParent.popMsg(state); } - : identifier - | QuotedIdentifier - ; - -looseNonReserved - : nonReserved | KW_FROM | KW_TO - ; - -//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved -//Non reserved keywords are basically the keywords that can be used as identifiers. -//All the KW_* are automatically not only keywords, but also reserved keywords. -//That means, they can NOT be used as identifiers. -//If you would like to use them as identifiers, put them in the nonReserved list below. 
-//If you are not sure, please refer to the SQL2011 column in -//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html -nonReserved - : - KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS - | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS - | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY - | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY - | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE - | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT - | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE - | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR - | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG - | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE - | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY - | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER - | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE - | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED - | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED - | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED - | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET - | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR - | KW_WORK - | KW_TRANSACTION - | KW_WRITE - | KW_ISOLATION - | KW_LEVEL - | KW_SNAPSHOT - | KW_AUTOCOMMIT - | KW_ANTI - | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND - | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS -; - -//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. -sql11ReservedKeywordsUsedAsCastFunctionName - : - KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP - ; - -//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. -//We are planning to remove the following whole list after several releases. -//Thus, please do not change the following list unless you know what to do. 
-sql11ReservedKeywordsUsedAsIdentifier - : - KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN - | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE - | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT - | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL - | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION - | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT - | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE - | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH -//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. - | KW_REGEXP | KW_RLIKE - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g new file mode 100644 index 0000000000000..12cd5f54a0297 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g @@ -0,0 +1,244 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. +*/ + +parser grammar KeywordParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) 
+ ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. +looseIdentifier + : + Identifier + | looseNonReserved -> Identifier[$looseNonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : + identifier (DOT identifier)? -> identifier+ + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +looseNonReserved + : nonReserved | KW_FROM | KW_TO + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. 
+//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI + | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND + | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. 
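For illustration, a minimal sketch of what the nonReserved list above means in practice: those keywords still parse as ordinary identifiers. The table and column names below are hypothetical, and the snippet assumes a Spark 1.x-style sqlContext; YEAR and DATA both appear in the nonReserved rule, so they are accepted as column names.

// Hypothetical query: `year` and `data` are keywords, but non-reserved,
// so the parser treats them as identifiers here.
val df = sqlContext.sql("SELECT year, data FROM logs WHERE year > 2015")
df.show()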
+sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 9aeea69cd2d9b..1db3aed65815d 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -26,7 +26,7 @@ ASTLabelType=CommonTree; backtrack=false; k=3; } -import SelectClauseParser, FromClauseParser, IdentifiersParser, ExpressionParser; +import SelectClauseParser, FromClauseParser, IdentifiersParser, KeywordParser, ExpressionParser; tokens { TOK_INSERT; @@ -2320,26 +2320,6 @@ regularBody[boolean topLevel] ) | selectStatement[topLevel] - | - (LPAREN selectStatement0[true]) => nestedSetOpSelectStatement[topLevel] - ; - -nestedSetOpSelectStatement[boolean topLevel] - : - ( - LPAREN s=selectStatement0[topLevel] RPAREN -> {$s.tree} - ) - (set=setOpSelectStatement[$nestedSetOpSelectStatement.tree, topLevel]) - -> {set == null}? 
- {$nestedSetOpSelectStatement.tree} - -> {$set.tree} - ; - -selectStatement0[boolean topLevel] - : - (selectStatement[true]) => selectStatement[topLevel] - | - (nestedSetOpSelectStatement[true]) => nestedSetOpSelectStatement[topLevel] ; selectStatement[boolean topLevel] diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 27ae62f1212f6..0ad0f4976c77a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; -final class UnsafeExternalRowSorter { +public final class UnsafeExternalRowSorter { /** * If positive, forces records to be spilled to disk at the given frequency (measured in numbers @@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) { testSpillFrequency = frequency; } - @VisibleForTesting - void insertRow(UnsafeRow row) throws IOException { + public void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( row.getBaseObject(), @@ -110,8 +109,7 @@ private void cleanupResources() { sorter.cleanupResources(); } - @VisibleForTesting - Iterator sort() throws IOException { + public Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -160,7 +158,6 @@ public UnsafeRow next() { } } - public Iterator sort(Iterator inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 069c665a39680..a0a56d728cde9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -499,12 +499,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C s"Sampling fraction ($fraction) must be on interval [0, 100]") Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, - relation) + relation)( + isTableSample = true) case Token("TOK_TABLEBUCKETSAMPLE", Token(numerator, Nil) :: Token(denominator, Nil) :: Nil) => val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)( + isTableSample = true) case a => noParseRule("Sampling", a) }.getOrElse(relation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 23e4709bbd882..876aa0eae0e90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -559,7 +561,13 @@ class 
Analyzer( } resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + case n: NewInstance + // If this is an inner class of another class, register the outer object in `OuterScopes`. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + if n.outerPointer.isEmpty && + n.cls.isMemberClass && + !Modifier.isStatic(n.cls.getModifiers) => val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { throw new AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1be97c7b81197..26bb96eb085ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression + // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), @@ -125,13 +126,12 @@ object FunctionRegistry { expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), + expression[CreateNamedStruct]("named_struct"), + expression[NaNvl]("nanvl"), expression[Coalesce]("nvl"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), - expression[CreateNamedStruct]("named_struct"), - expression[Sqrt]("sqrt"), - expression[NaNvl]("nanvl"), // math functions expression[Acos]("acos"), @@ -145,24 +145,26 @@ object FunctionRegistry { expression[Cos]("cos"), expression[Cosh]("cosh"), expression[Conv]("conv"), + expression[ToDegrees]("degrees"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Factorial]("factorial"), - expression[Hypot]("hypot"), expression[Hex]("hex"), + expression[Hypot]("hypot"), expression[Logarithm]("log"), - expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), expression[Log2]("log2"), + expression[Log]("ln"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Pow]("pow"), - expression[Pow]("power"), expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), + expression[Pow]("pow"), + expression[Pow]("power"), + expression[ToRadians]("radians"), expression[Rint]("rint"), expression[Round]("round"), expression[ShiftLeft]("shiftleft"), @@ -172,10 +174,9 @@ object FunctionRegistry { expression[Signum]("signum"), expression[Sin]("sin"), expression[Sinh]("sinh"), + expression[Sqrt]("sqrt"), expression[Tan]("tan"), expression[Tanh]("tanh"), - expression[ToDegrees]("degrees"), - expression[ToRadians]("radians"), // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), @@ -186,11 +187,13 @@ object FunctionRegistry { expression[CovSample]("covar_samp"), expression[First]("first"), expression[First]("first_value"), + expression[Kurtosis]("kurtosis"), expression[Last]("last"), expression[Last]("last_value"), expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), + expression[Skewness]("skewness"), expression[StddevSamp]("stddev"), 
expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), @@ -198,36 +201,34 @@ object FunctionRegistry { expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), - expression[Skewness]("skewness"), - expression[Kurtosis]("kurtosis"), // string functions expression[Ascii]("ascii"), expression[Base64]("base64"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), - expression[Encode]("encode"), expression[Decode]("decode"), + expression[Encode]("encode"), expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), + expression[FormatString]("format_string"), expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), - expression[JsonTuple]("json_tuple"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), - expression[Lower]("lower"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), - expression[RegExpExtract]("regexp_extract"), - expression[RegExpReplace]("regexp_replace"), - expression[StringInstr]("instr"), + expression[Lower]("lower"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), - expression[FormatString]("format_string"), + expression[JsonTuple]("json_tuple"), expression[FormatString]("printf"), - expression[StringRPad]("rpad"), + expression[RegExpExtract]("regexp_extract"), + expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), + expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), @@ -237,8 +238,8 @@ object FunctionRegistry { expression[SubstringIndex]("substring_index"), expression[StringTranslate]("translate"), expression[StringTrim]("trim"), - expression[UnBase64]("unbase64"), expression[Upper]("ucase"), + expression[UnBase64]("unbase64"), expression[Unhex]("unhex"), expression[Upper]("upper"), @@ -246,7 +247,6 @@ object FunctionRegistry { expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), - expression[CurrentTimestamp]("now"), expression[DateDiff]("datediff"), expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), @@ -262,6 +262,7 @@ object FunctionRegistry { expression[Month]("month"), expression[MonthsBetween]("months_between"), expression[NextDay]("next_day"), + expression[CurrentTimestamp]("now"), expression[Quarter]("quarter"), expression[Second]("second"), expression[ToDate]("to_date"), @@ -273,9 +274,9 @@ object FunctionRegistry { expression[Year]("year"), // collection functions + expression[ArrayContains]("array_contains"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1a2ec7ed931ea..a12f7396fe819 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -235,7 +235,7 @@ package object dsl { implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) - def nullable: AttributeReference = a.withNullability(true) + def canBeNull: AttributeReference = 
a.withNullability(true) def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala new file mode 100644 index 0000000000000..b58a5273041e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.rules._ + +/** + * Rewrites an expression using rules that are guaranteed to preserve the result while attempting + * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization + * will always return the same answer given the same input (i.e. false positives should not be + * possible). However, it is possible that two canonical expressions that are not equal will in fact + * return the same answer given any input (i.e. false negatives are possible). + * + * The following rules are applied: + * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. + * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered + * by `hashCode`. + * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. + * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + */ +object Canonicalize extends RuleExecutor[Expression] { + override protected def batches: Seq[Batch] = + Batch( + "Expression Canonicalization", FixedPoint(100), + IgnoreNamesTypes, + Reorder) :: Nil + + /** Remove names and nullability from types. */ + protected object IgnoreNamesTypes extends Rule[Expression] { + override def apply(e: Expression): Expression = e transformUp { + case a: AttributeReference => + AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + } + } + + /** Collects adjacent commutative operations. */ + protected def gatherCommutative( + e: Expression, + f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match { + case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) + case other => other :: Nil + } + + /** Orders a set of commutative operations by their hash code. */ + protected def orderCommutative( + e: Expression, + f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = + gatherCommutative(e, f).sortBy(_.hashCode()) + + /** Rearrange expressions that are commutative or associative.
*/ + protected object Reorder extends Rule[Expression] { + override def apply(e: Expression): Expression = e transformUp { + case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) + case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + + case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) + case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) + + case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) + case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + + case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 119496c7eee5e..692c16092fe3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -144,49 +144,32 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + /** + * Returns an expression where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, etc.) See [[Canonicalize]] for more details. + * + * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always + * evaluate to the same result. + */ + lazy val canonicalized: Expression = Canonicalize.execute(this) + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). + * + * See [[Canonicalize]] for more details. */ - def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { - def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 - case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) - case (i1, i2) => i1 == i2 - } - } - // Non-deterministic expressions cannot be semantic equal - if (!deterministic || !other.deterministic) return false - val elements1 = this.productIterator.toSeq - val elements2 = other.asInstanceOf[Product].productIterator.toSeq - checkSemantic(elements1, elements2) - } + def semanticEquals(other: Expression): Boolean = + deterministic && other.deterministic && canonicalized == other.canonicalized /** - * Returns the hash for this expression. Expressions that compute the same result, even if - * they differ cosmetically should return the same hash. + * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + * + * See [[Canonicalize]] for more details. 
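For illustration, a minimal, self-contained sketch of the reordering idea behind Canonicalize, written against a toy expression ADT rather than Spark's classes: commutative operands are flattened and sorted by a stable key, so cosmetically different but equivalent trees become structurally equal.

object CanonicalizeSketch {
  sealed trait Expr
  final case class Ref(name: String) extends Expr
  final case class Lit(value: Int) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr

  // Flatten nested Adds into their operands (the gatherCommutative idea).
  private def gather(e: Expr): Seq[Expr] = e match {
    case Add(l, r) => gather(l) ++ gather(r)
    case other => Seq(other)
  }

  // Sort commutative operands by a stable key so equivalent trees line up.
  def canonicalize(e: Expr): Expr = e match {
    case a: Add => gather(a).map(canonicalize).sortBy(_.hashCode()).reduceLeft[Expr](Add(_, _))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val x = Add(Ref("a"), Lit(1))
    val y = Add(Lit(1), Ref("a"))
    // Cosmetically different, semantically equal once canonicalized.
    println(canonicalize(x) == canonicalize(y)) // true
  }
}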
*/ - def semanticHash() : Int = { - def computeHash(e: Seq[Any]): Int = { - // See http://stackoverflow.com/questions/113511/hash-code-implementation - var hash: Int = 17 - e.foreach(i => { - val h: Int = i match { - case e: Expression => e.semanticHash() - case Some(e: Expression) => e.semanticHash() - case t: Traversable[_] => computeHash(t.toSeq) - case null => 0 - case other => other.hashCode() - } - hash = hash * 37 + h - }) - hash - } - - computeHash(this.productIterator.toSeq) - } + def semanticHash(): Int = canonicalized.hashCode() /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, @@ -369,7 +352,6 @@ abstract class UnaryExpression extends Expression { } } - /** * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala new file mode 100644 index 0000000000000..acea049adca3d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +object ExpressionSet { + /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */ + def apply(expressions: TraversableOnce[Expression]): ExpressionSet = { + val set = new ExpressionSet() + expressions.foreach(set.add) + set + } +} + +/** + * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] + * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * + * Internally this set uses the canonical representation, but also keeps track of the original + * expressions to ease debugging. Since different expressions can share the same canonical + * representation, this means that operations that extract expressions from this set are only + * guaranteed to see at least one such expression.
For example: + * + * {{{ + * val set = ExpressionSet(a + 1, 1 + a) + * + * set.iterator => Iterator(a + 1) + * set.contains(a + 1) => true + * set.contains(1 + a) => true + * set.contains(a + 2) => false + * }}} + */ +class ExpressionSet protected( + protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, + protected val originals: mutable.Buffer[Expression] = new ArrayBuffer) + extends Set[Expression] { + + protected def add(e: Expression): Unit = { + if (!baseSet.contains(e.canonicalized)) { + baseSet.add(e.canonicalized) + originals.append(e) + } + } + + override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized) + + override def +(elem: Expression): ExpressionSet = { + val newSet = new ExpressionSet(baseSet.clone(), originals.clone()) + newSet.add(elem) + newSet + } + + override def -(elem: Expression): ExpressionSet = { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } + + override def iterator: Iterator[Expression] = originals.iterator + + /** + * Returns a string containing both the post [[Canonicalize]] expressions and the original + * expressions in this set. + */ + def toDebugString: String = + s""" + |baseSet: ${baseSet.mkString(", ")} + |originals: ${originals.mkString(", ")} + """.stripMargin +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index c49c601c3034b..dbd0acf06caa7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -35,7 +35,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def dataType: DataType = StringType - override val prettyName = "INPUT_FILE_NAME" + override def prettyName: String = "input_file_name" override protected def initInternal(): Unit = {} @@ -48,6 +48,4 @@ case class InputFileName() extends LeafExpression with Nondeterministic { s"final ${ctx.javaType(dataType)} ${ev.value} = " + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } - - override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index a474017221721..32bae133608c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -68,6 +68,8 @@ case class HyperLogLogPlusPlus( inputAggBufferOffset = 0) } + override def prettyName: String = "approx_count_distinct" + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5af234609ddd4..ed812e06799a9 100644 ---  a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -95,8 +95,6 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) - - override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 200c6a05df7e3..c3e9fa33e63a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -34,7 +34,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") - } else if (trueValue.dataType != falseValue.dataType) { + } else if (trueValue.dataType.asNullable != falseValue.dataType.asNullable) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index dcbb594afd86e..33bd3f20959b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -234,8 +234,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def prettyName: String = "hash" - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" - override def eval(input: InternalRow): Any = { var hash = seed var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1554382840a48..2aeb9575f1dd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -123,7 +123,7 @@ object SamplePushDown extends Rule[LogicalPlan] { // Push down projection into sample case Project(projectList, s @ Sample(lb, up, replace, seed, child)) => Sample(lb, up, replace, seed, - Project(projectList, child)) + Project(projectList, child))() } } @@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(projects, output, child)) - if (e.outputSet -- a.references).nonEmpty => - val newOutput = output.filter(a.references.contains(_)) - val newProjects = projects.map { proj => - proj.zip(output).filter { case (e, a) => + // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) + case p @ Project(_, a: Aggregate) if (a.outputSet -- 
p.references).nonEmpty => + p.copy( + child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + projectList = w.projectList.filter(p.references.contains), + windowExpressions = w.windowExpressions.filter(p.references.contains))) + case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + val newOutput = e.output.filter(a.references.contains(_)) + val newProjects = e.projections.map { proj => + proj.zip(e.output).filter { case (e, a) => newOutput.contains(a) }.unzip._1 } - a.copy(child = Expand(newProjects, newOutput, child)) + a.copy(child = Expand(newProjects, newOutput, grandChild)) + // TODO: support some logical plan for Dataset - case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- e.references -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) - - // Eliminate attributes that are not needed to calculate the specified aggregates. + // Prunes the unused columns from child of Aggregate/Window/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = Project(a.references.toSeq, child)) - - // Eliminate attributes that are not needed to calculate the Generate. + a.copy(child = prunedChild(child, a.references)) + case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => + w.copy(child = prunedChild(child, w.references)) + case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = Project(g.references.toSeq, g.child)) + g.copy(child = prunedChild(g.child, g.references)) + // Turn off `join` for Generate if no column from it's child is used case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - case p @ Project(projectList, g: Generate) if g.join => - val neededChildOutput = p.references -- g.generatorOutput ++ g.references - if (neededChildOutput == g.child.outputSet) { - p + // Eliminate unneeded attributes from right side of a LeftSemiJoin. + case j @ Join(left, right, LeftSemi, condition) => + j.copy(right = prunedChild(right, j.references)) + + // all the columns will be used to compare, so we can't prune them + case p @ Project(_, _: SetOperation) => p + case p @ Project(_, _: Distinct) => p + // Eliminate unneeded attributes from children of Union. + case p @ Project(_, u: Union) => + if ((u.outputSet -- p.references).nonEmpty) { + val firstChild = u.children.head + val newOutput = prunedChild(firstChild, p.references).output + // pruning the columns of all children based on the pruned first child. 
+ val newChildren = u.children.map { p => + val selected = p.output.zipWithIndex.filter { case (a, i) => + newOutput.contains(firstChild.output(i)) + }.map(_._1) + Project(selected, p) + } + p.copy(child = u.withNewChildren(newChildren)) } else { - Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + p } - case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) - if (a.outputSet -- p.references).nonEmpty => - Project( - projectList, - Aggregate( - groupingExpressions, - aggregateExpressions.filter(e => p.references.contains(e)), - child)) - - // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => - // Collect the list of all references required either above or to evaluate the condition. - val allReferences: AttributeSet = - AttributeSet( - projectList.flatMap(_.references.iterator)) ++ - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) - - Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) - - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case Join(left, right, LeftSemi, condition) => - // Collect the list of all references required to evaluate the condition. - val allReferences: AttributeSet = - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - Join(left, prunedChild(right, allReferences), LeftSemi, condition) - - // Push down project through limit, so that we may have chance to push it further. - case Project(projectList, Limit(exp, child)) => - Limit(exp, Project(projectList, child)) + // Can't prune the columns on LeafNode + case p @ Project(_, l: LeafNode) => p - // Push down project if possible when the child is sort. - case p @ Project(projectList, s @ Sort(_, _, grandChild)) => - if (s.references.subsetOf(p.outputSet)) { - s.copy(child = Project(projectList, grandChild)) + // Eliminate no-op Projects + case p @ Project(projectList, child) if child.output == p.output => child + + // for all other logical plans that inherits the output from it's children + case p @ Project(_, child) => + val required = child.references ++ p.references + if ((child.inputSet -- required).nonEmpty) { + val newChildren = child.children.map(c => prunedChild(c, required)) + p.copy(child = child.withNewChildren(newChildren)) } else { - val neededReferences = s.references ++ p.references - if (neededReferences == grandChild.outputSet) { - // No column we can prune, return the original plan. - p - } else { - // Do not use neededReferences.toSeq directly, should respect grandChild's output order. 
- val newProjectList = grandChild.output.filter(neededReferences.contains) - p.copy(child = s.copy(child = Project(newProjectList, grandChild))) - } + p } - - // Eliminate no-op Projects - case Project(projectList, child) if child.output == projectList => child } /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + Project(c.output.filter(allReferences.contains), c) } else { c } @@ -804,7 +792,14 @@ object SimplifyFilters extends Rule[LogicalPlan] { */ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, project @ Project(fields, grandChild)) => + // SPARK-13473: We can't push the predicate down when the underlying projection outputs non- + // deterministic field(s). Non-deterministic expressions are essentially stateful. This + // implies that, for a given input row, the output is determined by the expression's initial + // state and all the input rows processed before. In other words, the order of input rows + // matters for non-deterministic expressions, while pushing down predicates changes the order. + case filter @ Filter(condition, project @ Project(fields, grandChild)) + if fields.forall(_.deterministic) => + // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). val aliasMap = AttributeMap(fields.collect { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5e7d144ae4107..a74b288cb22ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -63,17 +63,19 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } /** - * A sequence of expressions that describes the data property of the output rows of this - * operator. For example, if the output of this operator is column `a`, an example `constraints` - * can be `Set(a > 10, a < 20)`. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then * canonicalized and filtered automatically to contain only those attributes that appear in the - * [[outputSet]] + * [[outputSet]]. + * + * See [[Canonicalize]] for more details.
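For illustration, a minimal sketch of the case the SPARK-13473 guard above protects against, using the public DataFrame API (the DataFrame `df` and its columns are hypothetical). Pushing the filter below the projection would substitute the alias `r` with `rand(42)` and evaluate it a second time, independently of the projected value, so the rule now only fires when every projected field is deterministic.

import org.apache.spark.sql.functions._

// `r` is non-deterministic, so the Filter must stay above the Project;
// PushPredicateThroughProject skips this plan.
val projected = df.select(col("a"), rand(42).as("r"))
val filtered = projected.filter(col("r") > 0.5)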
*/ protected def validConstraints: Set[Expression] = Set.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d3b5879777a76..f9f1f88cec846 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -45,6 +45,9 @@ object LocalRelation { case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { + // A local relation must have resolved output. + require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.") + /** * Returns an identical copy of this relation with new exprIds for all attributes. Different * attributes are required when a relation is going to be included multiple times in the same diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 35df2429dbe46..31e775d60f950 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -315,7 +315,38 @@ abstract class UnaryNode extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias. + */ + protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + override protected def validConstraints: Set[Expression] = child.constraints + + override def statistics: Statistics = { + // There should be some overhead in Row object, the size should not be zero when there are + // no columns; this helps to prevent divide-by-zero errors. + val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 + val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 + // Assume there will be the same number of rows as child has. + var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize + if (sizeInBytes == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children).
+ sizeInBytes = 1 + } + Statistics(sizeInBytes = sizeInBytes) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 70ecbce82927f..e81a0f9487469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias - */ - private def getAliasedConstraints: Set[Expression] = { - projectList.flatMap { - case a @ Alias(e, _) => - child.constraints.map(_ transform { - case expr: Expression if expr.semanticEquals(e) => - a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet - } - - override def validConstraints: Set[Expression] = { - child.constraints.union(getAliasedConstraints) - } + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(projectList)) } /** @@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = child.constraints.union(splitConjunctivePredicates(condition).toSet) - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - } // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. @@ -176,6 +157,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation Some(children.flatMap(_.maxRows).min) } } + + override def statistics: Statistics = { + val leftSize = left.statistics.sizeInBytes + val rightSize = right.statistics.sizeInBytes + val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize + Statistics(sizeInBytes = sizeInBytes) + } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { @@ -188,6 +176,10 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le childrenResolved && left.output.length == right.output.length && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } + + override def statistics: Statistics = { + Statistics(sizeInBytes = left.statistics.sizeInBytes) + } } /** Factory for constructing new `Union` nodes. */ @@ -321,6 +313,10 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + // We manually set statistics of BroadcastHint to smallest value to make sure + // the plan wrapped by BroadcastHint will be considered to broadcast later. 
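For illustration, the row-width scaling performed by the new UnaryNode.statistics, with made-up sizes (an Int column plus a String column and 8 bytes of per-row overhead; the per-type defaultSize values are the ones revised later in this diff):

// Toy calculation mirroring UnaryNode.statistics; the numbers are illustrative.
val childRowSize = 4 + 20 + 8 // Int + String + row overhead = 32
val outputRowSize = 4 + 8 // only the Int column survives = 12
val childSizeInBytes = BigInt(1000000)
val estimated = (childSizeInBytes * outputRowSize) / childRowSize // 375000
// Clamped to at least 1 so a BinaryNode's product of child sizes never becomes zero.

The BroadcastHint comment above ties into the same machinery: at the API level that node is produced by org.apache.spark.sql.functions.broadcast, and pinning its size estimate to the smallest possible value keeps the hinted side under the broadcast threshold.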
+ override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class InsertIntoTable( @@ -426,6 +422,17 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows + + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(aggregateExpressions)) + + override def statistics: Statistics = { + if (groupingExpressions.isEmpty) { + Statistics(sizeInBytes = 1) + } else { + super.statistics + } + } } case class Window( @@ -521,9 +528,7 @@ case class Expand( AttributeSet(projections.flatten.flatMap(_.references)) override def statistics: Statistics = { - // TODO shouldn't we factor in the size of the projection versus the size of the backing child - // row? - val sizeInBytes = child.statistics.sizeInBytes * projections.length + val sizeInBytes = super.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } } @@ -637,15 +642,29 @@ case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode { * @param withReplacement Whether to sample with replacement. * @param seed the random seed * @param child the LogicalPlan + * @param isTableSample Is created from TABLESAMPLE in the parser. */ case class Sample( lowerBound: Double, upperBound: Double, withReplacement: Boolean, seed: Long, - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan)( + val isTableSample: java.lang.Boolean = false) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def statistics: Statistics = { + val ratio = upperBound - lowerBound + // BigInt can't multiply with Double + var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100 + if (sizeInBytes == 0) { + sizeInBytes = 1 + } + Statistics(sizeInBytes = sizeInBytes) + } + + override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index c646dcfa11811..e01f69f81359e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -31,5 +31,6 @@ trait BroadcastMode { * IdentityBroadcastMode requires that rows are broadcasted in their original form. */ case object IdentityBroadcastMode extends BroadcastMode { + // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index f2c6f34ea51c7..c40e140e8c5c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -47,9 +47,9 @@ class BinaryType private() extends AtomicType { } /** - * The default size of a value of the BinaryType is 4096 bytes. + * The default size of a value of the BinaryType is 100 bytes. 
*/ - override def defaultSize: Int = 4096 + override def defaultSize: Int = 100 private[spark] override def asNullable: BinaryType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ede..6a59e9728a9f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -340,6 +340,9 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + /** Maximum number of decimal digits an Int can represent */ + val MAX_INT_DIGITS = 9 + /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 71ea5b8841e1d..9c1319c1c5e6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -91,9 +91,9 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } /** - * The default size of a value of the DecimalType is 4096 bytes. + * The default size of a value of the DecimalType is 8 bytes (precision <= 18) or 16 bytes. */ - override def defaultSize: Int = 4096 + override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16 override def simpleString: String = s"decimal($precision,$scale)" @@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside an Int + */ + def is32BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_INT_DIGITS + case _ => false + } + } + /** * Returns if dt is a DecimalType that fits inside a long */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index fca0b799eb809..06ee0fbfe9642 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -23,8 +23,10 @@ private[sql] object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException("null literals can't be casted to ObjectType") - // No casting or comparison is supported. - override private[sql] def acceptsType(other: DataType): Boolean = false + override private[sql] def acceptsType(other: DataType): Boolean = other match { + case ObjectType(_) => true + case _ => false + } override private[sql] def simpleString: String = "Object" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index a7627a2de1611..44a25361f31c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -38,9 +38,9 @@ class StringType private() extends AtomicType { private[sql] val ordering = implicitly[Ordering[InternalType]] /** - * The default size of a value of the StringType is 4096 bytes. + * The default size of a value of the StringType is 20 bytes.
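For illustration, how the revised defaultSize behaves for DecimalType under the code above (8 bytes when the precision fits in a Long, 16 otherwise, with is32BitDecimalType keyed off MAX_INT_DIGITS = 9):

import org.apache.spark.sql.types._

// Precision 10 still fits in a Long (<= 18 digits), so the estimate is 8 bytes.
assert(DecimalType(10, 2).defaultSize == 8)
// Precision 20 exceeds MAX_LONG_DIGITS, so the estimate is 16 bytes.
assert(DecimalType(20, 2).defaultSize == 16)
// Up to 9 digits of precision fits in an Int.
assert(DecimalType.is32BitDecimalType(DecimalType(9, 0)))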
*/ - override def defaultSize: Int = 4096 + override def defaultSize: Int = 20 private[spark] override def asNullable: StringType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 7664c30ee7650..9d2449f3b729e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -71,10 +71,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { */ def userClass: java.lang.Class[UserType] - /** - * The default size of a value of the UserDefinedType is 4096 bytes. - */ - override def defaultSize: Int = 4096 + override def defaultSize: Int = sqlType.defaultSize /** * For UDT, asNullable will not change the nullability of its internal sqlType and just returns diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 812aa5acd8e56..53a8d6e53e38a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -202,70 +202,6 @@ class CatalystQlSuite extends PlanTest { "from windowData") } - test("nesting UNION") { - val parsed = parser.parsePlan( - """ - |SELECT `u_1`.`id` FROM (((SELECT `t0`.`id` FROM `default`.`t0`) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL - |(SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - """.stripMargin) - - val expected = Project( - UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, - SubqueryAlias("u_1", - Union( - Union( - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None)), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))))) - - comparePlans(parsed, expected) - - val parsedSame = parser.parsePlan( - """ - |SELECT `u_1`.`id` FROM ((SELECT `t0`.`id` FROM `default`.`t0`) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`) UNION ALL - |(SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - """.stripMargin) - - comparePlans(parsedSame, expected) - - val parsed2 = parser.parsePlan( - """ - |SELECT `u_1`.`id` FROM ((((SELECT `t0`.`id` FROM `default`.`t0`) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL - |(SELECT `t0`.`id` FROM `default`.`t0`)) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - """.stripMargin) - - val expected2 = Project( - UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, - SubqueryAlias("u_1", - Union( - Union( - Union( - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None)), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))))) - 
- comparePlans(parsed2, expected2) - } - test("subquery") { parser.parsePlan("select (select max(b) from s) ss from t") parser.parsePlan("select * from t where a = (select b from s)") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index e0a95ba8bbd5c..ef825e606202f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -60,7 +60,18 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - analyzer.checkAnalysis(analyzer.execute(inputPlan)) + val analysisAttempt = analyzer.execute(inputPlan) + try analyzer.checkAnalysis(analysisAttempt) catch { + case a: AnalysisException => + fail( + s""" + |Failed to Analyze Plan + |$inputPlan + | + |Partial Analysis + |$analysisAttempt + """.stripMargin, a) + } } protected def assertAnalysisError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e00060f9b6aff..cca320fae9505 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -23,12 +23,14 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, StructType} +import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -74,7 +76,7 @@ class JavaSerializable(val value: Int) extends Serializable { } } -class ExpressionEncoderSuite extends SparkFunSuite { +class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() @@ -305,6 +307,15 @@ class ExpressionEncoderSuite extends SparkFunSuite { """.stripMargin, e) } + // Test the correct resolution of serialization / deserialization. 
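+    // namedExpressions is the serializer side (object -> columns) and fromRowExpression the
+    // deserializer side (columns -> object); analyzing the nested Project over a LocalRelation
+    // of the raw ObjectType attribute checks that both directions resolve during analysis.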
+ val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() + val inputPlan = LocalRelation(attr) + val plan = + Project(Alias(encoder.fromRowExpression, "obj")() :: Nil, + Project(encoder.namedExpressions, + inputPlan)) + assertAnalysisSuccess(plan) + val isCorrect = (input, convertedBack) match { case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2) case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala new file mode 100644 index 0000000000000..ce42e5784ccd2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionSetSuite extends SparkFunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + def setTest(size: Int, exprs: Expression*): Unit = { + test(s"expect $size: ${exprs.mkString(", ")}") { + val set = ExpressionSet(exprs) + if (set.size != size) { + fail(set.toDebugString) + } + } + } + + def setTestIgnore(size: Int, exprs: Expression*): Unit = + ignore(s"expect $size: ${exprs.mkString(", ")}") {} + + // Commutative + setTest(1, aUpper + 1, aLower + 1) + setTest(2, aUpper + 1, aLower + 2) + setTest(2, aUpper + 1, fakeA + 1) + setTest(2, aUpper + 1, bUpper + 1) + + setTest(1, aUpper + aLower, aLower + aUpper) + setTest(1, aUpper + bUpper, bUpper + aUpper) + setTest(1, + aUpper + bUpper + 3, + bUpper + 3 + aUpper, + bUpper + aUpper + 3, + Literal(3) + aUpper + bUpper) + setTest(1, + aUpper * bUpper * 3, + bUpper * 3 * aUpper, + bUpper * aUpper * 3, + Literal(3) * aUpper * bUpper) + setTest(1, aUpper === bUpper, bUpper === aUpper) + + setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper) + + + // Not commutative + setTest(2, aUpper - bUpper, bUpper - aUpper) + + // Reversable + setTest(1, aUpper > bUpper, bUpper < aUpper) + setTest(1, aUpper >= bUpper, bUpper <= aUpper) + + test("add to / remove from set") { + val initialSet = ExpressionSet(aUpper + 1 :: Nil) + + assert((initialSet 
+ (aUpper + 1)).size == 1) + assert((initialSet + (aUpper + 2)).size == 2) + assert((initialSet - (aUpper + 1)).size == 0) + assert((initialSet - (aUpper + 2)).size == 1) + + assert((initialSet + (aLower + 1)).size == 1) + assert((initialSet - (aLower + 1)).size == 0) + + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index c890fffc40167..715d01a3cd876 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Explode, Literal} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest { Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'aa.int, 'gid.int), - Project(Seq('c, 'a), + Project(Seq('a, 'c), input))).analyze comparePlans(optimized, expected) } + test("Column pruning on Filter") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze + val expected = + Project('a :: Nil, + Filter('c > Literal(0.0), + Project(Seq('a, 'c), input))).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("Column pruning on except/intersect/distinct") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Except(input, input)).analyze + comparePlans(Optimize.execute(query), query) + + val query2 = Project('a :: Nil, Intersect(input, input)).analyze + comparePlans(Optimize.execute(query2), query2) + val query3 = Project('a :: Nil, Distinct(input)).analyze + comparePlans(Optimize.execute(query3), query3) + } + + test("Column pruning on Project") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze + val expected = Project(Seq('a), input).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("column pruning for group") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val originalQuery = + testRelation + .groupBy('a)('a, count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .groupBy('a)('a as 'c, count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for Project(ne, Limit)") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = 
Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("push down project past sort") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + } + + test("Column pruning on Union") { + val input1 = LocalRelation('a.int, 'b.string, 'c.double) + val input2 = LocalRelation('c.int, 'd.string, 'e.double) + val query = Project('b :: Nil, + Union(input1 :: input2 :: Nil)).analyze + val expected = Project('b :: Nil, + Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + comparePlans(Optimize.execute(query), expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 7805723ec86e2..1292aa0003dd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughJoin, PushPredicateThroughGenerate, PushPredicateThroughAggregate, - ColumnPruning, CollapseProject) :: Nil } @@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("column pruning for group") { - val originalQuery = - testRelation - .groupBy('a)('a, count('b)) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for group with alias") { - val originalQuery = - testRelation - .groupBy('a)('a as 'c, count('b)) - .select('c) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a as 'c).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for Project(ne, Limit)") { - val originalQuery = - testRelation - .select('a, 'b) - .limit(2) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .limit(2).analyze - - comparePlans(optimized, correctAnswer) - } - // After this line is unimplemented. 
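+  // ColumnPruning is no longer part of this suite's optimizer batch; the column-pruning tests
+  // that used to live here are now covered in ColumnPruningSuite.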
test("simple push down") { val originalQuery = @@ -145,7 +98,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter through project") { + test("nondeterministic: can't push down filter with nondeterministic condition through project") { val originalQuery = testRelation .select(Rand(10).as('rand), 'a) .where('rand > 5 || 'a > 5) @@ -156,36 +109,15 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("nondeterministic: push down part of filter through project") { + test("nondeterministic: can't push down filter through project with nondeterministic field") { val originalQuery = testRelation .select(Rand(10).as('rand), 'a) - .where('rand > 5 && 'a > 5) - .analyze - - val optimized = Optimize.execute(originalQuery) - - val correctAnswer = testRelation .where('a > 5) - .select(Rand(10).as('rand), 'a) - .where('rand > 5) - .analyze - - comparePlans(optimized, correctAnswer) - } - - test("nondeterministic: push down filter through project") { - val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('a > 5 && 'a < 10) .analyze val optimized = Optimize.execute(originalQuery) - val correctAnswer = testRelation - .where('a > 5 && 'a < 10) - .select(Rand(10).as('rand), 'a) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, originalQuery) } test("filters: combines filters") { @@ -604,43 +536,10 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push down project past sort") { - val x = testRelation.subquery('x) - - // push down valid - val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze - - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) - - // push down invalid - val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) - } - - val optimized1 = Optimize.execute(originalQuery1.analyze) - val correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze - - comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) - } - test("push project and filter down into sample") { val x = testRelation.subquery('x) val originalQuery = - Sample(0.0, 0.6, false, 11L, x).select('a) + Sample(0.0, 0.6, false, 11L, x)().select('a) val originalQueryAnalyzed = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery)) @@ -648,7 +547,7 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQueryAnalyzed) val correctAnswer = - Sample(0.0, 0.6, false, 11L, x.select('a)) + Sample(0.0, 0.6, false, 11L, x.select('a))() comparePlans(optimized, correctAnswer.analyze) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala similarity index 77% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a5b487bcc822f..2f382bbda0c58 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -class JoinOrderSuite extends PlanTest { +class JoinOptimizationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, BooleanSimplification, @@ -92,4 +92,30 @@ class JoinOrderSuite extends PlanTest { comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } + + test("broadcasthint sets relation statistics to smallest value") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + + val optimized = Optimize.execute(query) + + val expected = + Project(Seq($"x.key", $"y.key"), + Join( + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), + Inner, None)).analyze + + comparePlans(optimized, expected) + + val broadcastChildren = optimized.collect { + case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r + } + assert(broadcastChildren.size == 1) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 373b1ffa83d23..b68432b1a128f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aggregate") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "c1") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), + resolveColumn(aliasedRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c2bbca7c33f28..6b85f12521c2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -248,15 +248,15 @@ class DataTypeSuite extends 
SparkFunSuite { checkDefaultSize(LongType, 8) checkDefaultSize(FloatType, 4) checkDefaultSize(DoubleType, 8) - checkDefaultSize(DecimalType(10, 5), 4096) - checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096) + checkDefaultSize(DecimalType(10, 5), 8) + checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 16) checkDefaultSize(DateType, 4) checkDefaultSize(TimestampType, 8) - checkDefaultSize(StringType, 4096) - checkDefaultSize(BinaryType, 4096) + checkDefaultSize(StringType, 20) + checkDefaultSize(BinaryType, 100) checkDefaultSize(ArrayType(DoubleType, true), 800) - checkDefaultSize(ArrayType(StringType, false), 409600) - checkDefaultSize(MapType(IntegerType, StringType, true), 410000) + checkDefaultSize(ArrayType(StringType, false), 2000) + checkDefaultSize(MapType(IntegerType, StringType, true), 2400) checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) checkDefaultSize(structType, 812) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 4576ac2a3222f..57dbd7c2ff56f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -257,8 +257,7 @@ private void initializeInternal() throws IOException { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && - primitiveType.getDecimalMetadata().getPrecision() > - CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) { throw new IOException("Decimal with high precision is not supported."); } if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { @@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); int precision = type.getDecimalMetadata().getPrecision(); int scale = type.getDecimalMetadata().getScale(); - Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(), "Unsupported precision."); for (int n = 0; n < num; ++n) { @@ -480,11 +479,6 @@ private final class ColumnReader { */ private boolean useDictionary; - /** - * If useDictionary is true, the staging vector used to decode the ids. - */ - private ColumnVector dictionaryIds; - /** * Maximum definition level for this column. */ @@ -620,17 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException { } int num = Math.min(total, leftInPage); if (useDictionary) { - // Data is dictionary encoded. We will vector decode the ids and then resolve the values. - if (dictionaryIds == null) { - dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(total); - } // Read and decode dictionary ids. 
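+          // The ids now go into a reusable vector owned by the column (reserveDictionaryIds) and
+          // the definition levels are consumed in the same pass; for most types the dictionary
+          // itself is attached via setDictionary so values are decoded lazily on access.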
- readIntBatch(rowId, num, dictionaryIds); - decodeDictionaryIds(rowId, num, column); + ColumnVector dictionaryIds = column.reserveDictionaryIds(total);; + defColumn.readIntegers( + num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + decodeDictionaryIds(rowId, num, column, dictionaryIds); } else { + column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -667,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + private void decodeDictionaryIds(int rowId, int num, ColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: - if (column.dataType() == DataTypes.IntegerType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ByteType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ShortType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case INT64: - if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case FLOAT: - for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); - } - break; - case DOUBLE: - for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); - } + case BINARY: + column.setDictionary(dictionary); break; case FIXED_LEN_BYTE_ARRAY: - if (DecimalType.is64BitDecimalType(column.dataType())) { + // DecimalType written in the legacy mode + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); @@ -725,32 +685,9 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { } break; - case BINARY: - // TODO: this is incredibly inefficient as it blows up the dictionary right here. We - // need to do this better. We should probably add the dictionary data to the ColumnVector - // and reuse it across batches. This should mean adding a ByteArray would just update - // the length and offset. 
- for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); - } - break; - default: throw new NotImplementedException("Unsupported type: " + descriptor.getType()); } - - if (dictionaryIds.numNulls() > 0) { - // Copy the NULLs over. - // TODO: we can improve this by decoding the NULLs directly into column. This would - // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then - // just do the ID remapping as above. - for (int i = 0; i < num; ++i) { - if (dictionaryIds.getIsNull(rowId + i)) { - column.putNull(rowId + i); - } - } - } } /** @@ -767,14 +704,15 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IO private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || + DecimalType.is32BitDecimalType(column.dataType())) { defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - defColumn.readIntsAsLongs( + } else if (column.dataType() == DataTypes.ShortType) { + defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { throw new NotImplementedException("Unimplemented type: " + column.dataType()); @@ -830,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num, VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (DecimalType.is64BitDecimalType(column.dataType())) { + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putInt(rowId + i, + (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index bf3283e85329b..ee9a7a221bbde 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -85,7 +85,7 @@ public final void readBytes(int total, ColumnVector c, int rowId) { for (int i = 0; i < total; i++) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. 
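+      // Pull the low byte straight out of the page buffer with Platform.getByte and store it
+      // with putByte, so ByteType columns are filled through their byte-sized storage path.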
- c.putInt(rowId + i, buffer[offset]); + c.putByte(rowId + i, Platform.getByte(buffer, offset)); offset += 4; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 629959a73baf3..62157389013bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; @@ -26,7 +25,6 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -176,11 +174,11 @@ public int readInteger() { * if (this.readInt() == level) { * c[rowId] = data.readInteger(); * } else { - * c[rowId] = nullValue; + * c[rowId] = null; * } */ public void readIntegers(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data, int nullValue) { + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -189,7 +187,6 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readIntegers(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -198,9 +195,7 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putInt(rowId + i, data.readInteger()); - c.putNotNull(rowId + i); } else { - c.putInt(rowId + i, nullValue); c.putNull(rowId + i); } } @@ -223,7 +218,6 @@ public void readBooleans(int total, ColumnVector c, case RLE: if (currentValue == level) { data.readBooleans(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -232,7 +226,6 @@ public void readBooleans(int total, ColumnVector c, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putBoolean(rowId + i, data.readBoolean()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -245,7 +238,7 @@ public void readBooleans(int total, ColumnVector c, } } - public void readIntsAsLongs(int total, ColumnVector c, + public void readBytes(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; while (left > 0) { @@ -254,10 +247,7 @@ public void readIntsAsLongs(int total, ColumnVector c, switch (mode) { case RLE: if (currentValue == level) { - for (int i = 0; i < n; i++) { - c.putLong(rowId + i, data.readInteger()); - } - c.putNotNulls(rowId, n); + data.readBytes(n, c, rowId); } else { c.putNulls(rowId, n); } @@ -265,8 +255,7 @@ public void readIntsAsLongs(int total, ColumnVector c, case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putLong(rowId + i, data.readInteger()); - c.putNotNull(rowId + i); + c.putByte(rowId + i, data.readByte()); } else { c.putNull(rowId + i); } @@ -279,7 +268,7 @@ public void readIntsAsLongs(int total, ColumnVector c, } } - public void readBytes(int total, ColumnVector c, + public void 
readShorts(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; while (left > 0) { @@ -288,8 +277,9 @@ public void readBytes(int total, ColumnVector c, switch (mode) { case RLE: if (currentValue == level) { - data.readBytes(n, c, rowId); - c.putNotNulls(rowId, n); + for (int i = 0; i < n; i++) { + c.putShort(rowId + i, (short)data.readInteger()); + } } else { c.putNulls(rowId, n); } @@ -297,8 +287,7 @@ public void readBytes(int total, ColumnVector c, case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putByte(rowId + i, data.readByte()); - c.putNotNull(rowId + i); + c.putShort(rowId + i, (short)data.readInteger()); } else { c.putNull(rowId + i); } @@ -321,7 +310,6 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readLongs(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -330,7 +318,6 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putLong(rowId + i, data.readLong()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -353,7 +340,6 @@ public void readFloats(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readFloats(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -362,7 +348,6 @@ public void readFloats(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putFloat(rowId + i, data.readFloat()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -385,7 +370,6 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readDoubles(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -394,7 +378,6 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putDouble(rowId + i, data.readDouble()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -416,7 +399,6 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, switch (mode) { case RLE: if (currentValue == level) { - c.putNotNulls(rowId, n); data.readBinary(n, c, rowId); } else { c.putNulls(rowId, n); @@ -425,7 +407,6 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putNotNull(rowId + i); data.readBinary(1, c, rowId + i); } else { c.putNull(rowId + i); @@ -439,6 +420,40 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, } } + /** + * Decoding for dictionary ids. The IDs are populated into `values` and the nullability is + * populated into `nulls`. 
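+   * In the dictionary path the caller passes the column's reusable dictionaryIds vector as
+   * `values` and the data column itself as `nulls`, so the null flags end up where later reads
+   * of the (lazily decoded) values will look for them.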
+ */ + public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readIntegers(n, values, rowId); + } else { + nulls.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + values.putInt(rowId + i, data.readInteger()); + } else { + nulls.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + // The RLE reader implements the vectorized decoding interface when used to decode dictionary // IDs. This is different than the above APIs that decodes definitions levels along with values. @@ -560,12 +575,14 @@ private int readIntLittleEndianPaddedOnBitWidth() { throw new RuntimeException("Unreachable"); } + private int ceil8(int value) { + return (value + 7) / 8; + } + /** * Reads the next group. */ private void readNextGroup() { - Preconditions.checkArgument(this.offset < this.end, - "Reading past RLE/BitPacking stream. offset=" + this.offset + " end=" + this.end); int header = readUnsignedVarInt(); this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; switch (mode) { @@ -576,14 +593,12 @@ private void readNextGroup() { case PACKED: int numGroups = header >>> 1; this.currentCount = numGroups * 8; + int bytesToRead = ceil8(this.currentCount * this.bitWidth); if (this.currentBuffer.length < this.currentCount) { this.currentBuffer = new int[this.currentCount]; } currentBufferIdx = 0; - int bytesToRead = (int)Math.ceil((double)(this.currentCount * this.bitWidth) / 8.0D); - - bytesToRead = Math.min(bytesToRead, this.end - this.offset); int valueIndex = 0; for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); @@ -595,4 +610,4 @@ private void readNextGroup() { throw new ParquetDecodingException("not a valid mode " + this.mode); } } -} \ No newline at end of file +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0514252a8e53d..bb0247c2fbedf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,6 +19,10 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -27,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. 
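The get/put contract described above is easiest to see end to end. A minimal Scala sketch, assuming the on-heap mode and only the accessors referenced in this diff (allocate, putInt, putNull, getInt, getIsNull); it is illustrative, not part of the patch:

import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.execution.vectorized.ColumnVector
import org.apache.spark.sql.types.DataTypes

// Allocate a small on-heap integer column, write a value and a null, and read them back.
val col = ColumnVector.allocate(4, DataTypes.IntegerType, MemoryMode.ON_HEAP)
col.putInt(0, 42)
col.putNull(1)
assert(col.getInt(0) == 42)
assert(col.getIsNull(1))

With the dictionary support added below, the same getInt call transparently decodes through the column's dictionary when one has been set.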
@@ -157,7 +159,7 @@ public Object[] array() { } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { - list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + list[i] = getUTF8String(i).toString(); } } } else if (dt instanceof CalendarIntervalType) { @@ -204,28 +206,17 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return data.getDecimal(offset + ordinal, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - Array child = data.getByteArray(offset + ordinal); - return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + return data.getUTF8String(offset + ordinal); } @Override public byte[] getBinary(int ordinal) { - ColumnVector.Array array = data.getByteArray(offset + ordinal); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return data.getBinary(offset + ordinal); } @Override @@ -534,12 +525,57 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - public final Array getByteArray(int rowId) { + private Array getByteArray(int rowId) { Array array = getArray(rowId); array.data.loadBytes(array); return array; } + /** + * Returns the decimal for rowId. + */ + public final Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.apply(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(rowId), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + /** + * Returns the UTF8String for rowId. + */ + public final UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return UTF8String.fromBytes(v.getBytes()); + } + } + + /** + * Returns the byte array for rowId. + */ + public final byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return v.getBytes(); + } + } + /** * Append APIs. These APIs all behave similarly and will append data to the current vector. It * is not valid to mix the put and append APIs. The append APIs are slower and should only be @@ -816,6 +852,39 @@ public final int appendStruct(boolean isNull) { */ protected final ColumnarBatch.Row resultStruct; + /** + * The Dictionary for this column. + * + * If it's not null, will be used to decode the value in getXXX(). 
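+   * Keeping the Dictionary (plus an IntegerType dictionaryIds vector) instead of materializing
+   * the decoded values lets getInt/getLong/getUTF8String/... decode on access, which avoids
+   * blowing up string or binary dictionaries into every row of the batch.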
+ */ + protected Dictionary dictionary; + + /** + * Reusable column for ids of dictionary. + */ + protected ColumnVector dictionaryIds; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve a integer column for ids of dictionary. + */ + public ColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = allocate(capacity, DataTypes.IntegerType, + this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2aeef7f2f90fe..681ace3387139 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -22,24 +22,20 @@ import java.util.Iterator; import java.util.List; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.commons.lang.NotImplementedException; - /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly * for debugging or other non-performance critical paths. * These utilities are mostly used to convert ColumnVectors into other formats. */ public class ColumnVectorUtils { - public static String toString(ColumnVector.Array a) { - return new String(a.byteArray, a.byteArrayOffset, a.length); - } - /** * Returns the array data as the java primitive array. * For example, an array of IntegerType will return an int[]. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 070d897a7158c..8a0d7f8b12379 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; @@ -31,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class is the in memory representation of rows as they are streamed through operators. It * is designed to maximize CPU efficiency and not storage footprint. 
Since it is expected that @@ -193,29 +191,17 @@ public final boolean anyNull() { @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = columns[ordinal].getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + return columns[ordinal].getUTF8String(rowId); } @Override public final byte[] getBinary(int ordinal) { - ColumnVector.Array array = columns[ordinal].getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return columns[ordinal].getBinary(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e38ed051219b7..b06b7f2457b54 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,25 +18,11 @@ import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - - import org.apache.commons.lang.NotImplementedException; -import org.apache.commons.lang.NotImplementedException; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. 
@@ -171,7 +157,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return Platform.getByte(null, data + rowId); + if (dictionary == null) { + return Platform.getByte(null, data + rowId); + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -199,7 +189,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return Platform.getShort(null, data + 2 * rowId); + if (dictionary == null) { + return Platform.getShort(null, data + 2 * rowId); + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -233,7 +227,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return Platform.getInt(null, data + 4 * rowId); + if (dictionary == null) { + return Platform.getInt(null, data + 4 * rowId); + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -267,7 +265,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return Platform.getLong(null, data + 8 * rowId); + if (dictionary == null) { + return Platform.getLong(null, data + 8 * rowId); + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -301,7 +303,11 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public final float getFloat(int rowId) { - return Platform.getFloat(null, data + rowId * 4); + if (dictionary == null) { + return Platform.getFloat(null, data + rowId * 4); + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } } @@ -336,7 +342,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return Platform.getDouble(null, data + rowId * 8); + if (dictionary == null) { + return Platform.getDouble(null, data + rowId * 8); + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -394,7 +404,7 @@ private final void reserveInternal(int newCapacity) { } else if (type instanceof ShortType) { this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || - type instanceof DateType) { + type instanceof DateType || DecimalType.is32BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3502d31bd1dfa..305e84a86bdc7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,12 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.util.Arrays; + import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.util.Arrays; - /** * A column backed by an in memory JVM array. 
This stores the NULLs as a byte per value * and a java array for the values. @@ -68,7 +67,6 @@ public final void close() { doubleData = null; } - // // APIs dealing with nulls // @@ -154,7 +152,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return byteData[rowId]; + if (dictionary == null) { + return byteData[rowId]; + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -180,7 +182,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return shortData[rowId]; + if (dictionary == null) { + return shortData[rowId]; + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } @@ -217,7 +223,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return intData[rowId]; + if (dictionary == null) { + return intData[rowId]; + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -253,7 +263,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return longData[rowId]; + if (dictionary == null) { + return longData[rowId]; + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -280,7 +294,13 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { } @Override - public final float getFloat(int rowId) { return floatData[rowId]; } + public final float getFloat(int rowId) { + if (dictionary == null) { + return floatData[rowId]; + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } // // APIs dealing with doubles @@ -309,7 +329,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return doubleData[rowId]; + if (dictionary == null) { + return doubleData[rowId]; + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -377,7 +401,8 @@ private final void reserveInternal(int newCapacity) { short[] newData = new short[newCapacity]; if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); shortData = newData; - } else if (type instanceof IntegerType || type instanceof DateType) { + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index e3412f7a2ea4b..5f5b7f4c19cff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -1041,7 +1042,7 @@ class DataFrame private[sql]( * 
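+   * An illustrative call (fraction and seed are arbitrary):
+   * {{{
+   *   // keep roughly 10% of the rows, without replacement, using a fixed seed
+   *   df.sample(withReplacement = false, fraction = 0.1, seed = 11L)
+   * }}}
+   *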
@since 1.3.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan) + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } /** @@ -1073,7 +1074,7 @@ class DataFrame private[sql]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)()) }.toArray } @@ -1425,30 +1426,6 @@ class DataFrame private[sql]( */ def transform[U](t: DataFrame => DataFrame): DataFrame = t(this) - /** - * Returns a new RDD by applying a function to all rows of this DataFrame. - * @group rdd - * @since 1.3.0 - */ - def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) - - /** - * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], - * and then flattening the results. - * @group rdd - * @since 1.3.0 - */ - def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) - - /** - * Returns a new RDD by applying a function to each partition of this DataFrame. - * @group rdd - * @since 1.3.0 - */ - def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { - rdd.mapPartitions(f) - } - /** * Applies a function `f` to all rows. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index bb3cc02800d51..3eb1f0f0d58ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -36,6 +36,50 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} @Experimental final class DataFrameStatFunctions private[sql](df: DataFrame) { + /** + * Calculates the approximate quantiles of a numerical column of a DataFrame. + * + * The result of this algorithm has the following deterministic bound: + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient + * Online Computation of Quantile Summaries]] by Greenwald and Khanna. + * + * @param col the name of the numerical column + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. 
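A hedged usage sketch for the new Scala API declared below (the DataFrame built here is only for illustration; `sqlContext` is assumed to be in scope):

```scala
// Approximate quartiles of a numeric column with a 1% relative error.
val df = sqlContext.range(0, 10000).selectExpr("cast(id as double) as value")
val Array(q1, median, q3) =
  df.stat.approxQuantile("value", Array(0.25, 0.5, 0.75), 0.01)
// Each returned x satisfies floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
```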
+ * @return the approximate quantiles at the given probabilities + * + * @since 2.0.0 + */ + def approxQuantile( + col: String, + probabilities: Array[Double], + relativeError: Double): Array[Double] = { + StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray + } + + /** + * Python-friendly version of [[approxQuantile()]] + */ + private[spark] def approxQuantile( + col: String, + probabilities: List[Double], + relativeError: Double): java.util.List[Double] = { + approxQuantile(col, probabilities.toArray, relativeError).toList.asJava + } + /** * Calculate the sample covariance of two numerical columns of a DataFrame. * @param col1 the name of the first column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d6bdd3d825565..093504c765ee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -453,6 +453,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("json").save(path) * }}} * + * You can set the following JSON-specific option(s) for writing JSON files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`).
  • + * * @since 1.4.0 */ def json(path: String): Unit = format("json").save(path) @@ -492,10 +496,29 @@ final class DataFrameWriter private[sql](df: DataFrame) { * df.write().text("/path/to/output") * }}} * + * You can set the following option(s) for writing text files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`).
  • + * * @since 1.6.0 */ def text(path: String): Unit = format("text").save(path) + /** + * Saves the content of the [[DataFrame]] in CSV format at the specified path. + * This is equivalent to: + * {{{ + * format("csv").save(path) + * }}} + * + * You can set the following CSV-specific option(s) for writing CSV files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`).
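Taken together, a hedged sketch of how these writer options are exercised (`df` and the output paths are placeholders):

```scala
// The `compression` option applies to the text-based sources; csv() itself is new in 2.0.0.
df.write.option("compression", "gzip").json("/tmp/out/json")
df.write.option("compression", "snappy").text("/tmp/out/text")   // text() expects a single string column
df.write.option("compression", "bzip2").csv("/tmp/out/csv")
```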
  • + * + * @since 2.0.0 + */ + def csv(path: String): Unit = format("csv").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ea7e7255abde6..dd1fbcf3c881a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -564,7 +564,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = - withPlan(Sample(0.0, fraction, withReplacement, seed, _)) + withPlan(Sample(0.0, fraction, withReplacement, seed, _)()) /** * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 66ec0e7338fc4..a7258d742aa96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType /** @@ -305,6 +306,7 @@ class GroupedData protected[sql]( val values = df.select(pivotColumn) .distinct() .sort(pivotColumn) // ensure that the output columns are in a consistent logical order + .rdd .map(_.get(0)) .take(maxValues + 1) .toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala new file mode 100644 index 0000000000000..e90a04243164b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. + * + * @since 2.0.0 + */ +abstract class RuntimeConfig { + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: String): RuntimeConfig + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: Boolean): RuntimeConfig + + /** + * Sets the given Spark runtime configuration property. 
+ * + * @since 2.0.0 + */ + def set(key: String, value: Long): RuntimeConfig + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @throws NoSuchElementException if the key is not set and does not have a default value + * @since 2.0.0 + */ + @throws[NoSuchElementException]("if the key is not set") + def get(key: String): String + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @since 2.0.0 + */ + def getOption(key: String): Option[String] + + /** + * Resets the configuration property for the given key. + * + * @since 2.0.0 + */ + def unset(key: String): Unit + + /** + * Sets the given Hadoop configuration property. This is passed directly to Hadoop during I/O. + * + * @since 2.0.0 + */ + def setHadoop(key: String, value: String): RuntimeConfig + + /** + * Returns the value of the Hadoop configuration property. + * + * @throws NoSuchElementException if the key is not set + * @since 2.0.0 + */ + @throws[NoSuchElementException]("if the key is not set") + def getHadoop(key: String): String + + /** + * Returns the value of the Hadoop configuration property. + * + * @since 2.0.0 + */ + def getHadoopOption(key: String): Option[String] + + /** + * Resets the Hadoop configuration property for the given key. + * + * @since 2.0.0 + */ + def unsetHadoop(key: String): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a2f386850c1b5..cb4a6397b261b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,14 +25,12 @@ import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql.{execution => sparkexecution} -import org.apache.spark.sql.SQLConf.SQLConfEntry -import org.apache.spark.sql.catalyst.{InternalRow, _} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ @@ -41,8 +39,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.internal.SQLConf.SQLConfEntry import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager @@ -67,7 +66,7 @@ class SQLContext private[sql]( @transient protected[sql] val cacheManager: CacheManager, @transient private[sql] val listener: SQLListener, val isRootContext: Boolean) - extends org.apache.spark.Logging with Serializable { + extends Logging with Serializable { self => @@ -114,9 +113,27 @@ class SQLContext private[sql]( } /** - * @return Spark 
SQL configuration + * Per-session state, e.g. configuration, functions, temporary tables etc. */ - protected[sql] lazy val conf = new SQLConf + @transient + protected[sql] lazy val sessionState: SessionState = new SessionState(self) + protected[sql] def conf: SQLConf = sessionState.conf + protected[sql] def catalog: Catalog = sessionState.catalog + protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry + protected[sql] def analyzer: Analyzer = sessionState.analyzer + protected[sql] def optimizer: Optimizer = sessionState.optimizer + protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser + protected[sql] def planner: SparkPlanner = sessionState.planner + protected[sql] def continuousQueryManager = sessionState.continuousQueryManager + protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] = + sessionState.prepareForExecution + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + */ + @Experimental + def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** * Set Spark SQL configuration properties. @@ -178,43 +195,11 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - @transient - lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - - protected[sql] lazy val continuousQueryManager = new ContinuousQueryManager(this) - - @transient - protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) - - @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - - @transient - protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, conf) { - override val extendedResolutionRules = - python.ExtractPythonUDFs :: - PreInsertCastAndRename :: - (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) - - override val extendedCheckRules = Seq( - datasources.PreWriteCheck(catalog) - ) - } - - @transient - protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) - - @transient - protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) - protected[sql] def executeSql(sql: String): - org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = - new sparkexecution.QueryExecution(this, plan) + protected[sql] def executePlan(plan: LogicalPlan) = new QueryExecution(this, plan) /** * Add a jar to SQLContext @@ -298,10 +283,8 @@ class SQLContext private[sql]( * * @group basic * @since 1.3.0 - * TODO move to SQLSession? */ - @transient - val udf: UDFRegistration = new UDFRegistration(this) + def udf: UDFRegistration = sessionState.udf /** * Returns true if the table is currently cached in-memory. @@ -871,25 +854,9 @@ class SQLContext private[sql]( }.toArray } - @transient - protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) - @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) - /** - * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal - * row format conversions as needed. 
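Looping back to the RuntimeConfig interface added above, a minimal usage sketch; how a concrete instance is obtained is outside this patch, and the `conf` parameter below is assumed to be such an implementation:

```scala
import org.apache.spark.sql.RuntimeConfig

// Only calls defined on the abstract interface are used here.
def tuneAndRead(conf: RuntimeConfig): String = {
  conf.set("spark.sql.shuffle.partitions", 400L)                      // set(key, Long)
  conf.setHadoop("mapreduce.input.fileinputformat.split.minsize", "134217728")
  conf.getOption("spark.sql.shuffle.partitions").getOrElse("<unset>")
}
```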
- */ - @transient - protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { - val batches = Seq( - Batch("Subquery", Once, PlanSubqueries(self)), - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(self)) - ) - } - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ecfc170bee3ad..d894825632a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.spark.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF @@ -34,19 +35,17 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { - - private val functionRegistry = sqlContext.functionRegistry +class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( s""" | Registering new PythonUDF: | name: $name - | command: ${udf.command.toSeq} - | envVars: ${udf.envVars} - | pythonIncludes: ${udf.pythonIncludes} - | pythonExec: ${udf.pythonExec} + | command: ${udf.func.command.toSeq} + | envVars: ${udf.func.envVars} + | pythonIncludes: ${udf.func.pythonIncludes} + | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} """.stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d912aeb70d517..68a251757c596 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -100,7 +100,7 @@ private[r] object SQLUtils { } def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { - df.map(r => rowToRBytes(r)) + df.rdd.map(r => rowToRBytes(r)) } private[this] def doConversion(data: Object, dataType: DataType): Object = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index ea20115770f79..1d1d7edb240dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -29,12 +29,9 @@ /** * An iterator interface used to pull the output from generated function for multiple operators * (whole stage codegen). - * - * TODO: replaced it by batched columnar format. 
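Since DataFrame no longer exposes the RDD-returning map/flatMap/mapPartitions overloads, callers go through `.rdd` first, as the SQLUtils change above does. A hedged sketch (`df` and the column index are placeholders):

```scala
import org.apache.spark.rdd.RDD

// Before this patch: df.map(_.getString(0)) returned an RDD[String] directly.
// After: materialize the RDD[Row] explicitly, then map over it.
val firstColumn: RDD[String] = df.rdd.map(_.getString(0))
```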
*/ -public class BufferedRowIterator { +public abstract class BufferedRowIterator { protected LinkedList currentRows = new LinkedList<>(); - protected Iterator input; // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); @@ -49,8 +46,16 @@ public InternalRow next() { return currentRows.remove(); } - public void setInput(Iterator iter) { - input = iter; + /** + * Initializes from array of iterators of InternalRow. + */ + public abstract void init(Iterator iters[]); + + /** + * Append a row to currentRows. + */ + protected void append(InternalRow row) { + currentRows.add(row); } /** @@ -74,9 +79,5 @@ protected void incPeakExecutionMemory(long size) { * * After it's called, if currentRow is still null, it means no more rows left. */ - protected void processNext() throws IOException { - if (input.hasNext()) { - currentRows.add(input.next()); - } - } + protected abstract void processNext() throws IOException; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index cad7e25a32788..2cbe3f2c94202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -22,9 +22,12 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -102,7 +105,7 @@ private[sql] case class PhysicalRDD( override val metadata: Map[String, String] = Map.empty, isUnsafeRow: Boolean = false, override val outputPartitioning: Partitioning = UnknownPartitioning(0)) - extends LeafNode { + extends LeafNode with CodegenSupport { private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -128,6 +131,36 @@ private[sql] case class PhysicalRDD( val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" } + + override def upstreams(): Seq[RDD[InternalRow]] = { + rdd :: Nil + } + + // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen + // never requires UnsafeRow as input. 
+ override protected def doProduce(ctx: CodegenContext): String = { + val input = ctx.freshName("input") + // PhysicalRDD always just has one input + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + val row = ctx.freshName("row") + val numOutputRows = metricTerm(ctx, "numOutputRows") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columns = exprs.map(_.gen(ctx)) + s""" + | while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} + | if (shouldStop()) { + | return; + | } + | } + """.stripMargin + } } private[sql] object PhysicalRDD { @@ -140,8 +173,13 @@ private[sql] object PhysicalRDD { rdd: RDD[InternalRow], relation: BaseRelation, metadata: Map[String, String] = Map.empty): PhysicalRDD = { - // All HadoopFsRelations output UnsafeRows - val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] + val outputUnsafeRows = if (relation.isInstanceOf[ParquetRelation]) { + // The vectorized parquet reader does not produce unsafe rows. + !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + } else { + // All HadoopFsRelations output UnsafeRows + relation.isInstanceOf[HadoopFsRelation] + } val bucketSpec = relation match { case r: HadoopFsRelation => r.getBucketSpec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index d26a0b74674a6..12998a38f59e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.immutable.IndexedSeq - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -87,8 +85,8 @@ case class Expand( } } - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 4db88a09d8152..6bc4649d432ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -54,17 +55,19 @@ case class Generate( child: SparkPlan) extends UnaryNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def expressions: Seq[Expression] = generator :: Nil val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - if (join) { + val 
rows = if (join) { child.execute().mapPartitionsInternal { iter => - val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) + val generatorNullRow = new GenericInternalRow(generator.elementTypes.size) val joinedRow = new JoinedRow - val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -73,19 +76,26 @@ case class Generate( if (outer && outputRows.isEmpty) { joinedRow.withRight(generatorNullRow) :: Nil } else { - outputRows.map(or => joinedRow.withRight(or)) + outputRows.map(joinedRow.withRight) } - } ++ LazyIterator(() => boundGenerator.terminate()).map { row => + } ++ LazyIterator(boundGenerator.terminate).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - proj(joinedRow.withRight(row)) + joinedRow.withRight(row) } } } else { child.execute().mapPartitionsInternal { iter => - val proj = UnsafeProjection.create(output, output) - (iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate())).map(proj) + iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) + } + } + + val numOutputRows = longMetric("numOutputRows") + rows.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(output, output) + iter.map { r => + numOutputRows += 1 + proj(r) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 75cb6d1137c35..2ea889ea72c75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -37,7 +39,7 @@ case class Sort( global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) - extends UnaryNode { + extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -50,34 +52,36 @@ case class Sort( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - protected override def doExecute(): RDD[InternalRow] = { - val schema = child.schema - val childOutput = child.output + def createSorter(): UnsafeExternalRowSorter = { + val ordering = newOrdering(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val 
sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter + } + + protected override def doExecute(): RDD[InternalRow] = { val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") child.execute().mapPartitionsInternal { iter => - val ordering = newOrdering(sortOrder, childOutput) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - // The generator for prefix - val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } + val sorter = createSorter() val metrics = TaskContext.get().taskMetrics() // Remember spill data size of this task before execute this operator so that we can @@ -93,4 +97,74 @@ case class Sort( sortedIterator } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + // Name of sorter variable used in codegen. + private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = ctx.freshName("needToSort") + ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + + + // Initialize the class member variables. This includes the instance of the Sorter and + // the iterator to return sorted rows. 
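The prefix machinery above can be illustrated with a toy, non-Spark sketch: comparisons run first on an 8-byte prefix of the leading sort key, with the full value as a tie-breaker. This is ASCII-only and purely illustrative; the real comparator uses generated projections and unsigned comparisons.

```scala
// Build a left-aligned, zero-padded 8-byte prefix so numeric order matches byte order.
def prefixOf(s: String): Long =
  s.getBytes("UTF-8").padTo(8, 0.toByte).take(8)
    .foldLeft(0L)((acc, b) => (acc << 8) | (b & 0xffL))

val words  = Seq("prefix", "sql", "spark", "sort")
val sorted = words.sortBy(w => (prefixOf(w), w))   // List(prefix, sort, spark, sql)
```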
+ val thisPlan = ctx.addReferenceObj("plan", this) + sorterVariable = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, + s"$sorterVariable = $thisPlan.createSorter();") + val metrics = ctx.freshName("metrics") + ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") + val sortedIterator = ctx.freshName("sortedIter") + ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + + val addToSorter = ctx.freshName("addToSorter") + ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + val outputRow = ctx.freshName("outputRow") + val dataSize = metricTerm(ctx, "dataSize") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + s""" + | if ($needToSort) { + | $addToSorter(); + | Long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $sortedIterator = $sorterVariable.sort(); + | $dataSize.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val colExprs = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs) + + s""" + | // Convert the input attributes to an UnsafeRow and add it to the sorter + | ${code.code} + | $sorterVariable.insertRow(${code.value}); + """.stripMargin.trim + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7347156398674..0255103b63d81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeComman import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.internal.SQLConf private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => @@ -70,9 +71,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil - // no predicate can be evaluated by matching hash keys - case logical.Join(left, right, LeftSemi, condition) => - joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -81,11 +79,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose 
output should be small enough to be used in broadcast join. */ object CanBroadcast { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { - case BroadcastHint(p) => Some(p) - case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) - case _ => None + def unapply(plan: LogicalPlan): Option[LogicalPlan] = { + if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && + plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + Some(plan) + } else { + None + } } } @@ -96,7 +96,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Join implementations are chosen with the following precedence: * * - Broadcast: if one side of the join has an estimated physical size that is smaller than the - * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold * or if that side has an explicit broadcast hint (e.g. the user applied the * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side * of the join will be broadcasted and the other side will be streamed, with no shuffling @@ -250,22 +250,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join( - CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) => execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil - case logical.Join( - left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => + planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil + case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) => execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil + planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // TODO CartesianProduct doesn't support the Left Semi Join - case logical.Join(left, right, joinType, None) if joinType != LeftSemi => + case logical.Join(left, right, Inner, None) => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -283,6 +280,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } else { joins.BuildLeft } + // This join could be very slow or even hang forever joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index d79b547137c0d..cb68ca6ada366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import 
org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin} import org.apache.spark.sql.execution.metric.LongSQLMetricValue /** @@ -40,7 +40,9 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { case _: TungstenAggregate => "agg" - case _: BroadcastHashJoin => "join" + case _: BroadcastHashJoin => "bhj" + case _: SortMergeJoin => "smj" + case _: PhysicalRDD => "rdd" case _ => nodeName.toLowerCase } @@ -68,9 +70,11 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns the RDD of InternalRow which generates the input rows. + * Returns all the RDDs of InternalRow which generates the input rows. + * + * Note: right now we support up to two RDDs. */ - def upstream(): RDD[InternalRow] + def upstreams(): Seq[RDD[InternalRow]] /** * Returns Java source code to process the rows from upstream. @@ -179,19 +183,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def supportCodegen: Boolean = false - override def upstream(): RDD[InternalRow] = { - child.execute() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.execute() :: Nil } override def doProduce(ctx: CodegenContext): String = { + val input = ctx.freshName("input") + // Right now, InputAdapter is only used when there is one upstream. + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) s""" - | while (input.hasNext()) { - | InternalRow $row = (InternalRow) input.next(); + | while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} | if (shouldStop()) { @@ -215,7 +223,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() ---------> upstream() -------> upstream() ------> execute() + * doExecute() ---------> upstreams() -------> upstreams() ------> execute() * | * -----------------> produce() * | @@ -267,6 +275,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) public GeneratedIterator(Object[] references) { this.references = references; + } + + public void init(scala.collection.Iterator inputs[]) { ${ctx.initMutableStates()} } @@ -276,26 +287,40 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) ${code.trim} } } - """ + """.trim // try to compile, helpful for debug val cleanedSource = CodeFormatter.stripExtraNewLines(source) // println(s"${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) - plan.upstream().mapPartitions { iter => - - val clazz = CodeGenerator.compile(source) - val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.setInput(iter) - new Iterator[InternalRow] { - override def hasNext: Boolean = buffer.hasNext - override def next: InternalRow = buffer.next() + val rdds = plan.upstreams() + assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") + if (rdds.length == 1) { + 
rdds.head.mapPartitions { iter => + val clazz = CodeGenerator.compile(cleanedSource) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.init(Array(iter)) + new Iterator[InternalRow] { + override def hasNext: Boolean = buffer.hasNext + override def next: InternalRow = buffer.next() + } + } + } else { + // Right now, we support up to two upstreams. + rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => + val clazz = CodeGenerator.compile(cleanedSource) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.init(Array(leftIter, rightIter)) + new Iterator[InternalRow] { + override def hasNext: Boolean = buffer.hasNext + override def next: InternalRow = buffer.next() + } } } } - override def upstream(): RDD[InternalRow] = { + override def upstreams(): Seq[RDD[InternalRow]] = { throw new UnsupportedOperationException } @@ -312,8 +337,8 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) if (row != null) { // There is an UnsafeRow already s""" - | currentRows.add($row.copy()); - """.stripMargin + |append($row.copy()); + """.stripMargin.trim } else { assert(input != null) if (input.nonEmpty) { @@ -324,14 +349,14 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) ctx.currentVars = input val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" - | ${code.code.trim} - | currentRows.add(${code.value}.copy()); - """.stripMargin + |${code.code.trim} + |append(${code.value}.copy()); + """.stripMargin.trim } else { // There is no columns s""" - | currentRows.add(unsafeRow); - """.stripMargin + |append(unsafeRow); + """.stripMargin.trim } } } @@ -402,6 +427,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru b.copy(left = apply(left)) case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => b.copy(right = apply(right)) + case j @ SortMergeJoin(_, _, _, left, right) => + // The children of SortMergeJoin should do codegen separately. 
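The `rdds.size <= 2` restriction above is what the zipPartitions branch relies on; a minimal sketch of that pairing on plain RDDs (`sc` is an assumed SparkContext):

```scala
// Two upstreams with matching partition counts are consumed side by side,
// which is how a generated two-input iterator (e.g. for SortMergeJoin) sees its inputs.
val left   = sc.parallelize(1 to 4, 2)
val right  = sc.parallelize(Seq("a", "b", "c", "d"), 2)
val paired = left.zipPartitions(right) { (l, r) => l.zip(r) }
paired.collect()   // Array((1,a), (2,b), (3,c), (4,d))
```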
+ j.copy(left = apply(left), right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively inputs += input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 852203f3743dc..a46722963a6e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -121,8 +121,8 @@ case class TungstenAggregate( !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 55bddd196ec46..b2f443c0e9ae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { @@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { @@ -156,8 +156,9 @@ case class Range( private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + override def upstreams(): Seq[RDD[InternalRow]] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) :: Nil } protected override def doProduce(ctx: CodegenContext): String = { @@ -213,12 +214,15 @@ case class Range( | } """.stripMargin) + val input = ctx.freshName("input") + // Right now, Range is only used when there is one upstream. 
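A quick way to exercise the Range and whole-stage-codegen paths touched here from user code (`sqlContext` assumed in scope):

```scala
// Range is planned as a leaf codegen operator, so the filter and count below run
// inside the generated pipeline when whole-stage codegen is enabled.
val evens = sqlContext.range(0L, 1000000L, 1L, 4).filter("id % 2 = 0")
println(evens.count())
```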
+ ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; - | if (input.hasNext()) { - | initRange(((InternalRow) input.next()).getInt(0)); + | if ($input.hasNext()) { + | initRange(((InternalRow) $input.next()).getInt(0)); | } else { | return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 4858140229d45..22d427808593e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation( @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, @transient private[sql] var _statistics: Statistics = null, private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends LogicalPlan with MultiInstanceRelation { + extends logical.LeafNode with MultiInstanceRelation { override def producedAttributes: AttributeSet = outputSet @@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated, batchStats) } - override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index c6adb583f931b..5574645741823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,12 +21,13 @@ import java.util.NoSuchElementException import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala index e683a95ed2aef..bc8ef4ad7e236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec} import org.apache.spark.util.Utils @@ -44,4 +46,16 @@ private[datasources] object CompressionCodecs { s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.") } } + + /** + * Set compression configurations to Hadoop `Configuration`. + * `codec` should be a full class path + */ + def setCodecConfiguration(conf: Configuration, codec: String): Unit = { + conf.set("mapreduce.output.fileoutputformat.compress", "true") + conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) + conf.set("mapreduce.map.output.compress", "true") + conf.set("mapreduce.map.output.compress.codec", codec) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 2d3e1714d2b7b..c8b020d55a3cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 25911334a674f..f4271d165c9bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -32,7 +32,8 @@ import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7e5c8f2f48d6b..c3db2a0af4bd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 292de11cd5c2b..edead9b21b21c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ - private[csv] object CSVInferSchema { /** @@ -48,7 +47,11 @@ private[csv] object CSVInferSchema { tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - StructField(thisHeader, rootType, nullable = true) + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } StructType(structFields) @@ -65,12 +68,8 @@ private[csv] object CSVInferSchema { } def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { - first.zipAll(second, NullType, NullType).map { case ((a, b)) => - val tpe = findTightestCommonType(a, b).getOrElse(StringType) - tpe match { - case _: NullType => StringType - case other => other - } + first.zipAll(second, NullType, NullType).map { case (a, b) => + findTightestCommonType(a, b).getOrElse(NullType) } } @@ -149,6 +148,8 @@ private[csv] object CSVInferSchema { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) // Promote numeric types to the highest of the two and all numeric types to unlimited decimal case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => @@ -159,7 +160,6 @@ private[csv] object CSVInferSchema { } } - private[csv] object CSVTypeCast { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index da945c44cde1c..e9afee1cc5142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -24,7 +24,6 @@ import scala.util.control.NonFatal import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} -import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.hadoop.mapreduce.RecordWriter @@ -34,6 +33,7 @@ import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.CompressionCodecs import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -50,16 +50,16 @@ private[sql] class 
CSVRelation( case None => inferSchema(paths) } - private val params = new CSVOptions(parameters) + private val options = new CSVOptions(parameters) @transient private var cachedRDD: Option[RDD[String]] = None private def readText(location: String): RDD[String] = { - if (Charset.forName(params.charset) == Charset.forName("UTF-8")) { + if (Charset.forName(options.charset) == Charset.forName("UTF-8")) { sqlContext.sparkContext.textFile(location) } else { - val charset = params.charset + val charset = options.charset sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location) .mapPartitions { _.map { pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset) @@ -81,8 +81,8 @@ private[sql] class CSVRelation( private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = { val rdd = baseRdd(inputPaths) // Make sure firstLine is materialized before sending to executors - val firstLine = if (params.headerFlag) findFirstLine(rdd) else null - CSVRelation.univocityTokenizer(rdd, header, firstLine, params) + val firstLine = if (options.headerFlag) findFirstLine(rdd) else null + CSVRelation.univocityTokenizer(rdd, header, firstLine, options) } /** @@ -96,20 +96,16 @@ private[sql] class CSVRelation( val pathsString = inputs.map(_.getPath.toUri.toString) val header = schema.fields.map(_.name) val tokenizedRdd = tokenRdd(header, pathsString) - CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params) + CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, options) } override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = job.getConfiguration - params.compressionCodec.foreach { codec => - conf.set("mapreduce.output.fileoutputformat.compress", "true") - conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) - conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) - conf.set("mapreduce.map.output.compress", "true") - conf.set("mapreduce.map.output.compress.codec", codec) + options.compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) } - new CSVOutputWriterFactory(params) + new CSVOutputWriterFactory(options) } override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns) @@ -129,17 +125,17 @@ private[sql] class CSVRelation( private def inferSchema(paths: Array[String]): StructType = { val rdd = baseRdd(paths) val firstLine = findFirstLine(rdd) - val firstRow = new LineCsvReader(params).parseLine(firstLine) + val firstRow = new LineCsvReader(options).parseLine(firstLine) - val header = if (params.headerFlag) { + val header = if (options.headerFlag) { firstRow } else { firstRow.zipWithIndex.map { case (value, index) => s"C$index" } } val parsedRdd = tokenRdd(header, paths) - if (params.inferSchemaFlag) { - CSVInferSchema.infer(parsedRdd, header, params.nullValue) + if (options.inferSchemaFlag) { + CSVInferSchema.infer(parsedRdd, header, options.nullValue) } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName => @@ -153,8 +149,8 @@ private[sql] class CSVRelation( * Returns the first line of the first non-empty file in path */ private def findFirstLine(rdd: RDD[String]): String = { - if (params.isCommentSet) { - val comment = params.comment.toString + if (options.isCommentSet) { + val comment = options.comment.toString rdd.filter { line => line.trim.nonEmpty && !line.startsWith(comment) }.first() diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index ee6373d03e1fd..9e336422d1f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -44,6 +44,12 @@ private[sql] object JDBCRelation { * exactly once. The parameters minValue and maxValue are advisory in that * incorrect values may cause the partitioning to be poor, but no data * will fail to be represented. + * + * A null-value predicate is added to the first partition's where clause so that + * rows with a null value in the partition column are included. + * + * @param partitioning partition information to generate the where clause for each partition + * @return an array of partitions with a where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) @@ -66,7 +72,7 @@ private[sql] object JDBCRelation { if (upperBound == null) { lowerBound } else if (lowerBound == null) { - upperBound + s"$upperBound or $column is null" } else { s"$lowerBound AND $upperBound" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 31a95ed461215..e59dbd6b3d438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -48,10 +48,7 @@ private[sql] class JSONOptions( parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) /** Sets config options on a Jackson [[JsonFactory]].
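As an aside on the JDBCRelation.columnPartition change above: with the added predicate, rows whose partition column is NULL now land in the first partition instead of being silently dropped. A rough sketch of the clauses the stride loop would produce for a hypothetical partitioned read; the connection URL, table, column name, bounds and partition count below are illustrative and not taken from this patch, and sqlContext is assumed to be in scope:

    // Partitioned JDBC read; columnPartition() turns the bounds into one WHERE clause per partition.
    val df = sqlContext.read.jdbc(
      "jdbc:postgresql://host/db",   // illustrative connection URL
      "events",                      // table
      "id",                          // partition column
      0L, 100L, 4,                   // lowerBound, upperBound, numPartitions
      new java.util.Properties())
    // Approximate per-partition predicates after this change:
    //   partition 0: id < 25 or id is null    <- NULL rows are no longer lost
    //   partition 1: id >= 25 AND id < 50
    //   partition 2: id >= 50 AND id < 75
    //   partition 3: id >= 75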
*/ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index c893558136549..28136911fe240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -165,11 +165,7 @@ private[sql] class JSONRelation( override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { val conf = job.getConfiguration options.compressionCodec.foreach { codec => - conf.set("mapreduce.output.fileoutputformat.compress", "true") - conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) - conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) - conf.set("mapreduce.map.output.compress", "true") - conf.set("mapreduce.map.output.compress.codec", codec) + CompressionCodecs.setCodecConfiguration(conf, codec) } new BucketedOutputWriterFactory { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 42d89f4bf81d6..8a128b4b61769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter( } protected def decimalFromBinary(value: Binary): Decimal = { - if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { + if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 54dda0c3915b4..6f6340f541ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -25,8 +25,9 @@ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ -import org.apache.spark.sql.{AnalysisException, SQLConf} -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -37,7 +38,6 @@ import org.apache.spark.sql.types._ * [[MessageType]] schemas. 
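The Decimal.MAX_INT_DIGITS and Decimal.MAX_LONG_DIGITS constants that replace MAX_PRECISION_FOR_INT32 and MAX_PRECISION_FOR_INT64 here and in the following Parquet hunks carry the values 9 and 18, per the comments on the removed constants. A minimal standalone sketch of the precision-to-physical-type decision this encodes, with the thresholds inlined for illustration and an invented function name; it mirrors the match in CatalystWriteSupport rather than reproducing it:

    // Small decimals are stored as INT32/INT64; larger or legacy-format ones as FIXED_LEN_BYTE_ARRAY.
    def parquetTypeForDecimal(precision: Int, writeLegacyParquetFormat: Boolean): String = {
      val maxIntDigits = 9     // Decimal.MAX_INT_DIGITS
      val maxLongDigits = 18   // Decimal.MAX_LONG_DIGITS
      if (!writeLegacyParquetFormat && precision <= maxIntDigits) "INT32"
      else if (!writeLegacyParquetFormat && precision <= maxLongDigits) "INT64"
      else if (writeLegacyParquetFormat && precision <= maxLongDigits) "FIXED_LEN_BYTE_ARRAY (unscaled long)"
      else "FIXED_LEN_BYTE_ARRAY (unscaled bytes)"
    }
    // parquetTypeForDecimal(10, writeLegacyParquetFormat = false)  // "INT64"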
* * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * * @constructor * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL @@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) case UINT_8 => typeNotSupported() case UINT_16 => typeNotSupported() case UINT_32 => typeNotSupported() @@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() @@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter { // Returns the minimum number of bytes needed to store a decimal with a given `precision`. 
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ - - val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ - // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { Math.round( // convert double to long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index e78afa5ae6d0b..0252c79d8e143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -30,11 +30,11 @@ import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.io.api.{Binary, RecordConsumer} import org.apache.spark.Logging -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi writeLegacyParquetFormat match { // Standard mode, 1 <= precision <= 9, writes as INT32 - case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer // Standard mode, 10 <= precision <= 18, writes as INT64 - case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY - case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY case _ => binaryWriterUsingUnscaledBytes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 1e686d41f41db..184cbb2f296b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 
430257f60d9fe..8f3f6335e4282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, PartitionSpec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -48,7 +48,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { dataSchema.foreach(verifySchema) - new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext) + new TextRelation(None, dataSchema, partitionColumns, paths, parameters)(sqlContext) } override def shortName(): String = "text" @@ -114,6 +114,12 @@ private[sql] class TextRelation( /** Write path. */ override def prepareJobForWrite(job: Job): OutputWriterFactory = { + val conf = job.getConfiguration + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) + compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) + } + new OutputWriterFactory { override def newInstance( path: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dbb6b654b1a38..95d033bc57548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.internal.SQLConf /** * Contains methods for debugging query execution. 
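A hedged usage sketch of the `compression` option the text data source gains above, which now flows through the shared CompressionCodecs.setCodecConfiguration helper just like the CSV and JSON sources. The output path is illustrative, `df` is assumed to be a DataFrame with a single string column (as the text source requires), and the accepted short names are whatever shortCompressionCodecNames maps, presumably gzip, bzip2, lz4 and snappy judging by the imported codec classes:

    // Write gzip-compressed text files; the option is resolved via CompressionCodecs.getCodecClassName
    // and then applied to the Hadoop job configuration by setCodecConfiguration.
    df.write
      .format("text")
      .option("compression", "gzip")
      .save("/tmp/events-as-text")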
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index ddc08822f3e17..6699dbafe7e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -99,8 +99,8 @@ case class BroadcastHashJoin( } } - override def upstream(): RDD[InternalRow] = { - streamedPlan.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].upstreams() } override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index e8bd7f69dbab9..d83486df02c87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +27,6 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} - case class BroadcastNestedLoopJoin( left: SparkPlan, right: SparkPlan, @@ -51,125 +51,266 @@ case class BroadcastNestedLoopJoin( } private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) + } else { + // Always put the stream side on left to simplify implementation + UnsafeProjection.create(output, streamed.output ++ broadcast.output) + } } override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { joinType match { + case Inner => + left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case Inner => - // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case - left.output ++ right.output - case x => // TODO support the Left Semi Join + case LeftSemi => + left.output + case x => throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + @transient private lazy val boundCondition = { + if (condition.isDefined) { + newPredicate(condition.get, streamed.output ++ broadcast.output) + } else { + (r: InternalRow) => true + } + } - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") + /** + * The implementation for InnerJoin. 
+ */ + private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() + streamedIter.flatMap { streamedRow => + val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r)) + if (condition.isDefined) { + joinedRows.filter(boundCondition) + } else { + joinedRows + } + } + } + } - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val relation = broadcastedRelation.value + /** + * The implementation for these joins: + * + * LeftOuter with BuildRight + * RightOuter with BuildLeft + */ + private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) + + // Returns an iterator to avoid copy the rows. + new Iterator[InternalRow] { + // current row from stream side + private var streamRow: InternalRow = null + // have found a match for current row or not + private var foundMatch: Boolean = false + // the matched result row + private var resultRow: InternalRow = null + // the next index of buildRows to try + private var nextIndex: Int = 0 - val matchedRows = new CompactBuffer[InternalRow] - val includedBroadcastTuples = new BitSet(relation.length) + private def findNextMatch(): Boolean = { + if (streamRow == null) { + if (!streamedIter.hasNext) { + return false + } + streamRow = streamedIter.next() + nextIndex = 0 + foundMatch = false + } + while (nextIndex < buildRows.length) { + resultRow = joinedRow(streamRow, buildRows(nextIndex)) + nextIndex += 1 + if (boundCondition(resultRow)) { + foundMatch = true + return true + } + } + if (!foundMatch) { + resultRow = joinedRow(streamRow, nulls) + streamRow = null + true + } else { + resultRow = null + streamRow = null + findNextMatch() + } + } + + override def hasNext(): Boolean = { + resultRow != null || findNextMatch() + } + override def next(): InternalRow = { + val r = resultRow + resultRow = null + r + } + } + } + } + + /** + * The implementation for these joins: + * + * LeftSemi with BuildRight + */ + private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + assert(buildSide == BuildRight) + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val resultProj = genResultProjection + if (condition.isDefined) { + streamedIter.filter(l => + buildRows.exists(r => boundCondition(joinedRow(l, r))) + ) + } else { + streamedIter.filter(r => !buildRows.isEmpty) + } + } + } + + /** + * The implementation for these joins: + * + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + */ + private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + /** All rows that either match both-way, or rows from streamed joined with nulls. 
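The outerJoin above avoids buffering by hand-rolling an iterator, but its behaviour can be summarised by a much simpler, copy-heavy sketch over plain Scala collections (types and names are illustrative, not from the patch): every streamed row either joins with its matching build rows or is padded with nulls, modelled here as None.

    // Streamed outer join against a broadcast build side.
    def streamedOuterJoin[L, R](streamed: Iterator[L], buildRows: Array[R])
                               (matches: (L, R) => Boolean): Iterator[(L, Option[R])] =
      streamed.flatMap { l =>
        val matched: Seq[(L, Option[R])] =
          buildRows.toSeq.collect { case r if matches(l, r) => (l, Some(r)) }
        if (matched.nonEmpty) matched.iterator else Iterator((l, None))
      }

    // streamedOuterJoin(Iterator(1, 2, 3), Array(2, 3, 4))(_ == _).toList
    //   == List((1, None), (2, Some(2)), (3, Some(3)))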
*/ + val streamRdd = streamed.execute() + + val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val matched = new BitSet(buildRows.length) + val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => var i = 0 - var streamRowMatched = false - - while (i < relation.length) { - val broadcastedRow = relation(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() - streamRowMatched = true - includedBroadcastTuples.set(i) - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() - streamRowMatched = true - includedBroadcastTuples.set(i) - case _ => + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matched.set(i) } i += 1 } + } + Seq(matched).toIterator + } - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() - case _ => + val matchedBroadcastRows = matchedBuildRows.fold( + new BitSet(relation.value.length) + )(_ | _) + + if (joinType == LeftSemi) { + assert(buildSide == BuildLeft) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (matchedBroadcastRows.get(i)) { + buf += rel(i).copy() } + i += 1 } - Iterator((matchedRows, includedBroadcastTuples)) + return sparkContext.makeRDD(buf.toSeq) } - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new BitSet(broadcastedRelation.value.size) - )(_ | _) + val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val resultProj = genResultProjection + streamedIter.flatMap { streamedRow => + var i = 0 + var foundMatch = false + val matchedRows = new CompactBuffer[InternalRow] + + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matchedRows += joinedRow.copy() + foundMatch = true + } + i += 1 + } + + if (!foundMatch && joinType == FullOuter) { + matchedRows += joinedRow(streamedRow, nulls).copy() + } + matchedRows.iterator + } + } - /** Rows from broadcasted joined with nulls. 
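The defaultJoin path works in two passes: it first folds per-partition BitSets to learn which broadcast rows ever found a match, then unions the normally joined streamed side with the never-matched broadcast rows padded with nulls. A small illustrative sketch of that plan over plain collections, full outer flavour only, with invented types and names:

    import scala.collection.mutable

    def fullOuterAgainstBroadcast[L, R](streamed: Seq[L], buildRows: Array[R])
                                       (matches: (L, R) => Boolean): Seq[(Option[L], Option[R])] = {
      val matchedBuild = mutable.BitSet.empty            // stands in for the folded BitSets
      val streamSide: Seq[(Option[L], Option[R])] = streamed.flatMap { l =>
        val hits: Seq[(Option[L], Option[R])] = buildRows.indices.collect {
          case i if matches(l, buildRows(i)) =>
            matchedBuild += i
            (Some(l), Some(buildRows(i)))
        }
        if (hits.nonEmpty) hits else Seq((Some(l), None))
      }
      val buildOnly: Seq[(Option[L], Option[R])] = buildRows.indices.collect {
        case i if !matchedBuild.contains(i) => (None, Some(buildRows(i)))
      }
      streamSide ++ buildOnly                            // mirrors sparkContext.union(...)
    }

    // fullOuterAgainstBroadcast(Seq(1, 2), Array(2, 9))(_ == _)
    //   == Seq((Some(1), None), (Some(2), Some(2)), (None, Some(9)))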
*/ - val broadcastRowsWithNulls: Seq[InternalRow] = { + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 - val rel = broadcastedRelation.value - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - val joinedRow = new JoinedRow - joinedRow.withLeft(leftNulls) - while (i < rel.length) { - if (!allIncludedBroadcastTuples.get(i)) { - buf += resultProj(joinedRow.withRight(rel(i))).copy() - } - i += 1 - } - case (LeftOuter | FullOuter, BuildLeft) => - val joinedRow = new JoinedRow - joinedRow.withRight(rightNulls) - while (i < rel.length) { - if (!allIncludedBroadcastTuples.get(i)) { - buf += resultProj(joinedRow.withLeft(rel(i))).copy() - } - i += 1 - } - case _ => + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 } buf.toSeq } - // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), - sparkContext.makeRDD(broadcastRowsWithNulls) - ).map { row => - // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. - numOutputRows += 1 - row + matchedStreamRows, + sparkContext.makeRDD(notMatchedBroadcastRows) + ) + } + + protected override def doExecute(): RDD[InternalRow] = { + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() + + val resultRdd = (joinType, buildSide) match { + case (Inner, _) => + innerJoin(broadcastedRelation) + case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => + outerJoin(broadcastedRelation) + case (LeftSemi, BuildRight) => + leftSemiJoin(broadcastedRelation) + case _ => + /** + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + */ + defaultJoin(broadcastedRelation) + } + + val numOutputRows = longMetric("numOutputRows") + resultRdd.mapPartitionsInternal { iter => + val resultProj = genResultProjection + iter.map { r => + numOutputRows += 1 + resultProj(r) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index e417079b61b4e..fabd2fbe1e0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter - /** * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, * will be much faster than building the right partition for every row in left RDD, it also diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala deleted file mode 100644 index df6dac88187cc..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output: Seq[Attribute] = left.output - - /** The Streamed Relation */ - override def left: SparkPlan = streamed - - /** The Broadcast relation */ - override def right: SparkPlan = broadcast - - override def requiredChildDistribution: Seq[Distribution] = { - UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil - } - - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - val relation = broadcastedRelation.value - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < relation.length && !matched) { - if (boundCondition(joinedRow(streamedRow, relation(i)))) { - matched = true - } - i += 1 - } - if (matched) { - numOutputRows += 1 - } - matched - }) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index cd8a5670e2301..7ec4027188f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, 
RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Performs an sort merge join of two child relations. @@ -34,7 +35,7 @@ case class SortMergeJoin( rightKeys: Seq[Expression], condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan) extends BinaryNode with CodegenSupport { override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -125,6 +126,246 @@ case class SortMergeJoin( }.toScala } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + private def createJoinKey( + ctx: CodegenContext, + row: String, + keys: Seq[Expression], + input: Seq[Attribute]): Seq[ExprCode] = { + ctx.INPUT_ROW = row + keys.map(BindReferences.bindReference(_, input).gen(ctx)) + } + + private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { + vars.zipWithIndex.map { case (ev, i) => + val value = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "") + val code = + s""" + |$value = ${ev.value}; + """.stripMargin + ExprCode(code, "false", value) + } + } + + private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => + s""" + |if (comp == 0) { + | comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)}; + |} + """.stripMargin.trim + } + s""" + |comp = 0; + |${comparisons.mkString("\n")} + """.stripMargin + } + + /** + * Generate a function to scan both left and right to find a match, returns the term for + * matched one row from left side and buffered rows from right side. + */ + private def genScanner(ctx: CodegenContext): (String, String) = { + // Create class member for next row from both sides. + val leftRow = ctx.freshName("leftRow") + ctx.addMutableState("InternalRow", leftRow, "") + val rightRow = ctx.freshName("rightRow") + ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the right key as class members so they could be used in next function call. + val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + + // A list to hold all matched rows from right side. + val matches = ctx.freshName("matches") + val clsName = classOf[java.util.ArrayList[InternalRow]].getName + ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + // Copy the left keys as class members so they could be used in next function call. 
+ val matchedKeyVars = copyKeys(ctx, leftKeyVars) + + ctx.addNewFunction("findNextInnerJoinRows", + s""" + |private boolean findNextInnerJoinRows( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) { + | $leftRow = null; + | int comp = 0; + | while ($leftRow == null) { + | if (!leftIter.hasNext()) return false; + | $leftRow = (InternalRow) leftIter.next(); + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | $leftRow = null; + | continue; + | } + | if (!$matches.isEmpty()) { + | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | if (comp == 0) { + | return true; + | } + | $matches.clear(); + | } + | + | do { + | if ($rightRow == null) { + | if (!rightIter.hasNext()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return !$matches.isEmpty(); + | } + | $rightRow = (InternalRow) rightIter.next(); + | ${rightKeyTmpVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | $rightRow = null; + | continue; + | } + | ${rightKeyVars.map(_.code).mkString("\n")} + | } + | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | if (comp > 0) { + | $rightRow = null; + | } else if (comp < 0) { + | if (!$matches.isEmpty()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return true; + | } + | $leftRow = null; + | } else { + | $matches.add($rightRow.copy()); + | $rightRow = null;; + | } + | } while ($leftRow != null); + | } + | return false; // unreachable + |} + """.stripMargin) + + (leftRow, matches) + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. 
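The generated findNextInnerJoinRows above is a classic merge scan: advance the left row, buffer the run of right rows sharing its join key, and reuse that buffer while consecutive left rows carry the same key. A non-codegen sketch of the same step over two inputs pre-sorted by an integer key; the names, the key type and the flat pair-list output are illustrative:

    // Inner sort-merge join of two sorted inputs; `block` plays the role of the generated `matches` buffer.
    def sortMergeInnerJoin[L, R](left: Seq[L], right: Seq[R])
                                (leftKey: L => Int, rightKey: R => Int): Seq[(L, R)] = {
      val out = Seq.newBuilder[(L, R)]
      var ri = 0                       // read position on the right side
      var block: Seq[R] = Nil          // right rows matching the current key
      var blockKey = Int.MinValue
      for (l <- left) {
        val k = leftKey(l)
        if (k != blockKey || block.isEmpty) {
          while (ri < right.length && rightKey(right(ri)) < k) ri += 1   // skip smaller keys
          val start = ri
          while (ri < right.length && rightKey(right(ri)) == k) ri += 1  // buffer the equal-key run
          block = right.slice(start, ri)
          blockKey = k
        }
        block.foreach(r => out += ((l, r)))
      }
      out.result()
    }

    // sortMergeInnerJoin(Seq(1, 2, 2, 4), Seq(2, 2, 3, 4))(identity, identity)
    //   == Seq((2, 2), (2, 2), (2, 2), (2, 2), (4, 4))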
+ */ + private def splitVarsByCondition( + attributes: Seq[Attribute], + variables: Seq[ExprCode]): (String, String) = { + if (condition.isDefined) { + val condRefs = condition.get.references + val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => + condRefs.contains(a) + } + val beforeCond = used.map(_._2.code).mkString("\n") + val afterCond = notUsed.map(_._2.code).mkString("\n") + (beforeCond, afterCond) + } else { + (variables.map(_.code).mkString("\n"), "") + } + } + + override def doProduce(ctx: CodegenContext): String = { + val leftInput = ctx.freshName("leftInput") + ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") + val rightInput = ctx.freshName("rightInput") + ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + + val (leftRow, matches) = genScanner(ctx) + + // Create variables for row from both sides. + val leftVars = createLeftVars(ctx, leftRow) + val rightRow = ctx.freshName("rightRow") + val rightVars = createRightVar(ctx, rightRow) + val resultVars = leftVars ++ rightVars + + // Check condition + ctx.currentVars = resultVars + val cond = if (condition.isDefined) { + BindReferences.bindReference(condition.get, output).gen(ctx) + } else { + ExprCode("", "false", "true") + } + // Split the code of creating variables based on whether it's used by condition or not. + val loaded = ctx.freshName("loaded") + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + + + val size = ctx.freshName("size") + val i = ctx.freshName("i") + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | int $size = $matches.size(); + | boolean $loaded = false; + | $leftBefore + | for (int $i = 0; $i < $size; $i ++) { + | InternalRow $rightRow = (InternalRow) $matches.get($i); + | $rightBefore + | ${cond.code} + | if (${cond.isNull} || !${cond.value}) continue; + | if (!$loaded) { + | $loaded = true; + | $leftAfter + | } + | $rightAfter + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | if (shouldStop()) return; + |} + """.stripMargin + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cd543d4195286..45175d36d5c9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { /** * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. 
*/ -trait BaseLimit extends UnaryNode { +trait BaseLimit extends UnaryNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index 59345046da495..8f063e24fbf8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} +import org.apache.spark.sql.internal.SQLConf /** * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala index cd1c86516ec5f..9ffa272d2189a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.local import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} +import org.apache.spark.sql.internal.SQLConf /** * A [[HashJoinNode]] for broadcast join. 
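The BaseLimit codegen just shown keeps a per-operator counter in the consume path and flips a stopEarly flag exposed through the overridden shouldStop(), so the produce loop upstream bails out instead of scanning more input. A tiny plain-Scala analogue of that cooperation, with invented names and Long standing in for a row:

    class LimitedSink(limit: Int) {
      private var count = 0
      private var stopEarly = false
      def shouldStop: Boolean = stopEarly              // what the generated produce loop checks
      def consume(row: Long): Option[Long] =
        if (count < limit) { count += 1; Some(row) }   // forward the row downstream
        else { stopEarly = true; None }                // tell the producer to stop
    }

    // val sink = new LimitedSink(2)
    // Seq(1L, 2L, 3L).flatMap(sink.consume) == Seq(1L, 2L)   (and sink.shouldStop is now true)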
It takes a streamedNode and a broadcast diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala index b31c5a863832e..f79d795a904d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} +import org.apache.spark.sql.internal.SQLConf case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala index de2f4e661ab44..f3fa474b0f7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} +import org.apache.spark.sql.internal.SQLConf case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 85111bd6d1c98..6ccd6db0e6ca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf case class ExpandNode( conf: SQLConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index dd1113b6726cf..c5eb33cef4420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.internal.SQLConf case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala index 740d485f8d9e6..e594e132dea79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.local 
import scala.collection.mutable -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) extends BinaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala index 401b10a5ed307..9af45ac0aac9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 8726e4878106d..a5d09691dc46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.local import org.apache.spark.Logging -import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index b93bde58a55e5..b5ea08325c58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class NestedLoopJoinNode( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index bd73b08263f87..5fe068a13c8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, UnsafeProjection} +import 
org.apache.spark.sql.internal.SQLConf case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index 793700803f216..078fb50deb16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index b8467f6ae58e0..8ebfe3a68b3a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf /** * An operator that scans some local data collection in the form of Scala Seq. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index ca68b7677ce83..f52f5f7bb59b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.BoundedPriorityQueue case class TakeOrderedAndProjectNode( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala index 0f2b8303e7372..e53bc220d8d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index 00df0195279c3..c65a7bcff8503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: // Output iterator for results from Python. val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, + udf.func, bufferSize, reuseWorker ).compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9aff0be716d4b..0aa2785cb6b9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.{Accumulator, Logging} -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.Logging +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} import org.apache.spark.sql.types.DataType @@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType */ case class PythonUDF( name: String, - command: Array[Byte], - envVars: java.util.Map[String, String], - pythonIncludes: java.util.List[String], - pythonExec: String, - pythonVer: String, - broadcastVars: java.util.List[Broadcast[PythonBroadcast]], - accumulator: Accumulator[java.util.List[Array[Byte]]], + func: PythonFunction, dataType: DataType, children: Seq[Expression]) extends Expression with Unevaluable with NonSQLExpression with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 79ac1c85c0be5..d301874c223d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.Accumulator -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.Column import org.apache.spark.sql.types.DataType @@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType */ case class UserDefinedPythonFunction( name: String, - command: Array[Byte], - envVars: java.util.Map[String, String], - pythonIncludes: java.util.List[String], - pythonExec: String, - pythonVer: String, - broadcastVars: java.util.List[Broadcast[PythonBroadcast]], - accumulator: Accumulator[java.util.List[Array[Byte]]], + func: PythonFunction, dataType: DataType) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) + PythonUDF(name, func, dataType, e) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 7d701949afcf2..26e4eda542d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.stat +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow} @@ -27,6 +29,313 @@ import org.apache.spark.unsafe.types.UTF8String private[sql] object StatFunctions extends Logging { + import QuantileSummaries.Stats + + /** + * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass. + * + * The result of this algorithm has the following deterministic bound: + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first presented in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient + * Online Computation of Quantile Summaries]] by Greenwald and Khanna. + * + * @param df the dataframe + * @param cols numerical columns of the dataframe + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * + * @return for each column, returns the requested approximations + */ + def multipleApproxQuantiles( + df: DataFrame, + cols: Seq[String], + probabilities: Seq[Double], + relativeError: Double): Seq[Seq[Double]] = { + val columns: Seq[Column] = cols.map { colName => + val field = df.schema(colName) + require(field.dataType.isInstanceOf[NumericType], + s"Quantile calculation for column $colName with data type ${field.dataType}" + + " is not supported.") + Column(Cast(Column(colName).expr, DoubleType)) + } + val emptySummaries = Array.fill(cols.size)( + new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError)) + + // Note that it works more or less by accident as `rdd.aggregate` is not a pure function: + // this function returns the same array as given in the input (because `aggregate` reuses + // the same argument).
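An aside to make the aggregation step concrete (this sketch is not part of the patch): multipleApproxQuantiles keeps one summary per column and folds every row into those summaries with `rdd.aggregate`, which threads a zero value through a per-partition seqOp and then combines partial results with a combOp. The two helper functions that follow play exactly those roles. The same contract can be illustrated on plain Scala collections with a hypothetical `Summary` counter standing in for QuantileSummaries; unlike the patch, which (as the comment above notes) leans on `aggregate` reusing its zero array in place, the sketch stays purely functional:

    // Toy illustration of the aggregate(zero)(seqOp, combOp) pattern used below.
    // `Summary` is a hypothetical stand-in for QuantileSummaries: it only counts and sums.
    case class Summary(count: Long = 0L, sum: Double = 0.0) {
      def insert(x: Double): Summary = Summary(count + 1, sum + x)
      def merge(other: Summary): Summary = Summary(count + other.count, sum + other.sum)
    }

    val partitions: Seq[Seq[Double]] = Seq(Seq(1.0, 2.0), Seq(3.0), Seq(4.0, 5.0))
    // seqOp folds each partition into a partial summary; combOp merges the partials.
    val merged = partitions
      .map(_.foldLeft(Summary())((acc, x) => acc.insert(x)))
      .reduce((a, b) => a.merge(b))
    assert(merged.count == 5 && merged.sum == 15.0)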
+ def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { + var i = 0 + while (i < summaries.length) { + summaries(i) = summaries(i).insert(row.getDouble(i)) + i += 1 + } + summaries + } + + def merge( + sum1: Array[QuantileSummaries], + sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = { + sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) } + } + val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) + + summaries.map { summary => probabilities.map(summary.query) } + } + + /** + * Helper class to compute approximate quantile summary. + * This implementation is based on the algorithm proposed in the paper: + * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael + * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670) + * + * In order to optimize for speed, it maintains an internal buffer of the last seen samples, + * and only inserts them after crossing a certain size threshold. This guarantees a near-constant + * runtime complexity compared to the original algorithm. + * + * @param compressThreshold the compression threshold. + * After the internal buffer of statistics crosses this size, it attempts to compress the + * statistics together. + * @param relativeError the target relative error. + * It is uniform across the complete range of values. + * @param sampled a buffer of quantile statistics. + * See the G-K article for more details. + * @param count the count of all the elements *inserted in the sampled buffer* + * (excluding the head buffer) + * @param headSampled a buffer of latest samples seen so far + */ + class QuantileSummaries( + val compressThreshold: Int, + val relativeError: Double, + val sampled: ArrayBuffer[Stats] = ArrayBuffer.empty, + private[stat] var count: Long = 0L, + val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty) extends Serializable { + + import QuantileSummaries._ + + /** + * Returns a summary with the given observation inserted into the summary. + * This method may either modify in place the current summary (and return the same summary, + * modified in place), or it may create a new summary from scratch if necessary. + * @param x the new observation to insert into the summary + */ + def insert(x: Double): QuantileSummaries = { + headSampled.append(x) + if (headSampled.size >= defaultHeadSize) { + this.withHeadBufferInserted + } else { + this + } + } + + /** + * Inserts an array of (unsorted) samples in a batch, sorting the array first to traverse + * the summary statistics in a single batch. + * + * This method does not modify the current object and returns if necessary a new copy. + * + * @return a new quantile summary object. + */ + private def withHeadBufferInserted: QuantileSummaries = { + if (headSampled.isEmpty) { + return this + } + var currentCount = count + val sorted = headSampled.toArray.sorted + val newSamples: ArrayBuffer[Stats] = new ArrayBuffer[Stats]() + // The index of the next element to insert + var sampleIdx = 0 + // The index of the sample currently being inserted. + var opsIdx: Int = 0 + while(opsIdx < sorted.length) { + val currentSample = sorted(opsIdx) + // Add all the samples before the next observation.
+ while(sampleIdx < sampled.size && sampled(sampleIdx).value <= currentSample) { + newSamples.append(sampled(sampleIdx)) + sampleIdx += 1 + } + + // If it is the first one to insert, or if it is the last one + currentCount += 1 + val delta = + if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) { + 0 + } else { + math.floor(2 * relativeError * currentCount).toInt + } + + val tuple = Stats(currentSample, 1, delta) + newSamples.append(tuple) + opsIdx += 1 + } + + // Add all the remaining existing samples + while(sampleIdx < sampled.size) { + newSamples.append(sampled(sampleIdx)) + sampleIdx += 1 + } + new QuantileSummaries(compressThreshold, relativeError, newSamples, currentCount) + } + + /** + * Returns a new summary that compresses the summary statistics and the head buffer. + * + * This implements the COMPRESS function of the GK algorithm. It does not modify the object. + * + * @return a new summary object with compressed statistics + */ + def compress(): QuantileSummaries = { + // Inserts all the elements first + val inserted = this.withHeadBufferInserted + assert(inserted.headSampled.isEmpty) + assert(inserted.count == count + headSampled.size) + val compressed = + compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + } + + private def shallowCopy: QuantileSummaries = { + new QuantileSummaries(compressThreshold, relativeError, sampled, count, headSampled) + } + + /** + * Merges two (compressed) summaries together. + * + * Returns a new summary. + */ + def merge(other: QuantileSummaries): QuantileSummaries = { + require(headSampled.isEmpty, "Current buffer needs to be compressed before merge") + require(other.headSampled.isEmpty, "Other buffer needs to be compressed before merge") + if (other.count == 0) { + this.shallowCopy + } else if (count == 0) { + other.shallowCopy + } else { + // Merge the two buffers. + // The GK algorithm is a bit unclear about it, but it seems there is no need to adjust the + // statistics during the merging: the invariants are still respected after the merge. + // TODO: could replace full sort by ordered merge, the two lists are known to be sorted + // already. + val res = (sampled ++ other.sampled).sortBy(_.value) + val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) + new QuantileSummaries( + other.compressThreshold, other.relativeError, comp, other.count + count) + } + } + + /** + * Runs a query for a given quantile. + * The result follows the approximation guarantees detailed above. + * The query can only be run on a compressed summary: you need to call compress() before using + * it.
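A side note on the Stats(value, g, delta) bookkeeping that compress(), merge() and the query() method documented just above all rely on: within a compressed summary, the minimum possible rank of the i-th sample is the running sum of the g values up to and including it, and its maximum possible rank adds delta on top. A small standalone sketch with made-up numbers (not taken from the patch):

    // Rank bounds implied by a GK summary: minRank(i) = g(0) + ... + g(i),
    // maxRank(i) = minRank(i) + delta(i). query() scans these windows until one
    // covers the target rank within the allowed error.
    case class Stats(value: Double, g: Int, delta: Int)

    val sampled = Seq(Stats(1.0, 1, 0), Stats(4.0, 2, 1), Stats(9.0, 3, 2))
    val minRanks = sampled.scanLeft(0)((acc, s) => acc + s.g).tail
    sampled.zip(minRanks).foreach { case (s, minRank) =>
      println(s"value=${s.value} exact rank is somewhere in [$minRank, ${minRank + s.delta}]")
    }
    // Prints [1, 1] for 1.0, [3, 4] for 4.0 and [6, 8] for 9.0.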
+ * + * @param quantile the target quantile + * @return the approximate quantile value at the given probability + */ + def query(quantile: Double): Double = { + require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") + require(headSampled.isEmpty, + "Cannot operate on an uncompressed summary, call compress() first") + + if (quantile <= relativeError) { + return sampled.head.value + } + + if (quantile >= 1 - relativeError) { + return sampled.last.value + } + + // Target rank + val rank = math.ceil(quantile * count).toInt + val targetError = math.ceil(relativeError * count) + // Minimum rank at current sample + var minRank = 0 + var i = 1 + while (i < sampled.size - 1) { + val curSample = sampled(i) + minRank += curSample.g + val maxRank = minRank + curSample.delta + if (maxRank - targetError <= rank && rank <= minRank + targetError) { + return curSample.value + } + i += 1 + } + sampled.last.value + } + } + + object QuantileSummaries { + // TODO(tjhunter) more tuning could be done on the constants here, but for now + // the main cost of the algorithm is accessing the data in SQL. + /** + * The default value for the compression threshold. + */ + val defaultCompressThreshold: Int = 10000 + + /** + * The size of the head buffer. + */ + val defaultHeadSize: Int = 50000 + + /** + * The default value for the relative error (1%). + * With this value, the best extreme percentiles that can be approximated are 1% and 99%. + */ + val defaultRelativeError: Double = 0.01 + + /** + * Statistics from the Greenwald-Khanna paper. + * @param value the sampled value + * @param g the minimum rank jump from the previous value's minimum rank + * @param delta the maximum span of the rank. + */ + case class Stats(value: Double, g: Int, delta: Int) + + private def compressImmut( + currentSamples: IndexedSeq[Stats], + mergeThreshold: Double): ArrayBuffer[Stats] = { + val res: ArrayBuffer[Stats] = ArrayBuffer.empty + if (currentSamples.isEmpty) { + return res + } + // Start from the last element, which is always part of the set. + // The head contains the current new head, that may be merged with the current element. + var head = currentSamples.last + var i = currentSamples.size - 2 + // Do not compress the last element + while (i >= 1) { + // The current sample: + val sample1 = currentSamples(i) + // Do we need to compress? + if (sample1.g + head.g + head.delta < mergeThreshold) { + // Do not insert yet, just merge the current element into the head. + head = head.copy(g = head.g + sample1.g) + } else { + // Prepend the current head, and keep the current sample as target for merging. + res.prepend(head) + head = sample1 + } + i -= 1 + } + res.prepend(head) + // If necessary, add the minimum element: + res.prepend(currentSamples.head) + res + } + } + /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols, "correlation") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 510894afac7be..b9873d38a664f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1972,7 +1972,7 @@ object functions extends LegacyFunctions { def crc32(e: Column): Column = withExpr { Crc32(e.expr) } /** - * Calculates the hash code of given columns, and returns the result as a int column.
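To connect the implementation above with the user-facing API that DataFrameStatSuite exercises later in this patch, here is a hedged usage sketch (assuming an active sqlContext; only the rank bound is guaranteed, the concrete value depends on the data):

    // Sketch: querying an approximate median through the new DataFrame stat API.
    val df = sqlContext.range(1000).toDF("col1")  // values 0 ... 999
    val Array(median) = df.stat.approxQuantile("col1", Array(0.5), 0.01)
    // With N = 1000, p = 0.5 and err = 0.01 the guarantee above gives
    // floor(0.49 * N) <= rank(median) <= ceil(0.51 * N), i.e. a rank in [490, 510],
    // which for this particular data means a value between 489 and 509.
    assert(median >= 489.0 && median <= 509.0)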
+ * Calculates the hash code of given columns, and returns the result as an int column. * * @group misc_funcs * @since 2.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala new file mode 100644 index 0000000000000..058df1e3c19a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.sql.RuntimeConfig + +/** + * Implementation for [[RuntimeConfig]]. + */ +class RuntimeConfigImpl extends RuntimeConfig { + + private val conf = new SQLConf + + private val hadoopConf = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) + + override def set(key: String, value: String): RuntimeConfig = { + conf.setConfString(key, value) + this + } + + override def set(key: String, value: Boolean): RuntimeConfig = set(key, value.toString) + + override def set(key: String, value: Long): RuntimeConfig = set(key, value.toString) + + @throws[NoSuchElementException]("if the key is not set") + override def get(key: String): String = conf.getConfString(key) + + override def getOption(key: String): Option[String] = { + try Option(get(key)) catch { + case _: NoSuchElementException => None + } + } + + override def unset(key: String): Unit = conf.unsetConf(key) + + override def setHadoop(key: String, value: String): RuntimeConfig = { + hadoopConf.put(key, value) + this + } + + @throws[NoSuchElementException]("if the key is not set") + override def getHadoop(key: String): String = hadoopConf.synchronized { + if (hadoopConf.containsKey(key)) { + hadoopConf.get(key) + } else { + throw new NoSuchElementException(key) + } + } + + override def getHadoopOption(key: String): Option[String] = { + try Option(getHadoop(key)) catch { + case _: NoSuchElementException => None + } + } + + override def unsetHadoop(key: String): Unit = hadoopConf.remove(key) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a601c87fc930e..1d1e2884414d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -15,12 +15,12 @@ * limitations under the License. 
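A short usage sketch for the RuntimeConfigImpl added above (hypothetical driver-side snippet, not part of the patch): SQL keys are delegated to the internal SQLConf, Hadoop keys go to a synchronized map, and getOption turns the NoSuchElementException thrown by get into None.

    // Sketch: round-tripping configuration through the new RuntimeConfigImpl.
    val conf = new org.apache.spark.sql.internal.RuntimeConfigImpl()
    conf.set("spark.sql.shuffle.partitions", 8L)            // Long overload stores "8"
    conf.setHadoop("fs.defaultFS", "hdfs://namenode:8020")  // hypothetical Hadoop setting
    assert(conf.get("spark.sql.shuffle.partitions") == "8")
    assert(conf.getOption("not.a.registered.key").isEmpty)  // missing key becomes None
    conf.unset("spark.sql.shuffle.partitions")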
*/ -package org.apache.spark.sql +package org.apache.spark.sql.internal -import java.util.Properties +import java.util.{NoSuchElementException, Properties} -import scala.collection.immutable import scala.collection.JavaConverters._ +import scala.collection.immutable import org.apache.parquet.hadoop.ParquetOutputCommitter @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// -private[spark] object SQLConf { +object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( new java.util.HashMap[String, SQLConfEntry[_]]()) @@ -55,7 +55,7 @@ private[spark] object SQLConf { * configuration is only used internally and we should not expose it to the user. * @tparam T the value type */ - private[sql] class SQLConfEntry[T] private( + class SQLConfEntry[T] private( val key: String, val defaultValue: Option[T], val valueConverter: String => T, @@ -70,7 +70,7 @@ private[spark] object SQLConf { } } - private[sql] object SQLConfEntry { + object SQLConfEntry { private def apply[T]( key: String, @@ -345,12 +345,9 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "Enables using the custom ParquetUnsafeRowRecordReader.") - // Note: this can not be enabled all the time because the reader will not be returning UnsafeRows. - // Doing so is very expensive and we should remove this requirement instead of fixing it here. - // Initial testing seems to indicate only sort requires this. val PARQUET_VECTORIZED_READER_ENABLED = booleanConf( key = "spark.sql.parquet.enableVectorizedReader", - defaultValue = Some(false), + defaultValue = Some(true), doc = "Enables vectorized parquet decoding.") val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", @@ -528,7 +525,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { +class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. 
*/ @@ -537,85 +534,85 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon /** ************************ Spark SQL Params/Hints ******************* */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) + def useCompression: Boolean = getConf(COMPRESS_CACHED) - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) - private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) + def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) - private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) - private[spark] def targetPostShuffleInputSize: Long = + def targetPostShuffleInputSize: Long = getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) - private[spark] def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) - private[spark] def minNumPostShufflePartitions: Int = + def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) - private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) - private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) + def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) - private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) + def nativeView: Boolean = getConf(NATIVE_VIEW) - private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) - private[spark] def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) + def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def subexpressionEliminationEnabled: Boolean = + def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) - private[spark] def defaultSizeInBytes: Long = + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) - private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) + def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) - private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - private[spark] def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) - private[spark] def inMemoryPartitionPruning: Boolean = 
getConf(IN_MEMORY_PARTITION_PRUNING) + def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) - private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) + def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) + def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) - private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) + def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) - private[spark] def partitionDiscoveryEnabled(): Boolean = + def partitionDiscoveryEnabled(): Boolean = getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) - private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = + def partitionColumnTypeInferenceEnabled(): Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) - private[spark] def parallelPartitionDiscoveryThreshold: Int = + def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) - private[spark] def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) + def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) + def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) - private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) - private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) + def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) - private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID) @@ -649,6 +646,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon } /** Return the value of Spark SQL configuration property for the given key. */ + @throws[NoSuchElementException]("if key is not set") def getConfString(key: String): String = { Option(settings.get(key)). orElse { @@ -715,15 +713,15 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon settings.put(key, value) } - private[spark] def unsetConf(key: String): Unit = { + def unsetConf(key: String): Unit = { settings.remove(key) } - private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { + def unsetConf(entry: SQLConfEntry[_]): Unit = { settings.remove(entry.key) } - private[spark] def clear(): Unit = { + def clear(): Unit = { settings.clear() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala new file mode 100644 index 0000000000000..f93a405f77fc7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration} +import org.apache.spark.sql.catalyst.ParserInterface +import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog} +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.{PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.util.ExecutionListenerManager + + +/** + * A class that holds all session-specific state in a given [[SQLContext]]. + */ +private[sql] class SessionState(ctx: SQLContext) { + + // Note: These are all lazy vals because they depend on each other (e.g. conf) and we + // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. + + /** + * SQL-specific key-value configurations. + */ + lazy val conf = new SQLConf + + /** + * Internal catalog for managing table and database states. + */ + lazy val catalog: Catalog = new SimpleCatalog(conf) + + /** + * Internal catalog for managing functions registered by the user. + */ + lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() + + /** + * Interface exposed to the user for registering user-defined functions. + */ + lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) + + /** + * Logical query plan analyzer for resolving unresolved attributes and relations. + */ + lazy val analyzer: Analyzer = { + new Analyzer(catalog, functionRegistry, conf) { + override val extendedResolutionRules = + python.ExtractPythonUDFs :: + PreInsertCastAndRename :: + (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) + + override val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog)) + } + } + + /** + * Logical query plan optimizer. + */ + lazy val optimizer: Optimizer = new SparkOptimizer(ctx) + + /** + * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + */ + lazy val sqlParser: ParserInterface = new SparkQl(conf) + + /** + * Planner that converts optimized logical plans to physical plans. + */ + lazy val planner: SparkPlanner = new SparkPlanner(ctx) + + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. 
+ */ + lazy val prepareForExecution = new RuleExecutor[SparkPlan] { + override val batches: Seq[Batch] = Seq( + Batch("Subquery", Once, PlanSubqueries(ctx)), + Batch("Add exchange", Once, EnsureRequirements(ctx)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)) + ) + } + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + */ + lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + + /** + * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. + */ + lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java b/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java new file mode 100644 index 0000000000000..1e801cb6ee2a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * All classes in this package are considered an internal API to Spark and + * are subject to change between minor releases. + */ +package org.apache.spark.sql.internal; diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala index c1b8bf052c0ca..c2394f42e552d 100644 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala @@ -15,19 +15,10 @@ * limitations under the License. */ -package org.apache.spark.util - -import java.util.EventListener - -import org.apache.spark.TaskContext -import org.apache.spark.annotation.DeveloperApi +package org.apache.spark.sql /** - * :: DeveloperApi :: - * - * Listener providing a callback function to invoke when a task's execution completes. + * All classes in this package are considered an internal API to Spark and + * are subject to change between minor releases. 
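A side note on the lazy val pattern that the SessionState comment above calls out: because each field is initialized on first access, a subclass can override conf and every field derived from it will pick up the override instead of reading a half-constructed base class. A minimal standalone sketch with hypothetical names (not part of the patch):

    // Sketch of why SessionState holds interdependent, overridable state in lazy vals.
    class BaseState {
      lazy val conf: Map[String, String] = Map("dialect" -> "sql")
      // Evaluated on first access, so it sees whatever the most-derived class provides.
      lazy val parser: String = s"parser for ${conf("dialect")}"
    }
    class HiveState extends BaseState {
      override lazy val conf: Map[String, String] = Map("dialect" -> "hiveql")
    }
    assert(new HiveState().parser == "parser for hiveql")
    // With plain vals, parser would be computed inside BaseState's constructor,
    // before the override is in place, and conf could still be null at that point.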
*/ -@DeveloperApi -trait TaskCompletionListener extends EventListener { - def onTaskCompletion(context: TaskContext) -} +package object internal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 4165c382689f9..8e432e8f3d96b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -42,4 +42,9 @@ private case object OracleDialect extends JdbcDialect { None } } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR)) + case _ => None + } } diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv new file mode 100644 index 0000000000000..02d29cabf95f2 --- /dev/null +++ b/sql/core/src/test/resources/simple_sparse.csv @@ -0,0 +1,5 @@ +A,B,C,D +1,,, +,1,, +,,1, +,,,1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 83d7953aaf700..efa2eeaf4d751 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -46,7 +46,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)) + maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0))) + maybeBlock.nonEmpty } test("withColumn doesn't invalidate cached dataframe") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 78bf6c1bcebf2..7d96ef6fe0a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType @@ -256,7 +257,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + assert(testData2.count() === testData2.rdd.map(_ => 1).count()) checkAnswer( testData2.agg(count('a), sumDistinct('a)), // non-partial diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index bc1a336ea4fd0..368aa5cd141f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class DataFramePivotSuite extends QueryTest with SharedSQLContext{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 
f01f126f7696d..e865dbe6b5063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,6 +21,8 @@ import java.util.Random import org.scalatest.Matchers._ +import org.apache.spark.Logging +import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DoubleType @@ -123,6 +125,33 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(decimalRes) < 1e-12) } + test("approximate quantile") { + val n = 1000 + val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + + val q1 = 0.5 + val q2 = 0.8 + val epsilons = List(0.1, 0.05, 0.001) + + for (epsilon <- epsilons) { + val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon) + val Array(double2) = df.stat.approxQuantile("doubles", Array(q2), epsilon) + // Also make sure there is no regression by computing multiple quantiles at once. + val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) + val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) + + val error_single = 2 * 1000 * epsilon + val error_double = 2 * 2000 * epsilon + + assert(math.abs(single1 - q1 * n) < error_single) + assert(math.abs(double2 - 2 * q2 * n) < error_double) + assert(math.abs(s1 - q1 * n) < error_single) + assert(math.abs(s2 - q2 * n) < error_single) + assert(math.abs(d1 - 2 * q1 * n) < error_double) + assert(math.abs(d2 - 2 * q2 * n) < error_double) + } + } + test("crosstab") { val rng = new Random() val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) @@ -269,3 +298,40 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(0.until(1000).forall(i => filter4.mightContain(i * 3))) } } + + +class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Logging { + + // Turn on this test if you want to test the performance of approximate quantiles. 
+ ignore("computing quantiles should not take much longer than describe()") { + val df = sqlContext.range(5000000L).toDF("col1").cache() + def seconds(f: => Any): Double = { + // Do some warmup + logDebug("warmup...") + for (i <- 1 to 10) { + df.count() + f + } + logDebug("execute...") + // Do it 10 times and report median + val times = (1 to 10).map { i => + val start = System.nanoTime() + f + val end = System.nanoTime() + (end - start) / 1e9 + } + logDebug("execute done") + times.sum / times.length.toDouble + } + + logDebug("*** Normal describe ***") + val t1 = seconds { df.describe() } + logDebug(s"T1 = $t1") + logDebug("*** Just quantiles ***") + val t2 = seconds { + StatFunctions.multipleApproxQuantiles(df, Seq("col1"), Seq(0.1, 0.25, 0.5, 0.75, 0.9), 0.01) + } + logDebug(s"T1 = $t1, T2 = $t2") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4930c485da83f..84f30c0aaf862 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.types._ @@ -194,6 +195,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("a", Seq("a"), 1) :: Nil) } + test("sort after generate with join=true") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select($"*", explode($"a").as("c")).sortWithinPartitions("b", "c"), + Row(Seq("a"), 1, "a") :: Nil) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 498f007081918..33df6375e3aad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -613,12 +613,30 @@ class DatasetSuite extends QueryTest with SharedSQLContext { " - Input schema: struct\n" + " - Target schema: struct<_1:string>") } + + test("SPARK-13440: Resolving option fields") { + val df = Seq(1, 2, 3).toDS() + val ds = df.as[Option[Int]] + checkAnswer( + ds.filter(_ => true), + Some(1), Some(2), Some(3)) + } + + test("SPARK-13540 Dataset of nested class defined in Scala object") { + checkAnswer( + Seq(OuterObject.InnerClass("foo")).toDS(), + OuterObject.InnerClass("foo")) + } } class OuterClass extends Serializable { case class InnerClass(a: String) } +object OuterObject { + case class InnerClass(a: String) +} + case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8f2a0c0351361..5b98c11ef2a4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import 
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -46,7 +47,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j @@ -63,36 +63,41 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("join operator selection") { sqlContext.cacheManager.clearCache() - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), - ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), - ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeJoin]), // converted from Right Outer to Inner - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]), - ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]), - ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT 
JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[SortMergeJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } } // ignore("SortMergeJoin shouldn't work on unsortable columns") { @@ -118,9 +123,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("broadcasted hash outer join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") + sql("CACHE TABLE testData2") Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", - classOf[SortMergeOuterJoin]), + classOf[BroadcastHashJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", @@ -458,7 +464,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", - classOf[LeftSemiJoinBNL]), + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 6a375a33bfcf6..0b5a92c256e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.apache.spark._ +import org.apache.spark.sql.internal.SQLConf class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 5401212428d6f..c05aa5486ab15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import 
org.apache.spark.sql.execution.{LogicalRDD, Queryable} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SQLConf abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 14b9448d260f4..ec19d97d8cec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b4d6f4ecddc6b..16e769feca487 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -22,8 +22,9 @@ import java.sql.Timestamp import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ @@ -780,8 +781,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(cp.isEmpty, "should not use CartesianProduct for null-safe join") val smj = df.queryExecution.sparkPlan.collect { case smj: SortMergeJoin => smj + case j: BroadcastHashJoin => j } - assert(smj.size > 0, "should use SortMergeJoin") + assert(smj.size > 0, "should use SortMergeJoin or BroadcastHashJoin") checkAnswer(df, Row(100) :: Nil) } @@ -1979,9 +1981,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - // Would be nice if semantic equals for `+` understood commutative verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. 
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index bcac660a35a65..2d3e34d0e1292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -38,6 +38,7 @@ import org.apache.spark.util.Benchmark class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") .set("spark.sql.shuffle.partitions", "1") + .set("spark.sql.autoBroadcastJoinThreshold", "0") lazy val sc = SparkContext.getOrCreate(conf) lazy val sqlContext = SQLContext.getOrCreate(sc) @@ -69,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("range/limit/sum") { + val N = 500 << 20 + runBenchmark("range/limit/sum", N) { + sqlContext.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + ignore("stat functions") { val N = 100 << 20 @@ -187,6 +202,39 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("sort merge join") { + val N = 2 << 20 + runBenchmark("merge join", N) { + val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1") + val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + merge join codegen=false 1588 / 1880 1.3 757.1 1.0X + merge join codegen=true 1477 / 1531 1.4 704.2 1.1X + */ + + runBenchmark("sort merge join", N) { + val df1 = sqlContext.range(N) + .selectExpr(s"(id * 15485863) % ${N*10} as k1") + val df2 = sqlContext.range(N) + .selectExpr(s"(id * 15485867) % ${N*10} as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X + sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X + */ + } + ignore("rube") { val N = 5 << 20 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b1c588a63d030..4f01e46633c8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunS import org.apache.spark.sql._ import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.TestSQLContext class 
ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4de56783fabce..f66e08e6ca5c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row, SQLConf} +import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMem import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9350205d791d7..de371d85d9fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } + + test("Sort should be included in WholeStageCodegen") { + val df = sqlContext.range(3, 0, -1).sort(col("id")) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 86c2c25c2c7e1..d19fec6140acb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 95eb5cf912e2a..0000a5d1efd09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -17,19 +17,13 @@ 
package org.apache.spark.sql.execution.columnar.compression -import java.nio.ByteBuffer -import java.nio.ByteOrder +import java.nio.{ByteBuffer, ByteOrder} import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} -import org.apache.spark.sql.execution.columnar.BOOLEAN -import org.apache.spark.sql.execution.columnar.INT -import org.apache.spark.sql.execution.columnar.LONG -import org.apache.spark.sql.execution.columnar.NativeColumnType -import org.apache.spark.sql.execution.columnar.SHORT -import org.apache.spark.sql.execution.columnar.STRING +import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} import org.apache.spark.sql.types.AtomicType import org.apache.spark.util.Benchmark import org.apache.spark.util.Utils._ @@ -53,35 +47,70 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { () => rng.sample } - private[this] def runBenchmark[T <: AtomicType]( + private[this] def prepareEncodeInternal[T <: AtomicType]( + count: Int, + tpe: NativeColumnType[T], + supportedScheme: CompressionScheme, + input: ByteBuffer): ((ByteBuffer, ByteBuffer) => ByteBuffer, Double, ByteBuffer) = { + assert(supportedScheme.supports(tpe)) + + def toRow(d: Any) = new GenericInternalRow(Array[Any](d)) + val encoder = supportedScheme.encoder(tpe) + for (i <- 0 until count) { + encoder.gatherCompressibilityStats(toRow(tpe.extract(input)), 0) + } + input.rewind() + + val compressedSize = if (encoder.compressedSize == 0) { + input.remaining() + } else { + encoder.compressedSize + } + + (encoder.compress, encoder.compressionRatio, allocateLocal(4 + compressedSize)) + } + + private[this] def runEncodeBenchmark[T <: AtomicType]( name: String, iters: Int, count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) schemes.filter(_.supports(tpe)).map { scheme => - def toRow(d: Any) = new GenericInternalRow(Array[Any](d)) - val encoder = scheme.encoder(tpe) - for (i <- 0 until count) { - encoder.gatherCompressibilityStats(toRow(tpe.extract(input)), 0) - } - input.rewind() + val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) + val label = s"${getFormattedClassName(scheme)}(${compressionRatio.formatted("%.3f")})" - val label = s"${getFormattedClassName(scheme)}(${encoder.compressionRatio.formatted("%.3f")})" benchmark.addCase(label)({ i: Int => - val compressedSize = if (encoder.compressedSize == 0) { - input.remaining() - } else { - encoder.compressedSize + for (n <- 0L until iters) { + compressFunc(input, buf) + input.rewind() + buf.rewind() } + }) + } + + benchmark.run() + } - val buf = allocateLocal(4 + compressedSize) + private[this] def runDecodeBenchmark[T <: AtomicType]( + name: String, + iters: Int, + count: Int, + tpe: NativeColumnType[T], + input: ByteBuffer): Unit = { + val benchmark = new Benchmark(name, iters * count) + + schemes.filter(_.supports(tpe)).map { scheme => + val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) + val compressedBuf = compressFunc(input, buf) + val label = s"${getFormattedClassName(scheme)}" + + input.rewind() + + benchmark.addCase(label)({ i: Int => val rowBuf = new GenericMutableRow(1) - val compressedBuf = encoder.compress(input, buf) - input.rewind() for (n <- 0L until iters) { compressedBuf.rewind.position(4) @@ -96,16 +125,10 @@ object 
CompressionSchemeBenchmark extends AllCompressionSchemes { benchmark.run() } - def bitDecode(iters: Int): Unit = { + def bitEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * BOOLEAN.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // BOOLEAN Decode: Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 124.98 536.96 1.00 X - // RunLengthEncoding(2.494) 631.37 106.29 0.20 X - // BooleanBitSet(0.125) 1200.36 55.91 0.10 X val g = { val rng = genLowerSkewData() () => (rng().toInt % 2).toByte @@ -113,110 +136,176 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { for (i <- 0 until count) { testData.put(i * BOOLEAN.defaultSize, g()) } - runBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 3 / 4 19300.2 0.1 1.0X + // RunLengthEncoding(2.491) 923 / 939 72.7 13.8 0.0X + // BooleanBitSet(0.125) 359 / 363 187.1 5.3 0.0X + runEncodeBenchmark("BOOLEAN Encode", iters, count, BOOLEAN, testData) + + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 129 / 136 519.8 1.9 1.0X + // RunLengthEncoding 613 / 623 109.4 9.1 0.2X + // BooleanBitSet 1196 / 1222 56.1 17.8 0.1X + runDecodeBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) } - def shortDecode(iters: Int): Unit = { + def shortEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * SHORT.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Decode (Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 376.87 178.07 1.00 X - // RunLengthEncoding(1.498) 831.59 80.70 0.45 X val g1 = genLowerSkewData() for (i <- 0 until count) { testData.putShort(i * SHORT.defaultSize, g1().toShort) } - runBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Decode (Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 426.83 157.23 1.00 X - // RunLengthEncoding(1.996) 845.56 79.37 0.50 X + // SHORT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 6 / 7 10971.4 0.1 1.0X + // RunLengthEncoding(1.510) 1526 / 1542 44.0 22.7 0.0X + runEncodeBenchmark("SHORT Encode (Lower Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 811 / 837 82.8 12.1 1.0X + // RunLengthEncoding 1219 / 1266 55.1 18.2 0.7X + runDecodeBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) + val g2 = genHigherSkewData() for (i <- 0 until count) { testData.putShort(i * SHORT.defaultSize, 
g2().toShort) } - runBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 7 / 7 10112.4 0.1 1.0X + // RunLengthEncoding(2.009) 1623 / 1661 41.4 24.2 0.0X + runEncodeBenchmark("SHORT Encode (Higher Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 818 / 827 82.0 12.2 1.0X + // RunLengthEncoding 1202 / 1237 55.8 17.9 0.7X + runDecodeBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) } - def intDecode(iters: Int): Unit = { + def intEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * INT.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Decode(Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 325.16 206.39 1.00 X - // RunLengthEncoding(0.997) 1219.44 55.03 0.27 X - // DictionaryEncoding(0.500) 955.51 70.23 0.34 X - // IntDelta(0.250) 1146.02 58.56 0.28 X val g1 = genLowerSkewData() for (i <- 0 until count) { testData.putInt(i * INT.defaultSize, g1().toInt) } - runBenchmark("INT Decode(Lower Skew)", iters, count, INT, testData) // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Decode(Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 1133.45 59.21 1.00 X - // RunLengthEncoding(1.334) 1399.00 47.97 0.81 X - // DictionaryEncoding(0.501) 1032.87 64.97 1.10 X - // IntDelta(0.250) 948.02 70.79 1.20 X + // INT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 18 / 19 3716.4 0.3 1.0X + // RunLengthEncoding(1.001) 1992 / 2056 33.7 29.7 0.0X + // DictionaryEncoding(0.500) 723 / 739 92.8 10.8 0.0X + // IntDelta(0.250) 368 / 377 182.2 5.5 0.0X + runEncodeBenchmark("INT Encode (Lower Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 821 / 845 81.8 12.2 1.0X + // RunLengthEncoding 1246 / 1256 53.9 18.6 0.7X + // DictionaryEncoding 757 / 766 88.6 11.3 1.1X + // IntDelta 680 / 689 98.7 10.1 1.2X + runDecodeBenchmark("INT Decode (Lower Skew)", iters, count, INT, testData) + val g2 = genHigherSkewData() for (i <- 0 until count) { testData.putInt(i * INT.defaultSize, g2().toInt) } - runBenchmark("INT Decode(Higher Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 17 / 19 3888.4 0.3 1.0X + // RunLengthEncoding(1.339) 2127 / 2148 31.5 31.7 0.0X + // DictionaryEncoding(0.501) 960 / 972 69.9 14.3 0.0X + // IntDelta(0.250) 362 
/ 366 185.5 5.4 0.0X + runEncodeBenchmark("INT Encode (Higher Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 838 / 884 80.1 12.5 1.0X + // RunLengthEncoding 1287 / 1311 52.1 19.2 0.7X + // DictionaryEncoding 844 / 859 79.5 12.6 1.0X + // IntDelta 764 / 784 87.8 11.4 1.1X + runDecodeBenchmark("INT Decode (Higher Skew)", iters, count, INT, testData) } - def longDecode(iters: Int): Unit = { + def longEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * LONG.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Decode(Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 1101.07 60.95 1.00 X - // RunLengthEncoding(0.756) 1372.57 48.89 0.80 X - // DictionaryEncoding(0.250) 947.80 70.81 1.16 X - // LongDelta(0.125) 721.51 93.01 1.53 X val g1 = genLowerSkewData() for (i <- 0 until count) { testData.putLong(i * LONG.defaultSize, g1().toLong) } - runBenchmark("LONG Decode(Lower Skew)", iters, count, LONG, testData) // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Decode(Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 986.71 68.01 1.00 X - // RunLengthEncoding(1.013) 1348.69 49.76 0.73 X - // DictionaryEncoding(0.251) 865.48 77.54 1.14 X - // LongDelta(0.125) 816.90 82.15 1.21 X + // LONG Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 37 / 38 1804.8 0.6 1.0X + // RunLengthEncoding(0.748) 2065 / 2094 32.5 30.8 0.0X + // DictionaryEncoding(0.250) 950 / 962 70.6 14.2 0.0X + // LongDelta(0.125) 475 / 482 141.2 7.1 0.1X + runEncodeBenchmark("LONG Encode (Lower Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 888 / 894 75.5 13.2 1.0X + // RunLengthEncoding 1301 / 1311 51.6 19.4 0.7X + // DictionaryEncoding 887 / 904 75.7 13.2 1.0X + // LongDelta 693 / 735 96.8 10.3 1.3X + runDecodeBenchmark("LONG Decode (Lower Skew)", iters, count, LONG, testData) + val g2 = genHigherSkewData() for (i <- 0 until count) { testData.putLong(i * LONG.defaultSize, g2().toLong) } - runBenchmark("LONG Decode(Higher Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 34 / 35 1963.9 0.5 1.0X + // RunLengthEncoding(0.999) 2260 / 3021 29.7 33.7 0.0X + // DictionaryEncoding(0.251) 1270 / 1438 52.8 18.9 0.0X + // LongDelta(0.125) 496 / 509 135.3 7.4 0.1X + runEncodeBenchmark("LONG Encode (Higher Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // 
------------------------------------------------------------------------------------------- + // PassThrough 965 / 1494 69.5 14.4 1.0X + // RunLengthEncoding 1350 / 1378 49.7 20.1 0.7X + // DictionaryEncoding 892 / 924 75.2 13.3 1.1X + // LongDelta 817 / 847 82.2 12.2 1.2X + runDecodeBenchmark("LONG Decode (Higher Skew)", iters, count, LONG, testData) } - def stringDecode(iters: Int): Unit = { + def stringEncodingBenchmark(iters: Int): Unit = { val count = 65536 val strLen = 8 val tableSize = 16 val testData = allocateLocal(count * (4 + strLen)) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // STRING Decode: Avg Time(ms) Avg Rate(M/s) Relative Rate - // --------------------------------------------------------------------------------- - // PassThrough(1.000) 2277.05 29.47 1.00 X - // RunLengthEncoding(0.893) 2624.35 25.57 0.87 X - // DictionaryEncoding(0.167) 2672.28 25.11 0.85 X val g = { val dataTable = (0 until tableSize).map(_ => RandomStringUtils.randomAlphabetic(strLen)) val rng = genHigherSkewData() @@ -227,14 +316,29 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.put(g().getBytes) } testData.rewind() - runBenchmark("STRING Decode", iters, count, STRING, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 56 / 57 1197.9 0.8 1.0X + // RunLengthEncoding(0.893) 4892 / 4937 13.7 72.9 0.0X + // DictionaryEncoding(0.167) 2968 / 2992 22.6 44.2 0.0X + runEncodeBenchmark("STRING Encode", iters, count, STRING, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 2422 / 2449 27.7 36.1 1.0X + // RunLengthEncoding 2885 / 3018 23.3 43.0 0.8X + // DictionaryEncoding 2716 / 2752 24.7 40.5 0.9X + runDecodeBenchmark("STRING Decode", iters, count, STRING, testData) } def main(args: Array[String]): Unit = { - bitDecode(1024) - shortDecode(1024) - intDecode(1024) - longDecode(1024) - stringDecode(1024) + bitEncodingBenchmark(1024) + shortEncodingBenchmark(1024) + intEncodingBenchmark(1024) + longEncodingBenchmark(1024) + stringEncodingBenchmark(1024) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index f4cb2081b93a4..7af3f94aefea2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -79,4 +79,9 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType) } + + test("Merging NullTypes should yield NullType.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 356fe702eac33..ec0bf9b934928 100644 ---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -38,6 +38,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" private val boolFile = "bool.csv" + private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -246,7 +247,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema.fieldNames.size === 1) } - test("DDL test with empty file") { sqlContext.sql(s""" |CREATE TEMPORARY TABLE carsTable @@ -281,9 +281,8 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)) cars.coalesce(1).write - .format("csv") .option("header", "true") - .save(csvDir) + .csv(csvDir) val carsCopy = sqlContext.read .format("csv") @@ -409,4 +408,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(carsCopy, withHeader = true) } } + + test("Schema inference correctly identifies the datatype when data is sparse.") { + val df = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(simpleSparseFile)) + + assert( + df.schema.fields.map(field => field.dataType).deep == + Array(IntegerType, IntegerType, IntegerType, IntegerType).deep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index dd83a0e36f6f7..c7f33e17465b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index cef6b79a094d1..281a2cffa894a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(batch.column(0).getByte(i) == 1) assert(batch.column(1).getInt(i) == 2) assert(batch.column(2).getLong(i) == 3) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + assert(batch.column(3).getUTF8String(i).toString == "abc") i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 
3ded32c450541..bd51154c58aa6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -100,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { - df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 41a9404b00795..c85eeddc2c6d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -598,7 +599,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("null and non-null strings") { // Create a dataset where the first values are NULL and then some non-null values. The // number of non-nulls needs to be bigger than the ParquetReader batch size. 
- val data = sqlContext.range(200).map { i => + val data = sqlContext.range(200).rdd.map { i => if (i.getLong(0) < 150) Row(None) else Row("a") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 3d1677bed4770..8bc5c89959803 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b123d2b31efcf..acfc1a518a0a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index e8893073e305a..660f0f173a9e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -22,7 +22,8 @@ import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{Benchmark, Utils} /** @@ -149,21 +150,21 @@ object ParquetReadBenchmark { /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - SQL Parquet Reader 1350.56 11.65 1.00 X - SQL Parquet MR 1844.09 8.53 0.73 X - SQL Parquet Vectorized 1062.04 14.81 1.27 X + SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Reader 1042 / 1208 15.1 66.2 1.0X + SQL Parquet MR 1544 / 1607 10.2 98.2 0.7X + SQL Parquet Vectorized 674 / 739 23.3 42.9 1.5X */ sqlBenchmark.run() /* Intel(R) 
Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - ParquetReader 610.40 25.77 1.00 X - ParquetReader(Batched) 172.66 91.10 3.54 X - ParquetReader(Batch -> Row) 192.28 81.80 3.17 X + Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + ParquetReader 565 / 609 27.8 35.9 1.0X + ParquetReader(Batched) 165 / 174 95.3 10.5 3.4X + ParquetReader(Batch -> Row) 158 / 188 99.3 10.1 3.6X */ parquetReaderBenchmark.run() } @@ -217,12 +218,12 @@ object ParquetReadBenchmark { /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------- - SQL Parquet Reader 1737.94 6.03 1.00 X - SQL Parquet MR 2393.08 4.38 0.73 X - SQL Parquet Vectorized 1442.99 7.27 1.20 X - ParquetReader 1032.11 10.16 1.68 X + SQL Parquet Reader 1381 / 1679 7.6 131.7 1.0X + SQL Parquet MR 2005 / 2177 5.2 191.2 0.7X + SQL Parquet Vectorized 919 / 1044 11.4 87.6 1.5X + ParquetReader 1035 / 1163 10.1 98.7 1.3X */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 5cbcccbd862dc..e8c524e9e550d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -30,7 +30,8 @@ import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.schema.MessageType -import org.apache.spark.sql.{DataFrame, SaveMode, SQLConf} +import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index f95272530d585..6ae42a30fb00c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -57,6 +57,21 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-13503 Support to specify the option for compression codec for TEXT") { + val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") + + val tempFile = Utils.createTempDir() + tempFile.delete() + df.write + .option("compression", "gZiP") + .text(tempFile.getCanonicalPath) + val compressedFiles = tempFile.listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".gz"))) + verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath)) + + Utils.deleteRecursively(tempFile) + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 6dfff3770b882..7eb15249ebbd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -145,6 +146,33 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + + test(s"$testName using CartesianProduct") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + Filter(condition(), CartesianProduct(left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } testInnerJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index cd6b6fcbb18af..0d1c29fe574a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} @@ -104,6 +105,24 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition)), + 
expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } // --- Basic outer joins ------------------------------------------------------------------------ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index f3ad8409e5a3d..bc341db5571be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} @@ -94,10 +95,19 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using LeftSemiJoinBNL") { + test(s"$testName using BroadcastNestedLoopJoin build left") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, Some(condition)), + BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala index efc3227dd60d8..cd9277d3bcf1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf /** * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index eb70747926fe5..268f2aac87195 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.local import org.mockito.Mockito.{mock, when} import org.apache.spark.broadcast.TorrentBroadcast -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedMutableProjection, UnsafeProjection} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} +import org.apache.spark.sql.internal.SQLConf class HashJoinNodeSuite extends LocalNodeTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 1a485f967dd38..cd67a66ebf576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.local import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} class LocalNodeTest extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index 45df2ea6552d8..bcc87a9175517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.internal.SQLConf class NestedLoopJoinNodeSuite extends LocalNodeTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 46bb699b780a9..5b4f6f1d2461b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{JsonProtocol, Utils} @@ -153,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("Sort metrics") { + // Assume the execution plan is + // 
WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) + val df = sqlContext.range(10).sort('id) + testSparkPlanMetrics(df, 2, Map.empty) + } + test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala new file mode 100644 index 0000000000000..0a989d026ce1c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.stat + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.stat.StatFunctions.QuantileSummaries + + +class ApproxQuantileSuite extends SparkFunSuite { + + private val r = new Random(1) + private val n = 100 + private val increasing = "increasing" -> (0 until n).map(_.toDouble) + private val decreasing = "decreasing" -> (n until 0 by -1).map(_.toDouble) + private val random = "random" -> Seq.fill(n)(math.ceil(r.nextDouble() * 1000)) + + private def buildSummary( + data: Seq[Double], + epsi: Double, + threshold: Int): QuantileSummaries = { + var summary = new QuantileSummaries(threshold, epsi) + data.foreach { x => + summary = summary.insert(x) + } + summary.compress() + } + + private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { + val approx = summary.query(quant) + // The rank of the approximation. 
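+ // The summary's deterministic guarantee, checked by the assertions below: for a requested
+ // quantile `quant` with relative error `relativeError` over N = data.size values, the exact
+ // rank of the returned sample must lie in [floor((quant - relativeError) * N), ceil((quant + relativeError) * N)].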
+ val rank = data.count(_ < approx) // has to be <, not <= to be exact + val lower = math.floor((quant - summary.relativeError) * data.size) + val upper = math.ceil((quant + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) + } + + for { + (seq_name, data) <- Seq(increasing, decreasing, random) + epsi <- Seq(0.1, 0.0001) + compression <- Seq(1000, 10) + } { + + test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s = buildSummary(data, epsi, compression) + val min_approx = s.query(0.0) + assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") + val max_approx = s.query(1.0) + assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") + } + + test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s = buildSummary(data, epsi, compression) + assert(s.count == data.size, s"Found count=${s.count} but data size=${data.size}") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } + } + + // Tests for merging procedure + for { + (seq_name, data) <- Seq(increasing, decreasing, random) + epsi <- Seq(0.1, 0.0001) + compression <- Seq(1000, 10) + } { + + val (data1, data2) = { + val l = data.size + data.take(l / 2) -> data.drop(l / 2) + } + + test(s"Merging ordered lists with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s1 = buildSummary(data1, epsi, compression) + val s2 = buildSummary(data2, epsi, compression) + val s = s1.merge(s2) + val min_approx = s.query(0.0) + assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") + val max_approx = s.query(1.0) + assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } + + val (data11, data12) = { + data.sliding(2).map(_.head).toSeq -> data.sliding(2).map(_.last).toSeq + } + + test(s"Merging interleaved lists with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s1 = buildSummary(data11, epsi, compression) + val s2 = buildSummary(data12, epsi, compression) + val s = s1.merge(s2) + val min_approx = s.query(0.0) + assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") + val max_approx = s.query(1.0) + assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 8efdf8adb042a..97638a66ab473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -370,7 +370,7 @@ object ColumnarBatchBenchmark { } i = 0 while (i < count) { - sum += column.getByteArray(i).length + sum += column.getUTF8String(i).numBytes() i += 1 } 
column.reset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 445f311107e33..b3c3e66fbcbd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. val it = batch.rowIterator() @@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala new file mode 100644 index 0000000000000..f809e01169355 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RuntimeConfig + +class RuntimeConfigSuite extends SparkFunSuite { + + private def newConf(): RuntimeConfig = new RuntimeConfigImpl + + test("set and get") { + val conf = newConf() + conf + .set("k1", "v1") + .set("k2", 2) + .set("k3", value = false) + + assert(conf.get("k1") == "v1") + assert(conf.get("k2") == "2") + assert(conf.get("k3") == "false") + + intercept[NoSuchElementException] { + conf.get("notset") + } + } + + test("getOption") { + val conf = newConf().set("k1", "v1") + assert(conf.getOption("k1") == Some("v1")) + assert(conf.getOption("notset") == None) + } + + test("unset") { + val conf = newConf().set("k1", "v1") + assert(conf.get("k1") == "v1") + conf.unset("k1") + intercept[NoSuchElementException] { + conf.get("k1") + } + } + + test("set and get hadoop configuration") { + val conf = newConf() + conf + .setHadoop("k1", "v1") + .setHadoop("k2", "v2") + + assert(conf.getHadoop("k1") == "v1") + assert(conf.getHadoop("k2") == "v2") + + intercept[NoSuchElementException] { + conf.get("notset") + } + } + + test("getHadoopOption") { + val conf = newConf().setHadoop("k1", "v1") + assert(conf.getHadoopOption("k1") == Some("v1")) + assert(conf.getHadoopOption("notset") == None) + } + + test("unsetHadoop") { + val conf = newConf().setHadoop("k1", "v1") + assert(conf.getHadoop("k1") == "v1") + conf.unsetHadoop("k1") + intercept[NoSuchElementException] { + conf.getHadoop("k1") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 2e33777f14adc..2b89fa9f23815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -15,10 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf._ +import org.apache.spark.sql.internal.SQLConf._ class SQLConfEntrySuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index cf0701eca29ea..e944d328a3ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -15,8 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.internal +import org.apache.spark.sql.{QueryTest, SQLContext} import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7a0f7abaa1baf..30a5e2ea4acd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -171,6 +171,27 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( + "create table test.emp(name TEXT(32) NOT NULL," + + " theid INTEGER, \"Dept\" INTEGER)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('fred', 1, 10)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('mary', 2, null)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('joe ''foo'' \"bar\"', 3, 30)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('kathy', null, null)").executeUpdate() + conn.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE nullparts + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.EMP', user 'testUser', password 'testPass', + |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4') + """.stripMargin.replaceAll("\n", " ")) + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -187,8 +208,10 @@ class JDBCSuite extends SparkFunSuite val parentPlan = df.queryExecution.executedPlan // Check if SparkPlan Filter is removed in a physical plan and // the plan only has PhysicalRDD to scan JDBCRelation. 
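+ // With whole-stage codegen enabled, the scan is wrapped in a WholeStageCodegen node,
+ // so unwrap it before checking for the JDBC PhysicalRDD.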
- assert(parentPlan.isInstanceOf[PhysicalRDD]) - assert(parentPlan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] + assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD]) + assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) df } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) @@ -336,6 +359,23 @@ class JDBCSuite extends SparkFunSuite .collect().length === 3) } + test("Partitioning on column that might have null values.") { + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + .collect().length === 4) + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + .collect().length === 4) + // partitioning on a nullable quoted column + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + .collect().length === 4) + } + + test("SELECT * on partitioned table with a nullable partition column") { + assert(sql("SELECT * FROM nullparts").collect().size == 4) + } + test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index af04079ec895a..92061133cd49b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf private[sql] abstract class DataSourceTest extends QueryTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 7196b6dc13394..2ff79a2316bdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -304,30 +305,38 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic expectedCount: Int, requiredColumnNames: Set[String], expectedUnhandledFilters: Set[Filter]): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { - val queryExecution = sql(sqlString).queryExecution - val rawPlan = queryExecution.executedPlan.collect { - case p: execution.PhysicalRDD => p - } match { - case Seq(p) => p - case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") - } - val rawCount = rawPlan.execute().count() - assert(ColumnsRequired.set === requiredColumnNames) - - val table = caseInsensitiveContext.table("oneToTenFiltered") - val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _, _) => r - }.get - - assert(
relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) - - if (rawCount != expectedCount) { - fail( - s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) + // These tests check a particular plan, disable whole stage codegen. + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + try { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + assert(ColumnsRequired.set === requiredColumnNames) + + val table = caseInsensitiveContext.table("oneToTenFiltered") + val relation = table.queryExecution.logical.collectFirst { + case LogicalRelation(r, _, _) => r + }.get + + assert( + relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } finally { + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index a89c5f8007e78..db722975379a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -117,28 +118,35 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = sql(sqlString).queryExecution - val rawPlan = queryExecution.executedPlan.collect { - case p: execution.PhysicalRDD => p - } match { - case Seq(p) => p - case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") - } - val rawColumns = rawPlan.output.map(_.name) - val rawOutput = rawPlan.execute().first() - - if (rawColumns != expectedColumns) { - fail( - s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) - } - if (rawOutput.numFields != expectedColumns.size) { - fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + // These tests check a particular plan, disable whole stage codegen. + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + try { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. 
Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.numFields != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } finally { + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index e055da9e8a39a..588f6e268f31c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,7 +21,8 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLConf} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index c89a1516503e0..b3e146fba80be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} - +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.{SessionState, SQLConf} /** * A special [[SQLContext]] prepared for testing. 
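Both scan suites above now wrap their plan-shape assertions in the same pattern: switch whole-stage codegen off, run the checks, and restore the setting in a finally block so later tests see the usual configuration again. A minimal, dependency-free sketch of that try/finally override pattern follows; the helper name and the mutable map standing in for the real SQLConf are assumptions made purely for illustration.

    // Illustrative helper only -- the mutable map stands in for SQLConf, which this sketch
    // does not depend on. It temporarily overrides a key and always restores it afterwards,
    // mirroring the try/finally shape used by FilteredScanSuite and PrunedScanSuite above.
    import scala.collection.mutable

    object WithConfSketch {
      def withOverride[T](conf: mutable.Map[String, String], key: String, value: String)
          (body: => T): T = {
        val previous = conf.get(key)
        conf(key) = value
        try {
          body
        } finally {
          previous match {
            case Some(v) => conf(key) = v    // put the earlier value back
            case None    => conf.remove(key) // or drop the key entirely
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val conf = mutable.Map("spark.sql.codegen.wholeStage" -> "true")
        withOverride(conf, "spark.sql.codegen.wholeStage", "false") {
          println(conf("spark.sql.codegen.wholeStage")) // "false" while the checks run
        }
        println(conf("spark.sql.codegen.wholeStage"))   // "true" again afterwards
      }
    }

The suites themselves restore WHOLESTAGE_CODEGEN_ENABLED to its defaultValue rather than to whatever was previously set, but the protective try/finally shape is the same.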
@@ -31,16 +31,16 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel new SparkConf().set("spark.sql.testkey", "true"))) } - protected[sql] override lazy val conf: SQLConf = new SQLConf { - - clear() - - override def clear(): Unit = { - super.clear() - - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.map { - case (key, value) => setConfString(key, value) + @transient + protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } + } } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 66eaa3ebcd737..f32ba5fe68a63 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -32,10 +32,10 @@ import org.apache.hive.service.server.{HiveServer2, HiveServerServerOptionsProce import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 8fef22cf777f6..458d4f2c3c959 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -33,9 +33,10 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.sql.{DataFrame, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.{Utils => SparkUtils} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index f279b78f47c7d..fb31119a9e1dd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -288,8 +288,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val 
tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || - cmd_lower.equals("exit") || - tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + cmd_lower.equals("exit")) { + sessionState.close() + System.exit(0) + } + if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || isRemoteMode) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 72da266da4d01..81508e134695a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -234,4 +234,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { -> "Error in query: Table not found: nonexistent_table;" ) } + + test("SPARK-11624 Spark SQL CLI should set sessionState only once") { + runCliWithin(2.minute, Seq("-e", "!echo \"This is a test for Spark-11624\";"))( + "" -> "This is a test for Spark-11624") + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 5f9952a90a57d..c05527b519daa 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -202,7 +202,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } test("test multiple session") { - import org.apache.spark.sql.SQLConf + import org.apache.spark.sql.internal.SQLConf var defaultV1: String = null var defaultV2: String = null var data: ArrayBuffer[Int] = null diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9097c1a1d3117..0dc2a95eea70e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,14 +23,12 @@ import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.sql.internal.SQLConf /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... 
get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( @@ -57,7 +55,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Use Hive hash expression instead of the native one - TestHive.functionRegistry.unregisterFunction("hash") + TestHive.sessionState.functionRegistry.unregisterFunction("hash") RuleExecutor.resetTime() } @@ -67,7 +65,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - TestHive.functionRegistry.restore() + TestHive.sessionState.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 14cf9acf09d5b..22bad93e6dd58 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -72,6 +72,12 @@ protobuf-java ${protobuf.version} +--> + <dependency> + <groupId>${hive.group}</groupId> + <artifactId>hive-cli</artifactId> + </dependency> +
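The SparkSQLCLIDriver change earlier in this patch splits "quit"/"exit" out of the branch that delegates to Hive's CliDriver: those two commands now close the session state and exit immediately, while "source", "!...", "list", and remote-mode commands are still handed off. A small, self-contained sketch of that dispatch is shown below; the Action names are invented for illustration, and the real driver calls sessionState.close() and System.exit(0) instead of returning a value.

    // Illustrative sketch of the revised CLI command dispatch -- not the actual
    // SparkSQLCLIDriver code. "quit"/"exit" get their own branch so the session state can be
    // closed exactly once before exiting; the remaining special commands are still delegated.
    object CliDispatchSketch {
      sealed trait Action
      case object CloseSessionAndExit extends Action // quit / exit
      case object DelegateToHiveCli   extends Action // source, !cmd, list, or remote mode
      case object RunWithSparkSql     extends Action // everything else

      def dispatch(command: String, isRemoteMode: Boolean): Action = {
        val trimmed = command.trim
        val lower = trimmed.toLowerCase(java.util.Locale.ENGLISH)
        val firstToken = trimmed.split("\\s+").headOption.getOrElse("")
        if (lower == "quit" || lower == "exit") {
          CloseSessionAndExit
        } else if (firstToken.equalsIgnoreCase("source") ||
            trimmed.startsWith("!") ||
            firstToken.equalsIgnoreCase("list") ||
            isRemoteMode) {
          DelegateToHiveCli
        } else {
          RunWithSparkSql
        }
      }

      def main(args: Array[String]): Unit = {
        println(dispatch("exit", isRemoteMode = false))                       // CloseSessionAndExit
        println(dispatch("!echo \"This is a test for Spark-11624\";", false)) // DelegateToHiveCli
        println(dispatch("SELECT * FROM nullparts", isRemoteMode = false))    // RunWithSparkSql
      }
    }

The new CliSuite case above drives the CLI with a "!echo" command, which exercises the delegated branch rather than the new exit path.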