diff --git a/LICENSE b/LICENSE index 3c6117f4aa8f2..d7a790a62894f 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.9.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.9.2 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 465bc37788e5d..0cd0d75df0f70 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -18,10 +18,10 @@ Collate: 'schema.R' 'generics.R' 'jobj.R' - 'RDD.R' - 'pairRDD.R' 'column.R' 'group.R' + 'RDD.R' + 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' 'backend.R' @@ -36,3 +36,4 @@ Collate: 'stats.R' 'types.R' 'utils.R' +RoxygenNote: 5.0.1 diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 50655e9382325..a64a013b654ef 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -321,7 +321,7 @@ setMethod("colnames<-", } # Check if the column names have . in it - if (any(regexec(".", value, fixed=TRUE)[[1]][1] != -1)) { + if (any(regexec(".", value, fixed = TRUE)[[1]][1] != -1)) { stop("Colum names cannot contain the '.' symbol.") } @@ -351,7 +351,7 @@ setMethod("coltypes", types <- sapply(dtypes(x), function(x) {x[[2]]}) # Map Spark data types into R's data types using DATA_TYPES environment - rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + rTypes <- sapply(types, USE.NAMES = F, FUN = function(x) { # Check for primitive types type <- PRIMITIVE_TYPES[[x]] @@ -1779,7 +1779,7 @@ setMethod("merge", signature(x = "DataFrame", y = "DataFrame"), function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by, all = FALSE, all.x = all, all.y = all, - sort = TRUE, suffixes = c("_x","_y"), ... ) { + sort = TRUE, suffixes = c("_x", "_y"), ... ) { if (length(suffixes) != 2) { stop("suffixes must have length 2") @@ -2299,7 +2299,7 @@ setMethod("as.data.frame", function(x, ...) { # Check if additional parameters have been passed if (length(list(...)) > 0) { - stop(paste("Unused argument(s): ", paste(list(...), collapse=", "))) + stop(paste("Unused argument(s): ", paste(list(...), collapse = ", "))) } collect(x) }) @@ -2395,13 +2395,13 @@ setMethod("str", # Get the first elements for each column firstElements <- if (types[i] == "character") { - paste(paste0("\"", localDF[,i], "\""), collapse = " ") + paste(paste0("\"", localDF[, i], "\""), collapse = " ") } else { - paste(localDF[,i], collapse = " ") + paste(localDF[, i], collapse = " ") } # Add the corresponding number of spaces for alignment - spaces <- paste(rep(" ", max(nchar(names) - nchar(names[i]))), collapse="") + spaces <- paste(rep(" ", max(nchar(names) - nchar(names[i]))), collapse = "") # Get the short type. For 'character', it would be 'chr'; # 'for numeric', it's 'num', etc. 
@@ -2413,7 +2413,7 @@ setMethod("str", # Concatenate the colnames, coltypes, and first # elements of each column line <- paste0(" $ ", names[i], spaces, ": ", - dataType, " ",firstElements) + dataType, " ", firstElements) # Chop off extra characters if this is too long cat(substr(line, 1, MAX_CHAR_PER_ROW)) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index a78fbb714f2be..35c4e6f1afaf4 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -67,7 +67,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, setMethod("show", "RDD", function(object) { - cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep="")) + cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep = "")) }) setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 471bec1eacf03..b0e67c8ad26ab 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -103,7 +103,10 @@ parallelize <- function(sc, coll, numSlices = 1) { # TODO: bound/safeguard numSlices # TODO: unit tests for if the split works for all primitives # TODO: support matrix, data frame, etc + # nolint start + # suppress lintr warning: Place a space before left parenthesis, except in a function call. if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + # nolint end if (is.data.frame(coll)) { message(paste("context.R: A data frame is parallelized by columns.")) } else { diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d8a0393275390..eefdf178733fd 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -186,7 +186,7 @@ readMultipleObjects <- function(inputCon) { # of the objects, so the number of objects varies, we try to read # all objects in a loop until the end of the stream. data <- list() - while(TRUE) { + while (TRUE) { # If reaching the end of the stream, type returned should be "". type <- readType(inputCon) if (type == "") { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ddfa61717af2e..6ad71fcb46712 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -607,7 +607,7 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @rdname showDF #' @export -setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) +setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname subset # @export @@ -615,7 +615,7 @@ setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export -setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) +setGeneric("summarize", function(x, ...) 
{ standardGeneric("summarize") }) #' @rdname summary #' @export diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 346f33d7dab2c..5c0d3dcf3af90 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -54,7 +54,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, standardize = TRUE, solver = "auto") { family <- match.arg(family) - formula <- paste(deparse(formula), collapse="") + formula <- paste(deparse(formula), collapse = "") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitRModelFormula", formula, data@sdf, family, lambda, alpha, standardize, solver) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 70e87a93e610f..3bbf60d9b668c 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -100,7 +100,7 @@ writeJobj <- function(con, value) { writeString <- function(con, value) { utfVal <- enc2utf8(value) writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) - writeBin(utfVal, con, endian = "big", useBytes=TRUE) + writeBin(utfVal, con, endian = "big", useBytes = TRUE) } writeInt <- function(con, value) { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 3e9eafc7f5b90..c187869fdf121 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -153,7 +153,7 @@ sparkR.init <- function( if (!file.exists(path)) { stop("JVM is not ready after 10 seconds") } - f <- file(path, open="rb") + f <- file(path, open = "rb") backendPort <- readInt(f) monitorPort <- readInt(f) rLibPath <- readString(f) @@ -185,9 +185,9 @@ sparkR.init <- function( } sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) - if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { + if (is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- - paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + paste0("$LD_LIBRARY_PATH:", Sys.getenv("LD_LIBRARY_PATH")) } # Classpath separator is ";" on Windows diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index aa386e5da933b..fb6575cb42907 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -158,7 +158,7 @@ wrapInt <- function(value) { # Multiply `val` by 31 and add `addVal` to the result. Ensures that # integer-overflows are handled at every step. mult31AndAdd <- function(val, addVal) { - vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + vec <- c(bitwShiftL(val, c(4, 3, 2, 1, 0)), addVal) Reduce(function(a, b) { wrapInt(as.numeric(a) + as.numeric(b)) }, @@ -202,7 +202,7 @@ serializeToString <- function(rdd) { # This function amortizes the allocation cost by doubling # the size of the list every time it fills up. 
addItemToAccumulator <- function(acc, item) { - if(acc$counter == acc$size) { + if (acc$counter == acc$size) { acc$size <- acc$size * 2 length(acc$data) <- acc$size } diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index c55fe9ba7af7a..8c75c19ca7ac3 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -19,5 +19,5 @@ packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") dirs <- strsplit(packageDir, ",")[[1]] .libPaths(c(dirs, .libPaths())) - Sys.setenv(NOAWT=1) + Sys.setenv(NOAWT = 1) } diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index 207a37a0cb47f..c26b28b78dee8 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -25,6 +25,6 @@ run2 <- myfunc(-4L) sparkR.stop() -if(run1 != 6) quit(save = "no", status = 1) +if (run1 != 6) quit(save = "no", status = 1) -if(run2 != -3) quit(save = "no", status = 1) +if (run2 != -3) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index f2452ed97d2ea..976a7558a816d 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -23,8 +23,8 @@ sc <- sparkR.init() mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1, 1) @@ -37,7 +37,7 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1) @@ -49,8 +49,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1) @@ -73,8 +73,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") rdd1 <- parallelize(sc, "Spark is pretty.") saveAsObjectFile(rdd1, fileName1) diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index f054ac9a87d61..7bad4d2a7e106 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -31,7 +31,7 @@ test_that("union on two RDDs", { actual <- collect(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) - fileName <- tempfile(pattern="spark-test", 
fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) @@ -74,10 +74,10 @@ test_that("zipPartitions() on RDDs", { actual <- collect(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, - list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index bb86a5c922bde..8be6efc3dbed3 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -25,7 +25,7 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { - randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) useBroadcast <- function(x) { @@ -37,7 +37,7 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { - randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { sum(randomMat * x) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index af84a0abcf94d..e120462964d1e 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -96,9 +96,9 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$coefficients[,1]) + coefs <- as.vector(stats$coefficients[, 1]) - rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] + rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit")))) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 1b3a22486e95f..3b0c16be5a754 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -75,7 +75,7 @@ test_that("mapPartitions on RDD", { test_that("flatMap() on RDDs", { flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collect(flat) - expect_equal(actual, rep(intPairs, each=2)) + expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { @@ -245,9 +245,9 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { - l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collect(flatMapValues(l, function(x) { x })) - expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) # Generate x to x+1 for every value actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) @@ -448,12 +448,12 @@ test_that("zipRDD() on RDDs", { list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) mockFile <- c("Spark is 
pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) actual <- collect(zipRDD(rdd, rdd)) - expected <- lapply(mockFile, function(x) { list(x ,x) }) + expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) rdd1 <- parallelize(sc, 0:1, 1) @@ -484,7 +484,7 @@ test_that("cartesian() on RDDs", { expect_equal(actual, list()) mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -523,19 +523,19 @@ test_that("subtract() on RDDs", { # subtract by an empty RDD rdd2 <- parallelize(sc, list()) actual <- collect(subtract(rdd1, rdd2)) - expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), l) rdd2 <- parallelize(sc, list(2, 4)) actual <- collect(subtract(rdd1, rdd2)) - expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), list(1, 1, 3)) l <- list("a", "a", "b", "b", "c", "d") rdd1 <- parallelize(sc, l) rdd2 <- parallelize(sc, list("b", "d")) actual <- collect(subtract(rdd1, rdd2)) - expect_equal(as.list(sort(as.vector(actual, mode="character"))), + expect_equal(as.list(sort(as.vector(actual, mode = "character"))), list("a", "a", "c")) }) @@ -585,53 +585,53 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) - rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) - rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) - rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) + rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(actual, list()) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(actual, list()) }) test_that("leftOuterJoin() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) - rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",1), 
list("b",4))) - rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) + rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) expect_equal(sortKeyValueList(actual), @@ -639,57 +639,57 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) - rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) + rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) - rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3))) + rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("fullOuterJoin() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) - rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) + rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) - rdd2 
<- parallelize(sc, list(list("a",1), list("b",4))) + rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 11a8f12fd5432..63acbadfa6a16 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -41,8 +41,8 @@ sqlContext <- sparkRSQL.init(sc) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}") -jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") -parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") writeLines(mockLines, jsonPath) # For test nafunctions, like dropna(), fillna(),... 
@@ -51,7 +51,7 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", "{\"name\":\"David\",\"age\":60,\"height\":null}", "{\"name\":\"Amy\",\"age\":null,\"height\":null}", "{\"name\":null,\"age\":null,\"height\":null}") -jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +jsonPathNa <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesNa, jsonPathNa) # For test complex types in DataFrame @@ -59,7 +59,7 @@ mockLinesComplexType <- c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") -complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { @@ -151,9 +151,9 @@ test_that("create DataFrame from RDD", { expect_equal(as.list(collect(where(df2AsDF, df2AsDF$name == "Bob"))), list(name = "Bob", age = 16, height = 176.5)) - localDF <- data.frame(name=c("John", "Smith", "Sarah"), - age=c(19L, 23L, 18L), - height=c(176.5, 181.4, 173.7)) + localDF <- data.frame(name = c("John", "Smith", "Sarah"), + age = c(19L, 23L, 18L), + height = c(176.5, 181.4, 173.7)) df <- createDataFrame(sqlContext, localDF, schema) expect_is(df, "DataFrame") expect_equal(count(df), 3) @@ -263,7 +263,7 @@ test_that("create DataFrame from list or data.frame", { irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) iris_collected <- collect(irisdf) - expect_equivalent(iris_collected[,-5], iris[,-5]) + expect_equivalent(iris_collected[, -5], iris[, -5]) expect_equal(iris_collected$Species, as.character(iris$Species)) mtcarsdf <- createDataFrame(sqlContext, mtcars) @@ -329,7 +329,7 @@ test_that("create DataFrame from a data.frame with complex types", { mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") -mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { @@ -399,11 +399,11 @@ test_that("read/write json files", { expect_equal(count(df), 3) # Test write.df - jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".json") - write.df(df, jsonPath2, "json", mode="overwrite") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") # Test write.json - jsonPath3 <- tempfile(pattern="jsonPath3", fileext=".json") + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") write.json(df, jsonPath3) # Test read.json()/jsonFile() works with multiple input paths @@ -466,7 +466,7 @@ test_that("insertInto() on a registered table", { lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) df2 <- read.df(sqlContext, jsonPath2, "json") @@ -526,7 +526,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { textLines <- c("Michael", "Andy, 
30", "Justin, 19") - textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + textPath <- tempfile(pattern = "sparkr-textLines", fileext = ".tmp") writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) @@ -547,7 +547,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { - objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) @@ -611,7 +611,7 @@ test_that("collect() support Unicode characters", { "{\"name\":\"こんにちは\", \"age\":19}", "{\"name\":\"Xin chào\"}") - jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPath) df <- read.df(sqlContext, jsonPath, "json") @@ -705,7 +705,7 @@ test_that("names() colnames() set the column names", { # Test base::colnames base::names m2 <- cbind(1, 1:4) expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) - colnames(m2) <- c("x","Y") + colnames(m2) <- c("x", "Y") expect_equal(colnames(m2), c("x", "Y")) z <- list(a = 1, b = "c", c = 1:3) @@ -745,7 +745,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}", "{\"name\":\"Justin\", \"age\":19}") - jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPathWithDup) df <- read.json(sqlContext, jsonPathWithDup) @@ -774,7 +774,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { c(2, 2, 1), c(2, 2, 2)) names(expected) <- c("key", "value1", "value2") expect_equivalent( - result[order(result$key, result$value1, result$value2),], + result[order(result$key, result$value1, result$value2), ], expected) result <- collect(dropDuplicates(df, c("key", "value1"))) @@ -782,7 +782,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { c(1, 1, 1), c(1, 2, 1), c(2, 1, 2), c(2, 2, 2)) names(expected) <- c("key", "value1", "value2") expect_equivalent( - result[order(result$key, result$value1, result$value2),], + result[order(result$key, result$value1, result$value2), ], expected) result <- collect(dropDuplicates(df, "key")) @@ -790,7 +790,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { c(1, 1, 1), c(2, 1, 2)) names(expected) <- c("key", "value1", "value2") expect_equivalent( - result[order(result$key, result$value1, result$value2),], + result[order(result$key, result$value1, result$value2), ], expected) }) @@ -822,10 +822,10 @@ test_that("select operators", { expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") - expect_is(df[,1], "DataFrame") - expect_equal(columns(df[,1]), c("name")) - expect_equal(columns(df[,"age"]), c("age")) - df2 <- df[,c("age", "name")] + expect_is(df[, 1], "DataFrame") + expect_equal(columns(df[, 1]), c("name")) + expect_equal(columns(df[, "age"]), c("age")) + df2 <- df[, c("age", "name")] expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) @@ -884,7 +884,7 @@ test_that("drop column", { test_that("subsetting", { # read.json returns columns in random order df <- select(read.json(sqlContext, jsonPath), "name", "age") - filtered <- df[df$age > 20,] + filtered <- df[df$age > 20, ] expect_equal(count(filtered), 1) 
expect_equal(columns(filtered), c("name", "age")) expect_equal(collect(filtered)$name, "Andy") @@ -903,11 +903,11 @@ test_that("subsetting", { expect_equal(count(df4), 2) expect_equal(columns(df4), c("name", "age")) - df5 <- df[df$age %in% c(19), c(1,2)] + df5 <- df[df$age %in% c(19), c(1, 2)] expect_equal(count(df5), 1) expect_equal(columns(df5), c("name", "age")) - df6 <- subset(df, df$age %in% c(30), c(1,2)) + df6 <- subset(df, df$age %in% c(30), c(1, 2)) expect_equal(count(df6), 1) expect_equal(columns(df6), c("name", "age")) @@ -959,22 +959,22 @@ test_that("test HiveContext", { expect_is(df2, "DataFrame") expect_equal(count(df2), 3) - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) - hivetestDataPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) df4 <- sql(hiveCtx, "select * from hivetestbl") expect_is(df4, "DataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) - parquetDataPath <- tempfile(pattern="sparkr-test", fileext=".tmp") - invisible(saveAsTable(df, "parquetest", "parquet", mode="overwrite", path=parquetDataPath)) + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) df5 <- sql(hiveCtx, "select * from parquetest") expect_is(df5, "DataFrame") expect_equal(count(df5), 3) @@ -1094,7 +1094,7 @@ test_that("column binary mathfunctions", { "{\"a\":2, \"b\":6}", "{\"a\":3, \"b\":7}", "{\"a\":4, \"b\":8}") - jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPathWithDup) df <- read.json(sqlContext, jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) @@ -1244,7 +1244,7 @@ test_that("group by, agg functions", { df3 <- agg(gd, age = "stddev") expect_is(df3, "DataFrame") df3_local <- collect(df3) - expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2])) + expect_true(is.nan(df3_local[df3_local$name == "Andy", ][1, 2])) df4 <- agg(gd, sumAge = sum(df$age)) expect_is(df4, "DataFrame") @@ -1264,34 +1264,34 @@ test_that("group by, agg functions", { "{\"name\":\"ID1\", \"value\": \"10\"}", "{\"name\":\"ID1\", \"value\": \"22\"}", "{\"name\":\"ID2\", \"value\": \"-3\"}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines2, jsonPath2) gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") df6 <- agg(gd2, value = "sum") df6_local <- collect(df6) - expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) - expect_equal(-3, df6_local[df6_local$name == "ID2",][1, 2]) + expect_equal(42, df6_local[df6_local$name == "ID1", ][1, 2]) + expect_equal(-3, df6_local[df6_local$name == "ID2", ][1, 2]) df7 <- agg(gd2, value = "stddev") df7_local <- collect(df7) - expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6) - expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2])) + expect_true(abs(df7_local[df7_local$name == "ID1", ][1, 2] - 6.928203) < 1e-6) + 
expect_true(is.nan(df7_local[df7_local$name == "ID2", ][1, 2])) mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}", "{\"name\":\"Justin\", \"age\":1}") - jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines3, jsonPath3) df8 <- read.json(sqlContext, jsonPath3) gd3 <- groupBy(df8, "name") gd3_local <- collect(sum(gd3)) - expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) - expect_equal(20, gd3_local[gd3_local$name == "Justin",][1, 2]) + expect_equal(60, gd3_local[gd3_local$name == "Andy", ][1, 2]) + expect_equal(20, gd3_local[gd3_local$name == "Justin", ][1, 2]) expect_true(abs(collect(agg(df, sd(df$age)))[1, 1] - 7.778175) < 1e-6) gd3_local <- collect(agg(gd3, var(df8$age))) - expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) + expect_equal(162, gd3_local[gd3_local$name == "Justin", ][1, 2]) # Test stats::sd, stats::var are working expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) @@ -1304,10 +1304,10 @@ test_that("group by, agg functions", { test_that("arrange() and orderBy() on a DataFrame", { df <- read.json(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_equal(collect(sorted)[1,2], "Michael") + expect_equal(collect(sorted)[1, 2], "Michael") sorted2 <- arrange(df, "name", decreasing = FALSE) - expect_equal(collect(sorted2)[2,"age"], 19) + expect_equal(collect(sorted2)[2, "age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) @@ -1315,16 +1315,16 @@ test_that("arrange() and orderBy() on a DataFrame", { sorted4 <- orderBy(df, desc(df$name)) expect_equal(first(sorted4)$name, "Michael") - expect_equal(collect(sorted4)[3,"name"], "Andy") + expect_equal(collect(sorted4)[3, "name"], "Andy") sorted5 <- arrange(df, "age", "name", decreasing = TRUE) - expect_equal(collect(sorted5)[1,2], "Andy") + expect_equal(collect(sorted5)[1, 2], "Andy") - sorted6 <- arrange(df, "age","name", decreasing = c(T, F)) - expect_equal(collect(sorted6)[1,2], "Andy") + sorted6 <- arrange(df, "age", "name", decreasing = c(T, F)) + expect_equal(collect(sorted6)[1, 2], "Andy") sorted7 <- arrange(df, "name", decreasing = FALSE) - expect_equal(collect(sorted7)[2,"age"], 19) + expect_equal(collect(sorted7)[2, "age"], 19) }) test_that("filter() on a DataFrame", { @@ -1357,7 +1357,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Andy\", \"test\": \"no\"}", "{\"name\":\"Justin\", \"test\": \"yes\"}", "{\"name\":\"Bob\", \"test\": \"yes\"}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines2, jsonPath2) df2 <- read.json(sqlContext, jsonPath2) @@ -1409,12 +1409,12 @@ test_that("join() and merge() on a DataFrame", { expect_equal(names(merged), c("age", "name_x", "name_y", "test")) expect_equal(collect(orderBy(merged, merged$name_x))$age[3], 19) - merged <- merge(df, df2, suffixes = c("-X","-Y")) + merged <- merge(df, df2, suffixes = c("-X", "-Y")) expect_equal(count(merged), 3) expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) expect_equal(collect(orderBy(merged, merged$"name-X"))$age[1], 30) - merged <- merge(df, df2, by = "name", suffixes = c("-X","-Y"), sort = FALSE) + merged <- merge(df, df2, by = "name", suffixes = c("-X", "-Y"), sort = FALSE) expect_equal(count(merged), 3) expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) 
expect_equal(collect(orderBy(merged, merged$"name-Y"))$"name-X"[3], "Michael") @@ -1432,7 +1432,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Andy\", \"name_y\":\"Andy\", \"test\": \"no\"}", "{\"name\":\"Justin\", \"name_y\":\"Justin\", \"test\": \"yes\"}", "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") - jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines3, jsonPath3) df3 <- read.json(sqlContext, jsonPath3) expect_error(merge(df, df3), @@ -1460,8 +1460,8 @@ test_that("showDF()", { "|null|Michael|\n", "| 30| Andy|\n", "| 19| Justin|\n", - "+----+-------+\n", sep="") - expect_output(s , expected) + "+----+-------+\n", sep = "") + expect_output(s, expected) }) test_that("isLocal()", { @@ -1475,7 +1475,7 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPath2) df2 <- read.df(sqlContext, jsonPath2, "json") @@ -1558,7 +1558,7 @@ test_that("mutate(), transform(), rename() and names()", { test_that("read/write Parquet files", { df <- read.df(sqlContext, jsonPath, "json") # Test write.df and read.df - write.df(df, parquetPath, "parquet", mode="overwrite") + write.df(df, parquetPath, "parquet", mode = "overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") expect_is(df2, "DataFrame") expect_equal(count(df2), 3) @@ -1593,7 +1593,7 @@ test_that("read/write text files", { expect_equal(colnames(df), c("value")) expect_equal(count(df), 3) textPath <- tempfile(pattern = "textPath", fileext = ".txt") - write.df(df, textPath, "text", mode="overwrite") + write.df(df, textPath, "text", mode = "overwrite") # Test write.text and read.text textPath2 <- tempfile(pattern = "textPath2", fileext = ".txt") @@ -1631,13 +1631,13 @@ test_that("dropna() and na.omit() on a DataFrame", { # drop with columns - expected <- rows[!is.na(rows$name),] + expected <- rows[!is.na(rows$name), ] actual <- collect(dropna(df, cols = "name")) expect_identical(expected, actual) actual <- collect(na.omit(df, cols = "name")) expect_identical(expected, actual) - expected <- rows[!is.na(rows$age),] + expected <- rows[!is.na(rows$age), ] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. 
@@ -1647,13 +1647,13 @@ test_that("dropna() and na.omit() on a DataFrame", { expect_identical(expected$name, actual$name) actual <- collect(na.omit(df, cols = "age")) - expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + expected <- rows[!is.na(rows$age) & !is.na(rows$height), ] actual <- collect(dropna(df, cols = c("age", "height"))) expect_identical(expected, actual) actual <- collect(na.omit(df, cols = c("age", "height"))) expect_identical(expected, actual) - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name), ] actual <- collect(dropna(df)) expect_identical(expected, actual) actual <- collect(na.omit(df)) @@ -1661,31 +1661,31 @@ test_that("dropna() and na.omit() on a DataFrame", { # drop with how - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name), ] actual <- collect(dropna(df)) expect_identical(expected, actual) actual <- collect(na.omit(df)) expect_identical(expected, actual) - expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name), ] actual <- collect(dropna(df, "all")) expect_identical(expected, actual) actual <- collect(na.omit(df, "all")) expect_identical(expected, actual) - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name), ] actual <- collect(dropna(df, "any")) expect_identical(expected, actual) actual <- collect(na.omit(df, "any")) expect_identical(expected, actual) - expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + expected <- rows[!is.na(rows$age) & !is.na(rows$height), ] actual <- collect(dropna(df, "any", cols = c("age", "height"))) expect_identical(expected, actual) actual <- collect(na.omit(df, "any", cols = c("age", "height"))) expect_identical(expected, actual) - expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + expected <- rows[!is.na(rows$age) | !is.na(rows$height), ] actual <- collect(dropna(df, "all", cols = c("age", "height"))) expect_identical(expected, actual) actual <- collect(na.omit(df, "all", cols = c("age", "height"))) @@ -1693,7 +1693,7 @@ test_that("dropna() and na.omit() on a DataFrame", { # drop with threshold - expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2, ] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) @@ -1701,7 +1701,7 @@ test_that("dropna() and na.omit() on a DataFrame", { expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + - as.integer(!is.na(rows$name)) >= 3,] + as.integer(!is.na(rows$name)) >= 3, ] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) @@ -1754,7 +1754,7 @@ test_that("crosstab() on a DataFrame", { }) df <- toDF(rdd, list("a", "b")) ct <- crosstab(df, "a", "b") - ordered <- ct[order(ct$a_b),] + ordered <- ct[order(ct$a_b), ] row.names(ordered) <- NULL expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), stringsAsFactors 
= FALSE, row.names = NULL) @@ -1782,10 +1782,10 @@ test_that("freqItems() on a DataFrame", { negDoubles = input * -1.0, stringsAsFactors = F) rdf[ input %% 3 == 0, ] <- c(1, "1", -1) df <- createDataFrame(sqlContext, rdf) - multiColResults <- freqItems(df, c("numbers", "letters"), support=0.1) + multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) expect_true("1" %in% multiColResults$letters[[1]]) - singleColResult <- freqItems(df, "negDoubles", support=0.1) + singleColResult <- freqItems(df, "negDoubles", support = 0.1) expect_true(-1 %in% head(singleColResult$negDoubles)[[1]]) l <- lapply(c(0:99), function(i) { @@ -1860,9 +1860,9 @@ test_that("with() on a DataFrame", { test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) - data <- data.frame(c1=c(1,2,3), - c2=c(T,F,T), - c3=c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + data <- data.frame(c1 = c(1, 2, 3), + c2 = c(T, F, T), + c3 = c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) schema <- structType(structField("c1", "byte"), structField("c3", "boolean"), @@ -1874,7 +1874,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { # Test complex types x <- createDataFrame(sqlContext, list(list(as.environment( - list("a"="b", "c"="d", "e"="f"))))) + list("a" = "b", "c" = "d", "e" = "f"))))) expect_equal(coltypes(x), "map") df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") @@ -1918,7 +1918,7 @@ test_that("Method str()", { # the number of columns. Therefore, it will suffice to check for the # number of returned rows x <- runif(200, 1, 10) - df <- data.frame(t(as.matrix(data.frame(x,x,x,x,x,x,x,x,x)))) + df <- data.frame(t(as.matrix(data.frame(x, x, x, x, x, x, x, x, x)))) DF <- createDataFrame(sqlContext, df) out <- capture.output(str(DF)) expect_equal(length(out), 103) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index a9cf83dbdbdb1..e64ef1bb31a3a 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -23,7 +23,7 @@ sc <- sparkR.init() mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -35,7 +35,7 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -45,7 +45,7 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -63,7 +63,7 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") 
writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) # RDD @@ -77,8 +77,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1, 1L) @@ -91,7 +91,7 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) saveAsTextFile(rdd, fileName) @@ -102,8 +102,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1) @@ -127,8 +127,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) writeLines("Spark is awesome.", fileName2) @@ -140,7 +140,7 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 56f14a3bce61e..4218138f641d1 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -41,7 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists test_that("serializeToBytes on RDD", { # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) @@ -86,8 +86,8 @@ test_that("cleanClosure on R functions", { f <- function(x) { defUse <- base::as.integer(x) + 1 # Test for access operators `::`. lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. - l$field[1,1] <- 3 # Test for access operators `$`. - res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + l$field[1, 1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1, ] # Test for def-use chain of "defUse", and "" symbol. f(res) # Test for recursive calls. 
} newF <- cleanClosure(f) @@ -132,7 +132,7 @@ test_that("cleanClosure on R functions", { expect_equal(actual, expected) # Test for broadcast variables. - a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) aBroadcast <- broadcast(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 3ae072beca11b..b6784dbae3203 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -55,7 +55,7 @@ serializer <- SparkR:::readString(inputCon) # Include packages as required packageNames <- unserialize(SparkR:::readRaw(inputCon)) for (pkg in packageNames) { - suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) + suppressPackageStartupMessages(library(as.character(pkg), character.only = TRUE)) } # read function dependencies diff --git a/bin/pyspark b/bin/pyspark index 2ac4a8be250d6..6962f4577d5b0 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -67,7 +67,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.1-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.2-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 21fe28155a596..cb788497ffc79 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9.2-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/bin/spark-class b/bin/spark-class index 5d964ba96abd8..e710e388be1bc 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -35,42 +35,27 @@ else fi fi -# Find assembly jar -SPARK_ASSEMBLY_JAR= +# Find Spark jars. +# TODO: change the directory name when Spark jars move from "lib". if [ -f "${SPARK_HOME}/RELEASE" ]; then - ASSEMBLY_DIR="${SPARK_HOME}/lib" + SPARK_JARS_DIR="${SPARK_HOME}/lib" else - ASSEMBLY_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION" + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION" fi -GREP_OPTIONS= -num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then - echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 +if [ ! -d "$SPARK_JARS_DIR" ]; then + echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2 echo "You need to build Spark before running this program." 1>&2 exit 1 fi -if [ -d "$ASSEMBLY_DIR" ]; then - ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" - if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 - echo "$ASSEMBLY_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 - fi -fi -SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" - -LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" +LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*" # Add the launcher build dir to the classpath if requested. 
if [ -n "$SPARK_PREPEND_CLASSES" ]; then LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" fi -export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" - # For tests if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index c4fadb822323d..565b87c102b19 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -28,33 +28,27 @@ if "x%1"=="x" ( exit /b 1 ) -rem Find assembly jar -set SPARK_ASSEMBLY_JAR=0 - +rem Find Spark jars. +rem TODO: change the directory name when Spark jars move from "lib". if exist "%SPARK_HOME%\RELEASE" ( - set ASSEMBLY_DIR="%SPARK_HOME%\lib" + set SPARK_JARS_DIR="%SPARK_HOME%\lib" ) else ( - set ASSEMBLY_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%" + set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%" ) -for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( - set SPARK_ASSEMBLY_JAR=%%d -) -if "%SPARK_ASSEMBLY_JAR%"=="0" ( +if not exist "%SPARK_JARS_DIR%"\ ( echo Failed to find Spark assembly JAR. echo You need to build Spark before running this program. exit /b 1 ) -set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% +set LAUNCH_CLASSPATH=%SPARK_JARS_DIR%\* rem Add the launcher build dir to the classpath if requested. if not "x%SPARK_PREPEND_CLASSES%"=="x" ( set LAUNCH_CLASSPATH="%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH%" ) -set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% - rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 88ba3ccebdf20..b0e85bae7c309 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -34,8 +34,7 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { private final String streamId; private final long byteCount; private final StreamCallback callback; - - private volatile long bytesRead; + private long bytesRead; StreamInterceptor( TransportResponseHandler handler, diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 9162d0b977f83..be217522367c5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -17,8 +17,8 @@ package org.apache.spark.network.protocol; +import java.nio.charset.StandardCharsets; -import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; /** Provides a canonical set of Encoders for simple types. */ @@ -27,11 +27,11 @@ public class Encoders { /** Strings are encoded with their length followed by UTF-8 bytes. 
*/ public static class Strings { public static int encodedLength(String s) { - return 4 + s.getBytes(Charsets.UTF_8).length; + return 4 + s.getBytes(StandardCharsets.UTF_8).length; } public static void encode(ByteBuf buf, String s) { - byte[] bytes = s.getBytes(Charsets.UTF_8); + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); buf.writeInt(bytes.length); buf.writeBytes(bytes); } @@ -40,7 +40,7 @@ public static String decode(ByteBuf buf) { int length = buf.readInt(); byte[] bytes = new byte[length]; buf.readBytes(bytes); - return new String(bytes, Charsets.UTF_8); + return new String(bytes, StandardCharsets.UTF_8); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 431cb67a2ae0b..b802a5af63c94 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -28,9 +28,9 @@ import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Map; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; @@ -187,14 +187,14 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) - .toString(Charsets.UTF_8); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(StandardCharsets.UTF_8))) + .toString(StandardCharsets.UTF_8); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) - .toString(Charsets.UTF_8).toCharArray(); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(StandardCharsets.UTF_8))) + .toString(StandardCharsets.UTF_8).toCharArray(); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index ccc527306d920..fbed2f053dc6c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -21,11 +21,11 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import io.netty.buffer.Unpooled; @@ -68,7 +68,7 @@ public static int nonNegativeHash(Object obj) { * converted back to the same string through {@link #bytesToString(ByteBuffer)}. 
*/ public static ByteBuffer stringToBytes(String s) { - return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + return Unpooled.wrappedBuffer(s.getBytes(StandardCharsets.UTF_8)).nioBuffer(); } /** @@ -76,7 +76,7 @@ public static ByteBuffer stringToBytes(String s) { * converted back to the same byte buffer through {@link #stringToBytes(String)}. */ public static String bytesToString(ByteBuffer b) { - return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + return Unpooled.wrappedBuffer(b).toString(StandardCharsets.UTF_8); } /* @@ -236,11 +236,11 @@ public static long byteStringAs(String str, ByteUnit unit) { } } catch (NumberFormatException e) { - String timeError = "Size must be specified as bytes (b), " + + String byteError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + "E.g. 50b, 100k, or 250m."; - throw new NumberFormatException(timeError + "\n" + e.getMessage()); + throw new NumberFormatException(byteError + "\n" + e.getMessage()); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index fe933ed650caf..460110d78f15b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.*; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; @@ -27,7 +28,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Charsets; import com.google.common.base.Objects; import com.google.common.collect.Maps; import org.fusesource.leveldbjni.JniDBFactory; @@ -152,7 +152,7 @@ public void registerExecutor( try { if (db != null) { byte[] key = dbAppExecKey(fullId); - byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); + byte[] value = mapper.writeValueAsString(executorInfo).getBytes(StandardCharsets.UTF_8); db.put(key, value); } } catch (Exception e) { @@ -350,7 +350,7 @@ private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { // we stick a common prefix on all the keys so we can find them in the DB String appExecJson = mapper.writeValueAsString(appExecId); String key = (APP_KEY_PREFIX + ";" + appExecJson); - return key.getBytes(Charsets.UTF_8); + return key.getBytes(StandardCharsets.UTF_8); } private static AppExecId parseDbAppExecKey(String s) throws IOException { @@ -368,10 +368,10 @@ static ConcurrentMap<AppExecId, ExecutorShuffleInfo> reloadRegisteredExecutors(D ConcurrentMap<AppExecId, ExecutorShuffleInfo> registeredExecutors = Maps.newConcurrentMap(); if (db != null) { DBIterator itr = db.iterator(); - itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); + itr.seek(APP_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); while (itr.hasNext()) { Map.Entry<byte[], byte[]> e = itr.next(); - String key = new String(e.getKey(), Charsets.UTF_8); + String key = new String(e.getKey(), StandardCharsets.UTF_8); if (!key.startsWith(APP_KEY_PREFIX)) { break; } @@ -418,7 +418,7 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - static final byte[] KEY =
"StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); public final int major; public final int minor; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 675820308bd4c..2add9c83a73d2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -19,7 +19,12 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,6 +46,13 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient { private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); + private final ScheduledExecutorService heartbeaterThread = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("mesos-external-shuffle-client-heartbeater") + .build()); + /** * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. * Please refer to docs on {@link ExternalShuffleClient} for more information. @@ -53,21 +65,59 @@ public MesosExternalShuffleClient( super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); } - public void registerDriverWithShuffleService(String host, int port) throws IOException { + public void registerDriverWithShuffleService( + String host, + int port, + long heartbeatTimeoutMs, + long heartbeatIntervalMs) throws IOException { + checkInit(); - ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); + ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); - client.sendRpc(registerDriver, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully registered app " + appId + " with external shuffle service."); - } - - @Override - public void onFailure(Throwable e) { - logger.warn("Unable to register app " + appId + " with external shuffle service. " + + client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); + } + + private class RegisterDriverCallback implements RpcResponseCallback { + private final TransportClient client; + private final long heartbeatIntervalMs; + + private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) { + this.client = client; + this.heartbeatIntervalMs = heartbeatIntervalMs; + } + + @Override + public void onSuccess(ByteBuffer response) { + heartbeaterThread.scheduleAtFixedRate( + new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS); + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + "Please manually remove shuffle data after driver exit. 
Error: " + e); - } - }); + } + } + + @Override + public void close() { + heartbeaterThread.shutdownNow(); + super.close(); + } + + private class Heartbeater implements Runnable { + + private final TransportClient client; + + private Heartbeater(TransportClient client) { + this.client = client; + } + + @Override + public void run() { + // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout + client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 7fbe3384b4d4f..21c0ff4136aa8 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -24,6 +24,7 @@ import org.apache.spark.network.protocol.Encodable; import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -40,7 +41,8 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public static enum Type { - OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4); + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), + HEARTBEAT(5); private final byte id; @@ -64,6 +66,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); + case 5: return ShuffleServiceHeartbeat.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index eeb0019411628..d5f53ccb7f741 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -31,29 +31,34 @@ */ public class RegisterDriver extends BlockTransferMessage { private final String appId; + private final long heartbeatTimeoutMs; - public RegisterDriver(String appId) { + public RegisterDriver(String appId, long heartbeatTimeoutMs) { this.appId = appId; + this.heartbeatTimeoutMs = heartbeatTimeoutMs; } public String getAppId() { return appId; } + public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } + @Override protected Type type() { return Type.REGISTER_DRIVER; } @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId); + return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; } @Override public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); + buf.writeLong(heartbeatTimeoutMs); } @Override public int hashCode() { - return Objects.hashCode(appId); + return Objects.hashCode(appId, heartbeatTimeoutMs); } @Override @@ -66,6 +71,7 @@ public boolean equals(Object o) { 
public static RegisterDriver decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); - return new RegisterDriver(appId); + long heartbeatTimeout = buf.readLong(); + return new RegisterDriver(appId, heartbeatTimeout); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java new file mode 100644 index 0000000000000..b30bb9aed55b6 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol.mesos; + +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A heartbeat sent from the driver to the MesosExternalShuffleService. 
+ */ +public class ShuffleServiceHeartbeat extends BlockTransferMessage { + private final String appId; + + public ShuffleServiceHeartbeat(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.HEARTBEAT; } + + @Override + public int encodedLength() { return Encoders.Strings.encodedLength(appId); } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + public static ShuffleServiceHeartbeat decode(ByteBuf buf) { + return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf)); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 60a1b8b0451fe..d9b5f0261aaba 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; @@ -34,15 +35,16 @@ import static org.junit.Assert.*; public class ExternalShuffleBlockResolverSuite { - static String sortBlock0 = "Hello!"; - static String sortBlock1 = "World!"; + private static final String sortBlock0 = "Hello!"; + private static final String sortBlock1 = "World!"; - static String hashBlock0 = "Elementary"; - static String hashBlock1 = "Tabular"; + private static final String hashBlock0 = "Elementary"; + private static final String hashBlock1 = "Tabular"; - static TestShuffleDataContext dataContext; + private static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + private static final TransportConf conf = + new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { @@ -50,10 +52,12 @@ public static void beforeAll() throws IOException { dataContext.create(); // Write some sort and hash data. 
- dataContext.insertSortShuffleData(0, 0, - new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); - dataContext.insertHashShuffleData(1, 0, - new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + dataContext.insertSortShuffleData(0, 0, new byte[][] { + sortBlock0.getBytes(StandardCharsets.UTF_8), + sortBlock1.getBytes(StandardCharsets.UTF_8)}); + dataContext.insertHashShuffleData(1, 0, new byte[][] { + hashBlock0.getBytes(StandardCharsets.UTF_8), + hashBlock1.getBytes(StandardCharsets.UTF_8)}); } @AfterClass @@ -100,13 +104,15 @@ public void testSortShuffleBlocks() throws IOException { InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + String block0 = CharStreams.toString( + new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(sortBlock0, block0); InputStream block1Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + String block1 = CharStreams.toString( + new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); assertEquals(sortBlock1, block1); } @@ -119,13 +125,15 @@ public void testHashShuffleBlocks() throws IOException { InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + String block0 = CharStreams.toString( + new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(hashBlock0, block0); InputStream block1Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + String block1 = CharStreams.toString( + new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); assertEquals(hashBlock1, block1); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 532d7ab8d01bd..43d0201405872 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,8 +35,8 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. 
- Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { @@ -123,27 +124,29 @@ public void cleanupOnlyRemovedApp() throws IOException { assertCleanedUp(dataContext1); } - private void assertStillThere(TestShuffleDataContext dataContext) { + private static void assertStillThere(TestShuffleDataContext dataContext) { for (String localDir : dataContext.localDirs) { assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); } } - private void assertCleanedUp(TestShuffleDataContext dataContext) { + private static void assertCleanedUp(TestShuffleDataContext dataContext) { for (String localDir : dataContext.localDirs) { assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); } } - private TestShuffleDataContext createSomeData() throws IOException { + private static TestShuffleDataContext createSomeData() throws IOException { Random rand = new Random(123); TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); dataContext.create(); - dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), - new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); - dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, - new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, new byte[][] { + "GHI".getBytes(StandardCharsets.UTF_8), + "JKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8)}); return dataContext; } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 427a8315e02b7..e16166ade4e5d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -21,6 +21,7 @@ import java.io.*; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Map; @@ -825,14 +826,7 @@ public UTF8String translate(Map<Character, Character> dict) { @Override public String toString() { - try { - return new String(getBytes(), "utf-8"); - } catch (UnsupportedEncodingException e) { - // Turn the exception into unchecked so we can find out about it at runtime, but - // don't need to add lots of boilerplate code everywhere. - throwException(e); - return "unknown"; // we will never reach here.
- } + return new String(getBytes(), StandardCharsets.UTF_8); } @Override diff --git a/core/pom.xml b/core/pom.xml index be40d9936afd7..4c7e3a36620a9 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -314,7 +314,7 @@ <dependency> <groupId>net.sf.py4j</groupId> <artifactId>py4j</artifactId> - <version>0.9.1</version> + <version>0.9.2</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 9e0a840b72e27..efab61e132a20 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -43,7 +43,7 @@ private[spark] trait Logging { // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { - initializeIfNecessary() + initializeLogIfNecessary(false) log_ = LoggerFactory.getLogger(logName) } log_ @@ -95,17 +95,17 @@ private[spark] trait Logging { log.isTraceEnabled } - private def initializeIfNecessary() { + protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = { if (!Logging.initialized) { Logging.initLock.synchronized { if (!Logging.initialized) { - initializeLogging() + initializeLogging(isInterpreter) } } } } - private def initializeLogging() { + private def initializeLogging(isInterpreter: Boolean): Unit = { // Don't use a logger in here, as this is itself occurring during initialization of a logger // If Log4j 1.2 is being used, but is not initialized, load a default properties file val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr @@ -127,11 +127,11 @@ private[spark] trait Logging { } } - if (Utils.isInInterpreter) { + if (isInterpreter) { // Use the repl's main class to define the default log level when running the shell, // overriding the root logger's config if they're different. val rootLogger = LogManager.getRootLogger() - val replLogger = LogManager.getLogger("org.apache.spark.repl.Main") + val replLogger = LogManager.getLogger(logName) val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) if (replLevel != rootLogger.getEffectiveLevel()) { System.err.printf("Setting default log level to \"%s\".\n", replLevel) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ff8c63158532e..0e2d51f9e78dd 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils * * @param loadDefaults whether to also load values from Java system properties */ -class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { +class SparkConf private[spark] (loadDefaults: Boolean) extends Cloneable with Logging { import SparkConf._ @@ -57,21 +57,32 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { + loadFromSystemProperties(false) + } + + private[spark] def loadFromSystemProperties(silent: Boolean): SparkConf = { // Load any spark.* system properties for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { - set(key, value) + set(key, value, silent) } + this } /** Set a configuration variable.
*/ def set(key: String, value: String): SparkConf = { + set(key, value, false) + } + + private[spark] def set(key: String, value: String, silent: Boolean): SparkConf = { if (key == null) { throw new NullPointerException("null key") } if (value == null) { throw new NullPointerException("null value for " + key) } - logDeprecationWarning(key) + if (!silent) { + logDeprecationWarning(key) + } settings.put(key, value) this } @@ -395,7 +406,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(getAll) + val cloned = new SparkConf(false) + settings.entrySet().asScala.foreach { e => + cloned.set(e.getKey(), e.getValue(), true) + } + cloned } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 13e18a56c8fd8..0d3a5237d9906 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -66,7 +66,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking)) - // first() has to be overriden here in order for its return type to be Double instead of Object. + // first() has to be overridden here in order for its return type to be Double instead of Object. override def first(): JDouble = srdd.first() // Transformations (return a new RDD) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1aebbcd39638..d362c40b7af4b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -351,7 +351,7 @@ class JavaSparkContext(val sc: SparkContext) } /** - * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any + * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * etc). * @@ -383,7 +383,7 @@ class JavaSparkContext(val sc: SparkContext) } /** - * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any + * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. 
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 05d1c31a08f22..8f306770a184f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ +import java.nio.charset.StandardCharsets import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -26,7 +27,6 @@ import scala.collection.mutable import scala.language.existentials import scala.util.control.NonFatal -import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat} @@ -165,7 +165,7 @@ private[spark] class PythonRunner( val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - throw new PythonException(new String(obj, UTF_8), + throw new PythonException(new String(obj, StandardCharsets.UTF_8), writerThread.exception.getOrElse(null)) case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still @@ -624,7 +624,7 @@ private[spark] object PythonRDD extends Logging { } def writeUTF(str: String, dataOut: DataOutputStream) { - val bytes = str.getBytes(UTF_8) + val bytes = str.getBytes(StandardCharsets.UTF_8) dataOut.writeInt(bytes.length) dataOut.write(bytes) } @@ -817,7 +817,7 @@ private[spark] object PythonRDD extends Logging { private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, UTF_8) + override def call(arr: Array[Byte]) : String = new String(arr, StandardCharsets.UTF_8) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index bda872746c8b8..8bcd2903fe768 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9.2-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index a2a2f89f1e875..433764be89fb7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.nio.charset.StandardCharsets import java.util.Arrays import scala.collection.mutable @@ -121,7 +122,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String 
redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream) + val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) out.write(serverSocket.getLocalPort + "\n") out.flush() diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 9549784aeabf5..34cb7c61d7034 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -19,10 +19,10 @@ package org.apache.spark.api.python import java.{util => ju} import java.io.{DataInput, DataOutput} +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat @@ -134,7 +134,7 @@ object WriteInputFormatTestDataGenerator { sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) } + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(StandardCharsets.UTF_8)) } ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index af815f885e8ae..c7fb192f26bd0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,6 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} +import java.nio.charset.StandardCharsets import java.sql.{Date, Time, Timestamp} import scala.collection.JavaConverters._ @@ -109,7 +110,7 @@ private[spark] object SerDe { val bytes = new Array[Byte](len) in.readFully(bytes) assert(bytes(len - 1) == 0) - val str = new String(bytes.dropRight(1), "UTF-8") + val str = new String(bytes.dropRight(1), StandardCharsets.UTF_8) str } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 550e1ba6d3de0..8091aa8062a21 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 434aadd2c61ac..305994a3f3543 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -19,6 
+19,7 @@ package org.apache.spark.deploy import java.io._ import java.net.URL +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer @@ -348,7 +349,8 @@ private class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile def readState() { try { - val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream) + val masterStream = new InputStreamReader( + new URL("http://%s:8080/json".format(ip)).openStream, StandardCharsets.UTF_8) val json = JsonMethods.parse(masterStream) val workers = json \ "workers" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 06b5101b1f566..270ca84e24ae4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.Utils */ @DeveloperApi class SparkHadoopUtil extends Logging { - private val sparkConf = new SparkConf() + private val sparkConf = new SparkConf(false).loadFromSystemProperties(true) val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 175756b80b6bb..a62096d771724 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy import java.io.{ByteArrayOutputStream, PrintStream} import java.lang.reflect.InvocationTargetException import java.net.URI +import java.nio.charset.StandardCharsets import java.util.{List => JList} import java.util.jar.JarFile @@ -608,7 +609,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S stream.flush() // Get the output and discard any unnecessary lines from it. - Source.fromString(new String(out.toByteArray())).getLines + Source.fromString(new String(out.toByteArray(), StandardCharsets.UTF_8)).getLines .filter { line => !line.startsWith("log4j") && !line.startsWith("usage") } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala index e2fda29044385..000f7e8e1e6e8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala @@ -87,7 +87,7 @@ private[history] class ApplicationCache( /** * The cache of applications. * - * Tagged as `protected` so as to allow subclasses in tests to accesss it directly + * Tagged as `protected` so as to allow subclasses in tests to access it directly */ protected val appCache: LoadingCache[CacheKey, CacheEntry] = { CacheBuilder.newBuilder() @@ -447,7 +447,7 @@ private[history] class CacheMetrics(prefix: String) extends Source { private[history] trait ApplicationCacheOperations { /** - * Get the application UI and the probe neededed to see if it has been updated. + * Get the application UI and the probe needed to see if it has been updated. 
* @param appId application ID * @param attemptId attempt ID * @return If found, the Spark UI and any history information to be used in the cache @@ -590,7 +590,7 @@ private[history] object ApplicationCacheCheckFilterRelay extends Logging { // name of the attempt ID entry in the filter configuration. Optional. val ATTEMPT_ID = "attemptId" - // namer of the filter to register + // name of the filter to register val FILTER_NAME = "org.apache.spark.deploy.history.ApplicationCacheCheckFilter" /** the application cache to relay requests to */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 6b9b1408ee44e..c97ad4d72350a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -727,15 +727,28 @@ private[deploy] class Master( * every time a new app joins or resource availability changes. */ private def schedule(): Unit = { - if (state != RecoveryState.ALIVE) { return } + if (state != RecoveryState.ALIVE) { + return + } // Drivers take strict precedence over executors - val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers - for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { - for (driver <- waitingDrivers) { + val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) + val numWorkersAlive = shuffledAliveWorkers.size + var curPos = 0 + for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers + // We assign workers to each waiting driver in a round-robin fashion. For each driver, we + // start from the last worker that was assigned a driver, and continue onwards until we have + // explored all alive workers. 
+ var launched = false + var numWorkersVisited = 0 + while (numWorkersVisited < numWorkersAlive && !launched) { + val worker = shuffledAliveWorkers(curPos) + numWorkersVisited += 1 if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { launchDriver(worker, driver) waitingDrivers -= driver + launched = true } + curPos = (curPos + 1) % numWorkersAlive } } startExecutorsOnWorkers() diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 4172d924c802d..c0f9129a423f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -17,69 +17,89 @@ package org.apache.spark.deploy.mesos -import java.net.SocketAddress import java.nio.ByteBuffer +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.shuffle.protocol.BlockTransferMessage -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver +import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat} import org.apache.spark.network.util.TransportConf +import org.apache.spark.util.ThreadUtils /** * An RPC endpoint that receives registration requests from Spark drivers running on Mesos. * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. 
*/ -private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) +private[mesos] class MesosExternalShuffleBlockHandler( + transportConf: TransportConf, + cleanerIntervalS: Long) extends ExternalShuffleBlockHandler(transportConf, null) with Logging { - // Stores a map of driver socket addresses to app ids - private val connectedApps = new mutable.HashMap[SocketAddress, String] + ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") + .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS) + + // Stores a map of app id to app state (timeout value and last heartbeat) + private val connectedApps = new ConcurrentHashMap[String, AppState]() protected override def handleMessage( message: BlockTransferMessage, client: TransportClient, callback: RpcResponseCallback): Unit = { message match { - case RegisterDriverParam(appId) => + case RegisterDriverParam(appId, appState) => val address = client.getSocketAddress - logDebug(s"Received registration request from app $appId (remote address $address).") - if (connectedApps.contains(address)) { - val existingAppId = connectedApps(address) - if (!existingAppId.equals(appId)) { - logError(s"A new app '$appId' has connected to existing address $address, " + - s"removing previously registered app '$existingAppId'.") - applicationRemoved(existingAppId, true) - } + val timeout = appState.heartbeatTimeout + logInfo(s"Received registration request from app $appId (remote address $address, " + + s"heartbeat timeout $timeout ms).") + if (connectedApps.containsKey(appId)) { + logWarning(s"Received a registration request from app $appId, but it was already " + + s"registered") } - connectedApps(address) = appId + connectedApps.put(appId, appState) callback.onSuccess(ByteBuffer.allocate(0)) + case Heartbeat(appId) => + val address = client.getSocketAddress + Option(connectedApps.get(appId)) match { + case Some(existingAppState) => + logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " + + s"address $address).") + existingAppState.lastHeartbeat = System.nanoTime() + case None => + logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + + s"address $address, appId '$appId').") + } case _ => super.handleMessage(message, client, callback) } } - /** - * On connection termination, clean up shuffle files written by the associated application. - */ - override def channelInactive(client: TransportClient): Unit = { - val address = client.getSocketAddress - if (connectedApps.contains(address)) { - val appId = connectedApps(address) - logInfo(s"Application $appId disconnected (address was $address).") - applicationRemoved(appId, true /* cleanupLocalDirs */) - connectedApps.remove(address) - } else { - logWarning(s"Unknown $address disconnected.") - } - } - /** An extractor object for matching [[RegisterDriver]] message. 
*/ private object RegisterDriverParam { - def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId) + def unapply(r: RegisterDriver): Option[(String, AppState)] = + Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime()))) + } + + private object Heartbeat { + def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId) + } + + private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) + + private class CleanerThread extends Runnable { + override def run(): Unit = { + val now = System.nanoTime() + connectedApps.asScala.foreach { case (appId, appState) => + if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { + logInfo(s"Application $appId timed out. Removing shuffle files.") + connectedApps.remove(appId) + applicationRemoved(appId, true) + } + } + } } } @@ -93,7 +113,8 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage protected override def newShuffleBlockHandler( conf: TransportConf): ExternalShuffleBlockHandler = { - new MesosExternalShuffleBlockHandler(conf) + val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s") + new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 006e2e1472580..d3e092a34c172 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.rest import java.io.{DataOutputStream, FileNotFoundException} import java.net.{ConnectException, HttpURLConnection, SocketException, URL} +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeoutException import javax.servlet.http.HttpServletResponse @@ -28,7 +29,6 @@ import scala.concurrent.duration._ import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException -import com.google.common.base.Charsets import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.util.Utils @@ -211,7 +211,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { try { val out = new DataOutputStream(conn.getOutputStream) Utils.tryWithSafeFinally { - out.write(json.getBytes(Charsets.UTF_8)) + out.write(json.getBytes(StandardCharsets.UTF_8)) } { out.close() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 6049db6d989ae..7f4fe26c0d15e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.worker import java.io._ +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -174,7 +174,7 @@ private[deploy] class DriverRunner( val stderr = new File(baseDir, "stderr") val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"") val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40) - Files.append(header, stderr, UTF_8) + Files.append(header, stderr, StandardCharsets.UTF_8) CommandUtils.redirectStream(process.getErrorStream, 
stderr) } runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index c6687a4c63a6a..208a1bb68edb9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.worker import java.io._ +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -168,7 +168,7 @@ private[deploy] class ExecutorRunner( stdoutAppender = FileAppender(process.getInputStream, stdout, conf) val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, UTF_8) + Files.write(header, stderr, StandardCharsets.UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) // Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 07e3c12bc9bc3..48372d70d52a9 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -124,7 +124,7 @@ private[spark] class Executor( private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) /** - * Count the failure times of heartbeat. It should only be acessed in the heartbeat thread. Each + * Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each * successful heartbeat will reset it to 0. */ private var heartbeatFailures = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index c9606600ed068..0f579cfe420c5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -141,7 +141,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched - * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b576d4c5f3c45..8a36af27bdd27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1148,13 +1148,13 @@ class DAGScheduler( null } - // The success case is dealt with separately below. - // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here. 
- if (event.reason != Success) { - val attemptId = task.stageAttemptId - listenerBus.post(SparkListenerTaskEnd( - stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics)) - } + // The stage may have already finished when we get this event -- e.g. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can choose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + listenerBus.post(SparkListenerTaskEnd( + stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics)) if (!stageIdToStage.contains(task.stageId)) { // Skip all the actions if the stage has been cancelled. @@ -1164,8 +1164,6 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) event.reason match { case Success => - listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, - event.reason, event.taskInfo, taskMetrics)) stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 8354e2a6112a2..2d76d08af6cdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -19,11 +19,11 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI +import java.nio.charset.StandardCharsets import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} import org.apache.hadoop.fs.permission.FsPermission @@ -254,7 +254,7 @@ private[spark] object EventLoggingListener extends Logging { def initEventLog(logStream: OutputStream): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" - logStream.write(metadataJson.getBytes(Charsets.UTF_8)) + logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index def0aac720b64..dfcdd113dfb98 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -157,7 +157,7 @@ object InputFormatInfo { b) Decrement the currently allocated containers on that host. c) Compute rack info for each host and update rack -> count map based on (b). d) Allocate nodes based on (c) - e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node (even if data locality on that is very high) : this is to prevent fragility of job if a single (or small set of) hosts go down.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index 1ce83485f024b..6e9337bb90635 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -45,7 +45,7 @@ class SplitInfo( hashCode } - // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // This is practically useless since most of the Split impl's don't seem to implement equals :-( // So unless there is identity equality between underlyingSplits, it will always fail even if it // is pointing to same block. override def equals(other: Any): Boolean = other match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e1180980eed68..90b1813750be7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -448,7 +448,12 @@ private[spark] class CoarseMesosSchedulerBackend( s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") mesosExternalShuffleClient.get - .registerDriverWithShuffleService(slave.hostname, externalShufflePort) + .registerDriverWithShuffleService( + slave.hostname, + externalShufflePort, + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } @@ -506,6 +511,9 @@ private[spark] class CoarseMesosSchedulerBackend( + "on the mesos nodes.") } + // Close the mesos external shuffle client if used + mesosExternalShuffleClient.foreach(_.close()) + if (mesosDriver != null) { mesosDriver.stop() } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 3d5b7105f0ca8..1a8e545b4f59e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import scala.collection.mutable @@ -86,7 +87,7 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) schemaBytes.arrayOffset() + schemaBytes.position(), schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) - new Schema.Parser().parse(new String(bytes, "UTF-8")) + new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8)) }) /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index bcf65e9d7e25c..d21df4b95b3cd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -425,39 +425,10 @@ private[spark] class BlockManager( val iterToReturn: Iterator[Any] = { val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { - val diskIterator = dataDeserialize(blockId, diskBytes) - if (level.useMemory) { - // Cache the 
values before returning them - memoryStore.putIterator(blockId, diskIterator, level) match { - case Left(iter) => - // The memory store put() failed, so it returned the iterator back to us: - iter - case Right(_) => - // The put() succeeded, so we can read the values back: - memoryStore.getValues(blockId).get - } - } else { - diskIterator - } - } else { // storage level is serialized - if (level.useMemory) { - // Cache the bytes back into memory to speed up subsequent reads. - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.limit(), () => { - // https://issues.apache.org/jira/browse/SPARK-6076 - // If the file size is bigger than the free memory, OOM will happen. So if we - // cannot put it into MemoryStore, copyForMemory should not be created. That's why - // this action is put into a `() => ByteBuffer` and created lazily. - val copyForMemory = ByteBuffer.allocate(diskBytes.limit) - copyForMemory.put(diskBytes) - }) - if (putSucceeded) { - dataDeserialize(blockId, memoryStore.getBytes(blockId).get) - } else { - dataDeserialize(blockId, diskBytes) - } - } else { - dataDeserialize(blockId, diskBytes) - } + val diskValues = dataDeserialize(blockId, diskBytes) + maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) + } else { + dataDeserialize(blockId, maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)) } } val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) @@ -517,26 +488,7 @@ private[spark] class BlockManager( if (level.useMemory && memoryStore.contains(blockId)) { memoryStore.getBytes(blockId).get } else if (level.useDisk && diskStore.contains(blockId)) { - val bytes = diskStore.getBytes(blockId) - if (level.useMemory) { - // Cache the bytes back into memory to speed up subsequent reads. - val memoryStorePutSucceeded = memoryStore.putBytes(blockId, bytes.limit(), () => { - // https://issues.apache.org/jira/browse/SPARK-6076 - // If the file size is bigger than the free memory, OOM will happen. So if we cannot - // put it into MemoryStore, copyForMemory should not be created. That's why this - // action is put into a `() => ByteBuffer` and created lazily. - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - }) - if (memoryStorePutSucceeded) { - memoryStore.getBytes(blockId).get - } else { - bytes.rewind() - bytes - } - } else { - bytes - } + maybeCacheDiskBytesInMemory(info, blockId, level, diskStore.getBytes(blockId)) } else { releaseLock(blockId) throw new SparkException(s"Block $blockId was not found even though it's read-locked") @@ -717,8 +669,15 @@ private[spark] class BlockManager( level: StorageLevel, tellMaster: Boolean = true): Boolean = { require(values != null, "Values is null") - // If doPut() didn't hand work back to us, then block already existed or was successfully stored - doPutIterator(blockId, () => values, level, tellMaster).isEmpty + doPutIterator(blockId, () => values, level, tellMaster) match { + case None => + true + case Some(iter) => + // Caller doesn't care about the iterator values, so we can close the iterator here + // to free resources earlier + iter.close() + false + } } /** @@ -793,7 +752,14 @@ private[spark] class BlockManager( // We will drop it to disk later if the memory store can't hold it. 
val putSucceeded = if (level.deserialized) { val values = dataDeserialize(blockId, bytes.duplicate()) - memoryStore.putIterator(blockId, values, level).isRight + memoryStore.putIterator(blockId, values, level) match { + case Right(_) => true + case Left(iter) => + // If putting deserialized values in memory failed, we will put the bytes directly to + // disk, so we don't need this iterator and can close it to free resources earlier. + iter.close() + false + } } else { memoryStore.putBytes(blockId, size, () => bytes) } @@ -905,10 +871,10 @@ private[spark] class BlockManager( iterator: () => Iterator[Any], level: StorageLevel, tellMaster: Boolean = true, - keepReadLock: Boolean = false): Option[Iterator[Any]] = { + keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator] = { doPut(blockId, level, tellMaster = tellMaster, keepReadLock = keepReadLock) { putBlockInfo => val startTimeMs = System.currentTimeMillis - var iteratorFromFailedMemoryStorePut: Option[Iterator[Any]] = None + var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator] = None // Size of the block in bytes var size = 0L if (level.useMemory) { @@ -966,6 +932,85 @@ private[spark] class BlockManager( } } + /** + * Attempts to cache spilled bytes read from disk into the MemoryStore in order to speed up + * subsequent reads. This method requires the caller to hold a read lock on the block. + * + * @return a copy of the bytes. The original bytes passed to this method should no longer + * be used after this method returns. + */ + private def maybeCacheDiskBytesInMemory( + blockInfo: BlockInfo, + blockId: BlockId, + level: StorageLevel, + diskBytes: ByteBuffer): ByteBuffer = { + require(!level.deserialized) + if (level.useMemory) { + // Synchronize on blockInfo to guard against a race condition where two readers both try to + // put values read from disk into the MemoryStore. + blockInfo.synchronized { + if (memoryStore.contains(blockId)) { + BlockManager.dispose(diskBytes) + memoryStore.getBytes(blockId).get + } else { + val putSucceeded = memoryStore.putBytes(blockId, diskBytes.limit(), () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we + // cannot put it into MemoryStore, copyForMemory should not be created. That's why + // this action is put into a `() => ByteBuffer` and created lazily. + val copyForMemory = ByteBuffer.allocate(diskBytes.limit) + copyForMemory.put(diskBytes) + }) + if (putSucceeded) { + BlockManager.dispose(diskBytes) + memoryStore.getBytes(blockId).get + } else { + diskBytes.rewind() + diskBytes + } + } + } + } else { + diskBytes + } + } + + /** + * Attempts to cache spilled values read from disk into the MemoryStore in order to speed up + * subsequent reads. This method requires the caller to hold a read lock on the block. + * + * @return a copy of the iterator. The original iterator passed to this method should no longer + * be used after this method returns. + */ + private def maybeCacheDiskValuesInMemory( + blockInfo: BlockInfo, + blockId: BlockId, + level: StorageLevel, + diskIterator: Iterator[Any]): Iterator[Any] = { + require(level.deserialized) + if (level.useMemory) { + // Synchronize on blockInfo to guard against a race condition where two readers both try to + // put values read from disk into the MemoryStore. + blockInfo.synchronized { + if (memoryStore.contains(blockId)) { + // Note: if we had a means to discard the disk iterator, we would do that here.
+ memoryStore.getValues(blockId).get + } else { + memoryStore.putIterator(blockId, diskIterator, level) match { + case Left(iter) => + // The memory store put() failed, so it returned the iterator back to us: + iter + case Right(_) => + // The put() succeeded, so we can read the values back: + memoryStore.getValues(blockId).get + } + } + } + } else { + diskIterator + } + } + /** * Get peer block managers in the system. */ @@ -1057,7 +1102,7 @@ private[spark] class BlockManager( failures += 1 replicationFailed = true peersFailedToReplicateTo += peer - if (failures > maxReplicationFailures) { // too many failures in replcating to peers + if (failures > maxReplicationFailures) { // too many failures in replicating to peers done = true } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 478a928acd03c..b19c30e2ff779 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -70,7 +70,7 @@ final class ShuffleBlockFetcherIterator( private[this] var numBlocksToFetch = 0 /** - * The number of blocks proccessed by the caller. The iterator is exhausted when + * The number of blocks processed by the caller. The iterator is exhausted when * [[numBlocksProcessed]] == [[numBlocksToFetch]]. */ private[this] var numBlocksProcessed = 0 diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index a80b2357ff911..02d44dc732951 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf, TaskContext} import org.apache.spark.memory.MemoryManager import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel} -import org.apache.spark.util.{SizeEstimator, Utils} +import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) @@ -49,14 +49,6 @@ private[spark] class MemoryStore( // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `memoryManager` private val unrollMemoryMap = mutable.HashMap[Long, Long]() - // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a task - // after the unroll but before the actual putting of the block in the cache. - // This chunk of memory is expected to be released *as soon as* we finish - // caching the corresponding block as opposed to until after the task finishes. - // This is only used if a block is successfully unrolled in its entirety in - // memory (SPARK-4777). 
- private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() // Initial memory to request before unrolling any block private val unrollMemoryThreshold: Long = @@ -100,48 +92,151 @@ private[spark] class MemoryStore( */ def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): Boolean = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - // Work on a duplicate - since the original input might be used elsewhere. - lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] - val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false) - if (putSuccess) { + if (memoryManager.acquireStorageMemory(blockId, size)) { + // We acquired enough memory for the block, so go ahead and put it + // Work on a duplicate - since the original input might be used elsewhere. + val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] assert(bytes.limit == size) + val entry = new MemoryEntry(bytes, size, deserialized = false) + entries.synchronized { + entries.put(blockId, entry) + } + logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( + blockId, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) + true + } else { + false } - putSuccess } /** * Attempt to put the given block in memory store. * - * @return the estimated size of the stored data if the put() succeeded, or an iterator - * in case the put() failed (the returned iterator lets callers fall back to the disk - * store if desired). + * It's possible that the iterator is too large to materialize and store in memory. To avoid + * OOM exceptions, this method will gradually unroll the iterator while periodically checking + * whether there is enough free memory. If the block is successfully materialized, then the + * temporary unroll memory used during the materialization is "transferred" to storage memory, + * so we won't acquire more memory than is actually needed to store the block. + + * @return in case of success, the estimated size of the stored data. In case of + * failure, return an iterator containing the values of the block. The returned iterator + * will be backed by the combination of the partially-unrolled block and the remaining + * elements of the original input iterator. The caller must either fully consume this + * iterator or call `close()` on it in order to free the storage memory consumed by the + * partially-unrolled block. */ private[storage] def putIterator( blockId: BlockId, values: Iterator[Any], - level: StorageLevel): Either[Iterator[Any], Long] = { + level: StorageLevel): Either[PartiallyUnrolledIterator, Long] = { + require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - val unrolledValues = unrollSafely(blockId, values) - unrolledValues match { - case Left(arrayValues) => - // Values are fully unrolled in memory, so store them as an array - if (level.deserialized) { - val sizeEstimate = SizeEstimator.estimate(arrayValues.asInstanceOf[AnyRef]) - if (tryToPut(blockId, () => arrayValues, sizeEstimate, deserialized = true)) { - Right(sizeEstimate) - } else { - Left(arrayValues.toIterator) + + // Number of elements unrolled so far + var elementsUnrolled = 0 + // Whether there is still enough memory for us to continue unrolling this block + var keepUnrolling = true + // Initial per-task memory to request for unrolling blocks (bytes).
+ val initialMemoryThreshold = unrollMemoryThreshold + // How often to check whether we need to request more memory + val memoryCheckPeriod = 16 + // Memory currently reserved by this task for this particular unrolling operation + var memoryThreshold = initialMemoryThreshold + // Memory to request as a multiple of current vector size + val memoryGrowthFactor = 1.5 + // Keep track of unroll memory used by this particular block / putIterator() operation + var unrollMemoryUsedByThisBlock = 0L + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[Any] + + // Request enough memory to begin unrolling + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold) + + if (!keepUnrolling) { + logWarning(s"Failed to reserve initial memory threshold of " + + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } else { + unrollMemoryUsedByThisBlock += initialMemoryThreshold + } + + // Unroll this block safely, checking whether we have exceeded our threshold periodically + while (values.hasNext && keepUnrolling) { + vector += values.next() + if (elementsUnrolled % memoryCheckPeriod == 0) { + // If our vector's size has exceeded the threshold, request more memory + val currentSize = vector.estimateSize() + if (currentSize >= memoryThreshold) { + val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest } - } else { - val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator) - if (tryToPut(blockId, () => bytes, bytes.limit, deserialized = false)) { - Right(bytes.limit()) - } else { - Left(arrayValues.toIterator) + // New threshold is currentSize * memoryGrowthFactor + memoryThreshold += amountToRequest + } + } + elementsUnrolled += 1 + } + + if (keepUnrolling) { + // We successfully unrolled the entirety of this block + val arrayValues = vector.toArray + vector = null + val entry = if (level.deserialized) { + new MemoryEntry(arrayValues, SizeEstimator.estimate(arrayValues), deserialized = true) + } else { + val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator) + new MemoryEntry(bytes, bytes.limit, deserialized = false) + } + val size = entry.size + def transferUnrollToStorage(amount: Long): Unit = { + // Synchronize so that transfer is atomic + memoryManager.synchronized { + releaseUnrollMemoryForThisTask(amount) + val success = memoryManager.acquireStorageMemory(blockId, amount) + assert(success, "transferring unroll memory to storage memory failed") + } + } + // Acquire storage memory if necessary to store this block in memory. + val enoughStorageMemory = { + if (unrollMemoryUsedByThisBlock <= size) { + val acquiredExtra = + memoryManager.acquireStorageMemory(blockId, size - unrollMemoryUsedByThisBlock) + if (acquiredExtra) { + transferUnrollToStorage(unrollMemoryUsedByThisBlock) } + acquiredExtra + } else { // unrollMemoryUsedByThisBlock > size + // If this task attempt already owns more unroll memory than is necessary to store the + // block, then release the extra memory that will not be used. 
+ val excessUnrollMemory = unrollMemoryUsedByThisBlock - size + releaseUnrollMemoryForThisTask(excessUnrollMemory) + transferUnrollToStorage(size) + true } - case Right(iteratorValues) => - Left(iteratorValues) + } + if (enoughStorageMemory) { + entries.synchronized { + entries.put(blockId, entry) + } + val bytesOrValues = if (level.deserialized) "values" else "bytes" + logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( + blockId, bytesOrValues, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) + Right(size) + } else { + assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, + "released too much unroll memory") + Left(new PartiallyUnrolledIterator( + this, + unrollMemoryUsedByThisBlock, + unrolled = arrayValues.toIterator, + rest = Iterator.empty)) + } + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, vector.estimateSize()) + Left(new PartiallyUnrolledIterator( + this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values)) } } @@ -188,102 +283,10 @@ private[spark] class MemoryStore( entries.clear() } unrollMemoryMap.clear() - pendingUnrollMemoryMap.clear() memoryManager.releaseAllStorageMemory() logInfo("MemoryStore cleared") } - /** - * Unroll the given block in memory safely. - * - * The safety of this operation refers to avoiding potential OOM exceptions caused by - * unrolling the entirety of the block in memory at once. This is achieved by periodically - * checking whether the memory restrictions for unrolling blocks are still satisfied, - * stopping immediately if not. This check is a safeguard against the scenario in which - * there is not enough free memory to accommodate the entirety of a single block. - * - * This method returns either an array with the contents of the entire block or an iterator - * containing the values of the block (if the array would have exceeded available memory). - */ - def unrollSafely(blockId: BlockId, values: Iterator[Any]): Either[Array[Any], Iterator[Any]] = { - - // Number of elements unrolled so far - var elementsUnrolled = 0 - // Whether there is still enough memory for us to continue unrolling this block - var keepUnrolling = true - // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. - val initialMemoryThreshold = unrollMemoryThreshold - // How often to check whether we need to request more memory - val memoryCheckPeriod = 16 - // Memory currently reserved by this task for this particular unrolling operation - var memoryThreshold = initialMemoryThreshold - // Memory to request as a multiple of current vector size - val memoryGrowthFactor = 1.5 - // Keep track of pending unroll memory reserved by this method.
- var pendingMemoryReserved = 0L - // Underlying vector for unrolling the block - var vector = new SizeTrackingVector[Any] - - // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold) - - if (!keepUnrolling) { - logWarning(s"Failed to reserve initial memory threshold of " + - s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") - } else { - pendingMemoryReserved += initialMemoryThreshold - } - - // Unroll this block safely, checking whether we have exceeded our threshold periodically - try { - while (values.hasNext && keepUnrolling) { - vector += values.next() - if (elementsUnrolled % memoryCheckPeriod == 0) { - // If our vector's size has exceeded the threshold, request more memory - val currentSize = vector.estimateSize() - if (currentSize >= memoryThreshold) { - val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) - if (keepUnrolling) { - pendingMemoryReserved += amountToRequest - } - // New threshold is currentSize * memoryGrowthFactor - memoryThreshold += amountToRequest - } - } - elementsUnrolled += 1 - } - - if (keepUnrolling) { - // We successfully unrolled the entirety of this block - Left(vector.toArray) - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, vector.estimateSize()) - Right(vector.iterator ++ values) - } - - } finally { - // If we return an array, the values returned here will be cached in `tryToPut` later. - // In this case, we should release the memory only after we cache the block there. - if (keepUnrolling) { - val taskAttemptId = currentTaskAttemptId() - memoryManager.synchronized { - // Since we continue to hold onto the array until we actually cache it, we cannot - // release the unroll memory yet. Instead, we transfer it to pending unroll memory - // so `tryToPut` can further transfer it to normal storage memory later. - // TODO: we can probably express this without pending unroll memory (SPARK-10907) - unrollMemoryMap(taskAttemptId) -= pendingMemoryReserved - pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + pendingMemoryReserved - } - } else { - // Otherwise, if we return an iterator, we can only release the unroll memory when - // the task finishes since we don't know when the iterator will be consumed. - } - } - } - /** * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. */ @@ -291,48 +294,6 @@ private[spark] class MemoryStore( blockId.asRDDId.map(_.rddId) } - /** - * Try to put in a set of values, if we can free up enough space. The value should either be - * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size - * must also be passed by the caller. - * - * @return whether put was successful. - */ - private def tryToPut( - blockId: BlockId, - value: () => Any, - size: Long, - deserialized: Boolean): Boolean = { - val acquiredEnoughStorageMemory = { - // Synchronize on memoryManager so that the pending unroll memory isn't stolen by another - // task. - memoryManager.synchronized { - // Note: if we have previously unrolled this block successfully, then pending unroll - // memory should be non-zero. This is the amount that we already reserved during the - // unrolling process. In this case, we can just reuse this space to cache our block. 
- // The synchronization on `memoryManager` here guarantees that the release and acquire - // happen atomically. This relies on the assumption that all memory acquisitions are - // synchronized on the same lock. - releasePendingUnrollMemoryForThisTask() - memoryManager.acquireStorageMemory(blockId, size) - } - } - - if (acquiredEnoughStorageMemory) { - // We acquired enough memory for the block, so go ahead and put it - val entry = new MemoryEntry(value(), size, deserialized) - entries.synchronized { - entries.put(blockId, entry) - } - val valuesOrBytes = if (deserialized) "values" else "bytes" - logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) - true - } else { - false - } - } - /** * Try to evict blocks to free up a given amount of space to store a particular block. * Can fail if either the block is bigger than our memory or it would require replacing @@ -456,30 +417,11 @@ private[spark] class MemoryStore( } } - /** - * Release pending unroll memory of current unroll successful block used by this task - */ - def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { - val taskAttemptId = currentTaskAttemptId() - memoryManager.synchronized { - if (pendingUnrollMemoryMap.contains(taskAttemptId)) { - val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId)) - if (memoryToRelease > 0) { - pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease - if (pendingUnrollMemoryMap(taskAttemptId) == 0) { - pendingUnrollMemoryMap.remove(taskAttemptId) - } - memoryManager.releaseUnrollMemory(memoryToRelease) - } - } - } - } - /** * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = memoryManager.synchronized { - unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum + unrollMemoryMap.values.sum } /** @@ -520,3 +462,42 @@ private[spark] class MemoryStore( logMemoryUsage() } } + +/** + * The result of a failed [[MemoryStore.putIterator()]] call. + * + * @param memoryStore the memoryStore, used for freeing memory. + * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. + * @param unrolled an iterator for the partially-unrolled values. + * @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]]. + */ +private[storage] class PartiallyUnrolledIterator( + memoryStore: MemoryStore, + unrollMemory: Long, + unrolled: Iterator[Any], + rest: Iterator[Any]) + extends Iterator[Any] { + + private[this] var unrolledIteratorIsConsumed: Boolean = false + private[this] var iter: Iterator[Any] = { + val completionIterator = CompletionIterator[Any, Iterator[Any]](unrolled, { + unrolledIteratorIsConsumed = true + memoryStore.releaseUnrollMemoryForThisTask(unrollMemory) + }) + completionIterator ++ rest + } + + override def hasNext: Boolean = iter.hasNext + override def next(): Any = iter.next() + + /** + * Called to dispose of this iterator and free its memory. 
+ */ + def close(): Unit = { + if (!unrolledIteratorIsConsumed) { + memoryStore.releaseUnrollMemoryForThisTask(unrollMemory) + unrolledIteratorIsConsumed = true + } + iter = null + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 5a8c2914314c2..094953f2f5b5e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -102,7 +102,7 @@ private[spark] object UIWorkloadGenerator { try { setProperties(desc) job() - println("Job funished: " + desc) + println("Job finished: " + desc) } catch { case e: Exception => println("Job Failed: " + desc) diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 73d126ff6254e..c9b7493fcdc1b 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -18,7 +18,7 @@ package org.apache.spark.util /** - * A class loader which makes some protected methods in ClassLoader accesible. + * A class loader which makes some protected methods in ClassLoader accessible. */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 83ded92609d66..a06db9a4fcfa5 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -253,7 +253,7 @@ object SizeEstimator extends Logging { } else { // Estimate the size of a large array by sampling elements without replacement. // To exclude the shared objects that the array elements may link, sample twice - // and use the min one to caculate array size. + // and use the min one to calculate array size. val rand = new Random(42) val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE) val s1 = sampleArray(array, state, rand, drawn, length) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b4c49513711c2..63b9d34b79fe7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,6 +22,7 @@ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels +import java.nio.charset.StandardCharsets import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ @@ -1529,7 +1530,7 @@ private[spark] object Utils extends Logging { rawMod + (if (rawMod < 0) mod else 0) } - // Handles idiosyncracies with hash (add more as required) + // Handles idiosyncrasies with hash (add more as required) // This method should be kept in sync with // org.apache.spark.network.util.JavaUtils#nonNegativeHash(). def nonNegativeHash(obj: AnyRef): Int = { @@ -1599,7 +1600,7 @@ private[spark] object Utils extends Logging { * @param f function to be executed. If prepare is not None, the running time of each call to f * must be an order of magnitude longer than one millisecond for accurate timing. * @param prepare function to be executed before each call to f. Its running time doesn't count. 
- * @return the total time across all iterations (not couting preparation time) + * @return the total time across all iterations (not counting preparation time) */ def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = { if (prepare.isEmpty) { @@ -1819,15 +1820,6 @@ private[spark] object Utils extends Logging { } } - lazy val isInInterpreter: Boolean = { - try { - val interpClass = classForName("org.apache.spark.repl.Main") - interpClass.getMethod("interp").invoke(null) != null - } catch { - case _: ClassNotFoundException => false - } - } - /** * Return a well-formed URI for the file described by a user input string. * @@ -1904,7 +1896,7 @@ private[spark] object Utils extends Logging { require(file.exists(), s"Properties file $file does not exist") require(file.isFile(), s"Properties file $file is not a normal file") - val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8") + val inReader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8) try { val properties = new Properties() properties.load(inReader) @@ -2344,7 +2336,7 @@ private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.Ou def read(): Int = if (iterator.hasNext) iterator.next() else -1 } - val reader = new BufferedReader(new InputStreamReader(input)) + val reader = new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8)) val stringBuilder = new StringBuilder var line = reader.readLine() while (line != null) { diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 1314217023d15..3c61528ab5287 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -54,7 +54,7 @@ object RandomSampler { /** * Default maximum gap-sampling fraction. * For sampling fractions <= this value, the gap sampling optimization will be applied. - * Above this value, it is assumed that "tradtional" Bernoulli sampling is faster. The + * Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The * optimal value for this will depend on the RNG. More expensive RNGs will tend to make * the optimal value higher. The most reliable way to determine this value for a new RNG * is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close @@ -319,7 +319,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( /** * Skip elements with replication factor zero (i.e. elements that won't be sampled). 
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is - * q is the probabililty of Poisson(0; f) + * q is the probability of Poisson(0; f) */ private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e6a4ab7550c2a..a7e74c00793c3 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,6 +21,7 @@ import java.nio.channels.FileChannel; import java.nio.ByteBuffer; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -45,7 +46,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.base.Throwables; -import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -1058,7 +1058,7 @@ public void textFiles() throws IOException { rdd.saveAsTextFile(outputDir); // Read the plain text file and check it's OK File outputFile = new File(outputDir, "part-00000"); - String content = Files.toString(outputFile, Charsets.UTF_8); + String content = Files.toString(outputFile, StandardCharsets.UTF_8); Assert.assertEquals("1\n2\n3\n4\n", content); // Also try reading it in as a text file RDD List expected = Arrays.asList("1", "2", "3", "4"); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index b4fa33f32a3fd..a3502708aadec 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.sort; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Random; @@ -41,7 +42,7 @@ public class ShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); - return new String(strBytes); + return new String(strBytes, StandardCharsets.UTF_8); } @Test diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index b757ddc3b37f9..a79ed58133f1b 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -80,7 +80,6 @@ public int compare( } }; - SparkConf sparkConf; File tempDir; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @@ -99,7 +98,6 @@ public OutputStream apply(OutputStream stream) { @Before public void setUp() { MockitoAnnotations.initMocks(this); - sparkConf = new SparkConf(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index ff41768df1d8f..90849ab0bd8f3 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.junit.Assert; @@ -41,7 +42,7 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); - return new String(strBytes); + return new String(strBytes, StandardCharsets.UTF_8); } @Test diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index d1e806b2eb80a..e60678b300093 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.lang.ref.WeakReference -import scala.collection.mutable.{HashSet, SynchronizedSet} +import scala.collection.mutable.HashSet import scala.language.existentials import scala.util.Random @@ -442,25 +442,25 @@ class CleanerTester( checkpointIds: Seq[Long] = Seq.empty) extends Logging { - val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds - val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds - val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds - val toBeCheckpointIds = new HashSet[Long] with SynchronizedSet[Long] ++= checkpointIds + val toBeCleanedRDDIds = new HashSet[Int] ++= rddIds + val toBeCleanedShuffleIds = new HashSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] ++= broadcastIds + val toBeCheckpointIds = new HashSet[Long] ++= checkpointIds val isDistributed = !sc.isLocal val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { - toBeCleanedRDDIds -= rddId + toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds -= rddId } logInfo("RDD " + rddId + " cleaned") } def shuffleCleaned(shuffleId: Int): Unit = { - toBeCleanedShuffleIds -= shuffleId + toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds -= shuffleId } logInfo("Shuffle " + shuffleId + " cleaned") } def broadcastCleaned(broadcastId: Long): Unit = { - toBeCleanedBroadcstIds -= broadcastId + toBeCleanedBroadcstIds.synchronized { toBeCleanedBroadcstIds -= broadcastId } logInfo("Broadcast " + broadcastId + " cleaned") } @@ -469,7 +469,7 @@ class CleanerTester( } def checkpointCleaned(rddId: Long): Unit = { - toBeCheckpointIds -= rddId + toBeCheckpointIds.synchronized { toBeCheckpointIds -= rddId } logInfo("checkpoint " + rddId + " cleaned") } } @@ -578,18 +578,27 @@ class CleanerTester( } private def uncleanedResourcesToString = { + val s1 = toBeCleanedRDDIds.synchronized { + toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]") + } + val s2 = toBeCleanedShuffleIds.synchronized { + toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]") + } + val s3 = toBeCleanedBroadcstIds.synchronized { + toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]") + } s""" - |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")} - |\tShuffles = 
${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")} - |\tBroadcasts = ${toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tRDDs = $s1 + |\tShuffles = $s2 + |\tBroadcasts = $s3 """.stripMargin } private def isAllCleanedUp = - toBeCleanedRDDIds.isEmpty && - toBeCleanedShuffleIds.isEmpty && - toBeCleanedBroadcstIds.isEmpty && - toBeCheckpointIds.isEmpty + toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds.isEmpty } && + toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds.isEmpty } && + toBeCleanedBroadcstIds.synchronized { toBeCleanedBroadcstIds.isEmpty } && + toBeCheckpointIds.synchronized { toBeCheckpointIds.isEmpty } private def getRDDBlocks(rddId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 556afd08bbfe5..841fd02ae8bb6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark import java.io.File +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import scala.concurrent.Await import scala.concurrent.duration.Duration -import com.google.common.base.Charsets._ import com.google.common.io.Files import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat @@ -115,8 +115,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { val absolutePath2 = file2.getAbsolutePath try { - Files.write("somewords1", file1, UTF_8) - Files.write("somewords2", file2, UTF_8) + Files.write("somewords1", file1, StandardCharsets.UTF_8) + Files.write("somewords2", file2, StandardCharsets.UTF_8) val length1 = file1.length() val length2 = file2.length() @@ -243,11 +243,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { try { // Create 5 text files. 
- Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, UTF_8) - Files.write("someline1 in file2\nsomeline2 in file2", file2, UTF_8) - Files.write("someline1 in file3", file3, UTF_8) - Files.write("someline1 in file4\nsomeline2 in file4", file4, UTF_8) - Files.write("someline1 in file2\nsomeline2 in file5", file5, UTF_8) + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, + StandardCharsets.UTF_8) + Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) + Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) + Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) + Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 41f2a5c972b6b..05b4e67412f2e 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite @@ -35,10 +36,12 @@ class PythonRDDSuite extends SparkFunSuite { // The correctness will be tested in Python PythonRDD.writeIteratorToStream(Iterator("a", null), buffer) PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer) - PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer) - PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer) + PythonRDD.writeIteratorToStream(Iterator("a".getBytes(StandardCharsets.UTF_8), null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes(StandardCharsets.UTF_8)), buffer) PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer) - PythonRDD.writeIteratorToStream( - Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer) + PythonRDD.writeIteratorToStream(Iterator( + (null, null), + ("a".getBytes(StandardCharsets.UTF_8), null), + (null, "b".getBytes(StandardCharsets.UTF_8))), buffer) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 41ac60ece0eda..bb2adff57e944 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy import java.io._ +import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts @@ -34,7 +34,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch -// of properties that neeed to be cleared after tests. +// of properties that needed to be cleared after tests. 
class SparkSubmitSuite extends SparkFunSuite with Matchers @@ -593,7 +593,7 @@ class SparkSubmitSuite val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") - val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) + val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf), StandardCharsets.UTF_8) for ((key, value) <- defaults) writer.write(s"$key $value\n") writer.close() @@ -661,7 +661,7 @@ object UserClasspathFirstTest { val ccl = Thread.currentThread().getContextClassLoader() val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) - val contents = new String(bytes, 0, bytes.length, UTF_8) + val contents = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8) if (contents != "USER") { throw new SparkException("Should have read user resource, but instead read: " + contents) } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index b7ff5c9e8c0d3..d2e24912b570f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -501,7 +501,7 @@ class StandaloneDynamicAllocationSuite master.self.askWithRetry[MasterStateResponse](RequestMasterState) } - /** Get the applictions that are active from Master */ + /** Get the applications that are active from Master */ private def getApplications(): Seq[ApplicationInfo] = { getMasterState.activeApps } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 379c038c5503d..7017296bd1361 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -159,7 +159,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd master.self.askWithRetry[MasterStateResponse](RequestMasterState) } - /** Get the applictions that are active from Master */ + /** Get the applications that are active from Master */ private def getApplications(): Seq[ApplicationInfo] = { getMasterState.activeApps } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index e24188781f7cd..c874b95b0960a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -219,7 +219,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val cacheEntry = cache.lookupCacheEntry(app1, None) assert(1 === cacheEntry.probeTime) assert(cacheEntry.completed) - // assert about queries made of the opereations + // assert about queries made of the operations assert(1 === operations.getAppUICount, "getAppUICount") assert(1 === operations.attachCount, "attachCount") @@ -338,7 +338,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } /** - * Look up the cache entry and assert that it maches in the expected value. + * Look up the cache entry and assert that it matches the expected value. * This assertion works if the two CacheEntries are different -it looks at the fields.
* UI are compared on object equality; the timestamp and completed flags directly. * @param appId application ID @@ -384,7 +384,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val operations = new StubCacheOperations() val clock = new ManualClock(0) val size = 5 - // only two entries are retained, so we expect evictions to occurr on lookups + // only two entries are retained, so we expect evictions to occur on lookups implicit val cache: ApplicationCache = new TestApplicationCache(operations, retainedApplications = size, clock = clock) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8e8007f4ebf4b..5fd599e190c7c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.deploy.history import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, OutputStreamWriter} import java.net.URI +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ @@ -320,8 +320,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc var entry = inputStream.getNextEntry entry should not be null while (entry != null) { - val actual = new String(ByteStreams.toByteArray(inputStream), Charsets.UTF_8) - val expected = Files.toString(logs.find(_.getName == entry.getName).get, Charsets.UTF_8) + val actual = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8) + val expected = + Files.toString(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) actual should be (expected) totalEntries += 1 entry = inputStream.getNextEntry @@ -415,7 +416,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc if (isNewFormat) { EventLoggingListener.initEventLog(new FileOutputStream(file)) } - val writer = new OutputStreamWriter(bstream, "UTF-8") + val writer = new OutputStreamWriter(bstream, StandardCharsets.UTF_8) Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) } { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index e5cd2eddba1e8..5822261d8da75 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} +import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream import javax.servlet.http.{HttpServletRequest, HttpServletResponse} @@ -25,7 +26,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.codahale.metrics.Counter -import com.google.common.base.Charsets import com.google.common.io.{ByteStreams, Files} import 
org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} @@ -216,8 +216,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val expectedFile = { new File(logDir, entry.getName) } - val expected = Files.toString(expectedFile, Charsets.UTF_8) - val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) + val expected = Files.toString(expectedFile, StandardCharsets.UTF_8) + val actual = new String(ByteStreams.toByteArray(zipStream), StandardCharsets.UTF_8) actual should be (expected) filesCompared += 1 } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index ee889bf144546..a7bb9aa4686eb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.deploy.rest import java.io.DataOutputStream import java.net.{HttpURLConnection, URL} +import java.nio.charset.StandardCharsets import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import com.google.common.base.Charsets import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfterEach @@ -498,7 +498,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { if (body.nonEmpty) { conn.setDoOutput(true) val out = new DataOutputStream(conn.getOutputStream) - out.write(body.getBytes(Charsets.UTF_8)) + out.write(body.getBytes(StandardCharsets.UTF_8)) out.close() } conn diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index e5a448298a624..056e5463a0abf 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -98,14 +98,14 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext rdd.coalesce(4).count() } - // for count and coelesce, the same bytes should be read. + // for count and coalesce, the same bytes should be read. assert(bytesRead != 0) assert(bytesRead2 == bytesRead) } /** * This checks the situation where we have interleaved reads from - * different sources. Currently, we only accumulate fron the first + * different sources. Currently, we only accumulate from the first * read method we find in the task. This test uses cartesian to create * the interleaved reads. 
* @@ -183,7 +183,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext assert(records == numRecords) } - test("input metrics on recordsd read with cache") { + test("input metrics on records read with cache") { // prime the cache manager val rdd = sc.textFile(tmpFilePath, 4).cache() rdd.collect() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 47dbcb8fc0eaa..02806a16b9467 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.netty import java.io.InputStreamReader import java.nio._ -import java.nio.charset.Charset +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import scala.concurrent.{Await, Promise} @@ -103,7 +103,8 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val blockManager = mock[BlockDataManager] val blockId = ShuffleBlockId(0, 1, 2) val blockString = "Hello, world!" - val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap( + blockString.getBytes(StandardCharsets.UTF_8))) when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) @@ -117,7 +118,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => val actualString = CharStreams.toString( - new InputStreamReader(buf.createInputStream(), Charset.forName("UTF-8"))) + new InputStreamReader(buf.createInputStream(), StandardCharsets.UTF_8)) actualString should equal(blockString) buf.release() Success() diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 76451788d2406..864adddad3426 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -164,8 +164,8 @@ class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { val expectedHistogramResults = Array(4, 2, 1, 2, 3) assert(histogramResults === expectedHistogramResults) } - // Make sure this works with a NaN end bucket and an inifity - test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { + // Make sure this works with a NaN end bucket and an infinity + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfinity") { // Make sure that it works with two unequally spaced buckets and elements in each val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 7d51538d92597..b0d69de6e2ef4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -182,7 +182,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(sums(2) === 1) } - test("reduceByKey with many output partitons") { + test("reduceByKey with many output partitions") { val pairs = 
sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_ + _, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d8849d59482e6..55f4190680dd5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -134,6 +134,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val successfulStages = new HashSet[Int] val failedStages = new ArrayBuffer[Int] val stageByOrderOfExecution = new ArrayBuffer[Int] + val endedTasks = new HashSet[Long] override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { submittedStageInfos += stageSubmitted.stageInfo @@ -148,6 +149,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou failedStages += stageInfo.stageId } } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + endedTasks += taskEnd.taskInfo.taskId + } } var mapOutputTracker: MapOutputTrackerMaster = null @@ -195,6 +200,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() + sparkListener.endedTasks.clear() failure = null sc.addSparkListener(sparkListener) taskSets.clear() @@ -663,7 +669,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts) completeNextResultStageWithSuccess(1, 1) - // Confirm job finished succesfully + // Confirm job finished successfully sc.listenerBus.waitUntilEmpty(1000) assert(ended === true) assert(results === (0 until parts).map { idx => idx -> 42 }.toMap) @@ -982,6 +988,52 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(countSubmittedMapStageAttempts() === 2) } + test("task events always posted in speculation / when stage is killed") { + val baseRdd = new MyRDD(sc, 4, Nil) + val finalRdd = new MyRDD(sc, 4, List(new OneToOneDependency(baseRdd))) + submit(finalRdd, Array(0, 1, 2, 3)) + + // complete two tasks + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(0))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(1))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + // verify stage exists + assert(scheduler.stageIdToStage.contains(0)) + assert(sparkListener.endedTasks.size == 2) + + // finish other 2 tasks + runEvent(makeCompletionEvent( + taskSets(0).tasks(2), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(2))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(3))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.endedTasks.size == 4) + + // verify the stage is done + assert(!scheduler.stageIdToStage.contains(0)) + + // Stage should be complete. 
Finish one other successful task to simulate what can happen + // with a speculative task and make sure the event is sent out + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(5))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.endedTasks.size == 5) + + // make sure non-successful tasks also send out events + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), UnknownReason, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(6))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.endedTasks.size == 6) + } + test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) @@ -1944,6 +1996,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou info } + private def createFakeTaskInfoWithId(taskId: Long): TaskInfo = { + val info = new TaskInfo(taskId, 0, 0, 0L, "", "", TaskLocality.ANY, false) + info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info + } + private def makeCompletionEvent( task: Task[_], reason: TaskEndReason, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2c99dd5afb32e..d35ca411f4080 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -396,7 +396,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val rescheduleDelay = 300L val conf = new SparkConf(). set("spark.scheduler.executorTaskBlacklistTime", rescheduleDelay.toString).
- // dont wait to jump locality levels in this test + // don't wait to jump locality levels in this test set("spark.locality.wait", "0") sc = new SparkContext("local", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index dd76644288b4c..b18f0eb162b1d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -192,7 +192,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING) backend.statusUpdate(driver, status2) - verify(externalShuffleClient, times(1)).registerDriverWithShuffleService(anyString, anyInt) + verify(externalShuffleClient, times(1)) + .registerDriverWithShuffleService(anyString, anyInt, anyLong, anyLong) } test("mesos kills an executor when told") { diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index 683aaa3aab1ba..bdee889cdc409 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -191,7 +191,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { } val originalException = new NotSerializableException("someClass") - // verify thaht original exception is returned on failure + // verify that original exception is returned on failure assert(SerializationDebugger.improveException(o, originalException).eq(originalException)) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index dc4be1467794c..2e0c0596a75bb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1065,28 +1065,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.currentUnrollMemoryForThisTask === 0) } - /** - * Verify the result of MemoryStore#unrollSafely is as expected. - */ - private def verifyUnroll( - expected: Iterator[Any], - result: Either[Array[Any], Iterator[Any]], - shouldBeArray: Boolean): Unit = { - val actual: Iterator[Any] = result match { - case Left(arr: Array[Any]) => - assert(shouldBeArray, "expected iterator from unroll!") - arr.iterator - case Right(it: Iterator[Any]) => - assert(!shouldBeArray, "expected array from unroll!") - it - case _ => - fail("unroll returned neither an iterator nor an array...") - } - expected.zip(actual).foreach { case (e, a) => - assert(e === a, "unroll did not return original values!") - } - } - test("safely unroll blocks") { store = makeBlockManager(12000) val smallList = List.fill(40)(new Array[Byte](100)) @@ -1094,30 +1072,41 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemoryForThisTask === 0) - // Unroll with all the space in the world. This should succeed and return an array. 
- var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator) - verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) + // Unroll with all the space in the world. This should succeed. + var putResult = memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY) + assert(putResult.isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.releasePendingUnrollMemoryForThisTask() + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + assert(memoryStore.remove("unroll")) // Unroll with not enough space. This should succeed after kicking out someBlock1. - store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) - store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) - unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator) - verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) + assert(store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)) + assert(store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)) + putResult = memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY) + assert(putResult.isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(memoryStore.contains("someBlock2")) assert(!memoryStore.contains("someBlock1")) - memoryStore.releasePendingUnrollMemoryForThisTask() + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + assert(memoryStore.remove("unroll")) // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. // In the mean time, however, we kicked out someBlock2 before giving up. - store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) - unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator) - verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) + assert(store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)) + putResult = memoryStore.putIterator("unroll", bigList.iterator, StorageLevel.MEMORY_ONLY) assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(!memoryStore.contains("someBlock2")) + assert(putResult.isLeft) + bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => + assert(e === a, "putIterator() did not return original values!") + } + // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. 
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0) } test("safely unroll blocks through putIterator") { @@ -1208,6 +1197,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + result4.left.get.close() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory } test("multiple unrolls by the same thread") { @@ -1218,29 +1209,29 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisTask === 0) - // All unroll memory used is released because unrollSafely returned an array - memoryStore.putIterator("b1", smallIterator, memOnly) + // All unroll memory used is released because putIterator did not return an iterator + assert(memoryStore.putIterator("b1", smallIterator, memOnly).isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.putIterator("b2", smallIterator, memOnly) + assert(memoryStore.putIterator("b2", smallIterator, memOnly).isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - // Unroll memory is not released because unrollSafely returned an iterator + // Unroll memory is not released because putIterator returned an iterator // that still depends on the underlying vector used in the process - memoryStore.putIterator("b3", smallIterator, memOnly) + assert(memoryStore.putIterator("b3", smallIterator, memOnly).isLeft) val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls - memoryStore.putIterator("b4", smallIterator, memOnly) + assert(memoryStore.putIterator("b4", smallIterator, memOnly).isLeft) val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) - memoryStore.putIterator("b5", smallIterator, memOnly) + assert(memoryStore.putIterator("b5", smallIterator, memOnly).isLeft) val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask - memoryStore.putIterator("b6", smallIterator, memOnly) + assert(memoryStore.putIterator("b6", smallIterator, memOnly).isLeft) val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask - memoryStore.putIterator("b7", smallIterator, memOnly) + assert(memoryStore.putIterator("b7", smallIterator, memOnly).isLeft) val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index b367cc8358342..d30eafd2d4218 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.util import java.io._ +import java.nio.charset.StandardCharsets import java.util.concurrent.CountDownLatch import scala.collection.mutable.HashSet import scala.reflect._ -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.log4j.{Appender, Level, Logger} import org.apache.log4j.spi.LoggingEvent @@ -48,11 +48,11 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") - val inputStream = new ByteArrayInputStream(testString.getBytes(UTF_8)) + val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, UTF_8) === testString) + assert(Files.toString(testFile, StandardCharsets.UTF_8) === testString) } test("rolling file appender - time-based rolling") { @@ -100,7 +100,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { val allGeneratedFiles = new HashSet[String]() val items = (1 to 10).map { _.toString * 10000 } for (i <- 0 until items.size) { - testOutputStream.write(items(i).getBytes(UTF_8)) + testOutputStream.write(items(i).getBytes(StandardCharsets.UTF_8)) testOutputStream.flush() allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( testFile.getParentFile.toString, testFile.getName).map(_.toString) @@ -267,7 +267,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") for (i <- 0 until textToAppend.size) { - outputStream.write(textToAppend(i).getBytes(UTF_8)) + outputStream.write(textToAppend(i).getBytes(StandardCharsets.UTF_8)) outputStream.flush() Thread.sleep(sleepTimeBetweenTexts) } @@ -282,7 +282,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) val allText = generatedFiles.map { file => - Files.toString(file, UTF_8) + Files.toString(file, StandardCharsets.UTF_8) }.mkString("") assert(allText === expectedText) generatedFiles diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 412c0ac9d9be3..093d1bd6e5948 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStr import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets import java.text.DecimalFormatSymbols import java.util.Locale import java.util.concurrent.TimeUnit @@ -28,7 +29,6 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer import scala.util.Random -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration @@ -268,7 +268,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val tmpDir2 = Utils.createTempDir() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8)) + f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) f1.close() // Read first few bytes @@ -295,9 +295,9 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("reading offset bytes across multiple files") { val tmpDir = Utils.createTempDir() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) - Files.write("0123456789", files(0), UTF_8) - Files.write("abcdefghij", files(1), UTF_8) - Files.write("ABCDEFGHIJ", files(2), UTF_8) + Files.write("0123456789", files(0), StandardCharsets.UTF_8) + Files.write("abcdefghij", files(1), StandardCharsets.UTF_8) + Files.write("ABCDEFGHIJ", files(2), StandardCharsets.UTF_8) // Read first few bytes in the 1st file assert(Utils.offsetBytes(files, 0, 5) === "01234") @@ -529,7 +529,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { try { System.setProperty("spark.test.fileNameLoadB", "2") Files.write("spark.test.fileNameLoadA true\n" + - "spark.test.fileNameLoadB 1\n", outFile, UTF_8) + "spark.test.fileNameLoadB 1\n", outFile, StandardCharsets.UTF_8) val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) properties .filter { case (k, v) => k.startsWith("spark.")} @@ -559,7 +559,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") - Files.write("some text", sourceFile, UTF_8) + Files.write("some text", sourceFile, StandardCharsets.UTF_8) val path = if (Utils.isWindows) { @@ -801,7 +801,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { |trap "" SIGTERM |sleep 10 """.stripMargin - Files.write(cmd.getBytes(), file) + Files.write(cmd.getBytes(StandardCharsets.UTF_8), file) file.getAbsoluteFile.setExecutable(true) val process = new ProcessBuilder(file.getAbsolutePath).start() diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 853503bbc2bba..83eba3690e289 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ 
-53,7 +53,7 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { * Perform the chi square test on the 5 rows of randomly generated numbers evenly divided into * 10 bins. chiSquareTest returns true iff the null hypothesis (that the classifications * represented by the counts in the columns of the input 2-way table are independent of the - * rows) can be rejected with 100 * (1 - alpha) percent confidence, where alpha is prespeficied + * rows) can be rejected with 100 * (1 - alpha) percent confidence, where alpha is prespecified * as 0.05 */ val chiTest = new ChiSquareTest diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 4dabb51254af7..426b3117f14d4 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -116,8 +116,7 @@ def ensure_path_not_present(path): # dependencies within those projects. modules = [ "spark-core", "spark-mllib", "spark-streaming", "spark-repl", - "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", - "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", + "spark-graphx", "spark-streaming-kafka", "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" ] modules = map(lambda m: "%s_%s" % (m, SCALA_BINARY_VERSION), modules) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 9991a3be36f8b..512675a599b3c 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -153,7 +153,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.1.jar +py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar scala-compiler-2.11.7.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index c52af73b4eb2b..31f8694fedfcd 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -144,7 +144,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.1.jar +py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar scala-compiler-2.11.7.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 092a0900268f4..0fa8bccab0317 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -145,7 +145,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.1.jar +py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar scala-compiler-2.11.7.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 81d86ee3d4d1b..8d2f6e6e32ab9 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -151,7 +151,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.1.jar +py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar scala-compiler-2.11.7.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e38ad1e240fde..a114c4ae8dba0 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -152,7 +152,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.1.jar +py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar scala-compiler-2.11.7.jar diff --git a/dev/mima b/dev/mima index b7f8d62b7d26f..c8e2df6cfcd4f 100755 --- a/dev/mima +++ b/dev/mima @@ -40,7 +40,7 @@ SPARK_PROFILES="-Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -P generate_mima_ignore 
"$(build/sbt $SPARK_PROFILES "export assembly/fullClasspath" | tail -n1)" generate_mima_ignore "$(build/sbt $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" -echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt mimaReportBinaryIssues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/dev/run-tests.py b/dev/run-tests.py index a1e6f1bdb560e..d940cdad3e278 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -337,9 +337,6 @@ def build_spark_sbt(hadoop_version): build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "streaming-kafka-assembly/assembly", - "streaming-flume-assembly/assembly", - "streaming-mqtt-assembly/assembly", - "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 1781de4c657ce..d1184886e2c19 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -210,43 +210,6 @@ def __hash__(self): ) -streaming_zeromq = Module( - name="streaming-zeromq", - dependencies=[streaming], - source_file_regexes=[ - "external/zeromq", - ], - sbt_test_goals=[ - "streaming-zeromq/test", - ] -) - - -streaming_twitter = Module( - name="streaming-twitter", - dependencies=[streaming], - source_file_regexes=[ - "external/twitter", - ], - sbt_test_goals=[ - "streaming-twitter/test", - ] -) - - -streaming_mqtt = Module( - name="streaming-mqtt", - dependencies=[streaming], - source_file_regexes=[ - "external/mqtt", - "external/mqtt-assembly", - ], - sbt_test_goals=[ - "streaming-mqtt/test", - ] -) - - streaming_kafka = Module( name="streaming-kafka", dependencies=[streaming], @@ -260,51 +223,6 @@ def __hash__(self): ) -streaming_flume_sink = Module( - name="streaming-flume-sink", - dependencies=[streaming], - source_file_regexes=[ - "external/flume-sink", - ], - sbt_test_goals=[ - "streaming-flume-sink/test", - ] -) - - -streaming_akka = Module( - name="streaming-akka", - dependencies=[streaming], - source_file_regexes=[ - "external/akka", - ], - sbt_test_goals=[ - "streaming-akka/test", - ] -) - - -streaming_flume = Module( - name="streaming-flume", - dependencies=[streaming], - source_file_regexes=[ - "external/flume", - ], - sbt_test_goals=[ - "streaming-flume/test", - ] -) - - -streaming_flume_assembly = Module( - name="streaming-flume-assembly", - dependencies=[streaming_flume, streaming_flume_sink], - source_file_regexes=[ - "external/flume-assembly", - ] -) - - mllib = Module( name="mllib", dependencies=[streaming, sql], @@ -376,8 +294,6 @@ def __hash__(self): pyspark_core, streaming, streaming_kafka, - streaming_flume_assembly, - streaming_mqtt, streaming_kinesis_asl ], source_file_regexes=[ diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 84547748618d1..a4e17fd24eac2 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -72,7 +72,8 @@ class CustomReceiver(host: String, port: Int) socket = new Socket(host, port) // Until stopped or connection broken continue reading - val reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")) + val reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() while(!isStopped && userInput != null) { store(userInput) @@ -135,7 +136,8 @@ public class 
JavaCustomReceiver extends Receiver<String> { // connect to the server socket = new Socket(host, port); - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + BufferedReader reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); // Until stopped or connection broken continue reading while (!isStopped() && (userInput = reader.readLine()) != null) { @@ -254,64 +256,3 @@ The following table summarizes the characteristics of both types of receivers - -## Implementing and Using a Custom Actor-based Receiver - -Custom [Akka Actors](http://doc.akka.io/docs/akka/2.3.11/scala/actors.html) can also be used to -receive data. Here are the instructions. - -1. **Linking:** You need to add the following dependency to your SBT or Maven project (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - - groupId = org.apache.spark - artifactId = spark-streaming-akka_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - -2. **Programming:** - -
- <div class="codetabs"> - <div data-lang="scala" markdown="1">
- - You need to extend [`ActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.ActorReceiver) - so as to store received data into Spark using `store(...)` methods. The supervisor strategy of - this actor can be configured to handle failures, etc. - - class CustomActor extends ActorReceiver { - def receive = { - case data: String => store(data) - } - } - - // A new input stream can be created with this custom actor as - val ssc: StreamingContext = ... - val lines = AkkaUtils.createStream[String](ssc, Props[CustomActor](), "CustomReceiver") - - See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. -
- </div> - <div data-lang="java" markdown="1">
- - You need to extend [`JavaActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.JavaActorReceiver) - so as to store received data into Spark using `store(...)` methods. The supervisor strategy of - this actor can be configured to handle failures, etc. - - class CustomActor extends JavaActorReceiver { - @Override - public void onReceive(Object msg) throws Exception { - store((String) msg); - } - } - - // A new input stream can be created with this custom actor as - JavaStreamingContext jssc = ...; - JavaDStream<String> lines = AkkaUtils.createStream(jssc, Props.create(CustomActor.class), "CustomReceiver"); - - See [JavaActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaActorWordCount.scala) for an end-to-end example. -
- </div> - </div>
- -3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. -You need to package `spark-streaming-akka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into -the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` -are marked as `provided` dependencies as those are already present in a Spark installation. Then -use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - -Python API Since actors are available only in the Java and Scala libraries, AkkaUtils is not available in the Python API. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 998644f2e23db..6c36b41e78d52 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -594,7 +594,7 @@ data from a source and stores it in Spark's memory for processing. Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. - Examples: file systems, socket connections, and Akka actors. + Examples: file systems, and socket connections. - *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -631,7 +631,7 @@ as well as to run the receiver(s). We have already taken a look at the `ssc.socketTextStream(...)` in the [quick example](#a-quick-example) which creates a DStream from text data received over a TCP socket connection. Besides sockets, the StreamingContext API provides -methods for creating DStreams from files and Akka actors as input sources. +methods for creating DStreams from files as input sources. - **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as: @@ -658,17 +658,12 @@ methods for creating DStreams from files and Akka actors as input sources. Python API `fileStream` is not available in the Python API, only `textFileStream` is available. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka - actors by using `AkkaUtils.createStream(ssc, actorProps, actor-name)`. See the [Custom Receiver - Guide](streaming-custom-receivers.html) for more details. - - Python API Since actors are available only in the Java and Scala - libraries, `AkkaUtils.createStream` is not available in the Python API. +- **Streams based on Custom Receivers:** DStreams can be created with data streams received through custom receivers. See the [Custom Receiver + Guide](streaming-custom-receivers.html) and [DStream Akka](https://github.com/spark-packages/dstream-akka) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. 
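The queue-of-RDDs source kept by this patch is the usual way to unit-test DStream transformations without standing up a socket or file source. A minimal, self-contained Scala sketch of the pattern follows (the object name and test data are illustrative, not part of this patch):

```scala
import scala.collection.mutable

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}

object QueueStreamSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("QueueStreamSketch")
    val ssc = new StreamingContext(conf, Seconds(1))

    // Each RDD pushed into this queue is consumed as one batch of the DStream.
    val rddQueue = new mutable.Queue[RDD[Int]]()
    val stream = ssc.queueStream(rddQueue)
    stream.map(x => (x % 10, 1)).reduceByKey(_ + _).print()

    ssc.start()
    // Feed a few test batches while the context runs; each prints as its own micro-batch.
    for (_ <- 1 to 3) {
      rddQueue.synchronized {
        rddQueue += ssc.sparkContext.makeRDD(1 to 100, 2)
      }
      Thread.sleep(1000)
    }
    ssc.stop(stopSparkContext = true, stopGracefully = true)
  }
}
```

`queueStream` defaults to `oneAtATime = true`, dequeuing a single RDD per batch interval; passing `false` instead unions everything currently in the queue into each batch.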
-For more details on streams from sockets, files, and actors, -see the API documentations of the relevant functions in +For more details on streams from sockets and files, see the API documentation of the relevant functions in [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) for Scala, [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) for Python. @@ -2439,13 +2434,8 @@ that can be called to store the data in Spark. So, to migrate your custom networ BlockGenerator object (does not exist any more in Spark 1.0 anyway), and use `store(...)` methods on received data. -**Actor-based Receivers**: Data could have been received using any Akka Actors by extending the actor class with -`org.apache.spark.streaming.receivers.Receiver` trait. This has been renamed to -[`org.apache.spark.streaming.receiver.ActorHelper`](api/scala/index.html#org.apache.spark.streaming.receiver.ActorHelper) -and the `pushBlock(...)` methods to store received data has been renamed to `store(...)`. Other helper classes in -the `org.apache.spark.streaming.receivers` package were also moved -to [`org.apache.spark.streaming.receiver`](api/scala/index.html#org.apache.spark.streaming.receiver.package) -package and renamed for better clarity. +**Actor-based Receivers**: The Actor-based Receiver APIs have been moved to [DStream Akka](https://github.com/spark-packages/dstream-akka). +Please refer to the project for more details. *************************************************************************************************** *************************************************************************************************** diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index cebdb6d910228..66025ed6baab8 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -30,7 +30,7 @@ dependencies, and can support different cluster managers and deploy modes that S {% highlight bash %} ./bin/spark-submit \ - --class <main-class> + --class <main-class> \ --master <master-url> \ --deploy-mode <deploy-mode> \ --conf <key>=<value> \ @@ -92,8 +92,8 @@ run it with `--help`. Here are a few examples of common options: ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ --master spark://207.184.161.138:7077 \ - --deploy-mode cluster - --supervise + --deploy-mode cluster \ + --supervise \ --executor-memory 20G \ --total-executor-cores 100 \ /path/to/examples.jar \ @@ -120,8 +120,8 @@ export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ --master mesos://207.184.161.138:7077 \ - --deploy-mode cluster - --supervise + --deploy-mode cluster \ + --supervise \ --executor-memory 20G \ --total-executor-cores 100 \ http://path/to/examples.jar \ diff --git a/examples/pom.xml b/examples/pom.xml index 3a3f547915015..92bb373c7382d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -65,37 +65,6 @@ <version>${project.version}</version> <scope>provided</scope> </dependency>
- - org.apache.spark - spark-streaming-twitter_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming-akka_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming-mqtt_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming-zeromq_${scala.binary.version} - ${project.version} - - - org.spark-project.protobuf - protobuf-java - - - org.apache.spark spark-streaming-kafka_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java deleted file mode 100644 index 7884b8cdfff84..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.streaming; - -import java.util.Arrays; -import java.util.Iterator; - -import scala.Tuple2; - -import akka.actor.ActorSelection; -import akka.actor.Props; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.apache.spark.streaming.akka.AkkaUtils; -import org.apache.spark.streaming.akka.JavaActorReceiver; - -/** - * A sample actor as receiver, is also simplest. This receiver actor - * goes and subscribe to a typical publisher/feeder actor and receives - * data. - * - * @see [[org.apache.spark.examples.streaming.FeederActor]] - */ -class JavaSampleActorReceiver extends JavaActorReceiver { - - private final String urlOfPublisher; - - public JavaSampleActorReceiver(String urlOfPublisher) { - this.urlOfPublisher = urlOfPublisher; - } - - private ActorSelection remotePublisher; - - @Override - public void preStart() { - remotePublisher = getContext().actorSelection(urlOfPublisher); - remotePublisher.tell(new SubscribeReceiver(getSelf()), getSelf()); - } - - @Override - public void onReceive(Object msg) throws Exception { - @SuppressWarnings("unchecked") - T msgT = (T) msg; - store(msgT); - } - - @Override - public void postStop() { - remotePublisher.tell(new UnsubscribeReceiver(getSelf()), getSelf()); - } -} - -/** - * A sample word count program demonstrating the use of plugging in - * Actor as Receiver - * Usage: JavaActorWordCount - * and describe the AkkaSystem that Spark Sample feeder is running on. 
- *
- * To run this example locally, you may run Feeder Actor as
- * <pre>
- *     $ bin/run-example org.apache.spark.examples.streaming.FeederActor localhost 9999
- * </pre>
- * and then run the example
- * <pre>
- *     $ bin/run-example org.apache.spark.examples.streaming.JavaActorWordCount localhost 9999
- * </pre>
- */ -public class JavaActorWordCount { - - public static void main(String[] args) { - if (args.length < 2) { - System.err.println("Usage: JavaActorWordCount "); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - final String host = args[0]; - final String port = args[1]; - SparkConf sparkConf = new SparkConf().setAppName("JavaActorWordCount"); - // Create the context and set the batch size - JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); - - String feederActorURI = "akka.tcp://test@" + host + ":" + port + "/user/FeederActor"; - - /* - * Following is the use of AkkaUtils.createStream to plug in custom actor as receiver - * - * An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e type of data received and InputDstream - * should be same. - * - * For example: Both AkkaUtils.createStream and JavaSampleActorReceiver are parameterized - * to same type to ensure type safety. - */ - JavaDStream lines = AkkaUtils.createStream( - jssc, - Props.create(JavaSampleActorReceiver.class, feederActorURI), - "SampleReceiver"); - - // compute wordcount - lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String s) { - return Arrays.asList(s.split("\\s+")).iterator(); - } - }).mapToPair(new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }).print(); - - jssc.start(); - jssc.awaitTermination(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 5de56340c6d22..4544ad2b42ca7 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -36,6 +36,7 @@ import java.io.InputStreamReader; import java.net.ConnectException; import java.net.Socket; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Iterator; import java.util.regex.Pattern; @@ -130,7 +131,8 @@ private void receive() { try { // connect to the server socket = new Socket(host, port); - reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); // Until stopped or connection broken continue reading while (!isStopped() && (userInput = reader.readLine()) != null) { System.out.println("Received data '" + userInput + "'"); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java deleted file mode 100644 index da56637fe891a..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.examples.streaming.StreamingExamples; -import org.apache.spark.streaming.*; -import org.apache.spark.streaming.api.java.*; -import org.apache.spark.streaming.flume.FlumeUtils; -import org.apache.spark.streaming.flume.SparkFlumeEvent; - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: JavaFlumeEventCount - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - * - * To run this example: - * `$ bin/run-example org.apache.spark.examples.streaming.JavaFlumeEventCount ` - */ -public final class JavaFlumeEventCount { - private JavaFlumeEventCount() { - } - - public static void main(String[] args) { - if (args.length != 2) { - System.err.println("Usage: JavaFlumeEventCount "); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - String host = args[0]; - int port = Integer.parseInt(args[1]); - - Duration batchInterval = new Duration(2000); - SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); - JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); - JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(ssc, host, port); - - flumeStream.count(); - - flumeStream.count().map(new Function() { - @Override - public String call(Long in) { - return "Received " + in + " flume events."; - } - }).print(); - - ssc.start(); - ssc.awaitTermination(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java deleted file mode 100644 index f0ae9a99bae47..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.api.java.function.VoidFunction; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.apache.spark.streaming.twitter.TwitterUtils; -import scala.Tuple2; -import twitter4j.Status; - -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; - -/** - * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of - * the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) - */ -public class JavaTwitterHashTagJoinSentiments { - - public static void main(String[] args) { - if (args.length < 4) { - System.err.println("Usage: JavaTwitterHashTagJoinSentiments " + - " []"); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - String consumerKey = args[0]; - String consumerSecret = args[1]; - String accessToken = args[2]; - String accessTokenSecret = args[3]; - String[] filters = Arrays.copyOfRange(args, 4, args.length); - - // Set the system properties so that Twitter4j library used by Twitter stream - // can use them to generate OAuth credentials - System.setProperty("twitter4j.oauth.consumerKey", consumerKey); - System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret); - System.setProperty("twitter4j.oauth.accessToken", accessToken); - System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret); - - SparkConf sparkConf = new SparkConf().setAppName("JavaTwitterHashTagJoinSentiments"); - JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); - JavaReceiverInputDStream stream = TwitterUtils.createStream(jssc, filters); - - JavaDStream words = stream.flatMap(new FlatMapFunction() { - @Override - public Iterator call(Status s) { - return Arrays.asList(s.getText().split(" ")).iterator(); - } - }); - - JavaDStream hashTags = words.filter(new Function() { - @Override - public Boolean call(String word) { - return word.startsWith("#"); - } - }); - - // Read in the word-sentiment list and create a static RDD from it - String wordSentimentFilePath = "data/streaming/AFINN-111.txt"; - final JavaPairRDD wordSentiments = jssc.sparkContext().textFile(wordSentimentFilePath) - .mapToPair(new PairFunction(){ - @Override - public Tuple2 call(String line) { - String[] columns = line.split("\t"); - return new Tuple2<>(columns[0], Double.parseDouble(columns[1])); - } - }); - - JavaPairDStream hashTagCount = hashTags.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - // leave out the # character - return new Tuple2<>(s.substring(1), 1); - } - }); - - JavaPairDStream hashTagTotals = hashTagCount.reduceByKeyAndWindow( - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }, new Duration(10000)); - - // Determine the hash tags with the highest sentiment values by joining the streaming RDD - // with the static RDD inside the transform() method and then multiplying - // the frequency of the 
hash tag by its sentiment value - JavaPairDStream> joinedTuples = - hashTagTotals.transformToPair(new Function, - JavaPairRDD>>() { - @Override - public JavaPairRDD> call( - JavaPairRDD topicCount) { - return wordSentiments.join(topicCount); - } - }); - - JavaPairDStream topicHappiness = joinedTuples.mapToPair( - new PairFunction>, String, Double>() { - @Override - public Tuple2 call(Tuple2> topicAndTuplePair) { - Tuple2 happinessAndCount = topicAndTuplePair._2(); - return new Tuple2<>(topicAndTuplePair._1(), - happinessAndCount._1() * happinessAndCount._2()); - } - }); - - JavaPairDStream happinessTopicPairs = topicHappiness.mapToPair( - new PairFunction, Double, String>() { - @Override - public Tuple2 call(Tuple2 topicHappiness) { - return new Tuple2<>(topicHappiness._2(), - topicHappiness._1()); - } - }); - - JavaPairDStream happiest10 = happinessTopicPairs.transformToPair( - new Function, JavaPairRDD>() { - @Override - public JavaPairRDD call( - JavaPairRDD happinessAndTopics) { - return happinessAndTopics.sortByKey(false); - } - } - ); - - // Print hash tags with the most positive sentiment values - happiest10.foreachRDD(new VoidFunction>() { - @Override - public void call(JavaPairRDD happinessTopicPairs) { - List> topList = happinessTopicPairs.take(10); - System.out.println( - String.format("\nHappiest topics in last 10 seconds (%s total):", - happinessTopicPairs.count())); - for (Tuple2 pair : topList) { - System.out.println( - String.format("%s (%s happiness)", pair._2(), pair._1())); - } - } - }); - - jssc.start(); - jssc.awaitTermination(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 7796f362bb56c..d498af9c390a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -38,7 +38,7 @@ object SkewedGroupByTest { val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random - // map output sizes lineraly increase from the 1st to the last + // map output sizes linearly increase from the 1st to the last numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt var arr1 = new Array[(Int, Array[Byte])](numKVPairs) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala deleted file mode 100644 index 844772a289284..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import scala.collection.mutable.LinkedHashSet -import scala.util.Random - -import akka.actor._ -import com.typesafe.config.ConfigFactory - -import org.apache.spark.SparkConf -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.akka.{ActorReceiver, AkkaUtils} - -case class SubscribeReceiver(receiverActor: ActorRef) -case class UnsubscribeReceiver(receiverActor: ActorRef) - -/** - * Sends the random content to every receiver subscribed with 1/2 - * second delay. - */ -class FeederActor extends Actor { - - val rand = new Random() - val receivers = new LinkedHashSet[ActorRef]() - - val strings: Array[String] = Array("words ", "may ", "count ") - - def makeMessage(): String = { - val x = rand.nextInt(3) - strings(x) + strings(2 - x) - } - - /* - * A thread to generate random messages - */ - new Thread() { - override def run() { - while (true) { - Thread.sleep(500) - receivers.foreach(_ ! makeMessage) - } - } - }.start() - - def receive: Receive = { - case SubscribeReceiver(receiverActor: ActorRef) => - println("received subscribe from %s".format(receiverActor.toString)) - receivers += receiverActor - - case UnsubscribeReceiver(receiverActor: ActorRef) => - println("received unsubscribe from %s".format(receiverActor.toString)) - receivers -= receiverActor - } -} - -/** - * A sample actor as receiver, is also simplest. This receiver actor - * goes and subscribe to a typical publisher/feeder actor and receives - * data. - * - * @see [[org.apache.spark.examples.streaming.FeederActor]] - */ -class SampleActorReceiver[T](urlOfPublisher: String) extends ActorReceiver { - - lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - - override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - - def receive: PartialFunction[Any, Unit] = { - case msg => store(msg.asInstanceOf[T]) - } - - override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) - -} - -/** - * A sample feeder actor - * - * Usage: FeederActor - * and describe the AkkaSystem that Spark Sample feeder would start on. - */ -object FeederActor { - - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: FeederActor \n") - System.exit(1) - } - val Seq(host, port) = args.toSeq - - val akkaConf = ConfigFactory.parseString( - s"""akka.actor.provider = "akka.remote.RemoteActorRefProvider" - |akka.remote.enabled-transports = ["akka.remote.netty.tcp"] - |akka.remote.netty.tcp.hostname = "$host" - |akka.remote.netty.tcp.port = $port - |""".stripMargin) - val actorSystem = ActorSystem("test", akkaConf) - val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") - - println("Feeder started as:" + feeder) - - actorSystem.awaitTermination() - } -} - -/** - * A sample word count program demonstrating the use of plugging in - * - * Actor as Receiver - * Usage: ActorWordCount - * and describe the AkkaSystem that Spark Sample feeder is running on. 
- * - * To run this example locally, you may run Feeder Actor as - * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor localhost 9999` - * and then run the example - * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount localhost 9999` - */ -object ActorWordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println( - "Usage: ActorWordCount ") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Seq(host, port) = args.toSeq - val sparkConf = new SparkConf().setAppName("ActorWordCount") - // Create the context and set the batch size - val ssc = new StreamingContext(sparkConf, Seconds(2)) - - /* - * Following is the use of AkkaUtils.createStream to plug in custom actor as receiver - * - * An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e type of data received and InputDStream - * should be same. - * - * For example: Both AkkaUtils.createStream and SampleActorReceiver are parameterized - * to same type to ensure type safety. - */ - val lines = AkkaUtils.createStream[String]( - ssc, - Props(classOf[SampleActorReceiver[String]], - "akka.tcp://test@%s:%s/user/FeederActor".format(host, port.toInt)), - "SampleReceiver") - - // compute wordcount - lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 5ce5778e42a3a..d67da270a8178 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -20,6 +20,7 @@ package org.apache.spark.examples.streaming import java.io.{BufferedReader, InputStreamReader} import java.net.Socket +import java.nio.charset.StandardCharsets import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel @@ -83,7 +84,8 @@ class CustomReceiver(host: String, port: Int) logInfo("Connecting to " + host + ":" + port) socket = new Socket(host, port) logInfo("Connected to " + host + ":" + port) - val reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")) + val reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() while(!isStopped && userInput != null) { store(userInput) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala deleted file mode 100644 index 91e52e4eff5a7..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming._ -import org.apache.spark.streaming.flume._ -import org.apache.spark.util.IntParam - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: FlumeEventCount - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - * - * To run this example: - * `$ bin/run-example org.apache.spark.examples.streaming.FlumeEventCount ` - */ -object FlumeEventCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println( - "Usage: FlumeEventCount ") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Array(host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - - // Create the context and set the batch size - val sparkConf = new SparkConf().setAppName("FlumeEventCount") - val ssc = new StreamingContext(sparkConf, batchInterval) - - // Create a flume stream - val stream = FlumeUtils.createStream(ssc, host, port, StorageLevel.MEMORY_ONLY_SER_2) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala deleted file mode 100644 index dd725d72c23ef..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.SparkConf -import org.apache.spark.streaming._ -import org.apache.spark.streaming.flume._ -import org.apache.spark.util.IntParam - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with the Spark Sink running in a Flume agent. See - * the Spark Streaming programming guide for more details. - * - * Usage: FlumePollingEventCount - * `host` is the host on which the Spark Sink is running. - * `port` is the port at which the Spark Sink is listening. - * - * To run this example: - * `$ bin/run-example org.apache.spark.examples.streaming.FlumePollingEventCount [host] [port] ` - */ -object FlumePollingEventCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println( - "Usage: FlumePollingEventCount ") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Array(host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - - // Create the context and set the batch size - val sparkConf = new SparkConf().setAppName("FlumePollingEventCount") - val ssc = new StreamingContext(sparkConf, batchInterval) - - // Create a flume stream that polls the Spark Sink running in a Flume agent - val stream = FlumeUtils.createPollingStream(ssc, host, port) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala deleted file mode 100644 index d772ae309f40d..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence - -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.mqtt._ -import org.apache.spark.SparkConf - -/** - * A simple Mqtt publisher for demonstration purposes, repeatedly publishes - * Space separated String Message "hello mqtt demo for spark streaming" - */ -object MQTTPublisher { - - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: MQTTPublisher ") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Seq(brokerUrl, topic) = args.toSeq - - var client: MqttClient = null - - try { - val persistence = new MemoryPersistence() - client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) - - client.connect() - - val msgtopic = client.getTopic(topic) - val msgContent = "hello mqtt demo for spark streaming" - val message = new MqttMessage(msgContent.getBytes("utf-8")) - - while (true) { - try { - msgtopic.publish(message) - println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) - println("Queue is full, wait for to consume data from the message queue") - } - } - } catch { - case e: MqttException => println("Exception Caught: " + e) - } finally { - if (client != null) { - client.disconnect() - } - } - } -} - -/** - * A sample wordcount with MqttStream stream - * - * To work with Mqtt, Mqtt Message broker/server required. - * Mosquitto (http://mosquitto.org/) is an open source Mqtt Broker - * In ubuntu mosquitto can be installed using the command `$ sudo apt-get install mosquitto` - * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/ - * Example Java code for Mqtt Publisher and Subscriber can be found here - * https://bitbucket.org/mkjinesh/mqttclient - * Usage: MQTTWordCount - * and describe where Mqtt publisher is running. 
- * - * To run this example locally, you may run publisher as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.MQTTPublisher tcp://localhost:1883 foo` - * and run the example as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.MQTTWordCount tcp://localhost:1883 foo` - */ -object MQTTWordCount { - - def main(args: Array[String]) { - if (args.length < 2) { - // scalastyle:off println - System.err.println( - "Usage: MQTTWordCount ") - // scalastyle:on println - System.exit(1) - } - - val Seq(brokerUrl, topic) = args.toSeq - val sparkConf = new SparkConf().setAppName("MQTTWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) - val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) - val words = lines.flatMap(x => x.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - - wordCounts.print() - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala deleted file mode 100644 index 5af82e161a2f7..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import com.twitter.algebird._ -import com.twitter.algebird.CMSHasherImplicits._ - -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.twitter._ - -// scalastyle:off -/** - * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute - * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. - *
- * Note that since Algebird's implementation currently only supports Long inputs, - * the example operates on Long IDs. Once the implementation supports other inputs (such as String), - * the same approach could be used for computing popular topics for example. - * - * This blog post has a good overview of the Count-Min Sketch (CMS). The CMS is a data - * structure for approximate frequency estimation in data streams (e.g. Top-K elements, frequency - * of any given element, etc), that uses space sub-linear in the number of elements in the - * stream. Once elements are added to the CMS, the estimated count of an element can be computed, - * as well as "heavy-hitters" that occur more than a threshold percentage of the overall total - * count. - *
- * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the - * reduce operation. - */ -// scalastyle:on -object TwitterAlgebirdCMS { - def main(args: Array[String]) { - StreamingExamples.setStreamingLogLevels() - - // CMS parameters - val DELTA = 1E-3 - val EPS = 0.01 - val SEED = 1 - val PERC = 0.001 - // K highest frequency elements to take - val TOPK = 10 - - val filters = args - val sparkConf = new SparkConf().setAppName("TwitterAlgebirdCMS") - val ssc = new StreamingContext(sparkConf, Seconds(10)) - val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER_2) - - val users = stream.map(status => status.getUser.getId) - - // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) - val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC) - var globalCMS = cms.zero - val mm = new MapMonoid[Long, Int]() - var globalExact = Map[Long, Int]() - - val approxTopUsers = users.mapPartitions(ids => { - ids.map(id => cms.create(id)) - }).reduce(_ ++ _) - - val exactTopUsers = users.map(id => (id, 1)) - .reduceByKey((a, b) => a + b) - - approxTopUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - val partialTopK = partial.heavyHitters.map(id => - (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - globalCMS ++= partial - val globalTopK = globalCMS.heavyHitters.map(id => - (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, - partialTopK.mkString("[", ",", "]"))) - println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, - globalTopK.mkString("[", ",", "]"))) - } - }) - - exactTopUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partialMap = rdd.collect().toMap - val partialTopK = rdd.map( - {case (id, count) => (count, id)}) - .sortByKey(ascending = false).take(TOPK) - globalExact = mm.plus(globalExact.toMap, partialMap) - val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) - println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) - } - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala deleted file mode 100644 index 6442b2a4e294b..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import com.twitter.algebird.HyperLogLog._ -import com.twitter.algebird.HyperLogLogMonoid - -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.twitter._ - -// scalastyle:off -/** - * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute - * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. - * - * This blog post and this blog post - * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for - * estimating the cardinality of a data stream, i.e. the number of unique elements. - *
- * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the - * reduce operation. - */ -// scalastyle:on -object TwitterAlgebirdHLL { - def main(args: Array[String]) { - - StreamingExamples.setStreamingLogLevels() - - /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ - val BIT_SIZE = 12 - val filters = args - val sparkConf = new SparkConf().setAppName("TwitterAlgebirdHLL") - val ssc = new StreamingContext(sparkConf, Seconds(5)) - val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER) - - val users = stream.map(status => status.getUser.getId) - - val hll = new HyperLogLogMonoid(BIT_SIZE) - var globalHll = hll.zero - var userSet: Set[Long] = Set() - - val approxUsers = users.mapPartitions(ids => { - ids.map(id => hll.create(id)) - }).reduce(_ + _) - - val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) - - approxUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - globalHll += partial - println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) - println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) - } - }) - - exactUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - userSet ++= partial - println("Exact distinct users this batch: %d".format(partial.size)) - println("Exact distinct users overall: %d".format(userSet.size)) - println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1 - ) * 100)) - } - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala deleted file mode 100644 index a8d392ca35b40..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
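The HyperLogLog scaladoc above relies on the sketches forming a monoid. A minimal illustrative sketch of what that buys, assuming only the Algebird HyperLogLogMonoid API already used by the deleted example (the collection name perBatchSketches is hypothetical):

import com.twitter.algebird.HyperLogLogMonoid

val hll = new HyperLogLogMonoid(12)            // 12 bits, as in the deleted example
// Pretend these were built on two different partitions or batches.
val perBatchSketches = Seq(hll.create("user-1".getBytes), hll.create("user-2".getBytes))
// Monoid `+` merges the sketches, so a global estimate is just a reduce.
val merged = perBatchSketches.reduce(_ + _)
println(s"Approx distinct users: ${merged.estimatedSize.toInt}")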
- */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.SparkConf -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.twitter.TwitterUtils - -/** - * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of - * the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) - */ -object TwitterHashTagJoinSentiments { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: TwitterHashTagJoinSentiments " + - " []") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Array(consumerKey, consumerSecret, accessToken, accessTokenSecret) = args.take(4) - val filters = args.takeRight(args.length - 4) - - // Set the system properties so that Twitter4j library used by Twitter stream - // can use them to generate OAuth credentials - System.setProperty("twitter4j.oauth.consumerKey", consumerKey) - System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret) - System.setProperty("twitter4j.oauth.accessToken", accessToken) - System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret) - - val sparkConf = new SparkConf().setAppName("TwitterHashTagJoinSentiments") - val ssc = new StreamingContext(sparkConf, Seconds(2)) - val stream = TwitterUtils.createStream(ssc, None, filters) - - val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) - - // Read in the word-sentiment list and create a static RDD from it - val wordSentimentFilePath = "data/streaming/AFINN-111.txt" - val wordSentiments = ssc.sparkContext.textFile(wordSentimentFilePath).map { line => - val Array(word, happinessValue) = line.split("\t") - (word, happinessValue.toInt) - }.cache() - - // Determine the hash tags with the highest sentiment values by joining the streaming RDD - // with the static RDD inside the transform() method and then multiplying - // the frequency of the hash tag by its sentiment value - val happiest60 = hashTags.map(hashTag => (hashTag.tail, 1)) - .reduceByKeyAndWindow(_ + _, Seconds(60)) - .transform{topicCount => wordSentiments.join(topicCount)} - .map{case (topic, tuple) => (topic, tuple._1 * tuple._2)} - .map{case (topic, happinessValue) => (happinessValue, topic)} - .transform(_.sortByKey(false)) - - val happiest10 = hashTags.map(hashTag => (hashTag.tail, 1)) - .reduceByKeyAndWindow(_ + _, Seconds(10)) - .transform{topicCount => wordSentiments.join(topicCount)} - .map{case (topic, tuple) => (topic, tuple._1 * tuple._2)} - .map{case (topic, happinessValue) => (happinessValue, topic)} - .transform(_.sortByKey(false)) - - // Print hash tags with the most positive sentiment values - happiest60.foreachRDD(rdd => { - val topList = rdd.take(10) - println("\nHappiest topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (happiness, tag) => println("%s (%s happiness)".format(tag, happiness))} - }) - - happiest10.foreachRDD(rdd => { - val topList = rdd.take(10) - println("\nHappiest topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (happiness, tag) => println("%s (%s happiness)".format(tag, happiness))} - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala deleted file mode 100644 index 5b69963cc8880..0000000000000 --- 
a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.twitter._ -import org.apache.spark.SparkConf - -/** - * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter - * stream. The stream is instantiated with credentials and optionally filters supplied by the - * command line arguments. - * - * Run this on your local machine as - * - */ -object TwitterPopularTags { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: TwitterPopularTags " + - " []") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Array(consumerKey, consumerSecret, accessToken, accessTokenSecret) = args.take(4) - val filters = args.takeRight(args.length - 4) - - // Set the system properties so that Twitter4j library used by twitter stream - // can use them to generate OAuth credentials - System.setProperty("twitter4j.oauth.consumerKey", consumerKey) - System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret) - System.setProperty("twitter4j.oauth.accessToken", accessToken) - System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret) - - val sparkConf = new SparkConf().setAppName("TwitterPopularTags") - val ssc = new StreamingContext(sparkConf, Seconds(2)) - val stream = TwitterUtils.createStream(ssc, None, filters) - - val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) - - val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - - // Print popular hashtags - topCounts60.foreachRDD(rdd => { - val topList = rdd.take(10) - println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - topCounts10.foreachRDD(rdd => { - val topList = rdd.take(10) - println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala 
deleted file mode 100644 index 99b561750bf9f..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import scala.language.implicitConversions - -import akka.actor.ActorSystem -import akka.actor.actorRef2Scala -import akka.util.ByteString -import akka.zeromq._ -import akka.zeromq.Subscribe - -import org.apache.spark.SparkConf -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.zeromq._ - -/** - * A simple publisher for demonstration purposes, repeatedly publishes random Messages - * every one second. - */ -object SimpleZeroMQPublisher { - - def main(args: Array[String]): Unit = { - if (args.length < 2) { - System.err.println("Usage: SimpleZeroMQPublisher ") - System.exit(1) - } - - val Seq(url, topic) = args.toSeq - val acs: ActorSystem = ActorSystem() - - val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String): ByteString = ByteString(x) - val messages: List[ByteString] = List("words ", "may ", "count ") - while (true) { - Thread.sleep(1000) - pubSocket ! ZMQMessage(ByteString(topic) :: messages) - } - acs.awaitTermination() - } -} - -// scalastyle:off -/** - * A sample wordcount with ZeroMQStream stream - * - * To work with zeroMQ, some native libraries have to be installed. - * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide] - * (http://www.zeromq.org/intro:get-the-software) - * - * Usage: ZeroMQWordCount - * and describe where zeroMq publisher is running. - * - * To run this example locally, you may run publisher as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.0.1:1234 foo` - * and run the example as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.ZeroMQWordCount tcp://127.0.0.1:1234 foo` - */ -// scalastyle:on -object ZeroMQWordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: ZeroMQWordCount ") - System.exit(1) - } - StreamingExamples.setStreamingLogLevels() - val Seq(url, topic) = args.toSeq - val sparkConf = new SparkConf().setAppName("ZeroMQWordCount") - // Create the context and set the batch size - val ssc = new StreamingContext(sparkConf, Seconds(2)) - - def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator - - // For this stream, a zeroMQ publisher should be running. 
- val lines = ZeroMQUtils.createStream( - ssc, - url, - Subscribe(topic), - bytesToStringIterator _) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/external/akka/pom.xml b/external/akka/pom.xml deleted file mode 100644 index bbe644e3b32b3..0000000000000 --- a/external/akka/pom.xml +++ /dev/null @@ -1,70 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-akka_2.11 - - streaming-akka - - jar - Spark Project External Akka - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - ${akka.group} - akka-actor_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-remote_${scala.binary.version} - ${akka.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala deleted file mode 100644 index 33415c15be2ef..0000000000000 --- a/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala +++ /dev/null @@ -1,306 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.akka - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger - -import scala.concurrent.Future -import scala.concurrent.duration._ -import scala.language.postfixOps -import scala.reflect.ClassTag - -import akka.actor._ -import akka.actor.SupervisorStrategy.{Escalate, Restart} -import akka.pattern.ask -import akka.util.Timeout -import com.typesafe.config.ConfigFactory - -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.receiver.Receiver - -/** - * :: DeveloperApi :: - * A helper with set of defaults for supervisor strategy - */ -@DeveloperApi -object ActorReceiver { - - /** - * A OneForOneStrategy supervisor strategy with `maxNrOfRetries = 10` and - * `withinTimeRange = 15 millis`. For RuntimeException, it will restart the ActorReceiver; for - * others, it just escalates the failure to the supervisor of the supervisor. 
- */ - val defaultSupervisorStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = - 15 millis) { - case _: RuntimeException => Restart - case _: Exception => Escalate - } - - /** - * A default ActorSystem creator. It will use a unique system name - * (streaming-actor-system-) to start an ActorSystem that supports remote - * communication. - */ - val defaultActorSystemCreator: () => ActorSystem = () => { - val uniqueSystemName = s"streaming-actor-system-${TaskContext.get().taskAttemptId()}" - val akkaConf = ConfigFactory.parseString( - s"""akka.actor.provider = "akka.remote.RemoteActorRefProvider" - |akka.remote.enabled-transports = ["akka.remote.netty.tcp"] - |""".stripMargin) - ActorSystem(uniqueSystemName, akkaConf) - } -} - -/** - * :: DeveloperApi :: - * A base Actor that provides APIs for pushing received data into Spark Streaming for processing. - * - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * - * @example {{{ - * class MyActor extends ActorReceiver { - * def receive { - * case anything: String => store(anything) - * } - * } - * - * AkkaUtils.createStream[String](ssc, Props[MyActor](),"MyActorReceiver") - * - * }}} - * - * @note Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e. parametrized type of push block and InputDStream - * should be same. - */ -@DeveloperApi -abstract class ActorReceiver extends Actor { - - /** Store an iterator of received data as a data block into Spark's memory. */ - def store[T](iter: Iterator[T]) { - context.parent ! IteratorData(iter) - } - - /** - * Store the bytes of received data as a data block into Spark's memory. Note - * that the data in the ByteBuffer must be serialized using the same serializer - * that Spark is configured to use. - */ - def store(bytes: ByteBuffer) { - context.parent ! ByteBufferData(bytes) - } - - /** - * Store a single item of received data to Spark's memory asynchronously. - * These single items will be aggregated together into data blocks before - * being pushed into Spark's memory. - */ - def store[T](item: T) { - context.parent ! SingleItemData(item) - } - - /** - * Store a single item of received data to Spark's memory and returns a `Future`. - * The `Future` will be completed when the operator finishes, or with an - * `akka.pattern.AskTimeoutException` after the given timeout has expired. - * These single items will be aggregated together into data blocks before - * being pushed into Spark's memory. - * - * This method allows the user to control the flow speed using `Future` - */ - def store[T](item: T, timeout: Timeout): Future[Unit] = { - context.parent.ask(AskStoreSingleItemData(item))(timeout).map(_ => ())(context.dispatcher) - } -} - -/** - * :: DeveloperApi :: - * A Java UntypedActor that provides APIs for pushing received data into Spark Streaming for - * processing. - * - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * - * @example {{{ - * class MyActor extends JavaActorReceiver { - * @Override - * public void onReceive(Object msg) throws Exception { - * store((String) msg); - * } - * } - * - * AkkaUtils.createStream(jssc, Props.create(MyActor.class), "MyActorReceiver"); - * - * }}} - * - * @note Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e. parametrized type of push block and InputDStream - * should be same. 
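The Future-returning store described above is what enables simple flow control from inside a custom actor. A minimal sketch under that assumption (the receiver class, the "ready" message, and the five-second timeout are all illustrative, not part of the API):

import scala.concurrent.duration._
import akka.util.Timeout

class ThrottledReceiver extends ActorReceiver {
  import context.dispatcher  // ExecutionContext for the Future callback below

  def receive: Receive = {
    case msg: String =>
      // Only signal the upstream feeder for more data once Spark has stored this item
      // (or the store attempt has timed out).
      store(msg, Timeout(5.seconds)).foreach(_ => sender() ! "ready")
  }
}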
- */ -@DeveloperApi -abstract class JavaActorReceiver extends UntypedActor { - - /** Store an iterator of received data as a data block into Spark's memory. */ - def store[T](iter: Iterator[T]) { - context.parent ! IteratorData(iter) - } - - /** - * Store the bytes of received data as a data block into Spark's memory. Note - * that the data in the ByteBuffer must be serialized using the same serializer - * that Spark is configured to use. - */ - def store(bytes: ByteBuffer) { - context.parent ! ByteBufferData(bytes) - } - - /** - * Store a single item of received data to Spark's memory. - * These single items will be aggregated together into data blocks before - * being pushed into Spark's memory. - */ - def store[T](item: T) { - context.parent ! SingleItemData(item) - } - - /** - * Store a single item of received data to Spark's memory and returns a `Future`. - * The `Future` will be completed when the operator finishes, or with an - * `akka.pattern.AskTimeoutException` after the given timeout has expired. - * These single items will be aggregated together into data blocks before - * being pushed into Spark's memory. - * - * This method allows the user to control the flow speed using `Future` - */ - def store[T](item: T, timeout: Timeout): Future[Unit] = { - context.parent.ask(AskStoreSingleItemData(item))(timeout).map(_ => ())(context.dispatcher) - } -} - -/** - * :: DeveloperApi :: - * Statistics for querying the supervisor about state of workers. Used in - * conjunction with `AkkaUtils.createStream` and - * [[org.apache.spark.streaming.akka.ActorReceiverSupervisor]]. - */ -@DeveloperApi -case class Statistics(numberOfMsgs: Int, - numberOfWorkers: Int, - numberOfHiccups: Int, - otherInfo: String) - -/** Case class to receive data sent by child actors */ -private[akka] sealed trait ActorReceiverData -private[akka] case class SingleItemData[T](item: T) extends ActorReceiverData -private[akka] case class AskStoreSingleItemData[T](item: T) extends ActorReceiverData -private[akka] case class IteratorData[T](iterator: Iterator[T]) extends ActorReceiverData -private[akka] case class ByteBufferData(bytes: ByteBuffer) extends ActorReceiverData -private[akka] object Ack extends ActorReceiverData - -/** - * Provides Actors as receivers for receiving stream. - * - * As Actors can also be used to receive data from almost any stream source. - * A nice set of abstraction(s) for actors as receivers is already provided for - * a few general cases. It is thus exposed as an API where user may come with - * their own Actor to run as receiver for Spark Streaming input source. - * - * This starts a supervisor actor which starts workers and also provides - * [http://doc.akka.io/docs/akka/snapshot/scala/fault-tolerance.html fault-tolerance]. - * - * Here's a way to start more supervisor/workers as its children. - * - * @example {{{ - * context.parent ! Props(new Supervisor) - * }}} OR {{{ - * context.parent ! 
Props(new Worker, "Worker") - * }}} - */ -private[akka] class ActorReceiverSupervisor[T: ClassTag]( - actorSystemCreator: () => ActorSystem, - props: Props, - name: String, - storageLevel: StorageLevel, - receiverSupervisorStrategy: SupervisorStrategy - ) extends Receiver[T](storageLevel) with Logging { - - private lazy val actorSystem = actorSystemCreator() - protected lazy val actorSupervisor = actorSystem.actorOf(Props(new Supervisor), - "Supervisor" + streamId) - - class Supervisor extends Actor { - - override val supervisorStrategy = receiverSupervisorStrategy - private val worker = context.actorOf(props, name) - logInfo("Started receiver worker at:" + worker.path) - - private val n: AtomicInteger = new AtomicInteger(0) - private val hiccups: AtomicInteger = new AtomicInteger(0) - - override def receive: PartialFunction[Any, Unit] = { - - case IteratorData(iterator) => - logDebug("received iterator") - store(iterator.asInstanceOf[Iterator[T]]) - - case SingleItemData(msg) => - logDebug("received single") - store(msg.asInstanceOf[T]) - n.incrementAndGet - - case AskStoreSingleItemData(msg) => - logDebug("received single sync") - store(msg.asInstanceOf[T]) - n.incrementAndGet - sender() ! Ack - - case ByteBufferData(bytes) => - logDebug("received bytes") - store(bytes) - - case props: Props => - val worker = context.actorOf(props) - logInfo("Started receiver worker at:" + worker.path) - sender ! worker - - case (props: Props, name: String) => - val worker = context.actorOf(props, name) - logInfo("Started receiver worker at:" + worker.path) - sender ! worker - - case _: PossiblyHarmful => hiccups.incrementAndGet() - - case _: Statistics => - val workers = context.children - sender ! Statistics(n.get, workers.size, hiccups.get, workers.mkString("\n")) - - } - } - - def onStart(): Unit = { - actorSupervisor - logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) - } - - def onStop(): Unit = { - actorSupervisor ! PoisonPill - actorSystem.shutdown() - actorSystem.awaitTermination() - } -} diff --git a/external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala b/external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala deleted file mode 100644 index 38c35c5ae7a18..0000000000000 --- a/external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.akka - -import scala.reflect.ClassTag - -import akka.actor.{ActorSystem, Props, SupervisorStrategy} - -import org.apache.spark.api.java.function.{Function0 => JFunction0} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -object AkkaUtils { - - /** - * Create an input stream with a user-defined actor. See [[ActorReceiver]] for more details. - * - * @param ssc The StreamingContext instance - * @param propsForActor Props object defining creation of the actor - * @param actorName Name of the actor - * @param storageLevel RDD storage level (default: StorageLevel.MEMORY_AND_DISK_SER_2) - * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will - * be shut down when the receiver is stopping (default: - * ActorReceiver.defaultActorSystemCreator) - * @param supervisorStrategy the supervisor strategy (default: ActorReceiver.defaultStrategy) - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e. parametrized type of data received and createStream - * should be same. - */ - def createStream[T: ClassTag]( - ssc: StreamingContext, - propsForActor: Props, - actorName: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, - actorSystemCreator: () => ActorSystem = ActorReceiver.defaultActorSystemCreator, - supervisorStrategy: SupervisorStrategy = ActorReceiver.defaultSupervisorStrategy - ): ReceiverInputDStream[T] = ssc.withNamedScope("actor stream") { - val cleanF = ssc.sc.clean(actorSystemCreator) - ssc.receiverStream(new ActorReceiverSupervisor[T]( - cleanF, - propsForActor, - actorName, - storageLevel, - supervisorStrategy)) - } - - /** - * Create an input stream with a user-defined actor. See [[JavaActorReceiver]] for more details. - * - * @param jssc The StreamingContext instance - * @param propsForActor Props object defining creation of the actor - * @param actorName Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will - * be shut down when the receiver is stopping. - * @param supervisorStrategy the supervisor strategy - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e. parametrized type of data received and createStream - * should be same. - */ - def createStream[T]( - jssc: JavaStreamingContext, - propsForActor: Props, - actorName: String, - storageLevel: StorageLevel, - actorSystemCreator: JFunction0[ActorSystem], - supervisorStrategy: SupervisorStrategy - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - createStream[T]( - jssc.ssc, - propsForActor, - actorName, - storageLevel, - () => actorSystemCreator.call(), - supervisorStrategy) - } - - /** - * Create an input stream with a user-defined actor. See [[JavaActorReceiver]] for more details. 
- * - * @param jssc The StreamingContext instance - * @param propsForActor Props object defining creation of the actor - * @param actorName Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e. parametrized type of data received and createStream - * should be same. - */ - def createStream[T]( - jssc: JavaStreamingContext, - propsForActor: Props, - actorName: String, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - createStream[T](jssc.ssc, propsForActor, actorName, storageLevel) - } - - /** - * Create an input stream with a user-defined actor. Storage level of the data will be the default - * StorageLevel.MEMORY_AND_DISK_SER_2. See [[JavaActorReceiver]] for more details. - * - * @param jssc The StreamingContext instance - * @param propsForActor Props object defining creation of the actor - * @param actorName Name of the actor - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e. parametrized type of data received and createStream - * should be same. - */ - def createStream[T]( - jssc: JavaStreamingContext, - propsForActor: Props, - actorName: String - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - createStream[T](jssc.ssc, propsForActor, actorName) - } -} diff --git a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java deleted file mode 100644 index ac5ef31c8b355..0000000000000 --- a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
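A minimal usage sketch for the Scala createStream overload documented above, assuming ssc and a custom MyActor receiver are defined as in the deleted ActorWordCount example:

import akka.actor.Props

// Plug the custom actor in as a receiver and run a word count over the resulting DStream.
val lines = AkkaUtils.createStream[String](ssc, Props[MyActor](), "MyActorReceiver")
lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print()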
- */ - -package org.apache.spark.streaming.akka; - -import akka.actor.ActorSystem; -import akka.actor.Props; -import akka.actor.SupervisorStrategy; -import akka.util.Timeout; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.Test; - -import org.apache.spark.api.java.function.Function0; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; - -public class JavaAkkaUtilsSuite { - - @Test // tests the API, does not actually test data receiving - public void testAkkaUtils() { - JavaStreamingContext jsc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); - try { - JavaReceiverInputDStream test1 = AkkaUtils.createStream( - jsc, Props.create(JavaTestActor.class), "test"); - JavaReceiverInputDStream test2 = AkkaUtils.createStream( - jsc, Props.create(JavaTestActor.class), "test", StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = AkkaUtils.createStream( - jsc, - Props.create(JavaTestActor.class), - "test", StorageLevel.MEMORY_AND_DISK_SER_2(), - new ActorSystemCreatorForTest(), - SupervisorStrategy.defaultStrategy()); - } finally { - jsc.stop(); - } - } -} - -class ActorSystemCreatorForTest implements Function0 { - @Override - public ActorSystem call() { - return null; - } -} - - -class JavaTestActor extends JavaActorReceiver { - @Override - public void onReceive(Object message) throws Exception { - store((String) message); - store((String) message, new Timeout(1000)); - } -} diff --git a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala deleted file mode 100644 index ce95d9dd72f90..0000000000000 --- a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.akka - -import scala.concurrent.duration._ - -import akka.actor.{Props, SupervisorStrategy} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -class AkkaUtilsSuite extends SparkFunSuite { - - test("createStream") { - val ssc: StreamingContext = new StreamingContext("local[2]", "test", Seconds(1000)) - try { - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[String] = AkkaUtils.createStream( - ssc, Props[TestActor](), "test") - val test2: ReceiverInputDStream[String] = AkkaUtils.createStream( - ssc, Props[TestActor](), "test", StorageLevel.MEMORY_AND_DISK_SER_2) - val test3: ReceiverInputDStream[String] = AkkaUtils.createStream( - ssc, - Props[TestActor](), - "test", - StorageLevel.MEMORY_AND_DISK_SER_2, - supervisorStrategy = SupervisorStrategy.defaultStrategy) - val test4: ReceiverInputDStream[String] = AkkaUtils.createStream( - ssc, Props[TestActor](), "test", StorageLevel.MEMORY_AND_DISK_SER_2, () => null) - val test5: ReceiverInputDStream[String] = AkkaUtils.createStream( - ssc, Props[TestActor](), "test", StorageLevel.MEMORY_AND_DISK_SER_2, () => null) - val test6: ReceiverInputDStream[String] = AkkaUtils.createStream( - ssc, - Props[TestActor](), - "test", - StorageLevel.MEMORY_AND_DISK_SER_2, - () => null, - SupervisorStrategy.defaultStrategy) - } finally { - ssc.stop() - } - } -} - -class TestActor extends ActorReceiver { - override def receive: Receive = { - case m: String => store(m) - case m => store(m, 10.seconds) - } -} diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml deleted file mode 100644 index ac15b93c048da..0000000000000 --- a/external/flume-assembly/pom.xml +++ /dev/null @@ -1,168 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-flume-assembly_2.11 - jar - Spark Project External Flume Assembly - http://spark.apache.org/ - - - provided - streaming-flume-assembly - - - - - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - - - org.mortbay.jetty - jetty - - - org.mortbay.jetty - jetty-util - - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - - commons-codec - commons-codec - provided - - - commons-lang - commons-lang - provided - - - commons-net - commons-net - provided - - - com.google.protobuf - protobuf-java - provided - - - org.apache.avro - avro - provided - - - org.apache.avro - avro-ipc - provided - - - org.apache.avro - avro-mapred - ${avro.mapred.classifier} - provided - - - org.scala-lang - scala-library - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - - - - - flume-provided - - provided - - - - - diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml deleted file mode 100644 index e4effe158c826..0000000000000 --- a/external/flume-sink/pom.xml +++ /dev/null @@ -1,129 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - 
spark-streaming-flume-sink_2.11 - - streaming-flume-sink - - jar - Spark Project External Flume Sink - http://spark.apache.org/ - - - - org.apache.flume - flume-ng-sdk - - - - com.google.guava - guava - - - - org.apache.thrift - libthrift - - - - - org.apache.flume - flume-ng-core - - - com.google.guava - guava - - - org.apache.thrift - libthrift - - - - - org.scala-lang - scala-library - - - - com.google.guava - guava - test - - - - io.netty - netty - 3.4.0.Final - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.avro - avro-maven-plugin - ${avro.version} - - - ${project.basedir}/target/scala-${scala.binary.version}/src_managed/main/compiled_avro - - - - generate-sources - - idl-protocol - - - - - - org.apache.maven.plugins - maven-shade-plugin - - - - - - - - diff --git a/external/flume-sink/src/main/avro/sparkflume.avdl b/external/flume-sink/src/main/avro/sparkflume.avdl deleted file mode 100644 index 8806e863ac7c6..0000000000000 --- a/external/flume-sink/src/main/avro/sparkflume.avdl +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -@namespace("org.apache.spark.streaming.flume.sink") - -protocol SparkFlumeProtocol { - - record SparkSinkEvent { - map headers; - bytes body; - } - - record EventBatch { - string errorMsg = ""; // If this is empty it is a valid message, else it represents an error - string sequenceNumber; - array events; - } - - EventBatch getEventBatch (int n); - - void ack (string sequenceNumber); - - void nack (string sequenceNumber); -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala deleted file mode 100644 index aa530a7121bd0..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import org.slf4j.{Logger, LoggerFactory} - -/** - * Copy of the org.apache.spark.Logging for being used in the Spark Sink. - * The org.apache.spark.Logging is not used so that all of Spark is not brought - * in as a dependency. - */ -private[sink] trait Logging { - // Make the log field transient so that objects with Logging can - // be serialized and used on another machine - @transient private var _log: Logger = null - - // Method to get or create the logger for this object - protected def log: Logger = { - if (_log == null) { - initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - _log = LoggerFactory.getLogger(className) - } - _log - } - - // Log methods that take only a String - protected def logInfo(msg: => String) { - if (log.isInfoEnabled) log.info(msg) - } - - protected def logDebug(msg: => String) { - if (log.isDebugEnabled) log.debug(msg) - } - - protected def logTrace(msg: => String) { - if (log.isTraceEnabled) log.trace(msg) - } - - protected def logWarning(msg: => String) { - if (log.isWarnEnabled) log.warn(msg) - } - - protected def logError(msg: => String) { - if (log.isErrorEnabled) log.error(msg) - } - - // Log methods that take Throwables (Exceptions/Errors) too - protected def logInfo(msg: => String, throwable: Throwable) { - if (log.isInfoEnabled) log.info(msg, throwable) - } - - protected def logDebug(msg: => String, throwable: Throwable) { - if (log.isDebugEnabled) log.debug(msg, throwable) - } - - protected def logTrace(msg: => String, throwable: Throwable) { - if (log.isTraceEnabled) log.trace(msg, throwable) - } - - protected def logWarning(msg: => String, throwable: Throwable) { - if (log.isWarnEnabled) log.warn(msg, throwable) - } - - protected def logError(msg: => String, throwable: Throwable) { - if (log.isErrorEnabled) log.error(msg, throwable) - } - - protected def isTraceEnabled(): Boolean = { - log.isTraceEnabled - } - - private def initializeIfNecessary() { - if (!Logging.initialized) { - Logging.initLock.synchronized { - if (!Logging.initialized) { - initializeLogging() - } - } - } - } - - private def initializeLogging() { - Logging.initialized = true - - // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads - // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html - log - } -} - -private[sink] object Logging { - @volatile private var initialized = false - val initLock = new Object() - try { - // We use reflection here to handle the case where users remove the - // slf4j-to-jul bridge order to route their logs to JUL. 
- // scalastyle:off classforname - val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") - // scalastyle:on classforname - bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) - val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] - if (!installed) { - bridgeClass.getMethod("install").invoke(null) - } - } catch { - case e: ClassNotFoundException => // can't log anything yet so just fail silently - } -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala deleted file mode 100644 index 719fca0938b3a..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.util.UUID -import java.util.concurrent.{CountDownLatch, Executors} -import java.util.concurrent.atomic.AtomicLong - -import scala.collection.mutable - -import org.apache.flume.Channel - -/** - * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process - * requests. Each getEvents, ack and nack call is forwarded to an instance of this class. - * @param threads Number of threads to use to process requests. - * @param channel The channel that the sink pulls events from - * @param transactionTimeout Timeout in millis after which the transaction if not acked by Spark - * is rolled back. - */ -// Flume forces transactions to be thread-local. So each transaction *must* be committed, or -// rolled back from the thread it was originally created in. So each getEvents call from Spark -// creates a TransactionProcessor which runs in a new thread, in which the transaction is created -// and events are pulled off the channel. Once the events are sent to spark, -// that thread is blocked and the TransactionProcessor is saved in a map, -// until an ACK or NACK comes back or the transaction times out (after the specified timeout). -// When the response comes or a timeout is hit, the TransactionProcessor is retrieved and then -// unblocked, at which point the transaction is committed or rolled back. 
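For orientation, the client side of this sequence-number handshake (exercised by the SparkSinkSuite further down in this patch) looks roughly like the sketch below. The host, port and batch size are placeholders, error handling is omitted, and storedOk stands in for whatever the caller does with the events:

    import java.net.InetSocketAddress
    import org.apache.avro.ipc.NettyTransceiver
    import org.apache.avro.ipc.specific.SpecificRequestor

    // Connect to a running SparkSink and obtain the Avro-generated protocol stub.
    val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 9999))
    val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver)

    // The sink opens a Flume transaction and returns up to 1000 events plus a sequence number.
    val batch = client.getEventBatch(1000)
    if (!SparkSinkUtils.isErrorBatch(batch)) {
      val storedOk = true // placeholder for storing batch.getEvents somewhere durable
      if (storedOk) {
        client.ack(batch.getSequenceNumber)   // sink commits the transaction
      } else {
        client.nack(batch.getSequenceNumber)  // sink rolls the transaction back
      }
    }
    transceiver.close()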
- -private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, - val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { - val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) - // Protected by `sequenceNumberToProcessor` - private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() - // This sink will not persist sequence numbers and reuses them if it gets restarted. - // So it is possible to commit a transaction which may have been meant for the sink before the - // restart. - // Since the new txn may not have the same sequence number we must guard against accidentally - // committing a new transaction. To reduce the probability of that happening a random string is - // prepended to the sequence number. Does not change for life of sink - private val seqBase = UUID.randomUUID().toString.substring(0, 8) - private val seqCounter = new AtomicLong(0) - - // Protected by `sequenceNumberToProcessor` - private var stopped = false - - @volatile private var isTest = false - private var testLatch: CountDownLatch = null - - /** - * Returns a bunch of events to Spark over Avro RPC. - * @param n Maximum number of events to return in a batch - * @return [[EventBatch]] instance that has a sequence number and an array of at most n events - */ - override def getEventBatch(n: Int): EventBatch = { - logDebug("Got getEventBatch call from Spark.") - val sequenceNumber = seqBase + seqCounter.incrementAndGet() - createProcessor(sequenceNumber, n) match { - case Some(processor) => - transactionExecutorOpt.foreach(_.submit(processor)) - // Wait until a batch is available - will be an error if error message is non-empty - val batch = processor.getEventBatch - if (SparkSinkUtils.isErrorBatch(batch)) { - // Remove the processor if it is an error batch since no ACK is sent. - removeAndGetProcessor(sequenceNumber) - logWarning("Received an error batch - no events were received from channel! ") - } - batch - case None => - new EventBatch("Spark sink has been stopped!", "", java.util.Collections.emptyList()) - } - } - - private def createProcessor(seq: String, n: Int): Option[TransactionProcessor] = { - sequenceNumberToProcessor.synchronized { - if (!stopped) { - val processor = new TransactionProcessor( - channel, seq, n, transactionTimeout, backOffInterval, this) - sequenceNumberToProcessor.put(seq, processor) - if (isTest) { - processor.countDownWhenBatchAcked(testLatch) - } - Some(processor) - } else { - None - } - } - } - - /** - * Called by Spark to indicate successful commit of a batch - * @param sequenceNumber The sequence number of the event batch that was successful - */ - override def ack(sequenceNumber: CharSequence): Void = { - logDebug("Received Ack for batch with sequence number: " + sequenceNumber) - completeTransaction(sequenceNumber, success = true) - null - } - - /** - * Called by Spark to indicate failed commit of a batch - * @param sequenceNumber The sequence number of the event batch that failed - * @return - */ - override def nack(sequenceNumber: CharSequence): Void = { - completeTransaction(sequenceNumber, success = false) - logInfo("Spark failed to commit transaction. Will reattempt events.") - null - } - - /** - * Helper method to commit or rollback a transaction. - * @param sequenceNumber The sequence number of the batch that was completed - * @param success Whether the batch was successful or not. 
- */ - private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { - removeAndGetProcessor(sequenceNumber).foreach(processor => { - processor.batchProcessed(success) - }) - } - - /** - * Helper method to remove the TxnProcessor for a Sequence Number. Can be used to avoid a leak. - * @param sequenceNumber - * @return An `Option` of the transaction processor for the corresponding batch. Note that this - * instance is no longer tracked and the caller is responsible for that txn processor. - */ - private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): - Option[TransactionProcessor] = { - sequenceNumberToProcessor.synchronized { - sequenceNumberToProcessor.remove(sequenceNumber.toString) - } - } - - private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { - testLatch = latch - isTest = true - } - - /** - * Shuts down the executor used to process transactions. - */ - def shutdown() { - logInfo("Shutting down Spark Avro Callback Handler") - sequenceNumberToProcessor.synchronized { - stopped = true - sequenceNumberToProcessor.values.foreach(_.shutdown()) - } - transactionExecutorOpt.foreach(_.shutdownNow()) - } -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala deleted file mode 100644 index 14dffb15fef98..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.net.InetSocketAddress -import java.util.concurrent._ - -import org.apache.avro.ipc.NettyServer -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.flume.Context -import org.apache.flume.Sink.Status -import org.apache.flume.conf.{Configurable, ConfigurationException} -import org.apache.flume.sink.AbstractSink - -/** - * A sink that uses Avro RPC to run a server that can be polled by Spark's - * FlumePollingInputDStream. This sink has the following configuration parameters: - * - * hostname - The hostname to bind to. Default: 0.0.0.0 - * port - The port to bind to. (No default - mandatory) - * timeout - Time in seconds after which a transaction is rolled back, - * if an ACK is not received from Spark within that time - * threads - Number of threads to use to receive requests from Spark (Default: 10) - * - * This sink is unlike other Flume sinks in the sense that it does not push data, - * instead the process method in this sink simply blocks the SinkRunner the first time it is - * called. This sink starts up an Avro IPC server that uses the SparkFlumeProtocol. 
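For context, a Flume agent wires this sink in through its ordinary properties file; a minimal sketch follows, with the agent name, channel name and port as placeholders and the optional keys mirroring SparkSinkConfig further down:

    agent.sinks.spark.type = org.apache.spark.streaming.flume.sink.SparkSink
    agent.sinks.spark.hostname = 0.0.0.0
    agent.sinks.spark.port = 9999
    agent.sinks.spark.channel = memoryChannel
    # optional; defaults match SparkSinkConfig
    agent.sinks.spark.threads = 10
    agent.sinks.spark.timeout = 60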
- * - * Each time a getEventBatch call comes, creates a transaction and reads events - * from the channel. When enough events are read, the events are sent to the Spark receiver and - * the thread itself is blocked and a reference to it saved off. - * - * When the ack for that batch is received, - * the thread which created the transaction is is retrieved and it commits the transaction with the - * channel from the same thread it was originally created in (since Flume transactions are - * thread local). If a nack is received instead, the sink rolls back the transaction. If no ack - * is received within the specified timeout, the transaction is rolled back too. If an ack comes - * after that, it is simply ignored and the events get re-sent. - * - */ - -class SparkSink extends AbstractSink with Logging with Configurable { - - // Size of the pool to use for holding transaction processors. - private var poolSize: Integer = SparkSinkConfig.DEFAULT_THREADS - - // Timeout for each transaction. If spark does not respond in this much time, - // rollback the transaction - private var transactionTimeout = SparkSinkConfig.DEFAULT_TRANSACTION_TIMEOUT - - // Address info to bind on - private var hostname: String = SparkSinkConfig.DEFAULT_HOSTNAME - private var port: Int = 0 - - private var backOffInterval: Int = 200 - - // Handle to the server - private var serverOpt: Option[NettyServer] = None - - // The handler that handles the callback from Avro - private var handler: Option[SparkAvroCallbackHandler] = None - - // Latch that blocks off the Flume framework from wasting 1 thread. - private val blockingLatch = new CountDownLatch(1) - - override def start() { - logInfo("Starting Spark Sink: " + getName + " on port: " + port + " and interface: " + - hostname + " with " + "pool size: " + poolSize + " and transaction timeout: " + - transactionTimeout + ".") - handler = Option(new SparkAvroCallbackHandler(poolSize, getChannel, transactionTimeout, - backOffInterval)) - val responder = new SpecificResponder(classOf[SparkFlumeProtocol], handler.get) - // Using the constructor that takes specific thread-pools requires bringing in netty - // dependencies which are being excluded in the build. In practice, - // Netty dependencies are already available on the JVM as Flume would have pulled them in. - serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) - serverOpt.foreach(server => { - logInfo("Starting Avro server for sink: " + getName) - server.start() - }) - super.start() - } - - override def stop() { - logInfo("Stopping Spark Sink: " + getName) - handler.foreach(callbackHandler => { - callbackHandler.shutdown() - }) - serverOpt.foreach(server => { - logInfo("Stopping Avro Server for sink: " + getName) - server.close() - server.join() - }) - blockingLatch.countDown() - super.stop() - } - - override def configure(ctx: Context) { - import SparkSinkConfig._ - hostname = ctx.getString(CONF_HOSTNAME, DEFAULT_HOSTNAME) - port = Option(ctx.getInteger(CONF_PORT)). 
- getOrElse(throw new ConfigurationException("The port to bind to must be specified")) - poolSize = ctx.getInteger(THREADS, DEFAULT_THREADS) - transactionTimeout = ctx.getInteger(CONF_TRANSACTION_TIMEOUT, DEFAULT_TRANSACTION_TIMEOUT) - backOffInterval = ctx.getInteger(CONF_BACKOFF_INTERVAL, DEFAULT_BACKOFF_INTERVAL) - logInfo("Configured Spark Sink with hostname: " + hostname + ", port: " + port + ", " + - "poolSize: " + poolSize + ", transactionTimeout: " + transactionTimeout + ", " + - "backoffInterval: " + backOffInterval) - } - - override def process(): Status = { - // This method is called in a loop by the Flume framework - block it until the sink is - // stopped to save CPU resources. The sink runner will interrupt this thread when the sink is - // being shut down. - logInfo("Blocking Sink Runner, sink will continue to run..") - blockingLatch.await() - Status.BACKOFF - } - - private[flume] def getPort(): Int = { - serverOpt - .map(_.getPort) - .getOrElse( - throw new RuntimeException("Server was not started!") - ) - } - - /** - * Pass in a [[CountDownLatch]] for testing purposes. This batch is counted down when each - * batch is received. The test can simply call await on this latch till the expected number of - * batches are received. - * @param latch - */ - private[flume] def countdownWhenBatchReceived(latch: CountDownLatch) { - handler.foreach(_.countDownWhenBatchAcked(latch)) - } -} - -/** - * Configuration parameters and their defaults. - */ -private[flume] -object SparkSinkConfig { - val THREADS = "threads" - val DEFAULT_THREADS = 10 - - val CONF_TRANSACTION_TIMEOUT = "timeout" - val DEFAULT_TRANSACTION_TIMEOUT = 60 - - val CONF_HOSTNAME = "hostname" - val DEFAULT_HOSTNAME = "0.0.0.0" - - val CONF_PORT = "port" - - val CONF_BACKOFF_INTERVAL = "backoffInterval" - val DEFAULT_BACKOFF_INTERVAL = 200 -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala deleted file mode 100644 index 845fc8debda75..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.util.concurrent.ThreadFactory -import java.util.concurrent.atomic.AtomicLong - -/** - * Thread factory that generates daemon threads with a specified name format. 
- */ -private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { - - private val threadId = new AtomicLong() - - override def newThread(r: Runnable): Thread = { - val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) - t.setDaemon(true) - t - } - -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala deleted file mode 100644 index 47c0e294d6b52..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -private[flume] object SparkSinkUtils { - /** - * This method determines if this batch represents an error or not. - * @param batch - The batch to check - * @return - true if the batch represents an error - */ - def isErrorBatch(batch: EventBatch): Boolean = { - !batch.getErrorMsg.toString.equals("") // If there is an error message, it is an error batch. - } -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala deleted file mode 100644 index b15c2097e550c..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.nio.ByteBuffer -import java.util -import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} - -import scala.util.control.Breaks - -import org.apache.flume.{Channel, Transaction} - -// Flume forces transactions to be thread-local (horrible, I know!) -// So the sink basically spawns a new thread to pull the events out within a transaction. 
-// The thread fills in the event batch object that is set before the thread is scheduled. -// After filling it in, the thread waits on a condition - which is released only -// when the success message comes back for the specific sequence number for that event batch. -/** - * This class represents a transaction on the Flume channel. This class runs a separate thread - * which owns the transaction. The thread is blocked until the success call for that transaction - * comes back with an ACK or NACK. - * @param channel The channel from which to pull events - * @param seqNum The sequence number to use for the transaction. Must be unique - * @param maxBatchSize The maximum number of events to process per batch - * @param transactionTimeout Time in seconds after which a transaction must be rolled back - * without waiting for an ACK from Spark - * @param parent The parent [[SparkAvroCallbackHandler]] instance, for reporting timeouts - */ -private class TransactionProcessor(val channel: Channel, val seqNum: String, - var maxBatchSize: Int, val transactionTimeout: Int, val backOffInterval: Int, - val parent: SparkAvroCallbackHandler) extends Callable[Void] with Logging { - - // If a real batch is not returned, we always have to return an error batch. - @volatile private var eventBatch: EventBatch = new EventBatch("Unknown Error", "", - util.Collections.emptyList()) - - // Synchronization primitives - val batchGeneratedLatch = new CountDownLatch(1) - val batchAckLatch = new CountDownLatch(1) - - // Sanity check to ensure we don't loop like crazy - val totalAttemptsToRemoveFromChannel = Int.MaxValue / 2 - - // OK to use volatile, since the change would only make this true (otherwise it will be - // changed to false - we never apply a negation operation to this) - which means the transaction - // succeeded. - @volatile private var batchSuccess = false - - @volatile private var stopped = false - - @volatile private var isTest = false - - private var testLatch: CountDownLatch = null - - // The transaction that this processor would handle - var txOpt: Option[Transaction] = None - - /** - * Get an event batch from the channel. This method will block until a batch of events is - * available from the channel. If no events are available after a large number of attempts of - * polling the channel, this method will return an [[EventBatch]] with a non-empty error message - * - * @return An [[EventBatch]] instance with sequence number set to seqNum, filled with a - * maximum of maxBatchSize events - */ - def getEventBatch: EventBatch = { - batchGeneratedLatch.await() - eventBatch - } - - /** - * This method is to be called by the sink when it receives an ACK or NACK from Spark. This - * method is a no-op if it is called after transactionTimeout has expired since - * getEventBatch returned a batch of events. - * @param success True if an ACK was received and the transaction should be committed, else false. - */ - def batchProcessed(success: Boolean) { - logDebug("Batch processed for sequence number: " + seqNum) - batchSuccess = success - batchAckLatch.countDown() - } - - private[flume] def shutdown(): Unit = { - logDebug("Shutting down transaction processor") - stopped = true - } - - /** - * Populates events into the event batch. If the batch cannot be populated, - * this method will not set the events into the event batch, but it sets an error message. - */ - private def populateEvents() { - try { - txOpt = Option(channel.getTransaction) - if(txOpt.isEmpty) { - eventBatch.setErrorMsg("Something went wrong. 
Channel was " + - "unable to create a transaction!") - } - txOpt.foreach(tx => { - tx.begin() - val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) - val loop = new Breaks - var gotEventsInThisTxn = false - var loopCounter: Int = 0 - loop.breakable { - while (!stopped && events.size() < maxBatchSize - && loopCounter < totalAttemptsToRemoveFromChannel) { - loopCounter += 1 - Option(channel.take()) match { - case Some(event) => - events.add(new SparkSinkEvent(toCharSequenceMap(event.getHeaders), - ByteBuffer.wrap(event.getBody))) - gotEventsInThisTxn = true - case None => - if (!gotEventsInThisTxn && !stopped) { - logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" + - " the current transaction") - TimeUnit.MILLISECONDS.sleep(backOffInterval) - } else { - loop.break() - } - } - } - } - if (!gotEventsInThisTxn && !stopped) { - val msg = "Tried several times, " + - "but did not get any events from the channel!" - logWarning(msg) - eventBatch.setErrorMsg(msg) - } else { - // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("", seqNum, events) - } - }) - } catch { - case interrupted: InterruptedException => - // Don't pollute logs if the InterruptedException came from this being stopped - if (!stopped) { - logWarning("Error while processing transaction.", interrupted) - } - case e: Exception => - logWarning("Error while processing transaction.", e) - eventBatch.setErrorMsg(e.getMessage) - try { - txOpt.foreach(tx => { - rollbackAndClose(tx, close = true) - }) - } finally { - txOpt = None - } - } finally { - batchGeneratedLatch.countDown() - } - } - - /** - * Waits for upto transactionTimeout seconds for an ACK. If an ACK comes in - * this method commits the transaction with the channel. If the ACK does not come in within - * that time or a NACK comes in, this method rolls back the transaction. - */ - private def processAckOrNack() { - batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) - txOpt.foreach(tx => { - if (batchSuccess) { - try { - logDebug("Committing transaction") - tx.commit() - } catch { - case e: Exception => - logWarning("Error while attempting to commit transaction. Transaction will be rolled " + - "back", e) - rollbackAndClose(tx, close = false) // tx will be closed later anyway - } finally { - tx.close() - if (isTest) { - testLatch.countDown() - } - } - } else { - logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") - rollbackAndClose(tx, close = true) - // This might have been due to timeout or a NACK. Either way the following call does not - // cause issues. This is required to ensure the TransactionProcessor instance is not leaked - parent.removeAndGetProcessor(seqNum) - } - }) - } - - /** - * Helper method to rollback and optionally close a transaction - * @param tx The transaction to rollback - * @param close Whether the transaction should be closed or not after rolling back - */ - private def rollbackAndClose(tx: Transaction, close: Boolean) { - try { - logWarning("Spark was unable to successfully process the events. Transaction is being " + - "rolled back.") - tx.rollback() - } catch { - case e: Exception => - logError("Error rolling back transaction. 
Rollback may have failed!", e) - } finally { - if (close) { - tx.close() - } - } - } - - /** - * Helper method to convert a Map[String, String] to Map[CharSequence, CharSequence] - * @param inMap The map to be converted - * @return The converted map - */ - private def toCharSequenceMap(inMap: java.util.Map[String, String]): java.util.Map[CharSequence, - CharSequence] = { - val charSeqMap = new util.HashMap[CharSequence, CharSequence](inMap.size()) - charSeqMap.putAll(inMap) - charSeqMap - } - - /** - * When the thread is started it sets as many events as the batch size or less (if enough - * events aren't available) into the eventBatch and object and lets any threads waiting on the - * [[getEventBatch]] method to proceed. Then this thread waits for acks or nacks to come in, - * or for a specified timeout and commits or rolls back the transaction. - * @return - */ - override def call(): Void = { - populateEvents() - processAckOrNack() - null - } - - private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { - testLatch = latch - isTest = true - } -} diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties deleted file mode 100644 index 42df8792f147f..0000000000000 --- a/external/flume-sink/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file streaming/target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala deleted file mode 100644 index 7f6cecf9cd18d..0000000000000 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.net.InetSocketAddress -import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.JavaConverters._ -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.event.EventBuilder -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory - -// Due to MNG-1378, there is not a way to include test dependencies transitively. -// We cannot include Spark core tests as a dependency here because it depends on -// Spark core main, which has too many dependencies to require here manually. -// For this reason, we continue to use FunSuite and ignore the scalastyle checks -// that fail if this is detected. -// scalastyle:off -import org.scalatest.FunSuite - -class SparkSinkSuite extends FunSuite { -// scalastyle:on - - val eventsPerBatch = 1000 - val channelCapacity = 5000 - - test("Success with ack") { - val (channel, sink, latch) = initializeChannelAndSink() - channel.start() - sink.start() - - putEvents(channel, eventsPerBatch) - - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - client.ack(events.getSequenceNumber) - assert(events.getEvents.size() === 1000) - latch.await(1, TimeUnit.SECONDS) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Failure with nack") { - val (channel, sink, latch) = initializeChannelAndSink() - channel.start() - sink.start() - putEvents(channel, eventsPerBatch) - - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - assert(events.getEvents.size() === 1000) - client.nack(events.getSequenceNumber) - latch.await(1, TimeUnit.SECONDS) - assert(availableChannelSlots(channel) === 4000) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Failure with timeout") { - val (channel, sink, latch) = initializeChannelAndSink(Map(SparkSinkConfig - .CONF_TRANSACTION_TIMEOUT -> 1.toString)) - channel.start() - sink.start() - putEvents(channel, eventsPerBatch) - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - assert(events.getEvents.size() === 1000) - latch.await(1, TimeUnit.SECONDS) - assert(availableChannelSlots(channel) === 4000) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Multiple consumers") { - testMultipleConsumers(failSome = false) - } - - test("Multiple consumers with some failures") { - testMultipleConsumers(failSome = true) - } - - def 
testMultipleConsumers(failSome: Boolean): Unit = { - implicit val executorContext = ExecutionContext - .fromExecutorService(Executors.newFixedThreadPool(5)) - val (channel, sink, latch) = initializeChannelAndSink(Map.empty, 5) - channel.start() - sink.start() - (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - val transceiversAndClients = getTransceiverAndClient(address, 5) - val batchCounter = new CountDownLatch(5) - val counter = new AtomicInteger(0) - transceiversAndClients.foreach(x => { - Future { - val client = x._2 - val events = client.getEventBatch(1000) - if (!failSome || counter.getAndIncrement() % 2 == 0) { - client.ack(events.getSequenceNumber) - } else { - client.nack(events.getSequenceNumber) - throw new RuntimeException("Sending NACK for failure!") - } - events - }.onComplete { - case Success(events) => - assert(events.getEvents.size() === 1000) - batchCounter.countDown() - case Failure(t) => - // Don't re-throw the exception, causes a nasty unnecessary stack trace on stdout - batchCounter.countDown() - } - }) - batchCounter.await() - latch.await(1, TimeUnit.SECONDS) - executorContext.shutdown() - if(failSome) { - assert(availableChannelSlots(channel) === 3000) - } else { - assertChannelIsEmpty(channel) - } - sink.stop() - channel.stop() - transceiversAndClients.foreach(x => x._1.close()) - } - - private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty, - batchCounter: Int = 1): (MemoryChannel, SparkSink, CountDownLatch) = { - val channel = new MemoryChannel() - val channelContext = new Context() - - channelContext.put("capacity", channelCapacity.toString) - channelContext.put("transactionCapacity", 1000.toString) - channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides.asJava) - channel.setName(scala.util.Random.nextString(10)) - channel.configure(channelContext) - - val sink = new SparkSink() - val sinkContext = new Context() - sinkContext.put(SparkSinkConfig.CONF_HOSTNAME, "0.0.0.0") - sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) - sink.configure(sinkContext) - sink.setChannel(channel) - val latch = new CountDownLatch(batchCounter) - sink.countdownWhenBatchReceived(latch) - (channel, sink, latch) - } - - private def putEvents(ch: MemoryChannel, count: Int): Unit = { - val tx = ch.getTransaction - tx.begin() - (1 to count).foreach(x => ch.put(EventBuilder.withBody(x.toString.getBytes))) - tx.commit() - tx.close() - } - - private def getTransceiverAndClient(address: InetSocketAddress, - count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { - - (1 to count).map(_ => { - lazy val channelFactoryExecutor = Executors.newCachedThreadPool( - new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) - lazy val channelFactory = - new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) - val transceiver = new NettyTransceiver(address, channelFactory) - val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) - (transceiver, client) - }) - } - - private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - assert(availableChannelSlots(channel) === channelCapacity) - } - - private def availableChannelSlots(channel: MemoryChannel): Int = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - 
m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] - } -} diff --git a/external/flume/pom.xml b/external/flume/pom.xml deleted file mode 100644 index d650dd034d636..0000000000000 --- a/external/flume/pom.xml +++ /dev/null @@ -1,78 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-flume_2.11 - - streaming-flume - - jar - Spark Project External Flume - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-streaming-flume-sink_${scala.binary.version} - ${project.version} - - - org.apache.flume - flume-ng-core - - - org.apache.flume - flume-ng-sdk - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala deleted file mode 100644 index 5c773d4b07cf6..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.io.{ObjectInput, ObjectOutput} - -import scala.collection.JavaConverters._ - -import org.apache.spark.Logging -import org.apache.spark.util.Utils - -/** - * A simple object that provides the implementation of readExternal and writeExternal for both - * the wrapper classes for Flume-style Events. 
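(Summarising the readExternal/writeExternal code that follows: the layout is an int body length and the raw body bytes, then an int header count, and for each header a length-prefixed Java-serialized key string followed by a length-prefixed Java-serialized value string.)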
- */ -private[streaming] object EventTransformer extends Logging { - def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence], - Array[Byte]) = { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.readFully(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.readFully(keyBuff) - val key: String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.readFully(valBuff) - val value: String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - (headers, bodyBuff) - } - - def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence], - body: Array[Byte]) { - out.writeInt(body.length) - out.write(body) - val numHeaders = headers.size() - out.writeInt(numHeaders) - for ((k, v) <- headers.asScala) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala deleted file mode 100644 index b9d4e762ca05d..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Throwables - -import org.apache.spark.Logging -import org.apache.spark.streaming.flume.sink._ - -/** - * This class implements the core functionality of [[FlumePollingReceiver]]. When started it - * pulls data from Flume, stores it to Spark and then sends an Ack or Nack. This class should be - * run via an [[java.util.concurrent.Executor]] as this implements [[Runnable]] - * - * @param receiver The receiver that owns this instance. 
- */ - -private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends Runnable with - Logging { - - def run(): Unit = { - while (!receiver.isStopped()) { - val connection = receiver.getConnections.poll() - val client = connection.client - var batchReceived = false - var seq: CharSequence = null - try { - getBatch(client) match { - case Some(eventBatch) => - batchReceived = true - seq = eventBatch.getSequenceNumber - val events = toSparkFlumeEvents(eventBatch.getEvents) - if (store(events)) { - sendAck(client, seq) - } else { - sendNack(batchReceived, client, seq) - } - case None => - } - } catch { - case e: Exception => - Throwables.getRootCause(e) match { - // If the cause was an InterruptedException, then check if the receiver is stopped - - // if yes, just break out of the loop. Else send a Nack and log a warning. - // In the unlikely case, the cause was not an Exception, - // then just throw it out and exit. - case interrupted: InterruptedException => - if (!receiver.isStopped()) { - logWarning("Interrupted while receiving data from Flume", interrupted) - sendNack(batchReceived, client, seq) - } - case exception: Exception => - logWarning("Error while receiving data from Flume", exception) - sendNack(batchReceived, client, seq) - } - } finally { - receiver.getConnections.add(connection) - } - } - } - - /** - * Gets a batch of events from the specified client. This method does not handle any exceptions - * which will be propogated to the caller. - * @param client Client to get events from - * @return [[Some]] which contains the event batch if Flume sent any events back, else [[None]] - */ - private def getBatch(client: SparkFlumeProtocol.Callback): Option[EventBatch] = { - val eventBatch = client.getEventBatch(receiver.getMaxBatchSize) - if (!SparkSinkUtils.isErrorBatch(eventBatch)) { - // No error, proceed with processing data - logDebug(s"Received batch of ${eventBatch.getEvents.size} events with sequence " + - s"number: ${eventBatch.getSequenceNumber}") - Some(eventBatch) - } else { - logWarning("Did not receive events from Flume agent due to error on the Flume agent: " + - eventBatch.getErrorMsg) - None - } - } - - /** - * Store the events in the buffer to Spark. This method will not propogate any exceptions, - * but will propogate any other errors. - * @param buffer The buffer to store - * @return true if the data was stored without any exception being thrown, else false - */ - private def store(buffer: ArrayBuffer[SparkFlumeEvent]): Boolean = { - try { - receiver.store(buffer) - true - } catch { - case e: Exception => - logWarning("Error while attempting to store data received from Flume", e) - false - } - } - - /** - * Send an ack to the client for the sequence number. This method does not handle any exceptions - * which will be propagated to the caller. - * @param client client to send the ack to - * @param seq sequence number of the batch to be ack-ed. - * @return - */ - private def sendAck(client: SparkFlumeProtocol.Callback, seq: CharSequence): Unit = { - logDebug("Sending ack for sequence number: " + seq) - client.ack(seq) - logDebug("Ack sent for sequence number: " + seq) - } - - /** - * This method sends a Nack if a batch was received to the client with the given sequence - * number. Any exceptions thrown by the RPC call is simply thrown out as is - no effort is made - * to handle it. - * @param batchReceived true if a batch was received. 
If this is false, no nack is sent - * @param client The client to which the nack should be sent - * @param seq The sequence number of the batch that is being nack-ed. - */ - private def sendNack(batchReceived: Boolean, client: SparkFlumeProtocol.Callback, - seq: CharSequence): Unit = { - if (batchReceived) { - // Let Flume know that the events need to be pushed back into the channel. - logDebug("Sending nack for sequence number: " + seq) - client.nack(seq) // If the agent is down, even this could fail and throw - logDebug("Nack sent for sequence number: " + seq) - } - } - - /** - * Utility method to convert [[SparkSinkEvent]]s to [[SparkFlumeEvent]]s - * @param events - Events to convert to SparkFlumeEvents - * @return - The SparkFlumeEvent generated from SparkSinkEvent - */ - private def toSparkFlumeEvents(events: java.util.List[SparkSinkEvent]): - ArrayBuffer[SparkFlumeEvent] = { - // Convert each Flume event to a serializable SparkFlumeEvent - val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) - var j = 0 - while (j < events.size()) { - val event = events.get(j) - val sparkFlumeEvent = new SparkFlumeEvent() - sparkFlumeEvent.event.setBody(event.getBody) - sparkFlumeEvent.event.setHeaders(event.getHeaders) - buffer += sparkFlumeEvent - j += 1 - } - buffer - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala deleted file mode 100644 index 74bd0165c6209..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.flume - -import java.io.{Externalizable, ObjectInput, ObjectOutput} -import java.net.InetSocketAddress -import java.nio.ByteBuffer -import java.util.concurrent.Executors - -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import org.apache.avro.ipc.NettyServer -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol, Status} -import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} -import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory -import org.jboss.netty.handler.codec.compression._ - -import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils - -private[streaming] -class FlumeInputDStream[T: ClassTag]( - _ssc: StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean -) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { - - override def getReceiver(): Receiver[SparkFlumeEvent] = { - new FlumeReceiver(host, port, storageLevel, enableDecompression) - } -} - -/** - * A wrapper class for AvroFlumeEvent's with a custom serialization format. - * - * This is necessary because AvroFlumeEvent uses inner data structures - * which are not serializable. - */ -class SparkFlumeEvent() extends Externalizable { - var event: AvroFlumeEvent = new AvroFlumeEvent() - - /* De-serialize from bytes. */ - def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.readFully(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.readFully(keyBuff) - val key: String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.readFully(valBuff) - val value: String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - - event.setBody(ByteBuffer.wrap(bodyBuff)) - event.setHeaders(headers) - } - - /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody - out.writeInt(body.remaining()) - Utils.writeByteBuffer(body, out) - - val numHeaders = event.getHeaders.size() - out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders.asScala) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} - -private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in: AvroFlumeEvent): SparkFlumeEvent = { - val event = new SparkFlumeEvent - event.event = in - event - } -} - -/** A simple server that implements Flume's Avro protocol. 
 */ -private[streaming] -class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol { - override def append(event: AvroFlumeEvent): Status = { - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event)) - Status.OK - } - - override def appendBatch(events: java.util.List[AvroFlumeEvent]): Status = { - events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) - Status.OK - } -} - -/** A NetworkReceiver which listens for events using the - * Flume Avro interface. */ -private[streaming] -class FlumeReceiver( - host: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { - - lazy val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)) - var server: NettyServer = null - - private def initServer() = { - if (enableDecompression) { - val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), - Executors.newCachedThreadPool()) - val channelPipelineFactory = new CompressionChannelPipelineFactory() - - new NettyServer( - responder, - new InetSocketAddress(host, port), - channelFactory, - channelPipelineFactory, - null) - } else { - new NettyServer(responder, new InetSocketAddress(host, port)) - } - } - - def onStart() { - synchronized { - if (server == null) { - server = initServer() - server.start() - } else { - logWarning("Flume receiver being asked to start more than once without close") - } - } - logInfo("Flume receiver started") - } - - def onStop() { - synchronized { - if (server != null) { - server.close() - server = null - } - } - logInfo("Flume receiver stopped") - } - - override def preferredLocation: Option[String] = Option(host) - - /** A Netty Pipeline factory that will decompress incoming data from - * the Netty client and compress data going back to the client. - * - * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel - */ - private[streaming] - class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - def getPipeline(): ChannelPipeline = { - val pipeline = Channels.pipeline() - val encoder = new ZlibEncoder(6) - pipeline.addFirst("deflater", encoder) - pipeline.addFirst("inflater", new ZlibDecoder()) - pipeline - } - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala deleted file mode 100644 index d9c25e86540db..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements.  See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
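The push-based receiver removed above was consumed through FlumeUtils.createStream (its overloads are shown later in this patch). A minimal usage sketch, assuming a Spark 1.x StreamingContext and a Flume avro sink configured to send to this host and port (hostname, port, and app name are illustrative):

    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.flume.FlumeUtils

    val conf = new SparkConf().setMaster("local[2]").setAppName("FlumePushSketch")
    val ssc = new StreamingContext(conf, Seconds(2))

    // FlumeReceiver starts an Avro NettyServer on this host/port; the Flume agent pushes to it.
    val events = FlumeUtils.createStream(ssc, "localhost", 41414, StorageLevel.MEMORY_AND_DISK_SER_2)
    events.count().map(n => s"Received $n Flume events in this batch").print()

    ssc.start()
    ssc.awaitTermination()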
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume - - -import java.net.InetSocketAddress -import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit} - -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory - -import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.flume.sink._ -import org.apache.spark.streaming.receiver.Receiver - -/** - * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running - * [[org.apache.spark.streaming.flume.sink.SparkSink]]s. - * @param _ssc Streaming context that will execute this input stream - * @param addresses List of addresses at which SparkSinks are listening - * @param maxBatchSize Maximum size of a batch - * @param parallelism Number of parallel connections to open - * @param storageLevel The storage level to use. - * @tparam T Class type of the object of this stream - */ -private[streaming] class FlumePollingInputDStream[T: ClassTag]( - _ssc: StreamingContext, - val addresses: Seq[InetSocketAddress], - val maxBatchSize: Int, - val parallelism: Int, - storageLevel: StorageLevel - ) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { - - override def getReceiver(): Receiver[SparkFlumeEvent] = { - new FlumePollingReceiver(addresses, maxBatchSize, parallelism, storageLevel) - } -} - -private[streaming] class FlumePollingReceiver( - addresses: Seq[InetSocketAddress], - maxBatchSize: Int, - parallelism: Int, - storageLevel: StorageLevel - ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { - - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) - - lazy val channelFactory = - new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) - - lazy val receiverExecutor = Executors.newFixedThreadPool(parallelism, - new ThreadFactoryBuilder().setDaemon(true).setNameFormat("Flume Receiver Thread - %d").build()) - - private lazy val connections = new LinkedBlockingQueue[FlumeConnection]() - - override def onStart(): Unit = { - // Create the connections to each Flume agent. - addresses.foreach(host => { - val transceiver = new NettyTransceiver(host, channelFactory) - val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) - connections.add(new FlumeConnection(transceiver, client)) - }) - for (i <- 0 until parallelism) { - logInfo("Starting Flume Polling Receiver worker threads..") - // Threads that pull data from Flume. 
- receiverExecutor.submit(new FlumeBatchFetcher(this)) - } - } - - override def onStop(): Unit = { - logInfo("Shutting down Flume Polling Receiver") - receiverExecutor.shutdown() - // Wait up to a minute for the threads to die - if (!receiverExecutor.awaitTermination(60, TimeUnit.SECONDS)) { - receiverExecutor.shutdownNow() - } - connections.asScala.foreach(_.transceiver.close()) - channelFactory.releaseExternalResources() - } - - private[flume] def getConnections: LinkedBlockingQueue[FlumeConnection] = { - this.connections - } - - private[flume] def getMaxBatchSize: Int = { - this.maxBatchSize - } -} - -/** - * A wrapper around the transceiver and the Avro IPC API. - * @param transceiver The transceiver to use for communication with Flume - * @param client The client that the callbacks are received on. - */ -private[flume] class FlumeConnection(val transceiver: NettyTransceiver, - val client: SparkFlumeProtocol.Callback) - - - diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala deleted file mode 100644 index 3f87ce46e5952..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements.  See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
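The polling receiver above is the pull-based counterpart, used together with the SparkSink from org.apache.spark.streaming.flume.sink; it is normally created through FlumeUtils.createPollingStream, which is also removed later in this patch. A sketch under the assumption that a Flume agent runs a SparkSink on localhost:9999 (address, batch size, and parallelism are illustrative):

    import java.net.InetSocketAddress

    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.flume.FlumeUtils

    val conf = new SparkConf().setMaster("local[2]").setAppName("FlumePollingSketch")
    val ssc = new StreamingContext(conf, Seconds(2))

    // One FlumeConnection is opened per address; up to 5 FlumeBatchFetcher threads share them,
    // pulling at most 1000 events per getEventBatch call and ack-ing or nack-ing each batch.
    val addresses = Seq(new InetSocketAddress("localhost", 9999))
    val events = FlumeUtils.createPollingStream(
      ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2, 1000, 5)
    events.count().map(n => s"Pulled $n events from the sink in this batch").print()

    ssc.start()
    ssc.awaitTermination()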
- */ - -package org.apache.spark.streaming.flume - -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer -import java.util.{List => JList} -import java.util.Collections - -import scala.collection.JavaConverters._ - -import com.google.common.base.Charsets.UTF_8 -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} - -import org.apache.spark.util.Utils -import org.apache.spark.SparkConf - -/** - * Share codes for Scala and Python unit tests - */ -private[flume] class FlumeTestUtils { - - private var transceiver: NettyTransceiver = null - - private val testPort: Int = findFreePort() - - def getTestPort(): Int = testPort - - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - /** Send data to the flume receiver */ - def writeInput(input: JList[String], enableCompression: Boolean): Unit = { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.asScala.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) - event.setHeaders(Collections.singletonMap("test", "header")) - event - } - - // if last attempted transceiver had succeeded, close it - close() - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - if (client == null) { - throw new AssertionError("Cannot create client") - } - - // Send data - val status = client.appendBatch(inputEvents.asJava) - if (status != avro.Status.OK) { - throw new AssertionError("Sent events unsuccessfully") - } - } - - def close(): Unit = { - if (transceiver != null) { - transceiver.close() - transceiver = null - } - } - - /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) - extends NioClientSocketChannelFactory { - - override def newChannel(pipeline: ChannelPipeline): SocketChannel = { - val encoder = new ZlibEncoder(compressionLevel) - pipeline.addFirst("deflater", encoder) - pipeline.addFirst("inflater", new ZlibDecoder()) - super.newChannel(pipeline) - } - } - -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala deleted file mode 100644 index 3e3ed712f0dbf..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.io.{ByteArrayOutputStream, DataOutputStream} -import java.net.InetSocketAddress -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConverters._ - -import org.apache.spark.api.java.function.PairFunction -import org.apache.spark.api.python.PythonRDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -object FlumeUtils { - private val DEFAULT_POLLING_PARALLELISM = 5 - private val DEFAULT_POLLING_BATCH_SIZE = 1000 - - /** - * Create an input stream from a Flume source. - * @param ssc StreamingContext object - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream ( - ssc: StreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[SparkFlumeEvent] = { - createStream(ssc, hostname, port, storageLevel, false) - } - - /** - * Create an input stream from a Flume source. - * @param ssc StreamingContext object - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - * @param enableDecompression should netty server decompress input stream - */ - def createStream ( - ssc: StreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ): ReceiverInputDStream[SparkFlumeEvent] = { - val inputStream = new FlumeInputDStream[SparkFlumeEvent]( - ssc, hostname, port, storageLevel, enableDecompression) - - inputStream - } - - /** - * Creates an input stream from a Flume source. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - */ - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port) - } - - /** - * Creates an input stream from a Flume source.
- * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port, storageLevel, false) - } - - /** - * Creates an input stream from a Flume source. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - * @param enableDecompression should netty server decompress input stream - */ - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param hostname Address of the host on which the Spark Sink is running - * @param port Port of the host at which the Spark Sink is listening - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - ssc: StreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(ssc, Seq(new InetSocketAddress(hostname, port)), storageLevel) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param addresses List of InetSocketAddresses representing the hosts to connect to. - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - ssc: StreamingContext, - addresses: Seq[InetSocketAddress], - storageLevel: StorageLevel - ): ReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(ssc, addresses, storageLevel, - DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * @param addresses List of InetSocketAddresses representing the hosts to connect to. - * @param maxBatchSize Maximum number of events to be pulled from the Spark sink in a - * single RPC call - * @param parallelism Number of concurrent requests this stream should send to the sink.
Note - * that having a higher number of requests concurrently being pulled will - * result in this stream using more threads - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - ssc: StreamingContext, - addresses: Seq[InetSocketAddress], - storageLevel: StorageLevel, - maxBatchSize: Int, - parallelism: Int - ): ReceiverInputDStream[SparkFlumeEvent] = { - new FlumePollingInputDStream[SparkFlumeEvent](ssc, addresses, maxBatchSize, - parallelism, storageLevel) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param hostname Hostname of the host on which the Spark Sink is running - * @param port Port of the host at which the Spark Sink is listening - */ - def createPollingStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc, hostname, port, StorageLevel.MEMORY_AND_DISK_SER_2) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param hostname Hostname of the host on which the Spark Sink is running - * @param port Port of the host at which the Spark Sink is listening - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc, Array(new InetSocketAddress(hostname, port)), storageLevel) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param addresses List of InetSocketAddresses on which the Spark Sink is running. - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - jssc: JavaStreamingContext, - addresses: Array[InetSocketAddress], - storageLevel: StorageLevel - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc, addresses, storageLevel, - DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * @param addresses List of InetSocketAddresses on which the Spark Sink is running - * @param maxBatchSize The maximum number of events to be pulled from the Spark sink in a - * single RPC call - * @param parallelism Number of concurrent requests this stream should send to the sink. 
Note - * that having a higher number of requests concurrently being pulled will - * result in this stream using more threads - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - jssc: JavaStreamingContext, - addresses: Array[InetSocketAddress], - storageLevel: StorageLevel, - maxBatchSize: Int, - parallelism: Int - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) - } -} - -/** - * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and - * function so that it can be easily instantiated and called from Python's FlumeUtils. - */ -private[flume] class FlumeUtilsPythonHelper { - - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ): JavaPairDStream[Array[Byte], Array[Byte]] = { - val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) - FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) - } - - def createPollingStream( - jssc: JavaStreamingContext, - hosts: JList[String], - ports: JList[Int], - storageLevel: StorageLevel, - maxBatchSize: Int, - parallelism: Int - ): JavaPairDStream[Array[Byte], Array[Byte]] = { - assert(hosts.size() == ports.size()) - val addresses = hosts.asScala.zip(ports.asScala).map { - case (host, port) => new InetSocketAddress(host, port) - } - val dstream = FlumeUtils.createPollingStream( - jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) - FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) - } - -} - -private object FlumeUtilsPythonHelper { - - private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { - val byteStream = new ByteArrayOutputStream() - val output = new DataOutputStream(byteStream) - try { - output.writeInt(map.size) - map.asScala.foreach { kv => - PythonRDD.writeUTF(kv._1.toString, output) - PythonRDD.writeUTF(kv._2.toString, output) - } - byteStream.toByteArray - } - finally { - output.close() - } - } - - private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): - JavaPairDStream[Array[Byte], Array[Byte]] = { - dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { - override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { - val event = sparkEvent.event - val byteBuffer = event.getBody - val body = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(body) - (stringMapToByteArray(event.getHeaders), body) - } - }) - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala deleted file mode 100644 index 9515d07c5ee5b..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
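The Python helper above flattens each SparkFlumeEvent into a (headers, body) pair of byte arrays; the equivalent unpacking can be done directly in Scala on a stream returned by createStream or createPollingStream. A sketch (the helper name is illustrative, and the UTF-8 body encoding is an assumption since Flume event bodies are arbitrary bytes):

    import java.nio.charset.StandardCharsets

    import scala.collection.JavaConverters._

    import org.apache.spark.streaming.dstream.DStream
    import org.apache.spark.streaming.flume.SparkFlumeEvent

    def toHeaderBodyPairs(
        stream: DStream[SparkFlumeEvent]): DStream[(Map[String, String], String)] = {
      stream.map { sparkEvent =>
        val avroEvent = sparkEvent.event
        // Copy the body out of the ByteBuffer without disturbing its position.
        val buf = avroEvent.getBody.duplicate()
        val body = new Array[Byte](buf.remaining())
        buf.get(body)
        val headers = avroEvent.getHeaders.asScala.map {
          case (k, v) => k.toString -> v.toString
        }.toMap
        (headers, new String(body, StandardCharsets.UTF_8))
      }
    }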
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.util.{Collections, List => JList, Map => JMap} -import java.util.concurrent._ - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Charsets.UTF_8 -import org.apache.flume.event.EventBuilder -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables - -import org.apache.spark.streaming.flume.sink.{SparkSink, SparkSinkConfig} - -/** - * Share codes for Scala and Python unit tests - */ -private[flume] class PollingFlumeTestUtils { - - private val batchCount = 5 - val eventsPerBatch = 100 - private val totalEventsPerChannel = batchCount * eventsPerBatch - private val channelCapacity = 5000 - - def getTotalEvents: Int = totalEventsPerChannel * channels.size - - private val channels = new ArrayBuffer[MemoryChannel] - private val sinks = new ArrayBuffer[SparkSink] - - /** - * Start a sink and return the port of this sink - */ - def startSingleSink(): Int = { - channels.clear() - sinks.clear() - - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - channels += (channel) - sinks += sink - - sink.getPort() - } - - /** - * Start 2 sinks and return the ports - */ - def startMultipleSinks(): Seq[Int] = { - channels.clear() - sinks.clear() - - // Start the channel and sink. 
- val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() - - sinks += sink - sinks += sink2 - channels += channel - channels += channel2 - - sinks.map(_.getPort()) - } - - /** - * Send data and wait until all data has been received - */ - def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - } - - /** - * A Python-friendly method to assert the output - */ - def assertOutput( - outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { - require(outputHeaders.size == outputBodies.size) - val eventSize = outputHeaders.size - if (eventSize != totalEventsPerChannel * channels.size) { - throw new AssertionError( - s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") - } - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventBodyToVerify = s"${channels(k).getName}-$i" - val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") - var found = false - var j = 0 - while (j < eventSize && !found) { - if (eventBodyToVerify == outputBodies.get(j) && - eventHeaderToVerify == outputHeaders.get(j)) { - found = true - counter += 1 - } - j += 1 - } - } - if (counter != totalEventsPerChannel * channels.size) { - throw new AssertionError( - s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") - } - } - - def assertChannelsAreEmpty(): Unit = { - channels.foreach(assertChannelIsEmpty) - } - - private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { - throw new AssertionError(s"Channel ${channel.getName} is not empty") - } - } - - def close(): Unit = { - sinks.foreach(_.stop()) - sinks.clear() - channels.foreach(_.stop()) - channels.clear() - } - - private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), - 
Collections.singletonMap(s"test-$t", "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach - } - null - } - } - -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java b/external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java deleted file mode 100644 index d31aa5f5c096c..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Spark streaming receiver for Flume. - */ -package org.apache.spark.streaming.flume; \ No newline at end of file diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala deleted file mode 100644 index 9bfab68c4b8b7..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * Spark streaming receiver for Flume. - */ -package object flume diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java deleted file mode 100644 index 79c5b91654b42..0000000000000 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.flume; - -import java.net.InetSocketAddress; - -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; - -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; - -public class JavaFlumePollingStreamSuite extends LocalJavaStreamingContext { - @Test - public void testFlumeStream() { - // tests the API, does not actually test data receiving - InetSocketAddress[] addresses = new InetSocketAddress[] { - new InetSocketAddress("localhost", 12345) - }; - JavaReceiverInputDStream test1 = - FlumeUtils.createPollingStream(ssc, "localhost", 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createPollingStream( - ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createPollingStream( - ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test4 = FlumeUtils.createPollingStream( - ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2(), 100, 5); - } -} diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java deleted file mode 100644 index 3b5e0c7746b2c..0000000000000 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume; - -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; - -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; - -public class JavaFlumeStreamSuite extends LocalJavaStreamingContext { - @Test - public void testFlumeStream() { - // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", 12345, - StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", 12345, - StorageLevel.MEMORY_AND_DISK_SER_2(), false); - } -} diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties deleted file mode 100644 index 75e3b53a093f6..0000000000000 --- a/external/flume/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala deleted file mode 100644 index c97a27ca7c7aa..0000000000000 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements.  See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.reflect.ClassTag - -import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} -import org.apache.spark.util.Utils - -/** - * This is an output stream just for the test suites. All the output is collected into an - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
- * - * The buffer contains a sequence of RDDs, each containing a sequence of items - */ -class TestOutputStream[T: ClassTag](parent: DStream[T], - val output: ConcurrentLinkedQueue[Seq[T]] = new ConcurrentLinkedQueue[Seq[T]]()) - extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output.add(collected) - }, false) { - - // This is to clear the output buffer every time it is read from a checkpoint - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { - ois.defaultReadObject() - output.clear() - } -} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala deleted file mode 100644 index 10dcbf98bc3b6..0000000000000 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements.  See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.net.InetSocketAddress -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.collection.JavaConverters._ -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ - -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.util.{ManualClock, Utils} - -class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - - val maxAttempts = 5 - val batchDuration = Seconds(1) - - val conf = new SparkConf() - .setMaster("local[2]") - .setAppName(this.getClass.getSimpleName) - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - - val utils = new PollingFlumeTestUtils - - test("flume polling test") { - testMultipleTimes(testFlumePolling) - } - - test("flume polling test multiple hosts") { - testMultipleTimes(testFlumePollingMultipleHost) - } - - /** - * Run the given test until no more java.net.BindException's are thrown. - * Do this only up to a certain attempt limit.
- */ - private def testMultipleTimes(test: () => Unit): Unit = { - var testPassed = false - var attempt = 0 - while (!testPassed && attempt < maxAttempts) { - try { - test() - testPassed = true - } catch { - case e: Exception if Utils.isBindCollision(e) => - logWarning("Exception when running flume polling test: " + e) - attempt += 1 - } - } - assert(testPassed, s"Test failed after $attempt attempts!") - } - - private def testFlumePolling(): Unit = { - try { - val port = utils.startSingleSink() - - writeAndVerify(Seq(port)) - utils.assertChannelsAreEmpty() - } finally { - utils.close() - } - } - - private def testFlumePollingMultipleHost(): Unit = { - try { - val ports = utils.startMultipleSinks() - writeAndVerify(ports) - utils.assertChannelsAreEmpty() - } finally { - utils.close() - } - } - - def writeAndVerify(sinkPorts: Seq[Int]): Unit = { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - utils.eventsPerBatch, 5) - val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputQueue) - outputStream.register() - - ssc.start() - try { - utils.sendDatAndEnsureAllDataHasBeenReceived() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenOutput = outputQueue.asScala.toSeq.flatten - val headers = flattenOutput.map(_.event.getHeaders.asScala.map { - case (key, value) => (key.toString, value.toString) - }).map(_.asJava) - val bodies = flattenOutput.map(e => JavaUtils.bytesToString(e.event.getBody)) - utils.assertOutput(headers.asJava, bodies.asJava) - } - } finally { - ssc.stop() - } - } - -} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala deleted file mode 100644 index 38208c651805f..0000000000000 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.flume - -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.collection.JavaConverters._ -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, Matchers} -import org.scalatest.concurrent.Eventually._ - -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} - -class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { - val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - - test("flume input stream") { - testFlumeStream(testCompression = false) - } - - test("flume input compressed stream") { - testFlumeStream(testCompression = true) - } - - /** Run test on flume stream */ - private def testFlumeStream(testCompression: Boolean): Unit = { - val input = (1 to 100).map { _.toString } - val utils = new FlumeTestUtils - try { - val outputQueue = startContext(utils.getTestPort(), testCompression) - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - utils.writeInput(input.asJava, testCompression) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputQueue.asScala.toSeq.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) - output should be (input) - } - } finally { - if (ssc != null) { - ssc.stop() - } - utils.close() - } - } - - /** Setup and start the streaming context */ - private def startContext( - testPort: Int, testCompression: Boolean): (ConcurrentLinkedQueue[Seq[SparkFlumeEvent]]) = { - ssc = new StreamingContext(conf, Milliseconds(200)) - val flumeStream = FlumeUtils.createStream( - ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) - val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputQueue) - outputStream.register() - ssc.start() - outputQueue - } - - /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) - extends NioClientSocketChannelFactory { - - override def newChannel(pipeline: ChannelPipeline): SocketChannel = { - val encoder = new ZlibEncoder(compressionLevel) - pipeline.addFirst("deflater", encoder) - pipeline.addFirst("inflater", new ZlibDecoder()) - super.newChannel(pipeline) - } - } -} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 8a66621a3125c..726b5d8ec3d3b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -167,7 +167,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = { getLeaderOffsets(topicAndPartitions, before, 1).right.map { r => r.map { kv 
=> - // mapValues isnt serializable, see SI-7005 + // mapValues isn't serializable, see SI-7005 kv._1 -> kv._2.head } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0cb875c9758f9..edaafb912c5c5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -19,12 +19,12 @@ package org.apache.spark.streaming.kafka import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong} +import java.nio.charset.StandardCharsets import java.util.{List => JList, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import com.google.common.base.Charsets.UTF_8 import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} @@ -615,7 +615,7 @@ object KafkaUtils { /** * This is a helper class that wraps the KafkaUtils.createStream() into more * Python-friendly class and function so that it can be easily - * instantiated and called from Python's KafkaUtils (see SPARK-6027). + * instantiated and called from Python's KafkaUtils. * * The zero-arg constructor helps instantiate this class from the Class object * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() @@ -787,7 +787,7 @@ private object KafkaUtilsPythonHelper { def pickle(obj: Object, out: OutputStream, pickler: Pickler) { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8)) + out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(StandardCharsets.UTF_8)) } else { pickler.save(this) val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata] diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index b5b76cb92d866..23b74da64237a 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -132,7 +132,7 @@ private[kinesis] object KinesisRecordProcessor extends Logging { * Retry the given amount of times with a random backoff time (millis) less than the * given maxBackOffMillis * - * @param expression expression to evalute + * @param expression expression to evaluate * @param numRetriesLeft number of retries left * @param maxBackOffMillis: max millis between retries * diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0ace453ee9280..026387ed65d50 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -242,7 +243,7 @@ private[kinesis] class SimpleDataGenerator( val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() data.foreach { 
num => val str = num.toString - val data = ByteBuffer.wrap(str.getBytes()) + val data = ByteBuffer.wrap(str.getBytes(StandardCharsets.UTF_8)) val putRecordRequest = new PutRecordRequest().withStreamName(streamName) .withData(data) .withPartitionKey(str) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala index fdb270eaad8c9..0b455e574e6fa 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -51,7 +52,7 @@ private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataG val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() data.foreach { num => val str = num.toString - val data = ByteBuffer.wrap(str.getBytes()) + val data = ByteBuffer.wrap(str.getBytes(StandardCharsets.UTF_8)) val future = producer.addUserRecord(streamName, str, data) val kinesisCallBack = new FutureCallback[UserRecordResult]() { override def onFailure(t: Throwable): Unit = {} // do nothing diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 2555332d222da..905c33834df16 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -122,7 +122,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) testIsBlockValid = true) } - testIfEnabled("Test whether RDD is valid after removing blocks from block anager") { + testIfEnabled("Test whether RDD is valid after removing blocks from block manager") { testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, testBlockRemove = true) } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index fd15b6ccdc889..deac9090e2f48 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -194,7 +194,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft verify(checkpointerMock, times(1)).checkpoint() } - test("retry failed after exhausing all retries") { + test("retry failed after exhausting all retries") { val expectedErrorMessage = "final try error message" when(checkpointerMock.checkpoint()) .thenThrow(new ThrottlingException("error message")) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ca5d13da46e99..4460b6bccaa81 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala 
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -180,17 +180,20 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun Seconds(10), StorageLevel.MEMORY_ONLY, awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) - val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => - collected ++= rdd.collect() - logInfo("Collected = " + collected.mkString(", ")) + collected.synchronized { + collected ++= rdd.collect() + logInfo("Collected = " + collected.mkString(", ")) + } } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { testUtils.pushData(testData, aggregateTestData) - assert(collected === testData.toSet, "\nData received does not match data sent") + assert(collected.synchronized { collected === testData.toSet }, + "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) } @@ -205,10 +208,12 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun stream shouldBe a [ReceiverInputDStream[_]] - val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + val collected = new mutable.HashSet[Int] stream.foreachRDD { rdd => - collected ++= rdd.collect() - logInfo("Collected = " + collected.mkString(", ")) + collected.synchronized { + collected ++= rdd.collect() + logInfo("Collected = " + collected.mkString(", ")) + } } ssc.start() @@ -216,7 +221,8 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun eventually(timeout(120 seconds), interval(10 second)) { testUtils.pushData(testData, aggregateTestData) val modData = testData.map(_ + 5) - assert(collected === modData.toSet, "\nData received does not match data sent") + assert(collected.synchronized { collected === modData.toSet }, + "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) } diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml deleted file mode 100644 index ac2a3f65ed2f5..0000000000000 --- a/external/mqtt-assembly/pom.xml +++ /dev/null @@ -1,175 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-mqtt-assembly_2.11 - jar - Spark Project External MQTT Assembly - http://spark.apache.org/ - - - streaming-mqtt-assembly - - - - - org.apache.spark - spark-streaming-mqtt_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - - commons-lang - commons-lang - provided - - - com.google.protobuf - protobuf-java - provided - - - com.sun.jersey - jersey-server - provided - - - com.sun.jersey - jersey-core - provided - - - org.apache.hadoop - hadoop-client - provided - - - org.apache.avro - avro-mapred - ${avro.mapred.classifier} - provided - - - org.apache.curator - curator-recipes - provided - - - org.apache.zookeeper - zookeeper - provided - - - log4j - log4j - provided - - - net.java.dev.jets3t - jets3t - provided - - - org.scala-lang - scala-library - provided - - - org.slf4j - slf4j-api - provided - - - org.slf4j - slf4j-log4j12 - provided - - - org.xerial.snappy - snappy-java - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - - - *:* 
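The KinesisStreamSuite hunks above replace the removed mutable.SynchronizedSet mixin with explicit locking on a plain HashSet. A minimal, self-contained sketch of that pattern in isolation (the object and method names here are illustrative, not part of the patch):

import scala.collection.mutable

// Explicit locking replaces the deprecated SynchronizedSet mixin: every access
// that must observe a consistent view synchronizes on the set itself.
object SynchronizedCollector {
  private val collected = new mutable.HashSet[Int]

  def add(values: Seq[Int]): Unit = collected.synchronized {
    collected ++= values
  }

  def matches(expected: Set[Int]): Boolean = collected.synchronized {
    collected == expected
  }
}
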
- - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - - diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml deleted file mode 100644 index d0d968782c7f1..0000000000000 --- a/external/mqtt/pom.xml +++ /dev/null @@ -1,104 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-mqtt_2.11 - - streaming-mqtt - - jar - Spark Project External MQTT - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.eclipse.paho - org.eclipse.paho.client.mqttv3 - 1.0.2 - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.activemq - activemq-core - 5.7.0 - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - - org.apache.maven.plugins - maven-assembly-plugin - - - test-jar-with-dependencies - package - - single - - - - spark-streaming-mqtt-test-${project.version} - ${project.build.directory}/scala-${scala.binary.version}/ - false - - false - - src/main/assembly/assembly.xml - - - - - - - - diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml deleted file mode 100644 index c110b01b34e10..0000000000000 --- a/external/mqtt/src/main/assembly/assembly.xml +++ /dev/null @@ -1,44 +0,0 @@ - - - test-jar-with-dependencies - - jar - - false - - - - ${project.build.directory}/scala-${scala.binary.version}/test-classes - - - - - - - true - test - true - - org.apache.hadoop:*:jar - org.apache.zookeeper:*:jar - org.apache.avro:*:jar - - - - - diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala deleted file mode 100644 index 079bd8a9a87ea..0000000000000 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package org.apache.spark.streaming.mqtt
-
-import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
-import org.eclipse.paho.client.mqttv3.MqttCallback
-import org.eclipse.paho.client.mqttv3.MqttClient
-import org.eclipse.paho.client.mqttv3.MqttMessage
-import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
-
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.dstream._
-import org.apache.spark.streaming.receiver.Receiver
-
-/**
- * Input stream that subscribes to messages from an MQTT broker.
- * Uses Eclipse Paho as the MQTT client: http://www.eclipse.org/paho/
- * @param brokerUrl URL of the remote MQTT publisher
- * @param topic topic name to subscribe to
- * @param storageLevel RDD storage level.
- */
-private[streaming]
-class MQTTInputDStream(
-    _ssc: StreamingContext,
-    brokerUrl: String,
-    topic: String,
-    storageLevel: StorageLevel
-  ) extends ReceiverInputDStream[String](_ssc) {
-
-  private[streaming] override def name: String = s"MQTT stream [$id]"
-
-  def getReceiver(): Receiver[String] = {
-    new MQTTReceiver(brokerUrl, topic, storageLevel)
-  }
-}
-
-private[streaming]
-class MQTTReceiver(
-    brokerUrl: String,
-    topic: String,
-    storageLevel: StorageLevel
-  ) extends Receiver[String](storageLevel) {
-
-  def onStop() {
-  }
-
-  def onStart() {
-
-    // Set up persistence for messages
-    val persistence = new MemoryPersistence()
-
-    // Initialize the MQTT client, specifying brokerUrl, a generated clientId and the
-    // MqttClientPersistence store
-    val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
-
-    // The callback is triggered automatically whenever a new message arrives on the
-    // subscribed topic
-    val callback = new MqttCallback() {
-
-      // Handle an MQTT message
-      override def messageArrived(topic: String, message: MqttMessage) {
-        store(new String(message.getPayload(), "utf-8"))
-      }
-
-      override def deliveryComplete(token: IMqttDeliveryToken) {
-      }
-
-      override def connectionLost(cause: Throwable) {
-        restart("Connection lost ", cause)
-      }
-    }
-
-    // Set up the callback for the MqttClient. This needs to happen before
-    // connecting or subscribing, otherwise messages may be lost
-    client.setCallback(callback)
-
-    // Connect to the MQTT broker
-    client.connect()
-
-    // Subscribe to the MQTT topic
-    client.subscribe(topic)
-
-  }
-}
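For reference, a typical consumer of the stream deleted above paired MQTTInputDStream with the MQTTUtils factory that follows. A minimal sketch, assuming the (now removed) spark-streaming-mqtt artifact is still on the classpath; the broker URL and topic are placeholders:

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.mqtt.MQTTUtils

object MQTTWordCount {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(new SparkConf().setAppName("MQTTWordCount"), Seconds(2))
    // Placeholder broker URL and topic.
    val lines = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "foo",
      StorageLevel.MEMORY_ONLY_SER_2)
    // Each record is one MQTT message payload, decoded by the receiver as a UTF-8 string.
    lines.flatMap(_.split(" ")).map(w => (w, 1)).reduceByKey(_ + _).print()
    ssc.start()
    ssc.awaitTermination()
  }
}
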
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
deleted file mode 100644
index 7b8d56d6faf2d..0000000000000
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.mqtt
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext}
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-
-object MQTTUtils {
-  /**
-   * Create an input stream that receives messages pushed by an MQTT publisher.
-   * @param ssc StreamingContext object
-   * @param brokerUrl URL of the remote MQTT publisher
-   * @param topic Topic name to subscribe to
-   * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_AND_DISK_SER_2.
-   */
-  def createStream(
-      ssc: StreamingContext,
-      brokerUrl: String,
-      topic: String,
-      storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
-    ): ReceiverInputDStream[String] = {
-    new MQTTInputDStream(ssc, brokerUrl, topic, storageLevel)
-  }
-
-  /**
-   * Create an input stream that receives messages pushed by an MQTT publisher.
-   * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
-   * @param jssc JavaStreamingContext object
-   * @param brokerUrl URL of the remote MQTT publisher
-   * @param topic Topic name to subscribe to
-   */
-  def createStream(
-      jssc: JavaStreamingContext,
-      brokerUrl: String,
-      topic: String
-    ): JavaReceiverInputDStream[String] = {
-    implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
-    createStream(jssc.ssc, brokerUrl, topic)
-  }
-
-  /**
-   * Create an input stream that receives messages pushed by an MQTT publisher.
-   * @param jssc JavaStreamingContext object
-   * @param brokerUrl URL of the remote MQTT publisher
-   * @param topic Topic name to subscribe to
-   * @param storageLevel RDD storage level.
-   */
-  def createStream(
-      jssc: JavaStreamingContext,
-      brokerUrl: String,
-      topic: String,
-      storageLevel: StorageLevel
-    ): JavaReceiverInputDStream[String] = {
-    implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
-    createStream(jssc.ssc, brokerUrl, topic, storageLevel)
-  }
-}
-
-/**
- * This is a helper class that wraps the MQTTUtils methods in a more Python-friendly class
- * and function so that they can be easily instantiated and called from Python's MQTTUtils.
- */
-private[mqtt] class MQTTUtilsPythonHelper {
-
-  def createStream(
-      jssc: JavaStreamingContext,
-      brokerUrl: String,
-      topic: String,
-      storageLevel: StorageLevel
-    ): JavaDStream[String] = {
-    MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
-  }
-}
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package-info.java b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package-info.java
deleted file mode 100644
index 728e0d8663d01..0000000000000
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package-info.java
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * MQTT receiver for Spark Streaming. - */ -package org.apache.spark.streaming.mqtt; \ No newline at end of file diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package.scala deleted file mode 100644 index 63d0d138183a9..0000000000000 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * MQTT receiver for Spark Streaming. - */ -package object mqtt diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/mqtt/JavaMQTTStreamSuite.java b/external/mqtt/src/test/java/org/apache/spark/streaming/mqtt/JavaMQTTStreamSuite.java deleted file mode 100644 index ce5aa1e0cdda4..0000000000000 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/mqtt/JavaMQTTStreamSuite.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt; - -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; - -import org.apache.spark.streaming.LocalJavaStreamingContext; - -public class JavaMQTTStreamSuite extends LocalJavaStreamingContext { - @Test - public void testMQTTStream() { - String brokerUrl = "abc"; - String topic = "def"; - - // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = MQTTUtils.createStream(ssc, brokerUrl, topic); - JavaReceiverInputDStream test2 = MQTTUtils.createStream(ssc, brokerUrl, topic, - StorageLevel.MEMORY_AND_DISK_SER_2()); - } -} diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties deleted file mode 100644 index 75e3b53a093f6..0000000000000 --- a/external/mqtt/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
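The LocalJavaStreamingContext base class above pins spark.streaming.clock to ManualClock so batches advance only under test control. A Scala analogue of the same setup and teardown, sketched as a ScalaTest mix-in (the trait name is illustrative, not part of the patch):

import org.scalatest.{BeforeAndAfter, Suite}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

trait LocalStreamingContext extends BeforeAndAfter { self: Suite =>
  protected var ssc: StreamingContext = _

  before {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("test")
      // Use ManualClock so batch processing advances only when the test drives it.
      .set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
    ssc = new StreamingContext(conf, Milliseconds(1000))
    ssc.checkpoint("checkpoint")
  }

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
  }
}
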
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala deleted file mode 100644 index fdcd18c6fb048..0000000000000 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt - -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually - -import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, StreamingContext} - -class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { - - private val batchDuration = Milliseconds(500) - private val master = "local[2]" - private val framework = this.getClass.getSimpleName - private val topic = "def" - - private var ssc: StreamingContext = _ - private var mqttTestUtils: MQTTTestUtils = _ - - before { - ssc = new StreamingContext(master, framework, batchDuration) - mqttTestUtils = new MQTTTestUtils - mqttTestUtils.setup() - } - - after { - if (ssc != null) { - ssc.stop() - ssc = null - } - if (mqttTestUtils != null) { - mqttTestUtils.teardown() - mqttTestUtils = null - } - } - - test("mqtt input stream") { - val sendMessage = "MQTT demo for spark streaming" - val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, - StorageLevel.MEMORY_ONLY) - - @volatile var receiveMessage: List[String] = List() - receiveStream.foreachRDD { rdd => - if (rdd.collect.length > 0) { - receiveMessage = receiveMessage ::: List(rdd.first) - receiveMessage - } - } - - ssc.start() - - // Retry it because we don't know when the receiver will start. 
- eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - mqttTestUtils.publishData(topic, sendMessage) - assert(sendMessage.equals(receiveMessage(0))) - } - ssc.stop() - } -} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala deleted file mode 100644 index 26c6dc45d5115..0000000000000 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt - -import java.net.{ServerSocket, URI} - -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import org.apache.activemq.broker.{BrokerService, TransportConnector} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.util.Utils - -/** - * Share codes for Scala and Python unit tests - */ -private[mqtt] class MQTTTestUtils extends Logging { - - private val persistenceDir = Utils.createTempDir() - private val brokerHost = "localhost" - private val brokerPort = findFreePort() - - private var broker: BrokerService = _ - private var connector: TransportConnector = _ - - def brokerUri: String = { - s"$brokerHost:$brokerPort" - } - - def setup(): Unit = { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt://" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - def teardown(): Unit = { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - Utils.deleteRecursively(persistenceDir) - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(topic: String, data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes(UTF_8)) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: 
MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - if (client != null) { - client.disconnect() - client.close() - client = null - } - } - } - -} diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml deleted file mode 100644 index 5d4053afcbba7..0000000000000 --- a/external/twitter/pom.xml +++ /dev/null @@ -1,70 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-twitter_2.11 - - streaming-twitter - - jar - Spark Project External Twitter - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.twitter4j - twitter4j-stream - 4.0.4 - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala deleted file mode 100644 index bdd57fdde3b89..0000000000000 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter - -import twitter4j._ -import twitter4j.auth.Authorization -import twitter4j.auth.OAuthAuthorization -import twitter4j.conf.ConfigurationBuilder - -import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.Receiver - -/* A stream of Twitter statuses, potentially filtered by one or more keywords. -* -* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. -* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is -* such that this may return a sampled subset of all tweets during each interval. -* -* If no Authorization object is provided, initializes OAuth authorization using the system -* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. 
-*/ -private[streaming] -class TwitterInputDStream( - _ssc: StreamingContext, - twitterAuth: Option[Authorization], - filters: Seq[String], - storageLevel: StorageLevel - ) extends ReceiverInputDStream[Status](_ssc) { - - private def createOAuthAuthorization(): Authorization = { - new OAuthAuthorization(new ConfigurationBuilder().build()) - } - - private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) - - override def getReceiver(): Receiver[Status] = { - new TwitterReceiver(authorization, filters, storageLevel) - } -} - -private[streaming] -class TwitterReceiver( - twitterAuth: Authorization, - filters: Seq[String], - storageLevel: StorageLevel - ) extends Receiver[Status](storageLevel) with Logging { - - @volatile private var twitterStream: TwitterStream = _ - @volatile private var stopped = false - - def onStart() { - try { - val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) - newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status): Unit = { - store(status) - } - // Unimplemented - def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} - def onTrackLimitationNotice(i: Int) {} - def onScrubGeo(l: Long, l1: Long) {} - def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) { - if (!stopped) { - restart("Error receiving tweets", e) - } - } - }) - - val query = new FilterQuery - if (filters.size > 0) { - query.track(filters.mkString(",")) - newTwitterStream.filter(query) - } else { - newTwitterStream.sample() - } - setTwitterStream(newTwitterStream) - logInfo("Twitter receiver started") - stopped = false - } catch { - case e: Exception => restart("Error starting Twitter stream", e) - } - } - - def onStop() { - stopped = true - setTwitterStream(null) - logInfo("Twitter receiver stopped") - } - - private def setTwitterStream(newTwitterStream: TwitterStream) = synchronized { - if (twitterStream != null) { - twitterStream.shutdown() - } - twitterStream = newTwitterStream - } -} diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala deleted file mode 100644 index 9cb0106ab1e7b..0000000000000 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
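Because TwitterReceiver falls back to Twitter4J's default OAuth when no Authorization is supplied, callers set the four twitter4j.oauth.* system properties before creating the stream through the TwitterUtils factory deleted below. A minimal usage sketch with placeholder credentials:

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter.TwitterUtils

object TwitterExample {
  def main(args: Array[String]): Unit = {
    // Placeholder credentials; the receiver reads these four properties
    // when no explicit Authorization is passed.
    System.setProperty("twitter4j.oauth.consumerKey", "<consumerKey>")
    System.setProperty("twitter4j.oauth.consumerSecret", "<consumerSecret>")
    System.setProperty("twitter4j.oauth.accessToken", "<accessToken>")
    System.setProperty("twitter4j.oauth.accessTokenSecret", "<accessTokenSecret>")

    val ssc = new StreamingContext(new SparkConf().setAppName("TwitterExample"), Seconds(2))
    // None => default OAuth; filter tweets matching the given keyword.
    val tweets = TwitterUtils.createStream(ssc, None, Seq("spark"))
    tweets.map(_.getText).print()
    ssc.start()
    ssc.awaitTermination()
  }
}
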
- */
-
-package org.apache.spark.streaming.twitter
-
-import twitter4j.Status
-import twitter4j.auth.Authorization
-
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-
-object TwitterUtils {
-  /**
-   * Create an input stream that returns tweets received from Twitter.
-   * @param ssc StreamingContext object
-   * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth
-   *        authorization; this uses the system properties twitter4j.oauth.consumerKey,
-   *        twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and
-   *        twitter4j.oauth.accessTokenSecret
-   * @param filters Set of filter strings to get only those tweets that match them
-   * @param storageLevel Storage level to use for storing the received objects
-   */
-  def createStream(
-      ssc: StreamingContext,
-      twitterAuth: Option[Authorization],
-      filters: Seq[String] = Nil,
-      storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
-    ): ReceiverInputDStream[Status] = {
-    new TwitterInputDStream(ssc, twitterAuth, filters, storageLevel)
-  }
-
-  /**
-   * Create an input stream that returns tweets received from Twitter using Twitter4J's default
-   * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
-   * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and
-   * twitter4j.oauth.accessTokenSecret.
-   * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
-   * @param jssc JavaStreamingContext object
-   */
-  def createStream(jssc: JavaStreamingContext): JavaReceiverInputDStream[Status] = {
-    createStream(jssc.ssc, None)
-  }
-
-  /**
-   * Create an input stream that returns tweets received from Twitter using Twitter4J's default
-   * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
-   * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and
-   * twitter4j.oauth.accessTokenSecret.
-   * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
-   * @param jssc JavaStreamingContext object
-   * @param filters Set of filter strings to get only those tweets that match them
-   */
-  def createStream(jssc: JavaStreamingContext, filters: Array[String]
-    ): JavaReceiverInputDStream[Status] = {
-    createStream(jssc.ssc, None, filters)
-  }
-
-  /**
-   * Create an input stream that returns tweets received from Twitter using Twitter4J's default
-   * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
-   * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and
-   * twitter4j.oauth.accessTokenSecret.
-   * @param jssc JavaStreamingContext object
-   * @param filters Set of filter strings to get only those tweets that match them
-   * @param storageLevel Storage level to use for storing the received objects
-   */
-  def createStream(
-      jssc: JavaStreamingContext,
-      filters: Array[String],
-      storageLevel: StorageLevel
-    ): JavaReceiverInputDStream[Status] = {
-    createStream(jssc.ssc, None, filters, storageLevel)
-  }
-
-  /**
-   * Create an input stream that returns tweets received from Twitter.
-   * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
-   * @param jssc JavaStreamingContext object
-   * @param twitterAuth Twitter4J Authorization
-   */
-  def createStream(jssc: JavaStreamingContext, twitterAuth: Authorization
-    ): JavaReceiverInputDStream[Status] = {
-    createStream(jssc.ssc, Some(twitterAuth))
-  }
-
-  /**
-   * Create an input stream that returns tweets received from Twitter.
-   * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
-   * @param jssc JavaStreamingContext object
-   * @param twitterAuth Twitter4J Authorization
-   * @param filters Set of filter strings to get only those tweets that match them
-   */
-  def createStream(
-      jssc: JavaStreamingContext,
-      twitterAuth: Authorization,
-      filters: Array[String]
-    ): JavaReceiverInputDStream[Status] = {
-    createStream(jssc.ssc, Some(twitterAuth), filters)
-  }
-
-  /**
-   * Create an input stream that returns tweets received from Twitter.
-   * @param jssc JavaStreamingContext object
-   * @param twitterAuth Twitter4J Authorization object
-   * @param filters Set of filter strings to get only those tweets that match them
-   * @param storageLevel Storage level to use for storing the received objects
-   */
-  def createStream(
-      jssc: JavaStreamingContext,
-      twitterAuth: Authorization,
-      filters: Array[String],
-      storageLevel: StorageLevel
-    ): JavaReceiverInputDStream[Status] = {
-    createStream(jssc.ssc, Some(twitterAuth), filters, storageLevel)
-  }
-}
diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package-info.java b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package-info.java
deleted file mode 100644
index 258c0950a0aa7..0000000000000
--- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package-info.java
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * Twitter feed receiver for Spark Streaming.
- */
-package org.apache.spark.streaming.twitter;
\ No newline at end of file
diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package.scala
deleted file mode 100644
index 580e37fa8f814..0000000000000
--- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package.scala
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-/**
- * Twitter feed receiver for Spark Streaming.
- */
-package object twitter
diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
deleted file mode 100644
index cfedb5a042a35..0000000000000
--- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter; - -import org.junit.Test; -import twitter4j.Status; -import twitter4j.auth.Authorization; -import twitter4j.auth.NullAuthorization; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; -import org.apache.spark.streaming.api.java.JavaDStream; - -public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { - @Test - public void testTwitterStream() { - String[] filters = { "filter1", "filter2" }; - Authorization auth = NullAuthorization.getInstance(); - - // tests the API, does not actually test data receiving - JavaDStream test1 = TwitterUtils.createStream(ssc); - JavaDStream test2 = TwitterUtils.createStream(ssc, filters); - JavaDStream test3 = TwitterUtils.createStream( - ssc, filters, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaDStream test4 = TwitterUtils.createStream(ssc, auth); - JavaDStream test5 = TwitterUtils.createStream(ssc, auth, filters); - JavaDStream test6 = TwitterUtils.createStream(ssc, - auth, filters, StorageLevel.MEMORY_AND_DISK_SER_2()); - } -} diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties deleted file mode 100644 index 9a3569789d2e0..0000000000000 --- a/external/twitter/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the filetarget/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala deleted file mode 100644 index 7e5fc0cbb9b30..0000000000000 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter - -import org.scalatest.BeforeAndAfter -import twitter4j.Status -import twitter4j.auth.{Authorization, NullAuthorization} - -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - - val batchDuration = Seconds(1) - - private val master: String = "local[2]" - - private val framework: String = this.getClass.getSimpleName - - test("twitter input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val filters = Seq("filter1", "filter2") - val authorization: Authorization = NullAuthorization.getInstance() - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None) - val test2: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, None, filters) - val test3: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2) - val test4: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, Some(authorization)) - val test5: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, Some(authorization), filters) - val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream( - ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2) - - // Note that actually testing the data receiving is hard as authentication keys are - // necessary for accessing Twitter live stream - ssc.stop() - } -} diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml deleted file mode 100644 index f16bc0f319744..0000000000000 --- a/external/zeromq/pom.xml +++ /dev/null @@ -1,74 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-zeromq_2.11 - - streaming-zeromq - - jar - Spark Project External ZeroMQ - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-streaming-akka_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - ${akka.group} - akka-zeromq_${scala.binary.version} - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala deleted file mode 100644 index dd367cd43b807..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) 
under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.zeromq - -import scala.reflect.ClassTag - -import akka.util.ByteString -import akka.zeromq._ - -import org.apache.spark.Logging -import org.apache.spark.streaming.akka.ActorReceiver - -/** - * A receiver to subscribe to ZeroMQ stream. - */ -private[streaming] class ZeroMQReceiver[T: ClassTag]( - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) - extends ActorReceiver with Logging { - - override def preStart(): Unit = { - ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) - } - - def receive: Receive = { - - case Connecting => logInfo("connecting ...") - - case m: ZMQMessage => - logDebug("Received message for:" + m.frame(0)) - - // We ignore first frame for processing as it is the topic - val bytes = m.frames.tail - store(bytesToObjects(bytes)) - - case Closed => logInfo("received closed ") - } -} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala deleted file mode 100644 index 1784d6e8623ad..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.zeromq - -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import akka.actor.{ActorSystem, Props, SupervisorStrategy} -import akka.util.ByteString -import akka.zeromq.Subscribe - -import org.apache.spark.api.java.function.{Function => JFunction, Function0 => JFunction0} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.akka.{ActorReceiver, AkkaUtils} -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -object ZeroMQUtils { - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param ssc StreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic - * and each frame has sequence of byte thus it needs the converter - * (which might be deserializer of bytes) to translate from sequence - * of sequence of bytes, where sequence refer to a frame - * and sub sequence refer to its payload. - * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_AND_DISK_SER_2. - * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will - * be shut down when the receiver is stopping (default: - * ActorReceiver.defaultActorSystemCreator) - * @param supervisorStrategy the supervisor strategy (default: ActorReceiver.defaultStrategy) - */ - def createStream[T: ClassTag]( - ssc: StreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T], - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, - actorSystemCreator: () => ActorSystem = ActorReceiver.defaultActorSystemCreator, - supervisorStrategy: SupervisorStrategy = ActorReceiver.defaultSupervisorStrategy - ): ReceiverInputDStream[T] = { - AkkaUtils.createStream( - ssc, - Props(new ZeroMQReceiver(publisherUrl, subscribe, bytesToObjects)), - "ZeroMQReceiver", - storageLevel, - actorSystemCreator, - supervisorStrategy) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote ZeroMQ publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each - * frame has sequence of byte thus it needs the converter(which might be - * deserializer of bytes) to translate from sequence of sequence of bytes, - * where sequence refer to a frame and sub sequence refer to its payload. - * @param storageLevel Storage level to use for storing the received objects - * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will - * be shut down when the receiver is stopping. 
- * @param supervisorStrategy the supervisor strategy (default: ActorReceiver.defaultStrategy) - */ - def createStream[T]( - jssc: JavaStreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], - storageLevel: StorageLevel, - actorSystemCreator: JFunction0[ActorSystem], - supervisorStrategy: SupervisorStrategy - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = - (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T]( - jssc.ssc, - publisherUrl, - subscribe, - fn, - storageLevel, - () => actorSystemCreator.call(), - supervisorStrategy) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each - * frame has sequence of byte thus it needs the converter(which might be - * deserializer of bytes) to translate from sequence of sequence of bytes, - * where sequence refer to a frame and sub sequence refer to its payload. - * @param storageLevel RDD storage level. - */ - def createStream[T]( - jssc: JavaStreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], - storageLevel: StorageLevel - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = - (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T]( - jssc.ssc, - publisherUrl, - subscribe, - fn, - storageLevel) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each - * frame has sequence of byte thus it needs the converter(which might - * be deserializer of bytes) to translate from sequence of sequence of - * bytes, where sequence refer to a frame and sub sequence refer to its - * payload. - */ - def createStream[T]( - jssc: JavaStreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]] - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = - (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T]( - jssc.ssc, - publisherUrl, - subscribe, - fn) - } -} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package-info.java b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package-info.java deleted file mode 100644 index 587c524e2120f..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Zeromq receiver for spark streaming. - */ -package org.apache.spark.streaming.zeromq; \ No newline at end of file diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package.scala deleted file mode 100644 index 65e6e57f2c05d..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * Zeromq receiver for spark streaming. - */ -package object zeromq diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java b/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java deleted file mode 100644 index 9ff4b41f97d50..0000000000000 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.zeromq; - -import akka.actor.ActorSystem; -import akka.actor.SupervisorStrategy; -import akka.util.ByteString; -import akka.zeromq.Subscribe; -import org.junit.Test; - -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function0; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; - -public class JavaZeroMQStreamSuite extends LocalJavaStreamingContext { - - @Test // tests the API, does not actually test data receiving - public void testZeroMQStream() { - String publishUrl = "abc"; - Subscribe subscribe = new Subscribe((ByteString)null); - Function<byte[][], Iterable<String>> bytesToObjects = new BytesToObjects(); - Function0<ActorSystem> actorSystemCreator = new ActorSystemCreatorForTest(); - - JavaReceiverInputDStream<String> test1 = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects); - JavaReceiverInputDStream<String> test2 = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream<String> test3 = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2(), actorSystemCreator, - SupervisorStrategy.defaultStrategy()); - } -} - -class BytesToObjects implements Function<byte[][], Iterable<String>> { - @Override - public Iterable<String> call(byte[][] bytes) throws Exception { - return null; - } -} - -class ActorSystemCreatorForTest implements Function0<ActorSystem> { - @Override - public ActorSystem call() { - return null; - } -} diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties deleted file mode 100644 index 75e3b53a093f6..0000000000000 --- a/external/zeromq/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
-# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala deleted file mode 100644 index bac2679cabae5..0000000000000 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.zeromq - -import akka.actor.SupervisorStrategy -import akka.util.ByteString -import akka.zeromq.Subscribe - -import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -class ZeroMQStreamSuite extends SparkFunSuite { - - val batchDuration = Seconds(1) - - private val master: String = "local[2]" - - private val framework: String = this.getClass.getSimpleName - - test("zeromq input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val publishUrl = "abc" - val subscribe = new Subscribe(null.asInstanceOf[ByteString]) - val bytesToObjects = (bytes: Seq[ByteString]) => null.asInstanceOf[Iterator[String]] - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[String] = - ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, actorSystemCreator = () => null) - val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2, () => null) - val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, - StorageLevel.MEMORY_AND_DISK_SER_2, () => null, SupervisorStrategy.defaultStrategy) - val test4: ReceiverInputDStream[String] = - ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects) - val test5: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2) - val test6: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, - StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy = 
SupervisorStrategy.defaultStrategy) - - // TODO: Actually test data receiving. A real test needs the native ZeroMQ library - ssc.stop() - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index fe884d0022500..5485e30f5a2c9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -297,7 +297,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab /** * Restricts the graph to only the vertices and edges satisfying the predicates. The resulting - * subgraph satisifies + * subgraph satisfies * * {{{ * V' = {v : for all v in V where vpred(v)} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d537b6141cc90..fcb1b5999fae7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -236,11 +236,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @param preprocess a function to compute new vertex and edge data before filtering * @param epred edge pred to filter on after preprocess, see more details under * [[org.apache.spark.graphx.Graph#subgraph]] - * @param vpred vertex pred to filter on after prerocess, see more details under + * @param vpred vertex pred to filter on after preprocess, see more details under * [[org.apache.spark.graphx.Graph#subgraph]] * @tparam VD2 vertex type the vpred operates on * @tparam ED2 edge type the epred operates on - * @return a subgraph of the orginal graph, with its data unchanged + * @return a subgraph of the original graph, with its data unchanged * * @example This function can be used to filter the graph based on some property, without * changing the vertex and edge values in your program. For example, we could remove the vertices diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index 6dab465fb9012..a4e293d74a012 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -49,7 +49,7 @@ object ShippableVertexPartition { /** * Construct a `ShippableVertexPartition` from the given vertices with the specified routing * table, filling in missing vertices mentioned in the routing table using `defaultVal`, - * and merging duplicate vertex atrribute with mergeFunc. + * and merging duplicate vertex attribute with mergeFunc. */ def apply[VD: ClassTag]( iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 46faad2e68c50..00ba358a9b4a6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -54,7 +54,7 @@ import org.apache.spark.graphx._ * }}} * * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of - * neighbors whick link to `i` and `outDeg[j]` is the out degree of vertex `j`. + * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. 
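 *
 * Written out in summation form (a restatement of the update above, not a new
 * definition): PR[i] = alpha + (1 - alpha) * sum_{j in inNbrs[i]} PR[j] / outDeg[j].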
* * Note that this is not the "normalized" PageRank and as a consequence pages that have no * inlinks will have a PageRank of alpha. @@ -209,7 +209,7 @@ object PageRank extends Logging { } // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) - // Set the vertex attributes to (initalPR, delta = 0) + // Set the vertex attributes to (initialPR, delta = 0) .mapVertices { (id, attr) => if (id == src) (resetProb, Double.NegativeInfinity) else (0.0, 0.0) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala index 6aab28ff05355..dde25b96594be 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala @@ -30,7 +30,7 @@ package object graphx { */ type VertexId = Long - /** Integer identifer of a graph partition. Must be less than 2^30. */ + /** Integer identifier of a graph partition. Must be less than 2^30. */ // TODO: Consider using Char. type PartitionID = Int diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index a6d0cb6409664..d76e84ed8c9ed 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -92,7 +92,7 @@ private[graphx] object BytecodeUtils { /** * Given the class name, return whether we should look into the class or not. This is used to - * skip examing a large quantity of Java or Scala classes that we know for sure wouldn't access + * skip examining a large quantity of Java or Scala classes that we know for sure wouldn't access * the closures. Note that the class name is expected in ASM style (i.e. use "/" instead of "."). */ private def skipClass(className: String): Boolean = { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala index bff9f328d4907..e55b05fa996ad 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx import java.io.File import java.io.FileOutputStream import java.io.OutputStreamWriter +import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils @@ -30,7 +31,7 @@ class GraphLoaderSuite extends SparkFunSuite with LocalSparkContext { withSpark { sc => val tmpDir = Utils.createTempDir() val graphFile = new File(tmpDir.getAbsolutePath, "graph.txt") - val writer = new OutputStreamWriter(new FileOutputStream(graphFile)) + val writer = new OutputStreamWriter(new FileOutputStream(graphFile), StandardCharsets.UTF_8) for (i <- (1 until 101)) writer.write(s"$i 0\n") writer.close() try { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index f497e001dfa4f..cb981797d3239 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -229,7 +229,7 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { test("subgraph") { withSpark { sc => - // Create a star graph of 10 veritces. + // Create a star graph of 10 vertices. 
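      // (Sketch of the assumed fixture: starGraph(sc, n) builds edges (0, v) for
      // v = 1..n, so vertex 0 is the hub; the subgraph call below keeps only the
      // even-numbered vertices and the edges whose endpoints both survive.)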
val n = 10 val star = starGraph(sc, n) // Take only vertices whose vids are even diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 46410327a5d72..f6c7e07654ee9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -19,10 +19,10 @@ import java.io.BufferedReader; import java.io.File; -import java.io.FileFilter; import java.io.FileInputStream; import java.io.InputStreamReader; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -102,7 +102,7 @@ List<String> buildJavaCommand(String extraClassPath) throws IOException { File javaOpts = new File(join(File.separator, getConfDir(), "java-opts")); if (javaOpts.isFile()) { BufferedReader br = new BufferedReader(new InputStreamReader( - new FileInputStream(javaOpts), "UTF-8")); + new FileInputStream(javaOpts), StandardCharsets.UTF_8)); try { String line; while ((line = br.readLine()) != null) { @@ -171,21 +171,13 @@ List<String> buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as - // when running unit tests, or user code that embeds Spark and creates a SparkContext - // with a local or local-cluster master, will cause this code to be called from an - // environment where that env variable is not guaranteed to exist. - // - // For the testing case, we rely on the test code to set and propagate the test classpath - // appropriately. - // - // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. - // That duplicates some of the code in the shell scripts that look for the assembly, though. - String assembly = getenv(ENV_SPARK_ASSEMBLY); - if (assembly == null && !isTesting) { - assembly = findAssembly(); + // Add Spark jars to the classpath. For the testing case, we rely on the test code to set and + // propagate the test classpath appropriately. For normal invocation, look for the jars + // directory under SPARK_HOME. + String jarsDir = findJarsDir(!isTesting); + if (jarsDir != null) { + addToClassPath(cp, join(File.separator, jarsDir, "*")); } - addToClassPath(cp, assembly); // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate @@ -301,7 +293,7 @@ private Properties loadPropertiesFile() throws IOException { FileInputStream fd = null; try { fd = new FileInputStream(propsFile); - props.load(new InputStreamReader(fd, "UTF-8")); + props.load(new InputStreamReader(fd, StandardCharsets.UTF_8)); for (Map.Entry<Object, Object> e : props.entrySet()) { e.setValue(e.getValue().toString().trim()); } @@ -319,28 +311,25 @@ private Properties loadPropertiesFile() throws IOException { return props; } - private String findAssembly() { + private String findJarsDir(boolean failIfNotFound) { + // TODO: change to the correct directory once the assembly build is changed.
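+    // (Sketch of the search order implemented below, for orientation: a packaged
+    // distribution, marked by a RELEASE file, keeps its jars under SPARK_HOME/lib;
+    // a source build falls back to assembly/target/scala-<version>; and when
+    // failIfNotFound is false, a missing directory yields null instead of an
+    // IllegalStateException.)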
String sparkHome = getSparkHome(); File libdir; if (new File(sparkHome, "RELEASE").isFile()) { libdir = new File(sparkHome, "lib"); - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); + checkState(!failIfNotFound || libdir.isDirectory(), + "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); } else { libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); - } - - final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); - FileFilter filter = new FileFilter() { - @Override - public boolean accept(File file) { - return file.isFile() && re.matcher(file.getName()).matches(); + if (!libdir.isDirectory()) { + checkState(!failIfNotFound, + "Library directory '%s' does not exist; make sure Spark is built.", + libdir.getAbsolutePath()); + libdir = null; } - }; - File[] assemblies = libdir.listFiles(filter); - checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); - checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); - return assemblies[0].getAbsolutePath(); + } + return libdir != null ? libdir.getAbsolutePath() : null; } private String getConfDir() { diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 7942d7372faff..37afafea28fdc 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -30,7 +30,6 @@ class CommandBuilderUtils { static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; - static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; /** The set of known JVM vendors. 
*/ static enum JavaVendor { diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java index 6e7120167d605..c7959aee9f888 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java +++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java @@ -21,6 +21,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.concurrent.ThreadFactory; import java.util.logging.Level; import java.util.logging.Logger; @@ -42,7 +43,7 @@ class OutputRedirector { OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) { this.active = true; - this.reader = new BufferedReader(new InputStreamReader(in)); + this.reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); this.thread = tf.newThread(new Runnable() { @Override public void run() { diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index d36731840b1a1..00f967122bd70 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -199,11 +199,7 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th for (String arg : cmd) { if (arg.startsWith("-XX:MaxPermSize=")) { - if (isDriver) { - assertEquals("-XX:MaxPermSize=256m", arg); - } else { - assertEquals("-XX:MaxPermSize=256m", arg); - } + assertEquals("-XX:MaxPermSize=256m", arg); } } @@ -286,7 +282,6 @@ private boolean findInStringList(String list, String sep, String needle) { private SparkSubmitCommandBuilder newCommandBuilder(List<String> args) { SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); - builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); return builder; } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index f21b623e93253..2cd94fa8f5856 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -523,7 +523,7 @@ private[ml] object FeedForwardTopology { /** * Creates a multi-layer perceptron * @param layerSizes sizes of layers including input and output size - * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * @param softmax whether to use SoftMax or Sigmoid function for an output layer. * Softmax is default * @return multilayer perceptron topology */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index 521d209a8f0ed..27554acdf3c26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -481,7 +481,7 @@ object NominalAttribute extends AttributeFactory { * A binary attribute. * @param name optional name * @param index optional index - * @param values optionla values. If set, its size must be 2. + * @param values optional values. If set, its size must be 2.
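 *               (illustration only, assuming the usual negative-then-positive
 *               ordering: e.g. Array("negative", "positive"))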
*/ @DeveloperApi class BinaryAttribute private[ml] ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7f0397f6bd65a..bcbedc8bc108b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -93,6 +93,14 @@ final class DecisionTreeClassifier @Since("1.4.0") ( trees.head.asInstanceOf[DecisionTreeClassificationModel] } + /** (private[ml]) Train a decision tree on an RDD */ + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): DecisionTreeClassificationModel = { + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeClassificationModel] + } + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index e0ffbedf6cb03..82059b1d0ecbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -26,10 +26,10 @@ import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} @@ -158,9 +158,8 @@ final class GBTClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val oldGBT = new OldGBT(boostingStrategy) - val oldModel = oldGBT.run(oldDataset) - GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy) + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 897b23383c0cb..6e462924511e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -87,6 +87,14 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val trees.head.asInstanceOf[DecisionTreeRegressionModel] } + /** (private[ml]) Train a decision tree on an RDD */ + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): 
DecisionTreeRegressionModel = { + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeRegressionModel] + } + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9c842a6c88202..4cc2721aefb22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -25,10 +25,10 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, SquaredError => OldSquaredError} @@ -145,9 +145,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val oldGBT = new OldGBT(boostingStrategy) - val oldModel = oldGBT.run(oldDataset) - GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy) + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 976343ed961c5..13a13f0a7e402 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -150,7 +150,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: This does not handle cases where column pruning has been performed. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..44ab5b723bd7a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.Logging +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +private[ml] object GradientBoostedTrees extends Logging { + + /** + * Method to train a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return tuple of ensemble models and weights: + * (array of decision tree models, array of model weights) + */ + def run(input: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy + ): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case OldAlgo.Regression => + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) + case OldAlgo.Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") + } + } + + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. + * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @return tuple of ensemble models and weights: + * (array of decision tree models, array of model weights) + */ + def runWithValidation( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy + ): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case OldAlgo.Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) + case OldAlgo.Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. 
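+        // (Equivalently: a {0, 1} label y becomes 2y - 1 in {-1, +1}, so binary
+        // classification reuses the regression machinery with a margin-based loss,
+        // mirroring the unvalidated run() path above.)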
+ val remappedInput = input.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidationInput = validationInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, + validate = true) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Compute the initial predictions and errors for a dataset for the first + * iteration of gradient boosting. + * @param data: training data. + * @param initTreeWeight: learning rate assigned to the first tree. + * @param initTree: first DecisionTreeModel. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to every sample. + */ + def computeInitialPredictionAndError( + data: RDD[LabeledPoint], + initTreeWeight: Double, + initTree: DecisionTreeRegressionModel, + loss: OldLoss): RDD[(Double, Double)] = { + data.map { lp => + val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction + val error = loss.computeError(pred, lp.label) + (pred, error) + } + } + + /** + * Update a zipped predictionError RDD + * (as obtained with computeInitialPredictionAndError) + * @param data: training data. + * @param predictionAndError: predictionError RDD + * @param treeWeight: Learning rate. + * @param tree: Tree using which the prediction and error should be updated. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to each sample. + */ + def updatePredictionError( + data: RDD[LabeledPoint], + predictionAndError: RDD[(Double, Double)], + treeWeight: Double, + tree: DecisionTreeRegressionModel, + loss: OldLoss): RDD[(Double, Double)] = { + + val newPredError = data.zip(predictionAndError).mapPartitions { iter => + iter.map { case (lp, (pred, error)) => + val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) + } + } + newPredError + } + + /** + * Internal method for performing regression using trees as base learners. + * @param input training dataset + * @param validationInput validation dataset, ignored if validate is set to false. + * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. + * @return tuple of ensemble models and weights: + * (array of decision tree models, array of model weights) + */ + def boost( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy, + validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val timer = new TimeTracker() + timer.start("total") + timer.start("init") + + boostingStrategy.assertValid() + + // Initialize gradient boosting parameters + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeRegressionModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) + val loss = boostingStrategy.loss + val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. 
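+    // (Rationale: each boosting iteration fits a tree to pseudo-residuals, which is
+    // a regression problem regardless of the outer task, hence Regression + Variance.)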
+ val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol + treeStrategy.algo = OldAlgo.Regression + treeStrategy.impurity = OldVariance + treeStrategy.assertValid() + + // Cache input + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + true + } else { + false + } + + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + + timer.stop("init") + + logDebug("##########") + logDebug("Building tree 0") + logDebug("##########") + + // Initialize tree + timer.start("building tree 0") + val firstTree = new DecisionTreeRegressor() + val firstTreeModel = firstTree.train(input, treeStrategy) + val firstTreeWeight = 1.0 + baseLearners(0) = firstTreeModel + baseLearnerWeights(0) = firstTreeWeight + + var predError: RDD[(Double, Double)] = + computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) + logDebug("error of gbt = " + predError.values.mean()) + + // Note: A model of type regression is used since we require raw prediction + timer.stop("building tree 0") + + var validatePredError: RDD[(Double, Double)] = + computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) + var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 + var bestM = 1 + + var m = 1 + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + + timer.start(s"building tree $m") + logDebug("###################################################") + logDebug("Gradient boosting tree iteration " + m) + logDebug("###################################################") + val dt = new DecisionTreeRegressor() + val model = dt.train(data, treeStrategy) + timer.stop(s"building tree $m") + // Update partial model + baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. + baseLearnerWeights(m) = learningRate + + predError = updatePredictionError( + input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) + logDebug("error of gbt = " + predError.values.mean()) + + if (validate) { + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. 
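+        // (Concretely, the check below stops training once
+        // bestValidateError - currentValidateError < validationTol * max(currentValidateError, 0.01),
+        // i.e. when the improvement on the validation set falls below the tolerance,
+        // taken relative to the current error with a 0.01 floor.)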
+ + validatePredError = updatePredictionError( + validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) + val currentValidateError = validatePredError.values.mean() + if (bestValidateError - currentValidateError < validationTol * Math.max( + currentValidateError, 0.01)) { + doneLearning = true + } else if (currentValidateError < bestValidateError) { + bestValidateError = currentValidateError + bestM = m + 1 + } + } + m += 1 + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() + if (persistedInput) input.unpersist() + + if (validate) { + (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM)) + } else { + (baseLearners, baseLearnerWeights) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 132dc174a894e..53935f328ab8a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -1226,7 +1227,7 @@ private[spark] object SerDe extends Serializable { def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write((module + "\n" + name + "\n").getBytes) + out.write((module + "\n" + name + "\n").getBytes(StandardCharsets.UTF_8)) } else { pickler.save(this) // it will be memorized by Pickler saveState(obj, out, pickler) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index f8de4e2220c4d..c8ec0c16851f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -83,7 +83,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] /** * Returns the mean average precision (MAP) of all the queries. * If a query has an empty ground truth set, the average precision will be zero and a log - * warining is generated. + * warning is generated. */ lazy val meanAveragePrecision: Double = { predictionAndLabels.map { case (pred, lab) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 07eb750b06a3b..790d6b101ee5f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -58,7 +58,7 @@ class AssociationRules private[fpm] ( /** * Computes the association rules with confidence above [[minConfidence]]. * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] - * @return a [[Set[Rule[Item]]] containing the assocation rules. + * @return a [[Set[Rule[Item]]] containing the association rules. 
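 *
 * A minimal usage sketch (assuming `model` is an already-fitted FPGrowthModel):
 * {{{
 *   val rules = new AssociationRules().setMinConfidence(0.8).run(model.freqItemsets)
 * }}}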
* */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala index f31ed2aa90a64..145dc22b7428e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -74,7 +74,7 @@ import org.apache.spark.storage.StorageLevel * * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicRDDCheckpointer[T]( +private[spark] class PeriodicRDDCheckpointer[T]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index ae1faf6a2d841..6f0b0a9bc6004 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} import org.apache.spark.{Logging, Partitioner, SparkException} import org.apache.spark.annotation.Since @@ -317,40 +317,72 @@ class BlockMatrix @Since("1.3.0") ( } /** - * Adds two block matrices together. The matrices must have the same size and matching - * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are - * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even - * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will - * also be a [[DenseMatrix]]. + * For given matrices `this` and `other` of compatible dimensions and compatible block dimensions, + * it applies a binary function on their corresponding blocks. + * + * @param other The second BlockMatrix argument for the operator specified by `binMap` + * @param binMap A function taking two breeze matrices and returning a breeze matrix + * @return A [[BlockMatrix]] whose blocks are the results of a specified binary map on blocks + * of `this` and `other`. + * Note: `blockMap` ONLY works for `add` and `subtract` methods and it does not support + * operators such as (a, b) => -a + b + * TODO: Make the use of zero matrices more storage efficient. */ - @Since("1.3.0") - def add(other: BlockMatrix): BlockMatrix = { + private[mllib] def blockMap( + other: BlockMatrix, + binMap: (BM[Double], BM[Double]) => BM[Double]): BlockMatrix = { require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") require(numCols() == other.numCols(), "Both matrices must have the same number of columns. " + s"A.numCols: ${numCols()}, B.numCols: ${other.numCols()}") if (rowsPerBlock == other.rowsPerBlock && colsPerBlock == other.colsPerBlock) { - val addedBlocks = blocks.cogroup(other.blocks, createPartitioner()) + val newBlocks = blocks.cogroup(other.blocks, createPartitioner()) .map { case ((blockRowIndex, blockColIndex), (a, b)) => if (a.size > 1 || b.size > 1) { throw new SparkException("There are multiple MatrixBlocks with indices: " + s"($blockRowIndex, $blockColIndex). 
Please remove them.") } if (a.isEmpty) { - new MatrixBlock((blockRowIndex, blockColIndex), b.head) + val zeroBlock = BM.zeros[Double](b.head.numRows, b.head.numCols) + val result = binMap(zeroBlock, b.head.toBreeze) + new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) } else if (b.isEmpty) { new MatrixBlock((blockRowIndex, blockColIndex), a.head) } else { - val result = a.head.toBreeze + b.head.toBreeze + val result = binMap(a.head.toBreeze, b.head.toBreeze) new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) } } - new BlockMatrix(addedBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) + new BlockMatrix(newBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) } else { - throw new SparkException("Cannot add matrices with different block dimensions") + throw new SparkException("Cannot perform on matrices with different block dimensions") } } + /** + * Adds the given block matrix `other` to `this` block matrix: `this + other`. + * The matrices must have the same size and matching `rowsPerBlock` and `colsPerBlock` + * values. If one of the blocks that are being added are instances of [[SparseMatrix]], + * the resulting sub matrix will also be a [[SparseMatrix]], even if it is being added + * to a [[DenseMatrix]]. If two dense matrices are added, the output will also be a + * [[DenseMatrix]]. + */ + @Since("1.3.0") + def add(other: BlockMatrix): BlockMatrix = + blockMap(other, (x: BM[Double], y: BM[Double]) => x + y) + + /** + * Subtracts the given block matrix `other` from `this` block matrix: `this - other`. + * The matrices must have the same size and matching `rowsPerBlock` and `colsPerBlock` + * values. If one of the blocks that are being subtracted are instances of [[SparseMatrix]], + * the resulting sub matrix will also be a [[SparseMatrix]], even if it is being subtracted + * from a [[DenseMatrix]]. If two dense matrices are subtracted, the output will also be a + * [[DenseMatrix]]. + */ + @Since("2.0.0") + def subtract(other: BlockMatrix): BlockMatrix = + blockMap(other, (x: BM[Double], y: BM[Double]) => x - y) + /** Block (i,j) --> Set of destination partitions */ private type BlockDestinations = Map[(Int, Int), Set[Int]] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 7da82c862a2b1..e754e74492755 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -89,6 +89,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, + private var regParam: Double, private var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { @@ -98,6 +99,7 @@ class LinearRegressionWithSGD private[mllib] ( override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) + .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) /** @@ -105,7 +107,7 @@ class LinearRegressionWithSGD private[mllib] ( * numIterations: 100, miniBatchFraction: 1.0}. 
*/ @Since("0.8.0") - def this() = this(1.0, 100, 1.0) + def this() = this(1.0, 100, 0.0, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) @@ -141,7 +143,7 @@ object LinearRegressionWithSGD { stepSize: Double, miniBatchFraction: Double, initialWeights: Vector): LinearRegressionModel = { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) .run(input, initialWeights) } @@ -163,7 +165,7 @@ object LinearRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double): LinearRegressionModel = { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) + new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(input) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index fe2a46b9eecc7..e8f4422fd4b8f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -43,6 +43,7 @@ import org.apache.spark.mllib.linalg.Vector class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, + private var regParam: Double, private var miniBatchFraction: Double) extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD] with Serializable { @@ -54,10 +55,10 @@ class StreamingLinearRegressionWithSGD private[mllib] ( * (see `StreamingLinearAlgorithm`) */ @Since("1.1.0") - def this() = this(0.1, 50, 1.0) + def this() = this(0.1, 50, 0.0, 1.0) @Since("1.1.0") - val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction) protected var model: Option[LinearRegressionModel] = None @@ -71,8 +72,17 @@ class StreamingLinearRegressionWithSGD private[mllib] ( } /** - * Set the number of iterations of gradient descent to run per update. Default: 50. + * Set the regularization parameter. Default: 0.0. */ + @Since("2.0.0") + def setRegParam(regParam: Double): this.type = { + this.algorithm.optimizer.setRegParam(regParam) + this + } + + /** + * Set the number of iterations of gradient descent to run per update. Default: 50. + */ @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 0b118a76733fd..d8405d13ce904 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -59,7 +59,7 @@ case class BoostingStrategy @Since("1.4.0") ( * Check validity of parameters. * Throws exception if invalid. 
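+   * A sketch of typical use (illustrative values; `assertValid` itself stays Spark-internal,
+   * now visible to other Spark packages as `private[spark]`):
+   * {{{
+   *   val boostingStrategy = BoostingStrategy.defaultParams("Classification")
+   *   boostingStrategy.numIterations = 10
+   *   boostingStrategy.treeStrategy.maxDepth = 3
+   *   // callers such as GradientBoostedTrees validate the strategy up front
+   *   boostingStrategy.assertValid()
+   * }}}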
*/ - private[tree] def assertValid(): Unit = { + private[spark] def assertValid(): Unit = { treeStrategy.algo match { case Classification => require(treeStrategy.numClasses == 2, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 9e3e50192d507..8a0907564e728 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -133,7 +133,7 @@ class Strategy @Since("1.3.0") ( * Check validity of parameters. * Throws exception if invalid. */ - private[tree] def assertValid(): Unit = { + private[spark] def assertValid(): Unit = { algo match { case Classification => require(numClasses >= 2, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 48a4e38a346d6..9b60d018d0eda 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -45,7 +45,7 @@ object AbsoluteError extends Loss { if (label - prediction < 0) 1.0 else -1.0 } - override private[mllib] def computeError(prediction: Double, label: Double): Double = { + override private[spark] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction math.abs(err) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index b88743c0dbab6..5d92ce495b04d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -47,7 +47,7 @@ object LogLoss extends Loss { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - override private[mllib] def computeError(prediction: Double, label: Double): Double = { + override private[spark] def computeError(prediction: Double, label: Double): Double = { val margin = 2.0 * label * prediction // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 687cde325ffed..de14ddf024d75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -61,5 +61,5 @@ trait Loss extends Serializable { * @param label True label. * @return Measure of model error on datapoint. 
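+   * For intuition, the concrete losses in this package implement this as (sketch,
+   * mirroring the implementations touched below):
+   * {{{
+   *   // SquaredError:  (label - prediction)^2
+   *   // AbsoluteError: |label - prediction|
+   *   // LogLoss:       2.0 * log(1 + exp(-2.0 * label * prediction))
+   * }}}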
*/ - private[mllib] def computeError(prediction: Double, label: Double): Double + private[spark] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index cb97f6fd29d95..4eb6810c46b20 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -45,7 +45,7 @@ object SquaredError extends Loss { - 2.0 * (label - prediction) } - override private[mllib] def computeError(prediction: Double, label: Double): Double = { + override private[spark] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction err * err } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 240781bcd335b..58fd010e4905f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -158,7 +158,7 @@ object LinearDataGenerator { /** * Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso, - * and uregularized variants. + * and unregularized variants. * * @param sc SparkContext to be used for generating the RDD. * @param nexamples Number of examples that will be contained in the RDD. diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index b8ddf907d05ad..1c18b2b266fef 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -19,8 +19,8 @@ import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; -import com.google.common.base.Charsets; import com.google.common.io.Files; import org.junit.After; @@ -55,7 +55,7 @@ public void setUp() throws IOException { tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, file, Charsets.US_ASCII); + Files.write(s, file, StandardCharsets.UTF_8); path = tempDir.toURI().toString(); } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 84fc08be09ee7..114a238462a3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.source.libsvm -import java.io.{File, IOException} +import java.io.File +import java.nio.charset.StandardCharsets -import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.{SparkException, SparkFunSuite} @@ -42,7 +42,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin tempDir = Utils.createTempDir() val file = new File(tempDir, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) path = tempDir.toURI.toString } @@ -88,7 +88,7 @@ class 
LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.read.format("libsvm").load(path) val tempDir2 = Utils.createTempDir() val writepath = tempDir2.toURI.toString - // TODO: Remove requirement to coalesce by supporting mutiple reads. + // TODO: Remove requirement to coalesce by supporting multiple reads. df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) val df2 = sqlContext.read.format("libsvm").load(writepath) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index cea0adc55c076..28fada7053d65 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -496,7 +496,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) * - * The model weights of mutinomial logstic regression in R have `K` set of linear predictors + * The model weights of multinomial logistic regression in R have `K` set of linear predictors * for `K` classes classification problem; however, only `K-1` set is required if the first * outcome is chosen as a "pivot", and the other `K-1` outcomes are separately regressed against * the pivot outcome. This can be done by subtracting the first weights from those `K-1` set diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index d91ba8a6fdb72..f737d2c51a262 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -192,6 +192,49 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze()) } + test("subtract") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val B = new BlockMatrix(rdd, rowPerPart, colPerPart) + + val expected = BDM( + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (-1.0, 0.0, 0.0, 0.0)) + + val AsubtractB = gridBasedMat.subtract(B) + assert(AsubtractB.numRows() === m) + assert(AsubtractB.numCols() === B.numCols()) + assert(AsubtractB.toBreeze() === expected) + + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match + intercept[IllegalArgumentException] { + gridBasedMat.subtract(C) + } + val largerBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))), + ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0)))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n) + intercept[SparkException] 
{ // partitioning doesn't match + gridBasedMat.subtract(C2) + } + // subtracting BlockMatrices composed of SparseMatrices + val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4)) + val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4)) + val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8) + val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8) + + assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze()) + } + test("multiply") { // identity matrix val blocks: Seq[((Int, Int), Matrix)] = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 142b90e764a7c..46fcebe132749 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -144,7 +144,7 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(chi.size === numCols) assert(chi(1000) != null) // SPARK-3087 - // Detect continous features or labels + // Detect continuous features or labels val random = new Random(11L) val continuousLabel = Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index dca8ea815aa6a..5518bdf527c8a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -1075,7 +1075,7 @@ object DecisionTreeSuite extends SparkFunSuite { assert(a.isLeaf === b.isLeaf) assert(a.split === b.split) (a.stats, b.stats) match { - // TODO: Check other fields besides the infomation gain. + // TODO: Check other fields besides the information gain. 
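+    //       (InformationGainStats also exposes impurity, leftImpurity and
+    //       rightImpurity, which a fuller equality check could compare as well.)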
case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) case (None, None) => case _ => throw new AssertionError( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 70219e9ad9d3e..e542f21a1802c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.mllib.util import java.io.File +import java.nio.charset.StandardCharsets import scala.io.Source import breeze.linalg.{squaredDistance => breezeSquaredDistance} -import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.SparkException @@ -84,7 +84,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) val path = tempDir.toURI.toString val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() @@ -117,7 +117,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) val path = tempDir.toURI.toString intercept[SparkException] { @@ -134,7 +134,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) val path = tempDir.toURI.toString intercept[SparkException] { diff --git a/pom.xml b/pom.xml index ea5da3dc4c6a2..0faa691c5e78b 100644 --- a/pom.xml +++ b/pom.xml @@ -101,14 +101,6 @@ sql/hive external/docker-integration-tests assembly - external/twitter - external/flume - external/flume-sink - external/flume-assembly - external/akka - external/mqtt - external/mqtt-assembly - external/zeromq examples repl launcher @@ -119,8 +111,6 @@ UTF-8 UTF-8 - com.typesafe.akka - 2.3.11 1.7 3.3.9 spark @@ -133,7 +123,6 @@ ${hadoop.version} 0.98.17-hadoop2 hbase - 1.6.0 3.4.5 2.4.0 org.spark-project.hive @@ -197,7 +186,6 @@ during compilation if the dependency is transivite (e.g. "graphx/" depending on "core/" and needing Hadoop classes in the classpath to compile). 
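+       For example, a build profile can flip one of these scope properties along the
+       lines of (illustrative):
+
+         <hadoop.deps.scope>provided</hadoop.deps.scope>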
--> - compile compile compile compile @@ -511,37 +499,6 @@ ${protobuf.version} ${hadoop.deps.scope} - - ${akka.group} - akka-actor_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-remote_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-slf4j_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-testkit_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-zeromq_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-actor_${scala.binary.version} - - - org.apache.mesos mesos @@ -1630,46 +1587,6 @@ ${hive.parquet.version} compile - - org.apache.flume - flume-ng-core - ${flume.version} - ${flume.deps.scope} - - - io.netty - netty - - - org.apache.flume - flume-ng-auth - - - org.apache.thrift - libthrift - - - org.mortbay.jetty - servlet-api - - - - - org.apache.flume - flume-ng-sdk - ${flume.version} - ${flume.deps.scope} - - - io.netty - netty - - - org.apache.thrift - libthrift - - - org.apache.calcite calcite-core @@ -2521,9 +2438,6 @@ maven does not complain when they're provided on the command line for a sub-module that does not have them. --> - - flume-provided - hadoop-provided diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 567a717b9d24b..2a4a874fef8bb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -192,7 +192,7 @@ object MimaExcludes { ) ++ Seq( // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulator.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") ) ++ Seq( // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") @@ -314,10 +314,14 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") + ) ++ Seq( + // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this") ) case v if v.startsWith("1.6") => Seq( @@ -334,7 +338,7 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.columnar"), // The shuffle package is considered private. excludePackage("org.apache.spark.shuffle"), - // The collections utlities are considered pricate. + // The collections utilities are considered private. 
excludePackage("org.apache.spark.util.collection") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ @@ -639,7 +643,7 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), - // SPARK-7910 Adding a method to get the partioner to JavaRDD, + // SPARK-7910 Adding a method to get the partitioner to JavaRDD, ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), @@ -657,7 +661,7 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") ) ++ Seq( - // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // SPARK-4655 - Making Stage an Abstract class broke binary compatibility even though // the stage class is defined as private[spark] ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") ) ++ Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e74fb174725d3..d7519e82b8706 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -26,7 +26,6 @@ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} -import net.virtualvoid.sbt.graph.Plugin.graphSettings import spray.revolver.RevolverPlugin._ @@ -39,11 +38,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt, - streamingTwitter, streamingZeromq + streaming, streamingKafka ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka", - "streaming-mqtt", "streaming-twitter", "streaming-zeromq" + "streaming", "streaming-kafka" ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( @@ -58,8 +55,8 @@ object BuildCommons { Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -144,7 +141,7 @@ object SparkBuild extends PomBuild { "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) - lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( + lazy val sharedSettings = sparkGenjavadocSettings ++ Seq ( javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), @@ -241,16 +238,15 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ 
optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings ++ Revolver.settings)) + ExcludedDependencies.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove streamingAkka and sketch from this list after 2.0.0 allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, streamingAkka, testTags, sketch + unsafe, testTags, sketch ).contains(x) }.foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) @@ -262,9 +258,6 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) - /* Enable Assembly for streamingMqtt test */ - enable(inConfig(Test)(Assembly.settings))(streamingMqtt) - /* Package pyspark artifacts in a separate zip file for YARN. */ enable(PySparkAssembly.settings)(assembly) @@ -280,8 +273,6 @@ object SparkBuild extends PomBuild { /* Hive console settings */ enable(Hive.settings)(hive) - enable(Flume.settings)(streamingFlumeSink) - enable(Java8TestSettings.settings)(java8Tests) enable(DockerIntegrationTests.settings)(dockerIntegrationTests) @@ -347,10 +338,6 @@ object Unsafe { ) } -object Flume { - lazy val settings = sbtavro.SbtAvro.avroSettings -} - object DockerIntegrationTests { // This serves to override the override specified in DependencyOverrides: lazy val settings = Seq( @@ -388,10 +375,6 @@ object OldDeps { name := "old-deps", scalaVersion := "2.10.5", libraryDependencies := Seq( - "spark-streaming-mqtt", - "spark-streaming-zeromq", - "spark-streaming-flume", - "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-graphx", @@ -532,7 +515,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { @@ -630,7 +613,6 @@ object Unidoc { private def ignoreUndocumentedPackages(packages: Seq[Seq[File]]): Seq[Seq[File]] = { packages .map(_.filterNot(_.getName.contains("$"))) - .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) @@ -651,9 +633,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, yarn, testTags), // Skip actual catalyst, but include the 
subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. @@ -672,8 +654,7 @@ object Unidoc { "-public", "-group", "Core Java API", packageList("api.java", "api.java.function"), "-group", "Spark Streaming", packageList( - "streaming.api.java", "streaming.flume", "streaming.akka", "streaming.kafka", - "streaming.mqtt", "streaming.twitter", "streaming.zeromq", "streaming.kinesis" + "streaming.api.java", "streaming.kafka", "streaming.kinesis" ), "-group", "MLlib", packageList( "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", @@ -773,7 +754,6 @@ object TestSettings { scalacOptions in (Compile, doc) := Seq( "-groups", "-skip-packages", Seq( - "akka", "org.apache.spark.api.python", "org.apache.spark.network", "org.apache.spark.deploy", diff --git a/project/build.properties b/project/build.properties index 86ca8755820a4..1e38156e0b577 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.9 +sbt.version=0.13.11 diff --git a/project/plugins.sbt b/project/plugins.sbt index 822a7c4a82d5e..eeca94a47ce79 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,14 +1,14 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "4.0.0") addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.9") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -16,7 +16,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") -addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" diff --git a/python/docs/Makefile b/python/docs/Makefile index b6d24d8599cf7..903009790ba3b 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.1-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.2-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.9.1-src.zip b/python/lib/py4j-0.9.2-src.zip similarity index 51% rename from python/lib/py4j-0.9.1-src.zip rename to python/lib/py4j-0.9.2-src.zip index fedde845fda19..881bb759d7823 100644 Binary files a/python/lib/py4j-0.9.1-src.zip and b/python/lib/py4j-0.9.2-src.zip differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index d530723ca9803..111ebaafee3e1 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -37,6 +37,8 @@ """ +import types + from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD @@ -64,6 +66,24 @@ def deco(f): return deco +def copy_func(f, name=None, sinceversion=None, doc=None): + """ + Returns a function with same code, globals, defaults, closure, and + name (or provide a new name). 
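+
+    A hypothetical sketch of the intended use (names and values are illustrative):
+
+    >>> def f(x): return x + 1
+    >>> g = copy_func(f, name="g", doc="`g` is an alias for `f`.")
+    >>> g(1)
+    2
+    >>> g.__doc__
+    '`g` is an alias for `f`.'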
+ """ + # See + # http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python + fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__, + f.__closure__) + # in case f was given attrs (note this dict is a shallow copy): + fn.__dict__.update(f.__dict__) + if doc is not None: + fn.__doc__ = doc + if sinceversion is not None: + fn = since(sinceversion)(fn) + return fn + + # for back compatibility from pyspark.sql import SQLContext, HiveContext, Row diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 3866a49c0b5f6..19ec6fcc5d6dc 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -22,7 +22,7 @@ basestring = str long = int -from pyspark import since +from pyspark import copy_func, since from pyspark.context import SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.types import * @@ -337,7 +337,7 @@ def cast(self, dataType): raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) - astype = cast + astype = copy_func(cast, sinceversion=1.4, doc=":func:`astype` is an alias for :func:`cast`.") @since(1.3) def between(self, lowerBound, upperBound): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7008e8fadffc3..7e1854c43be3b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -26,7 +26,7 @@ else: from itertools import imap as map -from pyspark import since +from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -829,8 +829,6 @@ def filter(self, condition): raise TypeError("condition should be string or Column") return DataFrame(jdf, self.sql_ctx) - where = filter - @ignore_unicode_prefix @since(1.3) def groupBy(self, *cols): @@ -1361,8 +1359,20 @@ def toPandas(self): # Pandas compatibility ########################################################################################## - groupby = groupBy - drop_duplicates = dropDuplicates + groupby = copy_func( + groupBy, + sinceversion=1.4, + doc=":func:`groupby` is an alias for :func:`groupBy`.") + + drop_duplicates = copy_func( + dropDuplicates, + sinceversion=1.4, + doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.") + + where = copy_func( + filter, + sinceversion=1.3, + doc=":func:`where` is an alias for :func:`filter`.") def _to_scala_map(sc, jm): diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py deleted file mode 100644 index edd5886a85079..0000000000000 --- a/python/pyspark/streaming/flume.py +++ /dev/null @@ -1,143 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sys -if sys.version >= "3": - from io import BytesIO -else: - from StringIO import StringIO -from py4j.protocol import Py4JJavaError - -from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int -from pyspark.streaming import DStream - -__all__ = ['FlumeUtils', 'utf8_decoder'] - - -def utf8_decoder(s): - """ Decode the unicode as UTF-8 """ - if s is None: - return None - return s.decode('utf-8') - - -class FlumeUtils(object): - - @staticmethod - def createStream(ssc, hostname, port, - storageLevel=StorageLevel.MEMORY_AND_DISK_2, - enableDecompression=False, - bodyDecoder=utf8_decoder): - """ - Create an input stream that pulls events from Flume. - - :param ssc: StreamingContext object - :param hostname: Hostname of the slave machine to which the flume data will be sent - :param port: Port of the slave machine to which the flume data will be sent - :param storageLevel: Storage level to use for storing the received objects - :param enableDecompression: Should netty server decompress input stream - :param bodyDecoder: A function used to decode body (default is utf8_decoder) - :return: A DStream object - """ - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - helper = FlumeUtils._get_helper(ssc._sc) - jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) - return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) - - @staticmethod - def createPollingStream(ssc, addresses, - storageLevel=StorageLevel.MEMORY_AND_DISK_2, - maxBatchSize=1000, - parallelism=5, - bodyDecoder=utf8_decoder): - """ - Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - This stream will poll the sink for data and will pull events as they are available. - - :param ssc: StreamingContext object - :param addresses: List of (host, port)s on which the Spark Sink is running. - :param storageLevel: Storage level to use for storing the received objects - :param maxBatchSize: The maximum number of events to be pulled from the Spark sink - in a single RPC call - :param parallelism: Number of concurrent requests this stream should send to the sink. 
- Note that having a higher number of requests concurrently being pulled - will result in this stream using more threads - :param bodyDecoder: A function used to decode body (default is utf8_decoder) - :return: A DStream object - """ - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - hosts = [] - ports = [] - for (host, port) in addresses: - hosts.append(host) - ports.append(port) - helper = FlumeUtils._get_helper(ssc._sc) - jstream = helper.createPollingStream( - ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) - return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) - - @staticmethod - def _toPythonDStream(ssc, jstream, bodyDecoder): - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - - def func(event): - headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) - headers = {} - strSer = UTF8Deserializer() - for i in range(0, read_int(headersBytes)): - key = strSer.loads(headersBytes) - value = strSer.loads(headersBytes) - headers[key] = value - body = bodyDecoder(event[1]) - return (headers, body) - return stream.map(func) - - @staticmethod - def _get_helper(sc): - try: - helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") - return helperClass.newInstance() - except Py4JJavaError as e: - # TODO: use --jar once it also work on driver - if 'ClassNotFoundException' in str(e.java_exception): - FlumeUtils._printErrorMsg(sc) - raise - - @staticmethod - def _printErrorMsg(sc): - print(""" -________________________________________________________________________________________________ - - Spark Streaming's Flume libraries not found in class path. Try one of the following. - - 1. Include the Flume library and its dependencies with in the - spark-submit command as - - $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... - - 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, - Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. - Then, include the jar in the spark-submit command as - - $ bin/spark-submit --jars ... 
- -________________________________________________________________________________________________ - -""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index a70b99249d3a2..02a88699a2886 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -192,13 +192,9 @@ def funcWithMessageHandler(m): @staticmethod def _get_helper(sc): try: - # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) - helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - return helperClass.newInstance() - except Py4JJavaError as e: - # TODO: use --jar once it also work on driver - if 'ClassNotFoundException' in str(e.java_exception): + return sc._jvm.org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper() + except TypeError as e: + if str(e) == "'JavaPackage' object is not callable": KafkaUtils._printErrorMsg(sc) raise diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index e681301681a81..434ce83e1e6f9 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -74,16 +74,14 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, try: # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, - regionName, initialPositionInStream, jduration, jlevel, - awsAccessKeyId, awsSecretKey) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): + helper = ssc._jvm.org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper() + except TypeError as e: + if str(e) == "'JavaPackage' object is not callable": KinesisUtils._printErrorMsg(ssc.sparkContext) raise + jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, jduration, jlevel, + awsAccessKeyId, awsSecretKey) stream = DStream(jstream, ssc, NoOpSerializer()) return stream.map(lambda v: decoder(v)) diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py deleted file mode 100644 index 388e9526ba73a..0000000000000 --- a/python/pyspark/streaming/mqtt.py +++ /dev/null @@ -1,73 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from py4j.protocol import Py4JJavaError - -from pyspark.storagelevel import StorageLevel -from pyspark.serializers import UTF8Deserializer -from pyspark.streaming import DStream - -__all__ = ['MQTTUtils'] - - -class MQTTUtils(object): - - @staticmethod - def createStream(ssc, brokerUrl, topic, - storageLevel=StorageLevel.MEMORY_AND_DISK_2): - """ - Create an input stream that pulls messages from a Mqtt Broker. - - :param ssc: StreamingContext object - :param brokerUrl: Url of remote mqtt publisher - :param topic: topic name to subscribe to - :param storageLevel: RDD storage level. - :return: A DStream object - """ - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - - try: - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - MQTTUtils._printErrorMsg(ssc.sparkContext) - raise - - return DStream(jstream, ssc, UTF8Deserializer()) - - @staticmethod - def _printErrorMsg(sc): - print(""" -________________________________________________________________________________________________ - - Spark Streaming's MQTT libraries not found in class path. Try one of the following. - - 1. Include the MQTT library and its dependencies with in the - spark-submit command as - - $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... - - 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, - Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. - Then, include the jar in the spark-submit command as - - $ bin/spark-submit --jars ... 
-________________________________________________________________________________________________ -""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 469c068134a46..eb4696c55d4f5 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -45,8 +45,6 @@ from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition -from pyspark.streaming.flume import FlumeUtils -from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream from pyspark.streaming.listener import StreamingListener @@ -1006,10 +1004,7 @@ class KafkaStreamTests(PySparkStreamingTestCase): def setUp(self): super(KafkaStreamTests, self).setUp() - - kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") - self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils = self.ssc._jvm.org.apache.spark.streaming.kafka.KafkaTestUtils() self._kafkaTestUtils.setup() def tearDown(self): @@ -1265,216 +1260,6 @@ def getKeyAndDoubleMessage(m): self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream) -class FlumeStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - - def setUp(self): - super(FlumeStreamTests, self).setUp() - - utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") - self._utils = utilsClz.newInstance() - - def tearDown(self): - if self._utils is not None: - self._utils.close() - self._utils = None - - super(FlumeStreamTests, self).tearDown() - - def _startContext(self, n, compressed): - # Start the StreamingContext and also collect the result - dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), - enableDecompression=compressed) - result = [] - - def get_output(_, rdd): - for event in rdd.collect(): - if len(result) < n: - result.append(event) - dstream.foreachRDD(get_output) - self.ssc.start() - return result - - def _validateResult(self, input, result): - # Validate both the header and the body - header = {"test": "header"} - self.assertEqual(len(input), len(result)) - for i in range(0, len(input)): - self.assertEqual(header, result[i][0]) - self.assertEqual(input[i], result[i][1]) - - def _writeInput(self, input, compressed): - # Try to write input to the receiver until success or timeout - start_time = time.time() - while True: - try: - self._utils.writeInput(input, compressed) - break - except: - if time.time() - start_time < self.timeout: - time.sleep(0.01) - else: - raise - - def test_flume_stream(self): - input = [str(i) for i in range(1, 101)] - result = self._startContext(len(input), False) - self._writeInput(input, False) - self.wait_for(result, len(input)) - self._validateResult(input, result) - - def test_compressed_flume_stream(self): - input = [str(i) for i in range(1, 101)] - result = self._startContext(len(input), True) - self._writeInput(input, True) - self.wait_for(result, len(input)) - self._validateResult(input, result) - - -class FlumePollingStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - maxAttempts = 5 - - def setUp(self): - utilsClz = \ - self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - 
.loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") - self._utils = utilsClz.newInstance() - - def tearDown(self): - if self._utils is not None: - self._utils.close() - self._utils = None - - def _writeAndVerify(self, ports): - # Set up the streaming context and input streams - ssc = StreamingContext(self.sc, self.duration) - try: - addresses = [("localhost", port) for port in ports] - dstream = FlumeUtils.createPollingStream( - ssc, - addresses, - maxBatchSize=self._utils.eventsPerBatch(), - parallelism=5) - outputBuffer = [] - - def get_output(_, rdd): - for e in rdd.collect(): - outputBuffer.append(e) - - dstream.foreachRDD(get_output) - ssc.start() - self._utils.sendDatAndEnsureAllDataHasBeenReceived() - - self.wait_for(outputBuffer, self._utils.getTotalEvents()) - outputHeaders = [event[0] for event in outputBuffer] - outputBodies = [event[1] for event in outputBuffer] - self._utils.assertOutput(outputHeaders, outputBodies) - finally: - ssc.stop(False) - - def _testMultipleTimes(self, f): - attempt = 0 - while True: - try: - f() - break - except: - attempt += 1 - if attempt >= self.maxAttempts: - raise - else: - import traceback - traceback.print_exc() - - def _testFlumePolling(self): - try: - port = self._utils.startSingleSink() - self._writeAndVerify([port]) - self._utils.assertChannelsAreEmpty() - finally: - self._utils.close() - - def _testFlumePollingMultipleHosts(self): - try: - port = self._utils.startSingleSink() - self._writeAndVerify([port]) - self._utils.assertChannelsAreEmpty() - finally: - self._utils.close() - - def test_flume_polling(self): - self._testMultipleTimes(self._testFlumePolling) - - def test_flume_polling_multiple_hosts(self): - self._testMultipleTimes(self._testFlumePollingMultipleHosts) - - -class MQTTStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - - def setUp(self): - super(MQTTStreamTests, self).setUp() - - MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") - self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() - self._MQTTTestUtils.setup() - - def tearDown(self): - if self._MQTTTestUtils is not None: - self._MQTTTestUtils.teardown() - self._MQTTTestUtils = None - - super(MQTTStreamTests, self).tearDown() - - def _randomTopic(self): - return "topic-%d" % random.randint(0, 10000) - - def _startContext(self, topic): - # Start the StreamingContext and also collect the result - stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) - result = [] - - def getOutput(_, rdd): - for data in rdd.collect(): - result.append(data) - - stream.foreachRDD(getOutput) - self.ssc.start() - return result - - def test_mqtt_stream(self): - """Test the Python MQTT stream API.""" - sendData = "MQTT demo for spark streaming" - topic = self._randomTopic() - result = self._startContext(topic) - - def retry(): - self._MQTTTestUtils.publishData(topic, sendData) - # Because "publishData" sends duplicate messages, here we should use > 0 - self.assertTrue(len(result) > 0) - self.assertEqual(sendData, result[0]) - - # Retry it because we don't know when the receiver will start. 
- self._retry_or_timeout(retry) - - def _retry_or_timeout(self, test_func): - start_time = time.time() - while True: - try: - test_func() - break - except: - if time.time() - start_time > self.timeout: - raise - time.sleep(0.01) - - class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -1498,10 +1283,7 @@ def test_kinesis_stream(self): import random kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) - kinesisTestUtilsClz = \ - self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils") - kinesisTestUtils = kinesisTestUtilsClz.newInstance() + kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils() try: kinesisTestUtils.createStream() aWSCredentials = kinesisTestUtils.getAWSCredentials() @@ -1566,57 +1348,6 @@ def search_kafka_assembly_jar(): return jars[0] -def search_flume_assembly_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") - jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly") - if not jars: - raise Exception( - ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -def search_mqtt_assembly_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") - jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly") - if not jars: - raise Exception( - ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " - "'build/mvn package' before running this test") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -def search_mqtt_test_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") - jars = glob.glob( - os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) - if not jars: - raise Exception( - ("Failed to find Spark Streaming MQTT test jar in %s. 
" % mqtt_test_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - def search_kinesis_asl_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") @@ -1637,24 +1368,18 @@ def search_kinesis_asl_assembly_jar(): if __name__ == "__main__": from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() - flume_assembly_jar = search_flume_assembly_jar() - mqtt_assembly_jar = search_mqtt_assembly_jar() - mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() if kinesis_asl_assembly_jar is None: kinesis_jar_present = False - jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, - mqtt_test_jar) + jars = kafka_assembly_jar else: kinesis_jar_present = True - jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, - mqtt_test_jar, kinesis_asl_assembly_jar) + jars = "%s,%s" % (kafka_assembly_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests, - StreamingListenerTests] + KafkaStreamTests, StreamingListenerTests] if kinesis_jar_present is True: testcases.append(KinesisStreamTests) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala index 14b448d076d84..5fe5c86289738 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala @@ -19,7 +19,12 @@ package org.apache.spark.repl import scala.collection.mutable.Set -object Main { +import org.apache.spark.Logging + +object Main extends Logging { + + initializeLogIfNecessary(true) + private var _interp: SparkILoop = _ def interp = _interp diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 999e7ad3ccbaa..a58f4234da14c 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.SQLContext object Main extends Logging { + initializeLogIfNecessary(true) + val conf = new SparkConf() val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") @@ -50,39 +52,27 @@ object Main extends Logging { // Visible for testing private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { interp = _interp + val jars = conf.getOption("spark.jars") + .map(_.replace(",", File.pathSeparator)) + .getOrElse("") val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", - "-classpath", getAddedJars.mkString(File.pathSeparator) + "-classpath", jars ) ++ args.toList val settings = new GenericRunnerSettings(scalaOptionError) settings.processArguments(interpArguments, true) if (!hasErrors) { - if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") 
interp.process(settings) // Repl starts and goes in loop of R.E.P.L Option(sparkContext).map(_.stop) } } - def getAddedJars: Array[String] = { - val envJars = sys.env.get("ADD_JARS") - if (envJars.isDefined) { - logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") - } - val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } - val jars = propJars.orElse(envJars).getOrElse("") - Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) - } - def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val jars = getAddedJars - val conf = new SparkConf() - .setMaster(getMaster) - .setJars(jars) - .setIfMissing("spark.app.name", "Spark shell") + conf.setIfMissing("spark.app.name", "Spark shell") // SparkContext will detect this configuration and register it with the RpcEnv's // file server, setting spark.repl.class.uri to the actual URI for executors to // use. This is sort of ugly but since executors are started as part of SparkContext @@ -115,12 +105,4 @@ object Main extends Logging { sqlContext } - private def getMaster: String = { - val master = { - val envMaster = sys.env.get("MASTER") - val propMaster = sys.props.get("spark.master") - propMaster.orElse(envMaster).getOrElse("local[*]") - } - master - } } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 239096be79e77..6bee880640ced 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -48,8 +48,8 @@ class ReplSuite extends SparkFunSuite { val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH) System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) - System.setProperty("spark.master", master) - Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) + Main.conf.set("spark.master", master) + Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) if (oldExecutorClasspath != null) { System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 0c37985a670b2..97df433a0b675 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -27,4 +27,4 @@ fi export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.1-src.zip:${PYTHONPATH}" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.2-src.zip:${PYTHONPATH}" diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java index 2520c7bb8dae4..01f89112a759b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.parser; +import java.nio.charset.StandardCharsets; + /** * A couple of utility methods that help with parsing ASTs. 
* @@ -76,7 +78,7 @@ public static String unescapeSQLString(String b) { byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); byte[] bValArr = new byte[1]; bValArr[0] = bVal; - String tmp = new String(bValArr); + String tmp = new String(bValArr, StandardCharsets.UTF_8); sb.append(tmp); i += 3; continue; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index f108264861eee..1219d4d453e13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -305,7 +305,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. */ def getStruct(i: Int): Row = { - // Product and Row both are recoginized as StructType in a Row + // Product and Row both are recognized as StructType in a Row val t = get(i) if (t.isInstanceOf[Product]) { Row.fromTuple(t.asInstanceOf[Product]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 2c7c58e66b855..35884139b6be8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -17,8 +17,22 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis._ + private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two + * identifiers are equal. + */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + caseSensitiveResolution + } else { + caseInsensitiveResolution + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c12b5c20ea7bf..bf07f4557a5b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - // Since we are creating a runtime mirror usign the class loader of current thread, + // Since we are creating a runtime mirror using the class loader of current thread, // we need to use def at here. So, every time we call mirror, it is using the // class loader of the current thread. // SPARK-13640: Synchronize this because universe.runtimeMirror is not thread-safe in Scala 2.10. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ad56c9864979b..9c38dd2ee4e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -155,7 +155,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { * * Note that technically this is an "optimization" and should go into the optimizer. However, * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern - * match because there are multuple (at least 2) levels of casts involved.
+ * match because there are multiple (at least 2) levels of casts involved. * * There are a lot more possible rules we can implement, but we don't do them * because we are not sure how common they are. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 38c1641f73d9f..2e30d83a60970 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -96,7 +96,7 @@ import org.apache.spark.sql.types.IntegerType * This rule duplicates the input data by two or more times (# distinct groups + an optional * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and * exchange operators. Keeping the number of distinct groups as low a possible should be priority, - * we could improve this in the current rule by applying more advanced expression cannocalization + * we could improve this in the current rule by applying more advanced expression canonicalization * techniques. */ object DistinctAggregationRewriter extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 57bdb164e1a0d..0f85f44ffa768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -307,7 +307,7 @@ object HiveTypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisions + // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true case p @ BinaryComparison(left @ StringType(), right @ DateType()) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 3831535574205..8bdf9b29c9641 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -53,7 +53,7 @@ object AttributeSet { * cosmetically (e.g., the names have different capitalizations). * * Note that we do not override equality for Attribute references as it is really weird when - * `AttributeReference("a"...) == AttrributeReference("b", ...)`. This tactic leads to broken tests, + * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests, * and also makes doing transformations hard (we always try keep older trees instead of new ones * when the transformation was a no-op). 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index b58a5273041e4..ae1f6006135bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.rules._ - /** * Rewrites an expression using rules that are guaranteed preserve the result while attempting * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization @@ -30,26 +28,23 @@ import org.apache.spark.sql.catalyst.rules._ * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered * by `hashCode`. -* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. + * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. */ -object Canonicalize extends RuleExecutor[Expression] { - override protected def batches: Seq[Batch] = - Batch( - "Expression Canonicalization", FixedPoint(100), - IgnoreNamesTypes, - Reorder) :: Nil +object Canonicalize extends { + def execute(e: Expression): Expression = { + expressionReorder(ignoreNamesTypes(e)) + } /** Remove names and nullability from types. */ - protected object IgnoreNamesTypes extends Rule[Expression] { - override def apply(e: Expression): Expression = e transformUp { - case a: AttributeReference => - AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) - } + private def ignoreNamesTypes(e: Expression): Expression = e match { + case a: AttributeReference => + AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + case _ => e } /** Collects adjacent commutative operations. */ - protected def gatherCommutative( + private def gatherCommutative( e: Expression, f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match { case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) @@ -57,25 +52,25 @@ object Canonicalize extends RuleExecutor[Expression] { } /** Orders a set of commutative operations by their hash code. */ - protected def orderCommutative( + private def orderCommutative( e: Expression, f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = gatherCommutative(e, f).sortBy(_.hashCode()) /** Rearrange expressions that are commutative or associative. 
*/ - protected object Reorder extends Rule[Expression] { - override def apply(e: Expression): Expression = e transformUp { - case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) - case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + private def expressionReorder(e: Expression): Expression = e match { + case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) + case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + + case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) + case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) - case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) - case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) + case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) + case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) - case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) - case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) - case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) - case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) - } + case _ => e } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 692c16092fe3f..16a1b2aee2730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -152,7 +152,10 @@ abstract class Expression extends TreeNode[Expression] { * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always * evaluate to the same result. */ - lazy val canonicalized: Expression = Canonicalize.execute(this) + lazy val canonicalized: Expression = { + val canonicalizedChildred = children.map(_.canonicalized) + Canonicalize.execute(withNewChildren(canonicalizedChildred)) + } /** * Returns true when two expressions will always compute the same result, even if they differ @@ -161,7 +164,7 @@ abstract class Expression extends TreeNode[Expression] { * See [[Canonicalize]] for more details. */ def semanticEquals(other: Expression): Boolean = - deterministic && other.deterministic && canonicalized == other.canonicalized + deterministic && other.deterministic && canonicalized == other.canonicalized /** * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index acea049adca3d..644a5b28a2151 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -36,7 +36,7 @@ object ExpressionSet { * Internally this set uses the canonical representation, but keeps also track of the original * expressions to ease debugging. 
Since different expressions can share the same canonical * representation, this means that operations that extract expressions from this set are only - * guranteed to see at least one such expression. For example: + * guaranteed to see at least one such expression. For example: * * {{{ * val set = AttributeSet(a + 1, 1 + a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 22184f1ddfbb5..500ff447a9754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -973,7 +973,7 @@ case class ScalaUDF( // scalastyle:on line.size.limit - // Generate codes used to convert the arguments to Scala type for user-defined funtions + // Generate codes used to convert the arguments to Scala type for user-defined functions private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { val converterClassName = classOf[Any => Any].getName val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c4265a753933f..3dbe6348986b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -126,7 +126,7 @@ class CodegenContext { * For expressions that appear more than once, generate additional code to prevent * recomputing the value. * - * For example, consider two exprsesion generated from this SQL statement: + * For example, consider two expression generated from this SQL statement: * SELECT (col1 + col2), (col1 + col2) / col3. * * equivalentExpressions will match the tree containing `col1 + col2` and it will only @@ -140,7 +140,7 @@ class CodegenContext { // Foreach expression that is participating in subexpression elimination, the state to use. val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] - // The collection of sub-exression result resetting methods that need to be called on each row. + // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] def declareAddedFunctions(): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 5ceb36513f840..103ab365e3190 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -214,7 +214,7 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E /** Factory methods for CaseWhen. */ object CaseWhen { - // The maxium number of switches supported with codegen. + // The maximum number of switches supported with codegen. 
val MAX_NUM_CASES_FOR_CODEGEN = 20 def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index a76517a89cc4a..e6804d096cd96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.json4s.JsonAST._ @@ -109,7 +110,7 @@ object Literal { case DateType => create(0, DateType) case TimestampType => create(0L, TimestampType) case StringType => Literal("") - case BinaryType => Literal("".getBytes) + case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8)) case CalendarIntervalType => Literal(new CalendarInterval(0, 0)) case arr: ArrayType => create(Array(), arr) case map: MapType => create(Map(), map) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index b95c5dd892d06..7eba617fcde59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -364,7 +364,7 @@ object MapObjects { * used as input for the `lambdaFunction`. It also carries the element type info. * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. - * @param inputData An expression that when evaluted returns a collection object. + * @param inputData An expression that when evaluated returns a collection object. 
*/ case class MapObjects private( loopVar: LambdaVariable, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 85776670e5c4e..2de92d06ec836 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -71,6 +71,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughAggregate, LimitPushDown, ColumnPruning, + EliminateOperators, // Operator combine CollapseRepartition, CollapseProject, @@ -315,11 +316,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { - private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) @@ -380,12 +377,6 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy(child = w.copy( windowExpressions = w.windowExpressions.filter(p.references.contains))) - // Eliminate no-op Window - case w: Window if w.windowExpressions.isEmpty => w.child - - // Eliminate no-op Projects - case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child - // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p @@ -409,6 +400,24 @@ object ColumnPruning extends Rule[LogicalPlan] { } } +/** + * Eliminate no-op Project and Window. + * + * Note: this rule should be executed just after ColumnPruning. + */ +object EliminateOperators extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + // Eliminate no-op Projects + case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Eliminate no-op Window + case w: Window if w.windowExpressions.isEmpty => w.child + } + + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) +} + /** * Combines two adjacent [[Project]] operators into one and perform alias substitution, * merging the expressions into one single expression. 
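The Optimizer change above splits the removal of no-op Project and Window operators out of ColumnPruning into a dedicated EliminateOperators rule. As a rough sketch of the idea only, not Spark's actual classes or APIs, the self-contained Scala snippet below models how a Project whose output is semantically identical to its child's output can simply be replaced by the child; the names Attribute, Relation, Project and EliminateNoOpProject here are simplified stand-ins.

// Simplified stand-ins for Spark's expression and plan types; illustrative only.
case class Attribute(name: String, exprId: Long) {
  // Rough analogue of semanticEquals: cosmetic name differences are ignored,
  // only the expression id matters.
  def semanticEquals(other: Attribute): Boolean = exprId == other.exprId
}

sealed trait Plan { def output: Seq[Attribute] }
case class Relation(output: Seq[Attribute]) extends Plan
case class Project(projectList: Seq[Attribute], child: Plan) extends Plan {
  def output: Seq[Attribute] = projectList
}

object EliminateNoOpProject {
  // Same shape as the sameOutput helper introduced in the patch above.
  private def sameOutput(o1: Seq[Attribute], o2: Seq[Attribute]): Boolean =
    o1.size == o2.size && o1.zip(o2).forall { case (a, b) => a.semanticEquals(b) }

  def apply(plan: Plan): Plan = plan match {
    // A Project that returns exactly its child's columns is a no-op: drop it.
    case p @ Project(_, child) if sameOutput(child.output, p.output) => child
    case other => other
  }
}

object EliminateNoOpProjectDemo extends App {
  val a = Attribute("a", 1L)
  val b = Attribute("b", 2L)
  val rel = Relation(Seq(a, b))
  println(EliminateNoOpProject(Project(Seq(a, b), rel))) // Relation(...): no-op Project removed
  println(EliminateNoOpProject(Project(Seq(a), rel)))    // Project(...): real pruning, kept
}

In the real rule the same sameOutput check runs over Spark's attribute references, and the Window case is analogous: a Window with no window expressions contributes nothing and is replaced by its child.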
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b32c7d0fcbaa4..c8aadb2ed5340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -116,6 +117,23 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) + /** + * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function + * should only be called on analyzed plans since it will throw [[AnalysisException]] for + * unresolved [[Attribute]]s. + */ + def resolve(schema: StructType, resolver: Resolver): Seq[Attribute] = { + schema.map { field => + resolveQuoted(field.name, resolver).map { + case a: AttributeReference => a + case other => sys.error(s"can not handle nested schema yet... plan $this") + }.getOrElse { + throw new AnalysisException( + s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") + } + } + } + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index e4417e0955143..da90ddbd63afb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -66,7 +66,7 @@ object NumberConverter { * negative digit is found, ignore the suffix starting there. * * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first element that should be conisdered + * @param fromPos is the first element that should be considered * @return the result should be treated as an unsigned 64-bit integer. 
*/ private def encode(radix: Int, fromPos: Int): Long = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index d9a9b6151a0be..b11365b297184 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst import java.io._ +import java.nio.charset.StandardCharsets import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{NumericType, StringType} @@ -118,7 +119,7 @@ package object util { val writer = new PrintWriter(out) t.printStackTrace(writer) writer.flush() - new String(out.toByteArray) + new String(out.toByteArray, StandardCharsets.UTF_8) } def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index de9a56dc9c064..4e7bbc38d60ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -276,7 +276,7 @@ class AnalysisErrorSuite extends AnalysisTest { test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) - // Since we manually construct the logical plan at here and Sum only accetp + // Since we manually construct the logical plan at here and Sum only accept // LongType, DoubleType, and DecimalType. We use LongType as the type of a. 
val plan = Aggregate( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index aa1d2b08613dd..8b568b6dd6acd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -250,7 +250,7 @@ class AnalysisSuite extends AnalysisTest { assertAnalysisSuccess(plan) } - test("SPARK-8654: different types in inlist but can be converted to a commmon type") { + test("SPARK-8654: different types in inlist but can be converted to a common type") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, LocalRelation() ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index c30434a0063b0..6f289dcc475cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -205,7 +205,7 @@ class HiveTypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } - test("cast NullType for expresions that implement ExpectsInputTypes") { + test("cast NullType for expressions that implement ExpectsInputTypes") { import HiveTypeCoercionSuite._ ruleTest(HiveTypeCoercion.ImplicitTypeCasts, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index ce42e5784ccd2..0b350c6a98255 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -70,7 +70,7 @@ class ExpressionSetSuite extends SparkFunSuite { // Not commutative setTest(2, aUpper - bUpper, bUpper - aUpper) - // Reversable + // Reversible setTest(1, aUpper > bUpper, bUpper < aUpper) setTest(1, aUpper >= bUpper, bUpper <= aUpper) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 124172bd66f19..450222d8cbba3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -54,7 +56,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.default(FloatType), 0.0f) checkEvaluation(Literal.default(DoubleType), 0.0) checkEvaluation(Literal.default(StringType), "") - checkEvaluation(Literal.default(BinaryType), "".getBytes) + checkEvaluation(Literal.default(BinaryType), "".getBytes(StandardCharsets.UTF_8)) checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0)) checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), 
Decimal(0)) checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 4ad65db0977c7..fba5f53715039 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite @@ -440,7 +442,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") checkEvaluation(Hex(Literal.create(null, BinaryType)), null) - checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + checkEvaluation(Hex(Literal("helloHex".getBytes(StandardCharsets.UTF_8))), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") @@ -452,7 +454,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("unhex") { checkEvaluation(Unhex(Literal.create(null, StringType)), null) - checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8)) checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75131a6170222..60d50baf511d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets + import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite @@ -27,7 +29,8 @@ import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("md5") { - checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") + checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))), + "902fbdd2b1df0c4f70b4a5d23525e932") checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) @@ -35,27 +38,31 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("sha1") { - checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal("ABC".getBytes(StandardCharsets.UTF_8))), + "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) - checkEvaluation(Sha1(Literal("".getBytes)), 
"da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkEvaluation(Sha1(Literal("".getBytes(StandardCharsets.UTF_8))), + "da39a3ee5e6b4b0d3255bfef95601890afd80709") checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { - checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) + checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)), + DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) // unsupported bit length checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) - checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) + checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), + Literal.create(null, IntegerType)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) } test("crc32") { - checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) + checkEvaluation(Crc32(Literal("ABC".getBytes(StandardCharsets.UTF_8))), 2743272264L) checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 68545f33e5465..1265908182b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.scalatest.Matchers @@ -77,16 +78,16 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) - row.update(2, "World".getBytes) + row.update(2, "World".getBytes(StandardCharsets.UTF_8)) val unsafeRow: UnsafeRow = converter.apply(row) assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) + - roundedSize("Hello".getBytes.length) + - roundedSize("World".getBytes.length)) + roundedSize("Hello".getBytes(StandardCharsets.UTF_8).length) + + roundedSize("World".getBytes(StandardCharsets.UTF_8).length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") - assert(unsafeRow.getBinary(2) === "World".getBytes) + assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } test("basic conversion with primitive, string, date and timestamp types") { @@ -100,7 +101,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val unsafeRow: UnsafeRow = converter.apply(row) - assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + roundedSize("Hello".getBytes.length)) + assert(unsafeRow.getSizeInBytes === + 8 + (8 * 4) + roundedSize("Hello".getBytes(StandardCharsets.UTF_8).length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") @@ -175,7 +177,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setFloat(6, 600) 
r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) - r.update(9, "world".getBytes) + r.update(9, "world".getBytes(StandardCharsets.UTF_8)) r.setDecimal(10, Decimal(10), 10) r.setDecimal(11, Decimal(10.00, 38, 18), 38) // r.update(11, Array(11)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index 0dbfb01e881f5..f5374229ca5cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -131,7 +131,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { i += 1 } - // Merge the lower and upper halfs. + // Merge the lower and upper halves. hll.merge(buffer1a, buffer1b) // Create the other buffer in reverse diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 1522ee34e43a5..e2a8eb8ee1d34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import java.nio.charset.StandardCharsets + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -107,7 +109,8 @@ class GeneratedProjectionSuite extends SparkFunSuite { val fields = Array[DataType](StringType, struct) val unsafeProj = UnsafeProjection.create(fields) - val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes, + val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, + "".getBytes(StandardCharsets.UTF_8), UTF8String.fromString("")) val row1 = InternalRow(UTF8String.fromString(""), innerRow) val unsafe1 = unsafeProj(row1).copy() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index dd7d65ddc9e96..6187fb9e2fb87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -35,6 +35,7 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), ColumnPruning, + EliminateOperators, CollapseProject) :: Nil } @@ -327,8 +328,8 @@ class ColumnPruningSuite extends PlanTest { val input2 = LocalRelation('c.int, 'd.string, 'e.double) val query = Project('b :: Nil, Union(input1 :: input2 :: Nil)).analyze - val expected = Project('b :: Nil, - Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + val expected = + Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil).analyze comparePlans(Optimize.execute(query), expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 87ad81db11b64..e0e9b6d93ec96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -28,7 +28,8 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Filter Pushdown", FixedPoint(100), - ColumnPruning) :: + ColumnPruning, + EliminateOperators) :: Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e2f8146beee7b..51468fa5ced31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -43,6 +43,7 @@ class JoinOptimizationSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, + EliminateOperators, CollapseProject) :: Nil } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 68f146f7a2622..b084eda6f84c1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -18,6 +18,7 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.util.Iterator; import java.util.List; @@ -138,7 +139,7 @@ private static void appendValue(ColumnVector dst, DataType t, Object o) { } else if (t == DataTypes.DoubleType) { dst.appendDouble(((Double)o).doubleValue()); } else if (t == DataTypes.StringType) { - byte[] b =((String)o).getBytes(); + byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); } else if (t instanceof DecimalType) { DecimalType dt = (DecimalType)t; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 9a8aedfa56b8c..09c001baaeafd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.Iterator; +import java.util.NoSuchElementException; import org.apache.commons.lang.NotImplementedException; @@ -254,6 +255,9 @@ public Row next() { while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { ++rowId; } + if (rowId >= maxRows) { + throw new NoSuchElementException(); + } row.rowId = rowId++; return row; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f7ba61d2b804f..1751720a7db88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -78,7 +78,7 @@ class TypedColumn[-T, U]( * * {{{ * df("columnName") // On a specific DataFrame. 
- * col("columnName") // A generic column no yet associcated with a DataFrame. + * col("columnName") // A generic column no yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 52b567ea250b1..76b8d71ac9359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { */ def table(tableName: String): DataFrame = { Dataset.newDataFrame(sqlContext, - sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName))) + sqlContext.catalog.lookupRelation( + sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3349b8421b3e8..de87f4d7c24ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { options = extraOptions.toMap, partitionColumns = normalizedParCols.getOrElse(Nil)) - df.sqlContext.continuousQueryManager.startQuery( + df.sqlContext.sessionState.continuousQueryManager.startQuery( extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink()) } @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) + insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala rename to sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b40089acc429c..df5839dd5c3d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -818,7 +818,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(sqlContext.sqlParser.parseExpression(expr)) + Column(sqlContext.sessionState.sqlParser.parseExpression(expr)) }: _*) } @@ -919,7 +919,7 @@ class Dataset[T] private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** @@ -943,7 +943,7 @@ class Dataset[T] private[sql]( * 
@since 1.5.0 */ def where(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** @@ -1797,14 +1797,14 @@ class Dataset[T] private[sql]( */ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => withNewExecutionId { - val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) java.util.Arrays.asList(values : _*) } } private def collect(needCallback: Boolean): Array[T] = { def execute(): Array[T] = withNewExecutionId { - queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) } if (needCallback) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index deed45d273c33..d7cd84fd246c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental -class ExperimentalMethods protected[sql](sqlContext: SQLContext) { +class ExperimentalMethods private[sql]() { /** * Allows extra strategies to be injected into the query planner at runtime. Note this API diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 472ae716f1530..a8700de135ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -223,7 +223,7 @@ class GroupedDataset[K, V] private[sql]( * Internal helper function for building typed aggregations that return tuples. For simplicity * and code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. 
- * TODO: does not handle aggrecations that return nonflat results, + * TODO: does not handle aggregations that return nonflat results, */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 36fe57f78be1d..0f5d1c8cab519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -121,14 +121,7 @@ class SQLContext private[sql]( protected[sql] lazy val sessionState: SessionState = new SessionState(self) protected[sql] def conf: SQLConf = sessionState.conf protected[sql] def catalog: Catalog = sessionState.catalog - protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry protected[sql] def analyzer: Analyzer = sessionState.analyzer - protected[sql] def optimizer: Optimizer = sessionState.optimizer - protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser - protected[sql] def planner: SparkPlanner = sessionState.planner - protected[sql] def continuousQueryManager = sessionState.continuousQueryManager - protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] = - sessionState.prepareForExecution /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -197,7 +190,7 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) @@ -244,7 +237,7 @@ class SQLContext private[sql]( */ @Experimental @transient - val experimental: ExperimentalMethods = new ExperimentalMethods(this) + def experimental: ExperimentalMethods = sessionState.experimentalMethods /** * :: Experimental :: @@ -641,7 +634,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -687,7 +680,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -706,7 +699,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. 
*/ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan) + catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan) } /** @@ -800,7 +793,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(sqlParser.parseTableIdentifier(tableName)) + table(sessionState.sqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { @@ -837,9 +830,7 @@ class SQLContext private[sql]( * * @since 2.0.0 */ - def streams: ContinuousQueryManager = { - continuousQueryManager - } + def streams: ContinuousQueryManager = sessionState.continuousQueryManager /** * Returns the names of tables in the current database as an array. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d363cb000d39a..e97c6be7f177a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -151,7 +151,7 @@ private[sql] case class DataSourceScan( override val outputPartitioning = { val bucketSpec = relation match { // TODO: this should be closer to bucket planning. - case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled() => r.bucketSpec + case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9e60c1cd6141c..5b4254f741ab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { sqlContext.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData) lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next() + sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. 
Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index edaf3b36aa52e..cbde777d98415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.ExperimentalMethods +import org.apache.spark.sql.catalyst.optimizer.Optimizer -class SparkOptimizer(val sqlContext: SQLContext) - extends Optimizer { - override def batches: Seq[Batch] = super.batches :+ Batch( - "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*) +class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer { + override def batches: Seq[Batch] = super.batches :+ Batch( + "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 3be4cce045fea..e04683c499a32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.execution +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ -import org.apache.spark.Logging -import org.apache.spark.broadcast +import org.apache.spark.{broadcast, Logging, SparkEnv} +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -52,7 +53,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of subexpressionEliminationEnabled will be set by the desserializer after the + // the value of subexpressionEliminationEnabled will be set by the deserializer after the // constructor has run. val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.subexpressionEliminationEnabled @@ -65,7 +66,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ private val prepareCalled = new AtomicBoolean(false) - /** Overridden make copy also propogates sqlContext to copied plan. */ + /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { SQLContext.setActive(sqlContext) super.makeCopy(newArgs) @@ -220,7 +221,47 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Runs this query returning the result as an array. */ def executeCollect(): Array[InternalRow] = { - execute().map(_.copy()).collect() + // Packing the UnsafeRows into byte array for faster serialization. 
+ // The byte arrays are in the following format: + // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + // + // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also + // compressed. + val byteArrayRdd = execute().mapPartitionsInternal { iter => + val buffer = new Array[Byte](4 << 10) // 4K + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(codec.compressedOutputStream(bos)) + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + out.writeInt(row.getSizeInBytes) + row.writeToStream(out, buffer) + } + out.writeInt(-1) + out.flush() + out.close() + Iterator(bos.toByteArray) + } + + // Collect the byte arrays back to driver, then decode them as UnsafeRows. + val nFields = schema.length + val results = ArrayBuffer[InternalRow]() + + byteArrayRdd.collect().foreach { bytes => + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(codec.compressedInputStream(bis)) + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + results += row + sizeOfNextRow = ins.readInt() + } + } + results.toArray } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index d1569a4ec2b40..9da2c74c62fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -20,15 +20,20 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} +import org.apache.spark.sql.internal.SQLConf -class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { - val sparkContext: SparkContext = sqlContext.sparkContext +class SparkPlanner( + val sparkContext: SparkContext, + val conf: SQLConf, + val experimentalMethods: ExperimentalMethods) + extends SparkStrategies { - def numPartitions: Int = sqlContext.conf.numShufflePartitions + def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - sqlContext.experimental.extraStrategies ++ ( + experimentalMethods.extraStrategies ++ ( + FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index d12dab567b00a..11391bd12acae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -34,8 +34,21 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly } /** - * For each node, extract properties in the form of a list ['key1', 'key2', 'key3', 'value'] - * into a pair (key1.key2.key3, value). + * For each node, extract properties in the form of a list + * ['key_part1', 'key_part2', 'key_part3', 'value'] + * into a pair (key_part1.key_part2.key_part3, value). 
+ * + * Example format: + * + * TOK_TABLEPROPERTY + * :- 'k1' + * +- 'v1' + * TOK_TABLEPROPERTY + * :- 'k2' + * +- 'v2' + * TOK_TABLEPROPERTY + * :- 'k3' + * +- 'v3' */ private def extractProps( props: Seq[ASTNode], @@ -101,6 +114,16 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly } val props = dbprops.toSeq.flatMap { case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => + // Example format: + // + // TOK_DATABASEPROPERTIES + // +- TOK_DBPROPLIST + // :- TOK_TABLEPROPERTY + // : :- 'k1' + // : +- 'v1' + // :- TOK_TABLEPROPERTY + // :- 'k2' + // +- 'v2' extractProps(propList, "TOK_TABLEPROPERTY") case _ => parseFailed("Invalid CREATE DATABASE command", node) }.toMap @@ -112,16 +135,16 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly // Example format: // // TOK_CREATEFUNCTION - // :- db_name - // :- func_name - // :- alias - // +- TOK_RESOURCE_LIST - // :- TOK_RESOURCE_URI - // : :- TOK_JAR - // : +- '/path/to/jar' - // +- TOK_RESOURCE_URI - // :- TOK_FILE - // +- 'path/to/file' + // :- db_name + // :- func_name + // :- alias + // +- TOK_RESOURCE_LIST + // :- TOK_RESOURCE_URI + // : :- TOK_JAR + // : +- '/path/to/jar' + // +- TOK_RESOURCE_URI + // :- TOK_FILE + // +- 'path/to/file' val (funcNameArgs, otherArgs) = args.partition { case Token("TOK_RESOURCE_LIST", _) => false case Token("TOK_TEMPORARY", _) => false @@ -139,9 +162,9 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly } // Extract other keywords, if they exist val Seq(rList, temp) = getClauses(Seq("TOK_RESOURCE_LIST", "TOK_TEMPORARY"), otherArgs) - val resourcesMap = rList.toSeq.flatMap { - case Token("TOK_RESOURCE_LIST", resources) => - resources.map { + val resources: Seq[(String, String)] = rList.toSeq.flatMap { + case Token("TOK_RESOURCE_LIST", resList) => + resList.map { case Token("TOK_RESOURCE_URI", rType :: Token(rPath, Nil) :: Nil) => val resourceType = rType match { case Token("TOK_JAR", Nil) => "jar" @@ -153,8 +176,8 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly case _ => parseFailed("Invalid CREATE FUNCTION command", node) } case _ => parseFailed("Invalid CREATE FUNCTION command", node) - }.toMap - CreateFunction(funcName, alias, resourcesMap, temp.isDefined)(node.source) + } + CreateFunction(funcName, alias, resources, temp.isDefined)(node.source) case Token("TOK_ALTERTABLE", alterTableArgs) => AlterTableCommandParser.parse(node) @@ -248,15 +271,14 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly // issue. val tableIdent = TableIdentifier( cleanIdentifier(tableName), Some(cleanIdentifier(dbName))) - datasources.DescribeCommand( - UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) + datasources.DescribeCommand(tableIdent, isExtended = extended.isDefined) case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil => // It is describing a column with the format like "describe db.table column". nodeToDescribeFallback(node) case tableName :: Nil => // It is describing a table with the format like "describe table". 
datasources.DescribeCommand( - UnresolvedRelation(TableIdentifier(cleanIdentifier(tableName.text)), None), + TableIdentifier(cleanIdentifier(tableName.text)), isExtended = extended.isDefined) case _ => nodeToDescribeFallback(node) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bae0750788088..113cf9ae2f222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object CanBroadcast { def unapply(plan: LogicalPlan): Option[LogicalPlan] = { - if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && - plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + if (conf.autoBroadcastJoinThreshold > 0 && + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { Some(plan) } else { None @@ -398,11 +398,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") case describe @ LogicalDescribeCommand(table, isExtended) => - val resultPlan = self.sqlContext.executePlan(table).executedPlan - ExecutedCommand( - RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil + ExecutedCommand(RunnableDescribeCommand(table, describe.output, isExtended)) :: Nil - case logical.ShowFunctions(db, pattern) => ExecutedCommand(ShowFunctions(db, pattern)) :: Nil + case logical.ShowFunctions(db, pattern) => + ExecutedCommand(ShowFunctions(db, pattern)) :: Nil case logical.DescribeFunction(function, extended) => ExecutedCommand(DescribeFunction(function, extended)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 8fb4705581a38..81676d3ebb346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.internal.SQLConf /** * An interface for those physical operators that support codegen. @@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup /** * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. 
*/ -private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true @@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru } def apply(plan: SparkPlan): SparkPlan = { - if (sqlContext.conf.wholeStageEnabled) { + if (conf.wholeStageEnabled) { insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 3ec01185c4328..f9d606e37ea89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -40,7 +40,7 @@ import org.apache.spark.unsafe.types.UTF8String * so we do not have helper methods for them. * * - * WARNNING: This only works with HeapByteBuffer + * WARNING: This only works with HeapByteBuffer */ private[columnar] object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala index 58639275c111b..9fbe6db467ffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala @@ -55,20 +55,22 @@ object AlterTableCommandParser { /** * Extract partition spec from the given [[ASTNode]] as a map, assuming it exists. * - * Expected format: - * +- TOK_PARTSPEC - * :- TOK_PARTVAL - * : :- dt - * : +- '2008-08-08' - * +- TOK_PARTVAL - * :- country - * +- 'us' + * Example format: + * + * TOK_PARTSPEC + * :- TOK_PARTVAL + * : :- dt + * : +- '2008-08-08' + * +- TOK_PARTVAL + * :- country + * +- 'us' */ private def parsePartitionSpec(node: ASTNode): Map[String, String] = { node match { case Token("TOK_PARTSPEC", partitions) => partitions.map { // Note: sometimes there's a "=", "<" or ">" between the key and the value + // (e.g. when dropping all partitions with value > than a certain constant) case Token("TOK_PARTVAL", ident :: conj :: constant :: Nil) => (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text)) case Token("TOK_PARTVAL", ident :: constant :: Nil) => @@ -86,15 +88,16 @@ object AlterTableCommandParser { /** * Extract table properties from the given [[ASTNode]] as a map, assuming it exists. 
* - * Expected format: - * +- TOK_TABLEPROPERTIES - * +- TOK_TABLEPROPLIST - * :- TOK_TABLEPROPERTY - * : :- 'test' - * : +- 'value' - * +- TOK_TABLEPROPERTY - * :- 'comment' - * +- 'new_comment' + * Example format: + * + * TOK_TABLEPROPERTIES + * +- TOK_TABLEPROPLIST + * :- TOK_TABLEPROPERTY + * : :- 'test' + * : +- 'value' + * +- TOK_TABLEPROPERTY + * :- 'comment' + * +- 'new_comment' */ private def extractTableProps(node: ASTNode): Map[String, String] = { node match { @@ -209,21 +212,21 @@ object AlterTableCommandParser { Token("TOK_TABCOLNAME", colNames) :: colValues :: rest) :: Nil) :: _ => // Example format: // - // +- TOK_ALTERTABLE_SKEWED - // :- TOK_TABLESKEWED - // : :- TOK_TABCOLNAME - // : : :- dt - // : : +- country - // :- TOK_TABCOLVALUE_PAIR - // : :- TOK_TABCOLVALUES - // : : :- TOK_TABCOLVALUE - // : : : :- '2008-08-08' - // : : : +- 'us' - // : :- TOK_TABCOLVALUES - // : : :- TOK_TABCOLVALUE - // : : : :- '2009-09-09' - // : : : +- 'uk' - // +- TOK_STOREASDIR + // TOK_ALTERTABLE_SKEWED + // :- TOK_TABLESKEWED + // : :- TOK_TABCOLNAME + // : : :- dt + // : : +- country + // :- TOK_TABCOLVALUE_PAIR + // : :- TOK_TABCOLVALUES + // : : :- TOK_TABCOLVALUE + // : : : :- '2008-08-08' + // : : : +- 'us' + // : :- TOK_TABCOLVALUES + // : : :- TOK_TABCOLVALUE + // : : : :- '2009-09-09' + // : : : +- 'uk' + // +- TOK_STOREASDIR val names = colNames.map { n => cleanAndUnquoteString(n.text) } val values = colValues match { case Token("TOK_TABCOLVALUE", vals) => @@ -260,20 +263,20 @@ object AlterTableCommandParser { case Token("TOK_ALTERTABLE_SKEWED_LOCATION", Token("TOK_SKEWED_LOCATIONS", Token("TOK_SKEWED_LOCATION_LIST", locationMaps) :: Nil) :: Nil) :: _ => - // Expected format: + // Example format: // - // +- TOK_ALTERTABLE_SKEWED_LOCATION - // +- TOK_SKEWED_LOCATIONS - // +- TOK_SKEWED_LOCATION_LIST - // :- TOK_SKEWED_LOCATION_MAP - // : :- 'col1' - // : +- 'loc1' - // +- TOK_SKEWED_LOCATION_MAP - // :- TOK_TABCOLVALUES - // : +- TOK_TABCOLVALUE - // : :- 'col2' - // : +- 'col3' - // +- 'loc2' + // TOK_ALTERTABLE_SKEWED_LOCATION + // +- TOK_SKEWED_LOCATIONS + // +- TOK_SKEWED_LOCATION_LIST + // :- TOK_SKEWED_LOCATION_MAP + // : :- 'col1' + // : +- 'loc1' + // +- TOK_SKEWED_LOCATION_MAP + // :- TOK_TABCOLVALUES + // : +- TOK_TABCOLVALUE + // : :- 'col2' + // : +- 'col3' + // +- 'loc2' val skewedMaps = locationMaps.flatMap { case Token("TOK_SKEWED_LOCATION_MAP", col :: loc :: Nil) => col match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 54cdcb10ac213..e711797c1b51a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -22,7 +22,7 @@ import java.util.NoSuchElementException import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Dataset, Row, SQLContext} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical @@ -293,13 +293,14 @@ case object ClearCacheCommand extends RunnableCommand { case class DescribeCommand( - child: SparkPlan, + table: TableIdentifier, override val output: 
Seq[Attribute], isExtended: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - child.schema.fields.map { field => + val relation = sqlContext.sessionState.catalog.lookupRelation(table) + relation.schema.fields.map { field => val cmtKey = "comment" val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" Row(field.name, field.dataType.simpleString, comment) @@ -357,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru case Some(p) => try { val regex = java.util.regex.Pattern.compile(p) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction() + .filter(regex.matcher(_).matches()).map(Row(_)) } catch { // probably will failed in the regex that user provided, then returns empty row. case _: Throwable => Seq.empty[Row] } case None => - sqlContext.functionRegistry.listFunction().map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction().map(Row(_)) } } @@ -394,7 +396,7 @@ case class DescribeFunction( } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.functionRegistry.lookupFunction(functionName) match { + sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { case Some(info) => val result = Row(s"Function: ${info.getName}") :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 9df58d214a504..3fb2e34101a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -55,7 +55,7 @@ case class CreateDatabase( case class CreateFunction( functionName: String, alias: String, - resourcesMap: Map[String, String], + resources: Seq[(String, String)], isTemp: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 887f5469b5f8f..e65a771202bce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -143,7 +143,7 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths) + val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sqlContext, @@ -208,7 +208,20 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths) + // If they gave a schema, then we try and figure out the types of the partition columns + // from that schema. + val partitionSchema = userSpecifiedSchema.map { schema => + StructType( + partitionColumns.map { c => + // TODO: Case sensitivity. 
+ schema + .find(_.name.toLowerCase() == c.toLowerCase()) + .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) + }) + } + + val fileCatalog: FileCatalog = + new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sqlContext, @@ -220,22 +233,11 @@ case class DataSource( "It must be specified manually") } - // If they gave a schema, then we try and figure out the types of the partition columns - // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => - StructType( - partitionColumns.map { c => - // TODO: Case sensitivity. - schema - .find(_.name.toLowerCase() == c.toLowerCase()) - .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) - }) - }.getOrElse(fileCatalog.partitionSpec(None).partitionColumns) HadoopFsRelation( sqlContext, fileCatalog, - partitionSchema = partitionSchema, + partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -296,7 +298,7 @@ case class DataSource( resolveRelation() .asInstanceOf[HadoopFsRelation] .location - .partitionSpec(None) + .partitionSpec() .partitionColumns .fieldNames .toSet) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1adf3b6676555..7f6671552ebde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val partitionAndNormalColumnFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet - val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray + val selectedPartitions = t.location.listFiles(partitionFilters) logInfo { val total = t.partitionSpec.partitions.length @@ -180,7 +180,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) t.bucketSpec match { - case Some(spec) if t.sqlContext.conf.bucketingEnabled() => + case Some(spec) if t.sqlContext.conf.bucketingEnabled => val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { val bucketed = @@ -200,7 +200,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { requiredColumns.map(_.name).toArray, filters, None, - bucketFiles.toArray, + bucketFiles, confBroadcast, t.options).coalesce(1) } @@ -233,7 +233,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { a.map(_.name).toArray, f, None, - t.location.allFiles().toArray, + t.location.allFiles(), confBroadcast, t.options)) :: Nil } @@ -255,7 +255,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters: Seq[Expression], buckets: Option[BitSet], partitionColumns: StructType, - partitions: Array[Partition], + partitions: Seq[Partition], options: Map[String, String]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] @@ -272,14 +272,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { relation.bucketSpec match { - case Some(spec) if 
relation.sqlContext.conf.bucketingEnabled() => + case Some(spec) if relation.sqlContext.conf.bucketingEnabled => val requiredDataColumns = requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) // Builds RDD[Row]s for each selected partition. val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap { - case Partition(partitionValues, dir) => - val files = relation.location.getStatus(dir) + case Partition(partitionValues, files) => val bucketed = files.groupBy { f => BucketingUtils .getBucketId(f.getPath.getName) @@ -327,14 +326,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { - case Partition(partitionValues, dir) => + case Partition(partitionValues, files) => val dataRows = relation.fileFormat.buildInternalScan( relation.sqlContext, relation.dataSchema, requiredDataColumns.map(_.name).toArray, filters, buckets, - relation.location.getStatus(dir), + files, confBroadcast, options) @@ -525,33 +524,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) } - protected def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[Partition] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = - partitionPruningPredicates - .reduceOption(expressions.And) - .getOrElse(Literal(true)) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - partitions.filter { case Partition(values, _) => boundPredicate(values) } - } else { - partitions - } - } - // Based on Public API. protected def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala new file mode 100644 index 0000000000000..e2cbbc34d91a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow + +/** + * A single file that should be read, along with partition column values that + * need to be prepended to each row. The reading should start at the first + * valid record found after `offset`. + */ +case class PartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long) + +/** + * A collection of files that should be read as a single task possibly from multiple partitioned + * directories. + * + * IMPLEMENT ME: This is just a placeholder for a future implementation. + * TODO: This currently does not take locality information about the files into account. + */ +case class FilePartition(val index: Int, files: Seq[PartitionedFile]) extends Partition + +class FileScanRDD( + @transient val sqlContext: SQLContext, + readFunction: (PartitionedFile) => Iterator[InternalRow], + @transient val filePartitions: Seq[FilePartition]) + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + throw new NotImplementedError("Not Implemented Yet") + } + + override protected def getPartitions: Array[Partition] = Array.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala new file mode 100644 index 0000000000000..ef95d5d28961f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +/** + * A strategy for planning scans over collections of files that might be partitioned or bucketed + * by user specified columns. + * + * At a high level planning occurs in several phases: + * - Split filters by when they need to be evaluated. + * - Prune the schema of the data requested based on any projections present. 
Today this pruning + * is only done on top level columns, but formats should support pruning of nested columns as + * well. + * - Construct a reader function by passing filters and the schema into the FileFormat. + * - Using the partition pruning predicates, enumerate the list of files that should be read. + * - Split the files into tasks and construct a FileScanRDD. + * - Add any projection or filters that must be evaluated after the scan. + * + * Files are assigned into tasks using the following algorithm: + * - If the table is bucketed, group files by bucket id into the correct number of partitions. + * - If the table is not bucketed or bucketing is turned off: + * - If any file is larger than the threshold, split it into pieces based on that threshold + * - Sort the files by decreasing file size. + * - Assign the ordered files to buckets using the following algorithm. If the current partition + * is under the threshold with the addition of the next file, add it. If not, open a new bucket + * and add it. Proceed to the next file. + */ +private[sql] object FileSourceStrategy extends Strategy with Logging { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projects, filters, l@LogicalRelation(files: HadoopFsRelation, _, _)) + if files.fileFormat.toString == "TestFileFormat" => + // Filters on this relation fall into four categories based on where we can use them to avoid + // reading unneeded data: + // - partition keys only - used to prune directories to read + // - bucket keys only - optionally used to prune files to read + // - keys stored in the data only - optionally used to skip groups of data in files + // - filters that need to be evaluated again after the scan + val filterSet = ExpressionSet(filters) + + val partitionColumns = + AttributeSet(l.resolve(files.partitionSchema, files.sqlContext.analyzer.resolver)) + val partitionKeyFilters = + ExpressionSet(filters.filter(_.references.subsetOf(partitionColumns))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + + val bucketColumns = + AttributeSet( + files.bucketSpec + .map(_.bucketColumnNames) + .getOrElse(Nil) + .map(l.resolveQuoted(_, files.sqlContext.conf.resolver) + .getOrElse(sys.error("")))) + + // Partition keys are not available in the statistics of the files. + val dataFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) + + // Predicates with both partition keys and attributes need to be evaluated after the scan.
+ val afterScanFilters = filterSet -- partitionKeyFilters + logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}") + + val selectedPartitions = files.location.listFiles(partitionKeyFilters.toSeq) + + val filterAttributes = AttributeSet(afterScanFilters) + val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects + val requiredAttributes = AttributeSet(requiredExpressions).map(_.name).toSet + + val prunedDataSchema = + StructType( + files.dataSchema.filter(f => requiredAttributes.contains(f.name))) + logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}") + + val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + + val readFile = files.fileFormat.buildReader( + sqlContext = files.sqlContext, + partitionSchema = files.partitionSchema, + dataSchema = prunedDataSchema, + filters = pushedDownFilters, + options = files.options) + + val plannedPartitions = files.bucketSpec match { + case Some(bucketing) if files.sqlContext.conf.bucketingEnabled => + logInfo(s"Planning with ${bucketing.numBuckets} buckets") + val bucketed = + selectedPartitions + .flatMap { p => + p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen)) + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + + (0 until bucketing.numBuckets).map { bucketId => + FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + } + + case _ => + val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes") + + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + assert(file.getLen != 0) + (0L to file.getLen by maxSplitBytes).map { offset => + val remaining = file.getLen - offset + val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining + PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size) + } + } + }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + var currentSize = 0L + + /** Add the given file to the current partition. */ + def addFile(file: PartitionedFile): Unit = { + currentSize += file.length + currentFiles.append(file) + } + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + val newPartition = + FilePartition( + partitions.size, + currentFiles.toArray.toSeq) // Copy to a new Array. + partitions.append(newPartition) + } + currentFiles.clear() + currentSize = 0 + } + + // Assign files to partitions using "First Fit Decreasing" (FFD) + // TODO: consider adding a slop factor here? 
+ splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() + addFile(file) + } else { + addFile(file) + } + } + closePartition() + partitions + } + + val scan = + DataSourceScan( + l.output, + new FileScanRDD( + files.sqlContext, + readFile, + plannedPartitions), + files, + Map("format" -> files.fileFormat.toString)) + + val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) + val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan) + val withProjections = if (projects.forall(_.isInstanceOf[AttributeReference])) { + withFilter + } else { + execution.Project(projects, withFilter) + } + + withProjections :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c3f8d7f75a23a..3ac2ff494fa81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,17 +32,23 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ -object Partition { - def apply(values: InternalRow, path: String): Partition = +object PartitionDirectory { + def apply(values: InternalRow, path: String): PartitionDirectory = apply(values, new Path(path)) } -private[sql] case class Partition(values: InternalRow, path: Path) +/** + * Holds a directory in a partitioned collection of files as well as the partition values + * in the form of a Row. Before scanning, the files at `path` need to be enumerated. + */ +private[sql] case class PartitionDirectory(values: InternalRow, path: Path) -private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] case class PartitionSpec( + partitionColumns: StructType, + partitions: Seq[PartitionDirectory]) private[sql] object PartitionSpec { - val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) } private[sql] object PartitioningUtils { @@ -88,7 +94,7 @@ private[sql] object PartitioningUtils { }.unzip // We create pairs of (path -> path's partition value) here - // If the corresponding partition value is None, the pair will be skiped + // If the corresponding partition value is None, the pair will be skipped val pathsWithPartitionValues = paths.zip(partitionValues).flatMap(x => x._2.map(x._1 -> _)) if (pathsWithPartitionValues.isEmpty) { @@ -133,7 +139,7 @@ private[sql] object PartitioningUtils { // Finally, we create `Partition`s based on paths and resolved partition values.
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) => - Partition(InternalRow.fromSeq(literals.map(_.value)), path) + PartitionDirectory(InternalRow.fromSeq(literals.map(_.value)), path) } PartitionSpec(StructType(fields), partitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 38aa2dd80a4f1..6a0290c11228f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.Charset +import java.nio.charset.StandardCharsets import org.apache.spark.Logging import org.apache.spark.sql.execution.datasources.CompressionCodecs @@ -64,7 +64,7 @@ private[sql] class CSVOptions( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val charset = parameters.getOrElse("encoding", - parameters.getOrElse("charset", Charset.forName("UTF-8").name())) + parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) val quote = getChar("quote", '\"') val escape = getChar("escape", '\\') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index 8f1421844c648..8c3f63d307321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader} +import java.nio.charset.StandardCharsets import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings} @@ -76,7 +77,7 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten def writeRow(row: Seq[String], includeHeader: Boolean): String = { val buffer = new ByteArrayOutputStream() - val outputWriter = new OutputStreamWriter(buffer) + val outputWriter = new OutputStreamWriter(buffer, StandardCharsets.UTF_8) val writer = new CsvWriter(outputWriter, writerSettings) if (includeHeader) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 0e6b9855c70de..c96a508cf1baa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -52,7 +52,7 @@ object CSVRelation extends Logging { tokenizedRDD: RDD[Array[String]], schema: StructType, requiredColumns: Array[String], - inputs: Array[FileStatus], + inputs: Seq[FileStatus], sqlContext: SQLContext, params: CSVOptions): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index aff672281d640..a5f94262ff402 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.Charset +import java.nio.charset.{Charset, StandardCharsets} import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} @@ -103,7 +103,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: Filter before calling buildInternalScan. @@ -161,7 +161,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { sqlContext: SQLContext, options: CSVOptions, location: String): RDD[String] = { - if (Charset.forName(options.charset) == Charset.forName("UTF-8")) { + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { sqlContext.sparkContext.textFile(location) } else { val charset = options.charset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 903c9913ac0bd..04e51735c46f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -33,8 +33,9 @@ import org.apache.spark.sql.types._ * It is effective only when the table is a Hive table. */ case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends LogicalPlan with logical.Command { + table: TableIdentifier, + isExtended: Boolean) + extends LogicalPlan with logical.Command { override def children: Seq[LogicalPlan] = Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e295722cacf15..64a820c6d741f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -70,7 +70,7 @@ object JdbcUtils extends Logging { // Somewhat hacky, but there isn't a good way to identify whether a table exists for all // SQL database systems using JDBC meta data calls, considering "table" could also include - // the database name. Query used to find table exists can be overriden by the dialects. + // the database name. Query used to find table exists can be overridden by the dialects. 
Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(table)) try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 05b44d1a2a04f..3fa5ebf1bb81e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -95,7 +95,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: Filter files for all formats before calling buildInternalScan. @@ -115,7 +115,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - private def createBaseRdd(sqlContext: SQLContext, inputPaths: Array[FileStatus]): RDD[String] = { + private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = { val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index f1060074d6acb..342034ca0ff92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -274,7 +274,7 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister with requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - allFiles: Array[FileStatus], + allFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 2869a6a1ac074..6af403dec5fba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -94,7 +94,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { verifySchema(dataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 709a4246365dd..4864db7f2ac9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.exchange -import org.apache.spark.sql.SQLContext import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.internal.SQLConf /** * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] @@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._ * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the * input partition ordering requirements are met. */ -private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions +case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { + private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions - private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize - private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + val minNumPostShufflePartitions = conf.minNumPostShufflePartitions if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 12513e9106707..9eaadea1b11ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** @@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) * Find out duplicated exchanges in the spark plan, then use the same exchange for all the * references. */ -private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - if (!sqlContext.conf.exchangeReuseEnabled) { + if (!conf.exchangeReuseEnabled) { return plan } // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. 
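The changes above to CollapseCodegenStages, EnsureRequirements and ReuseExchange (and, earlier, to SparkOptimizer and SparkPlanner) all follow the same pattern: physical-plan rules are now parameterized by a SQLConf or ExperimentalMethods rather than by a full SQLContext, and QueryExecution reaches them through sessionState. The sketch below is a minimal, self-contained illustration of why that helps; MiniConf, MiniPlan, CollapseWhenEnabled and ConfDrivenRuleExample are hypothetical names used only for this example and are not part of this patch or of Spark's API.

// A rule whose behaviour depends only on configuration, mirroring
// CollapseCodegenStages(conf: SQLConf) instead of CollapseCodegenStages(sqlContext: SQLContext).
case class MiniConf(wholeStageEnabled: Boolean)

sealed trait MiniPlan
case class Scan(table: String) extends MiniPlan
case class Codegen(child: MiniPlan) extends MiniPlan

case class CollapseWhenEnabled(conf: MiniConf) {
  // Wrap the plan in a codegen node only when the conf flag is set.
  def apply(plan: MiniPlan): MiniPlan =
    if (conf.wholeStageEnabled) Codegen(plan) else plan
}

object ConfDrivenRuleExample extends App {
  val plan = Scan("t")
  // No SQLContext or SparkContext is needed to construct or exercise the rule;
  // only the conf value matters, which keeps unit tests cheap.
  assert(CollapseWhenEnabled(MiniConf(wholeStageEnabled = true))(plan) == Codegen(Scan("t")))
  assert(CollapseWhenEnabled(MiniConf(wholeStageEnabled = false))(plan) == Scan("t"))
}

With rules built this way, per-session state stays in one place: QueryExecution obtains the concrete optimizer, planner and preparation rules from sessionState, as the SQLContext and QueryExecution hunks earlier in this patch show.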
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 99f8841c8737b..0b0f59c3e4634 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.execution.local.LocalNode import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} @@ -157,10 +156,11 @@ private[joins] class UniqueKeyHashedRelation( private[execution] object HashedRelation { - def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { - apply(localNode.asIterator, keyGenerator) - } - + /** + * Create a HashedRelation from an Iterator of InternalRow. + * + * Note: The caller should make sure that these InternalRow are different objects. + */ def apply( input: Iterator[InternalRow], keyGenerator: Projection, @@ -193,7 +193,7 @@ private[execution] object HashedRelation { keyIsUnique = false existingMatchList } - matchList += currentRow.copy() + matchList += currentRow } } @@ -443,7 +443,7 @@ private[joins] object UnsafeHashedRelation { } else { existingMatchList } - matchList += unsafeRow.copy() + matchList += unsafeRow } } @@ -627,7 +627,7 @@ private[joins] object LongHashedRelation { keyIsUnique = false existingMatchList } - matchList += unsafeRow.copy() + matchList += unsafeRow } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 242ed612327cc..14389e45babed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -47,7 +47,7 @@ case class LeftSemiJoinHash( val numOutputRows = longMetric("numOutputRows") right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = HashedRelation(buildIter.map(_.copy()), rightKeyGenerator) hashSemiJoin(streamIter, hashRelation, numOutputRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala deleted file mode 100644 index 97f9358016940..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} -import org.apache.spark.sql.internal.SQLConf - -/** - * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of - * `buildSide`. The actual work of this node is defined in [[HashJoinNode]]. - */ -case class BinaryHashJoinNode( - conf: SQLConf, - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: LocalNode, - right: LocalNode) - extends BinaryLocalNode(conf) with HashJoinNode { - - protected override val (streamedNode, streamedKeys) = buildSide match { - case BuildLeft => (right, rightKeys) - case BuildRight => (left, leftKeys) - } - - private val (buildNode, buildKeys) = buildSide match { - case BuildLeft => (left, leftKeys) - case BuildRight => (right, rightKeys) - } - - override def output: Seq[Attribute] = left.output ++ right.output - - private def buildSideKeyGenerator: Projection = { - // We are expecting the data types of buildKeys and streamedKeys are the same. - assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)) - UnsafeProjection.create(buildKeys, buildNode.output) - } - - protected override def doOpen(): Unit = { - buildNode.open() - val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) - // We have built the HashedRelation. So, close buildNode. - buildNode.close() - - streamedNode.open() - // Set the HashedRelation used by the HashJoinNode. - withHashedRelation(hashedRelation) - } - - override def close(): Unit = { - // Please note that we do not need to call the close method of our buildNode because - // it has been called in this.open. - streamedNode.close() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala deleted file mode 100644 index 779f4833fa417..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
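The HashedRelation changes above drop the defensive `.copy()` when appending to a match list, and LeftSemiJoinHash compensates by copying rows up front with `buildIter.map(_.copy())`. The reason is that Spark's row iterators may reuse a single mutable buffer per partition, so storing the reference without a copy stores the same object many times. A standalone illustration of that pitfall (MutableRow is a toy class, not Spark's InternalRow):

    final class MutableRow(var value: Int) {
      def copy(): MutableRow = new MutableRow(value)
      override def toString: String = s"Row($value)"
    }

    object CopyDemo extends App {
      val shared = new MutableRow(0)
      // An iterator that reuses one row object, like an unsafe-row scan.
      def rows: Iterator[MutableRow] = Iterator.tabulate(3) { i => shared.value = i; shared }

      println(rows.toList)                // List(Row(2), Row(2), Row(2)) -- three references to one object
      println(rows.map(_.copy()).toList)  // List(Row(0), Row(1), Row(2)) -- copies preserve each value
    }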
-*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} -import org.apache.spark.sql.internal.SQLConf - -/** - * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast - * [[HashedRelation]]. The actual work of this node is defined in [[HashJoinNode]]. - */ -case class BroadcastHashJoinNode( - conf: SQLConf, - streamedKeys: Seq[Expression], - streamedNode: LocalNode, - buildSide: BuildSide, - buildOutput: Seq[Attribute], - hashedRelation: Broadcast[HashedRelation]) - extends UnaryLocalNode(conf) with HashJoinNode { - - override val child = streamedNode - - // Because we do not pass in the buildNode, we take the output of buildNode to - // create the inputSet properly. - override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput) - - override def output: Seq[Attribute] = buildSide match { - case BuildRight => streamedNode.output ++ buildOutput - case BuildLeft => buildOutput ++ streamedNode.output - } - - protected override def doOpen(): Unit = { - streamedNode.open() - // Set the HashedRelation used by the HashJoinNode. - withHashedRelation(hashedRelation.value) - } - - override def close(): Unit = { - streamedNode.close() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala deleted file mode 100644 index f79d795a904d1..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} -import org.apache.spark.sql.internal.SQLConf - -case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { - - override def output: Seq[Attribute] = child.output - - private[this] var convertToSafe: Projection = _ - - override def open(): Unit = { - child.open() - convertToSafe = FromUnsafeProjection(child.schema) - } - - override def next(): Boolean = child.next() - - override def fetch(): InternalRow = convertToSafe(child.fetch()) - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala deleted file mode 100644 index f3fa474b0f7ff..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} -import org.apache.spark.sql.internal.SQLConf - -case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { - - override def output: Seq[Attribute] = child.output - - private[this] var convertToUnsafe: Projection = _ - - override def open(): Unit = { - child.open() - convertToUnsafe = UnsafeProjection.create(child.schema) - } - - override def next(): Boolean = child.next() - - override def fetch(): InternalRow = convertToUnsafe(child.fetch()) - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala deleted file mode 100644 index 6ccd6db0e6ca4..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.internal.SQLConf - -case class ExpandNode( - conf: SQLConf, - projections: Seq[Seq[Expression]], - output: Seq[Attribute], - child: LocalNode) extends UnaryLocalNode(conf) { - - assert(projections.size > 0) - - private[this] var result: InternalRow = _ - private[this] var idx: Int = _ - private[this] var input: InternalRow = _ - private[this] var groups: Array[Projection] = _ - - override def open(): Unit = { - child.open() - groups = projections.map(ee => newMutableProjection(ee, child.output)()).toArray - idx = groups.length - } - - override def next(): Boolean = { - if (idx >= groups.length) { - if (child.next()) { - input = child.fetch() - idx = 0 - } else { - return false - } - } - result = groups(idx)(input) - idx += 1 - true - } - - override def fetch(): InternalRow = result - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala deleted file mode 100644 index c5eb33cef4420..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -import org.apache.spark.sql.internal.SQLConf - - -case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) - extends UnaryLocalNode(conf) { - - private[this] var predicate: (InternalRow) => Boolean = _ - - override def output: Seq[Attribute] = child.output - - override def open(): Unit = { - child.open() - predicate = GeneratePredicate.generate(condition, child.output) - } - - override def next(): Boolean = { - var found = false - while (!found && child.next()) { - found = predicate.apply(child.fetch()) - } - found - } - - override def fetch(): InternalRow = child.fetch() - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala deleted file mode 100644 index fd7948ffa9a9b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins._ - -/** - * An abstract node for sharing common functionality among different implementations of - * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]]. - * - * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]]. - */ -trait HashJoinNode { - - self: LocalNode => - - protected def streamedKeys: Seq[Expression] - protected def streamedNode: LocalNode - protected def buildSide: BuildSide - - private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: Seq[InternalRow] = _ - private[this] var currentMatchPosition: Int = -1 - - private[this] var joinRow: JoinedRow = _ - private[this] var resultProjection: (InternalRow) => InternalRow = _ - - private[this] var hashed: HashedRelation = _ - private[this] var joinKeys: Projection = _ - - private def streamSideKeyGenerator: Projection = - UnsafeProjection.create(streamedKeys, streamedNode.output) - - /** - * Sets the HashedRelation used by this node. This method needs to be called after - * before the first `next` gets called. - */ - protected def withHashedRelation(hashedRelation: HashedRelation): Unit = { - hashed = hashedRelation - } - - /** - * Custom open implementation to be overridden by subclasses. 
- */ - protected def doOpen(): Unit - - override def open(): Unit = { - doOpen() - joinRow = new JoinedRow - resultProjection = UnsafeProjection.create(schema) - joinKeys = streamSideKeyGenerator - } - - override def next(): Boolean = { - currentMatchPosition += 1 - if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) { - fetchNextMatch() - } else { - true - } - } - - /** - * Populate `currentHashMatches` with build-side rows matching the next streamed row. - * @return whether matches are found such that subsequent calls to `fetch` are valid. - */ - private def fetchNextMatch(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamedNode.next()) { - currentStreamedRow = streamedNode.fetch() - val key = joinKeys(currentStreamedRow) - if (!key.anyNull) { - currentHashMatches = hashed.get(key) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - - override def fetch(): InternalRow = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - resultProjection(ret) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala deleted file mode 100644 index e594e132dea79..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.local - -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.internal.SQLConf - -case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) - extends BinaryLocalNode(conf) { - - override def output: Seq[Attribute] = left.output - - private[this] var leftRows: mutable.HashSet[InternalRow] = _ - - private[this] var currentRow: InternalRow = _ - - override def open(): Unit = { - left.open() - leftRows = mutable.HashSet[InternalRow]() - while (left.next()) { - leftRows += left.fetch().copy() - } - left.close() - right.open() - } - - override def next(): Boolean = { - currentRow = null - while (currentRow == null && right.next()) { - currentRow = right.fetch() - if (!leftRows.contains(currentRow)) { - currentRow = null - } - } - currentRow != null - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = { - left.close() - right.close() - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala deleted file mode 100644 index 9af45ac0aac9a..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf - - -case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { - - private[this] var count = 0 - - override def output: Seq[Attribute] = child.output - - override def open(): Unit = child.open() - - override def close(): Unit = child.close() - - override def fetch(): InternalRow = child.fetch() - - override def next(): Boolean = { - if (count < limit) { - count += 1 - child.next() - } else { - false - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala deleted file mode 100644 index a5d09691dc46c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ /dev/null @@ -1,157 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. 
-* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.Logging -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType - -/** - * A local physical operator, in the form of an iterator. - * - * Before consuming the iterator, open function must be called. - * After consuming the iterator, close function must be called. - */ -abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { - - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - - /** - * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume - * any input data. - * - * Implementations of this must also call the `prepare()` function of its children. - */ - def prepare(): Unit = children.foreach(_.prepare()) - - /** - * Initializes the iterator state. Must be called before calling `next()`. - * - * Implementations of this must also call the `open()` function of its children. - */ - def open(): Unit - - /** - * Advances the iterator to the next tuple. Returns true if there is at least one more tuple. - */ - def next(): Boolean - - /** - * Returns the current tuple. - */ - def fetch(): InternalRow - - /** - * Closes the iterator and releases all resources. It should be idempotent. - * - * Implementations of this must also call the `close()` function of its children. - */ - def close(): Unit - - /** - * Returns the content through the [[Iterator]] interface. - */ - final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) - - /** - * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. 
- */ - final def collect(): Seq[Row] = { - val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) - val result = new scala.collection.mutable.ArrayBuffer[Row] - open() - try { - while (next()) { - result += converter.apply(fetch()).asInstanceOf[Row] - } - } finally { - close() - } - result - } - - protected def newMutableProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): () => MutableProjection = { - log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema") - GenerateMutableProjection.generate(expressions, inputSchema) - } - - protected def newPredicate( - expression: Expression, - inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - GeneratePredicate.generate(expression, inputSchema) - } -} - - -abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) { - override def children: Seq[LocalNode] = Seq.empty -} - - -abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) { - - def child: LocalNode - - override def children: Seq[LocalNode] = Seq(child) -} - -abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { - - def left: LocalNode - - def right: LocalNode - - override def children: Seq[LocalNode] = Seq(left, right) -} - -/** - * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. - */ -private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { - private var nextRow: InternalRow = _ - - override def hasNext: Boolean = { - if (nextRow == null) { - val res = localNode.next() - if (res) { - nextRow = localNode.fetch() - } - res - } else { - true - } - } - - override def next(): InternalRow = { - if (hasNext) { - val res = nextRow - nextRow = null - res - } else { - throw new NoSuchElementException - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala deleted file mode 100644 index b5ea08325c58e..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
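The LocalNode API being deleted above is a pull-based operator protocol: callers open() the node, repeatedly call next()/fetch(), then close(); LocalNodeIterator adapts that protocol to a plain Scala Iterator. A small self-contained sketch of the same adaptation (the Node trait is illustrative, not the removed class):

    trait Node[A] {
      def open(): Unit
      def next(): Boolean   // advance; returns true if a row is available
      def fetch(): A        // current row; valid only after next() returned true
      def close(): Unit
    }

    // Adapts the open/next/fetch protocol to Iterator, like the removed LocalNodeIterator.
    final class NodeIterator[A](node: Node[A]) extends Iterator[A] {
      private var buffered: Option[A] = None

      override def hasNext: Boolean = buffered.isDefined || {
        if (node.next()) { buffered = Some(node.fetch()); true } else false
      }

      override def next(): A = {
        if (!hasNext) throw new NoSuchElementException
        val row = buffered.get
        buffered = None
        row
      }
    }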
- */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.collection.{BitSet, CompactBuffer} - -case class NestedLoopJoinNode( - conf: SQLConf, - left: LocalNode, - right: LocalNode, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryLocalNode(conf) { - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalArgumentException( - s"NestedLoopJoin should not take $x as the JoinType") - } - } - - private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) - } - - private[this] var currentRow: InternalRow = _ - - private[this] var iterator: Iterator[InternalRow] = _ - - override def open(): Unit = { - val (streamed, build) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - build.open() - val buildRelation = new CompactBuffer[InternalRow] - while (build.next()) { - buildRelation += build.fetch().copy() - } - build.close() - - val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val joinedRow = new JoinedRow - val matchedBuildTuples = new BitSet(buildRelation.size) - val resultProj = genResultProjection - streamed.open() - - // streamedRowMatches also contains null rows if using outer join - val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow => - val matchedRows = new CompactBuffer[InternalRow] - - var i = 0 - var streamRowMatched = false - - // Scan the build relation to look for matches for each streamed row - while (i < buildRelation.size) { - val buildRow = buildRelation(i) - buildSide match { - case BuildRight => joinedRow(streamedRow, buildRow) - case BuildLeft => joinedRow(buildRow, streamedRow) - } - if (boundCondition(joinedRow)) { - matchedRows += resultProj(joinedRow).copy() - streamRowMatched = true - matchedBuildTuples.set(i) - } - i += 1 - } - - // If this row had no matches and we're using outer join, join it with the null rows - if (!streamRowMatched) { - (joinType, buildSide) match { - case (LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() - case (RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() - case _ => - } - } - - matchedRows.iterator - } - - // If we're using outer join, find rows on the build side that didn't match anything - // and join them with the null row - lazy val unmatchedBuildRows: Iterator[InternalRow] = { - var i = 0 - buildRelation.filter { row => - val r = !matchedBuildTuples.get(i) - i += 1 - r - }.iterator - } - iterator = (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - streamedRowMatches ++ - unmatchedBuildRows.map { buildRow => 
resultProj(joinedRow(leftNulls, buildRow)) } - case (LeftOuter | FullOuter, BuildLeft) => - streamedRowMatches ++ - unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) } - case _ => streamedRowMatches - } - } - - override def next(): Boolean = { - if (iterator.hasNext) { - currentRow = iterator.next() - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = { - left.close() - right.close() - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala deleted file mode 100644 index 5fe068a13c8a4..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, UnsafeProjection} -import org.apache.spark.sql.internal.SQLConf - - -case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) - extends UnaryLocalNode(conf) { - - private[this] var project: UnsafeProjection = _ - - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - override def open(): Unit = { - project = UnsafeProjection.create(projectList, child.output) - child.open() - } - - override def next(): Boolean = child.next() - - override def fetch(): InternalRow = { - project.apply(child.fetch()) - } - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala deleted file mode 100644 index 078fb50deb16f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} - - -/** - * Sample the dataset. - * - * @param conf the SQLConf - * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled - * will be ub - lb. - * @param withReplacement Whether to sample with replacement. - * @param seed the random seed - * @param child the LocalNode - */ -case class SampleNode( - conf: SQLConf, - lowerBound: Double, - upperBound: Double, - withReplacement: Boolean, - seed: Long, - child: LocalNode) extends UnaryLocalNode(conf) { - - override def output: Seq[Attribute] = child.output - - private[this] var iterator: Iterator[InternalRow] = _ - - private[this] var currentRow: InternalRow = _ - - override def open(): Unit = { - child.open() - val sampler = - if (withReplacement) { - // Disable gap sampling since the gap sampling method buffers two rows internally, - // requiring us to copy the row, which is more expensive than the random number generator. - new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) - } else { - new BernoulliCellSampler[InternalRow](lowerBound, upperBound) - } - sampler.setSeed(seed) - iterator = sampler.sample(child.asIterator) - } - - override def next(): Boolean = { - if (iterator.hasNext) { - currentRow = iterator.next() - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = child.close() - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala deleted file mode 100644 index 8ebfe3a68b3a3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf - -/** - * An operator that scans some local data collection in the form of Scala Seq. 
- */ -case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow]) - extends LeafLocalNode(conf) { - - private[this] var iterator: Iterator[InternalRow] = _ - private[this] var currentRow: InternalRow = _ - - override def open(): Unit = { - iterator = data.iterator - } - - override def next(): Boolean = { - if (iterator.hasNext) { - currentRow = iterator.next() - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = { - // Do nothing - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala deleted file mode 100644 index f52f5f7bb59b7..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.BoundedPriorityQueue - -case class TakeOrderedAndProjectNode( - conf: SQLConf, - limit: Int, - sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], - child: LocalNode) extends UnaryLocalNode(conf) { - - private[this] var projection: Option[Projection] = _ - private[this] var ord: Ordering[InternalRow] = _ - private[this] var iterator: Iterator[InternalRow] = _ - private[this] var currentRow: InternalRow = _ - - override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) - } - - override def open(): Unit = { - child.open() - projection = projectList.map(UnsafeProjection.create(_, child.output)) - ord = GenerateOrdering.generate(sortOrder, child.output) - // Priority keeps the largest elements, so let's reverse the ordering. - val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) - while (child.next()) { - queue += child.fetch() - } - // Close it eagerly since we don't need it. 
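The removed TakeOrderedAndProjectNode above computes the top `limit` rows by feeding every row into a bounded priority queue under a reversed ordering and then sorting the survivors. A self-contained sketch of that top-k pattern using the standard-library PriorityQueue (Spark's BoundedPriorityQueue is not used here):

    import scala.collection.mutable

    // Keep the k smallest elements of a stream: the heap's head is the worst element
    // retained so far, and any smaller incoming element evicts it.
    def topK[A](k: Int, xs: Iterator[A])(implicit ord: Ordering[A]): Seq[A] = {
      require(k > 0, "k must be positive")
      val heap = mutable.PriorityQueue.empty[A](ord) // dequeues the largest retained element first
      xs.foreach { x =>
        if (heap.size < k) heap.enqueue(x)
        else if (ord.lt(x, heap.head)) { heap.dequeue(); heap.enqueue(x) }
      }
      heap.toSeq.sorted(ord)
    }

    // topK(3, Iterator(9, 1, 7, 3, 5)) == Seq(1, 3, 5)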
- child.close() - iterator = queue.toArray.sorted(ord).iterator - } - - override def next(): Boolean = { - if (iterator.hasNext) { - val _currentRow = iterator.next() - currentRow = projection match { - case Some(p) => p(_currentRow) - case None => _currentRow - } - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = child.close() - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala deleted file mode 100644 index e53bc220d8d34..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf - -case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { - - override def output: Seq[Attribute] = children.head.output - - private[this] var currentChild: LocalNode = _ - - private[this] var nextChildIndex: Int = _ - - override def open(): Unit = { - currentChild = children.head - currentChild.open() - nextChildIndex = 1 - } - - private def advanceToNextChild(): Boolean = { - var found = false - var exit = false - while (!exit && !found) { - if (currentChild != null) { - currentChild.close() - } - if (nextChildIndex >= children.size) { - found = false - exit = true - } else { - currentChild = children(nextChildIndex) - nextChildIndex += 1 - currentChild.open() - found = currentChild.next() - } - } - found - } - - override def close(): Unit = { - if (currentChild != null) { - currentChild.close() - } - } - - override def fetch(): InternalRow = currentChild.fetch() - - override def next(): Boolean = { - if (currentChild.next()) { - true - } else { - advanceToNextChild() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index c65a7bcff8503..79e4491026b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** - * A physical plan that evalutes a [[PythonUDF]], one partition of tuples at a time. + * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time. 
* * Python evaluation works by sending the necessary (projected) input data via a socket to an * external Python process, and combine the result from the Python process with the original row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 14ba9f69bb1d7..25c8a69b1f1ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -17,13 +17,9 @@ package org.apache.spark.sql.execution.streaming -import java.io._ +import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.io.Codec - -import com.google.common.base.Charsets.UTF_8 -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} @@ -44,33 +40,12 @@ class FileStreamSource( dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging { private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) - private var maxBatchId = -1 - private val seenFiles = new OpenHashSet[String] + private val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath) + private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) - /** Map of batch id to files. This map is also stored in `metadataPath`. */ - private val batchToMetadata = new HashMap[Long, Seq[String]] - - { - // Restore file paths from the metadata files - val existingBatchFiles = fetchAllBatchFiles() - if (existingBatchFiles.nonEmpty) { - val existingBatchIds = existingBatchFiles.map(_.getPath.getName.toInt) - maxBatchId = existingBatchIds.max - // Recover "batchToMetadata" and "seenFiles" from existing metadata files. - existingBatchIds.sorted.foreach { batchId => - val files = readBatch(batchId) - if (files.isEmpty) { - // Assert that the corrupted file must be the latest metadata file. 
- if (batchId != maxBatchId) { - throw new IllegalStateException("Invalid metadata files") - } - maxBatchId = maxBatchId - 1 - } else { - batchToMetadata(batchId) = files - files.foreach(seenFiles.add) - } - } - } + private val seenFiles = new OpenHashSet[String] + metadataLog.get(None, maxBatchId).foreach { case (batchId, files) => + files.foreach(seenFiles.add) } /** Returns the schema of the data from this source */ @@ -112,7 +87,7 @@ class FileStreamSource( if (newFiles.nonEmpty) { maxBatchId += 1 - writeBatch(maxBatchId, newFiles) + metadataLog.add(maxBatchId, newFiles) } new LongOffset(maxBatchId) @@ -140,9 +115,7 @@ class FileStreamSource( val endId = end.offset if (startId + 1 <= endId) { - val files = (startId + 1 to endId).filter(_ >= 0).flatMap { batchId => - batchToMetadata.getOrElse(batchId, Nil) - }.toArray + val files = metadataLog.get(Some(startId + 1), endId).map(_._2).flatten logDebug(s"Return files from batches ${startId + 1}:$endId") logDebug(s"Streaming ${files.mkString(", ")}") Some(new Batch(end, dataFrameBuilder(files))) @@ -152,89 +125,9 @@ class FileStreamSource( } } - private def fetchAllBatchFiles(): Seq[FileStatus] = { - try fs.listStatus(new Path(metadataPath)) catch { - case _: java.io.FileNotFoundException => - fs.mkdirs(new Path(metadataPath)) - Seq.empty - } - } - private def fetchAllFiles(): Seq[String] = { fs.listStatus(new Path(path)) .filterNot(_.getPath.getName.startsWith("_")) .map(_.getPath.toUri.toString) } - - /** - * Write the metadata of a batch to disk. The file format is as follows: - * - * {{{ - * - * START - * -/a/b/c - * -/d/e/f - * ... - * END - * }}} - * - * Note: means the value of `FileStreamSource.VERSION`. Every file - * path starts with "-" so that we can know if a line is a file path easily. - */ - private def writeBatch(id: Int, files: Seq[String]): Unit = { - assert(files.nonEmpty, "create a new batch without any file") - val output = fs.create(new Path(metadataPath + "/" + id), true) - val writer = new PrintWriter(new OutputStreamWriter(output, UTF_8)) - try { - // scalastyle:off println - writer.println(FileStreamSource.VERSION) - writer.println(FileStreamSource.START_TAG) - files.foreach(file => writer.println(FileStreamSource.PATH_PREFIX + file)) - writer.println(FileStreamSource.END_TAG) - // scalastyle:on println - } finally { - writer.close() - } - batchToMetadata(id) = files - } - - /** Read the file names of the specified batch id from the metadata file */ - private def readBatch(id: Int): Seq[String] = { - val input = fs.open(new Path(metadataPath + "/" + id)) - try { - FileStreamSource.readBatch(input) - } finally { - input.close() - } - } -} - -object FileStreamSource { - - private val START_TAG = "START" - private val END_TAG = "END" - private val PATH_PREFIX = "-" - val VERSION = "FILESTREAM_V1" - - /** - * Parse a metadata file and return the content. If the metadata file is corrupted, it will return - * an empty `Seq`. 
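With the FileStreamSource rewrite above, recovery no longer parses hand-rolled metadata files: the source asks the log for its latest batch id and replays the logged file lists into `seenFiles`. A simplified sketch of that recovery step, with an in-memory Map standing in for the metadata log (illustrative only):

    // Recover source state from a batchId -> files log, as the new FileStreamSource does
    // with metadataLog.getLatest() and metadataLog.get(None, maxBatchId).
    def recover(log: Map[Long, Seq[String]]): (Long, Set[String]) = {
      val maxBatchId = if (log.isEmpty) -1L else log.keys.max
      val seenFiles  = log.values.flatten.toSet
      (maxBatchId, seenFiles)
    }

    // recover(Map(0L -> Seq("/data/a"), 1L -> Seq("/data/b", "/data/c")))
    //   == (1L, Set("/data/a", "/data/b", "/data/c"))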
- */ - def readBatch(input: InputStream): Seq[String] = { - val lines = scala.io.Source.fromInputStream(input)(Codec.UTF8).getLines().toArray - if (lines.length < 4) { - // version + start tag + end tag + at least one file path - return Nil - } - if (lines.head != VERSION) { - return Nil - } - if (lines(1) != START_TAG) { - return Nil - } - if (lines.last != END_TAG) { - return Nil - } - lines.slice(2, lines.length - 1).map(_.stripPrefix(PATH_PREFIX)) // Drop character "-" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala new file mode 100644 index 0000000000000..ac2842b6d5df9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -0,0 +1,193 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException} +import java.nio.ByteBuffer +import java.util.{ConcurrentModificationException, EnumSet} + +import scala.reflect.ClassTag + +import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.SQLContext + +/** + * A [[MetadataLog]] implementation based on HDFS. [[HDFSMetadataLog]] uses the specified `path` + * as the metadata storage. + * + * When writing a new batch, [[HDFSMetadataLog]] will firstly write to a temp file and then rename + * it to the final batch file. If the rename step fails, there must be multiple writers and only + * one of them will succeed and the others will fail. + * + * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing + * files in a directory always shows the latest files. 
+ */ +class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) extends MetadataLog[T] { + + private val metadataPath = new Path(path) + + private val fc = + if (metadataPath.toUri.getScheme == null) { + FileContext.getFileContext(sqlContext.sparkContext.hadoopConfiguration) + } else { + FileContext.getFileContext(metadataPath.toUri, sqlContext.sparkContext.hadoopConfiguration) + } + + if (!fc.util().exists(metadataPath)) { + fc.mkdir(metadataPath, FsPermission.getDirDefault, true) + } + + /** + * A `PathFilter` to filter only batch files + */ + private val batchFilesFilter = new PathFilter { + override def accept(path: Path): Boolean = try { + path.getName.toLong + true + } catch { + case _: NumberFormatException => false + } + } + + private val serializer = new JavaSerializer(sqlContext.sparkContext.conf).newInstance() + + private def batchFile(batchId: Long): Path = { + new Path(metadataPath, batchId.toString) + } + + override def add(batchId: Long, metadata: T): Boolean = { + get(batchId).map(_ => false).getOrElse { + // Only write metadata when the batch has not yet been written. + val buffer = serializer.serialize(metadata) + try { + writeBatch(batchId, JavaUtils.bufferToArray(buffer)) + true + } catch { + case e: IOException if "java.lang.InterruptedException" == e.getMessage => + // create may convert InterruptedException to IOException. Let's convert it back to + // InterruptedException so that this failure won't crash StreamExecution + throw new InterruptedException("Creating file is interrupted") + } + } + } + + /** + * Write a batch to a temp file then rename it to the batch file. + * + * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a + * valid behavior, we still need to prevent it from destroying the files. + */ + private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = { + // Use nextId to create a temp file + var nextId = 0 + while (true) { + val tempPath = new Path(metadataPath, s".${batchId}_$nextId.tmp") + fc.deleteOnExit(tempPath) + try { + val output = fc.create(tempPath, EnumSet.of(CreateFlag.CREATE)) + try { + output.write(bytes) + } finally { + output.close() + } + try { + // Try to commit the batch + // It will fail if there is an existing file (someone has committed the batch) + fc.rename(tempPath, batchFile(batchId), Options.Rename.NONE) + return + } catch { + case e: IOException if isFileAlreadyExistsException(e) => + // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. + // So throw an exception to tell the user this is not a valid behavior. + throw new ConcurrentModificationException( + s"Multiple HDFSMetadataLog are using $path", e) + case e: FileNotFoundException => + // Sometimes, "create" will succeed when multiple writers are calling it at the same + // time. However, only one writer can call "rename" successfully, others will get + // FileNotFoundException because the first writer has removed it. + throw new ConcurrentModificationException( + s"Multiple HDFSMetadataLog are using $path", e) + } + } catch { + case e: IOException if isFileAlreadyExistsException(e) => + // Failed to create "tempPath". There are two cases: + // 1. Someone is creating "tempPath" too. + // 2. This is a restart. "tempPath" has already been created but not moved to the final + // batch file (not committed). + // + // For both cases, the batch has not yet been committed. So we can retry it. 
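The core of HDFSMetadataLog.add() is the write-to-temp-then-rename commit shown above: whichever writer renames first owns the batch, and every other writer sees a failure it can interpret as "already committed". A local-filesystem sketch of the same idea using java.nio.file (the real code goes through Hadoop's FileContext and handles several more failure modes):

    import java.nio.charset.StandardCharsets
    import java.nio.file.{FileAlreadyExistsException, Files, Path}

    // Returns true if this call committed the batch, false if another writer got there first.
    def commitBatch(dir: Path, batchId: Long, payload: String): Boolean = {
      val tmp    = dir.resolve(s".$batchId.tmp")
      val target = dir.resolve(batchId.toString)
      Files.write(tmp, payload.getBytes(StandardCharsets.UTF_8))
      try {
        // Commit point: without REPLACE_EXISTING an existing batch file makes the move fail.
        Files.move(tmp, target)
        true
      } catch {
        case _: FileAlreadyExistsException =>
          Files.deleteIfExists(tmp)
          false
      }
    }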
+          //
+          // Note: there is a potential risk here: if HDFSMetadataLog A is running, someone can
+          // use the same metadata path to create another "HDFSMetadataLog" and make A fail.
+          // However, this is not a big problem because it requires that the attacker have
+          // permission to write to the metadata path. In addition, the old Streaming code has
+          // this issue as well: someone could create malicious checkpoint files to crash a
+          // Streaming application.
+          nextId += 1
+      }
+    }
+  }
+
+  private def isFileAlreadyExistsException(e: IOException): Boolean = {
+    e.isInstanceOf[FileAlreadyExistsException] ||
+      // Old Hadoop versions don't throw FileAlreadyExistsException. Although it's fixed in
+      // HADOOP-9361, we still need to support old Hadoop versions.
+      (e.getMessage != null && e.getMessage.startsWith("File already exists: "))
+  }
+
+  override def get(batchId: Long): Option[T] = {
+    val batchMetadataFile = batchFile(batchId)
+    if (fc.util().exists(batchMetadataFile)) {
+      val input = fc.open(batchMetadataFile)
+      val bytes = IOUtils.toByteArray(input)
+      Some(serializer.deserialize[T](ByteBuffer.wrap(bytes)))
+    } else {
+      None
+    }
+  }
+
+  override def get(startId: Option[Long], endId: Long): Array[(Long, T)] = {
+    val batchIds = fc.util().listStatus(metadataPath, batchFilesFilter)
+      .map(_.getPath.getName.toLong)
+      .filter { batchId =>
+        batchId <= endId && (startId.isEmpty || batchId >= startId.get)
+      }
+    batchIds.sorted.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map {
+      case (batchId, metadataOption) =>
+        (batchId, metadataOption.get)
+    }
+  }
+
+  override def getLatest(): Option[(Long, T)] = {
+    val batchIds = fc.util().listStatus(metadataPath, batchFilesFilter)
+      .map(_.getPath.getName.toLong)
+      .sorted
+      .reverse
+    for (batchId <- batchIds) {
+      val batch = get(batchId)
+      if (batch.isDefined) {
+        return Some((batchId, batch.get))
+      }
+    }
+    None
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala
new file mode 100644
index 0000000000000..3f9896d23ce36
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+/**
+ * A general MetadataLog that supports the following features:
+ *
+ *  - Allow the user to store a metadata object for each batch.
+ *  - Allow the user to query the latest batch id.
+ *  - Allow the user to query the metadata object of a specified batch id.
+ *  - Allow the user to query metadata objects in a range of batch ids.
+ */
+trait MetadataLog[T] {
+
+  /**
+   * Store the metadata for the specified batchId and return `true` if successful.
+   * If the batchId's metadata has already been stored, this method will return `false`.
+   */
+  def add(batchId: Long, metadata: T): Boolean
+
+  /**
+   * Return the metadata for the specified batchId if it's stored. Otherwise, return None.
+   */
+  def get(batchId: Long): Option[T]
+
+  /**
+   * Return metadata for batches between startId (inclusive) and endId (inclusive). If `startId` is
+   * `None`, just return all batches before endId (inclusive).
+   */
+  def get(startId: Option[Long], endId: Long): Array[(Long, T)]
+
+  /**
+   * Return the latest batch id and its metadata, if they exist.
+   */
+  def getLatest(): Option[(Long, T)]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
index 1bd71b6b02ea9..e3b2d2f67ee0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
@@ -38,7 +38,7 @@ trait Sink {
  * Accepts a new batch of data as well as a [[Offset]] that denotes how far in the input
  * data computation has progressed to. When computation restarts after a failure, it is important
  * that a [[Sink]] returns the same [[Offset]] as the most recent batch of data that
- * has been persisted durrably. Note that this does not necessarily have to be the
+ * has been persisted durably. Note that this does not necessarily have to be the
  * [[Offset]] for the most recent batch of data that was given to the sink. For example,
  * it is valid to buffer data before persisting, as long as the [[Offset]] is stored
  * transactionally as data is eventually persisted.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 096477ce0e511..d7ff44afadf22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -100,7 +100,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
- * tests and does not provide durablility.
+ * tests and does not provide durability.
  */
 class MemorySink(schema: StructType) extends Sink with Logging {
   /** An order list of batches that have been written to this [[Sink]]. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index e6d7480b0422c..0d580703f5547 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -62,12 +62,12 @@ case class ScalarSubquery(
 /**
  * Convert the subquery from logical plan into executed plan.
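For orientation, the rule below replaces each scalar subquery expression with its own fully planned subplan before the outer query runs. A hypothetical query that would exercise it (the table name `t` and the query shape are assumptions for illustration, not taken from this patch):

```scala
// Sketch: the parenthesized scalar subquery is planned separately by
// PlanSubqueries, and its single result feeds the outer predicate.
sqlContext.range(10).toDF("id").registerTempTable("t")
sqlContext.sql("SELECT * FROM t WHERE id > (SELECT avg(id) FROM t)").show()
```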
*/ -case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan) + val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() + val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 737e125f6cf03..dd4aa9e93ae4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -33,25 +33,6 @@ import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -/** - * Ensures that java functions signatures for methods that now return a [[TypedColumn]] still have - * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate - * "bridge" methods due to the use of covariant return types. - * - * {{{ - * // In LegacyFunctions: - * public abstract org.apache.spark.sql.Column avg(java.lang.String); - * - * // In functions: - * public static org.apache.spark.sql.TypedColumn avg(...); - * }}} - * - * This allows us to use the same functions both in typed [[Dataset]] operations and untyped - * [[DataFrame]] operations when the return type for a given function is statically known. - */ -private[sql] abstract class LegacyFunctions { - def count(columnName: String): Column -} /** * :: Experimental :: @@ -72,7 +53,7 @@ private[sql] abstract class LegacyFunctions { */ @Experimental // scalastyle:off -object functions extends LegacyFunctions { +object functions { // scalastyle:on private def withExpr(expr: Expression): Column = Column(expr) @@ -287,7 +268,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(ExpressionEncoder[Long]) + count(Column(columnName)).as(ExpressionEncoder[Long]()) /** * Aggregate function: returns the number of distinct items in a group. @@ -1180,7 +1161,7 @@ object functions extends LegacyFunctions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl()) + val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 384102e5eaa5b..cbdc37a2a1622 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -441,7 +441,7 @@ object SQLConf { // NOTE: // // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. - // 2. This option can be overriden by "spark.sql.parquet.output.committer.class". + // 2. This option can be overridden by "spark.sql.parquet.output.committer.class". 
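To make the note above concrete: the generic committer key is read from the Hadoop configuration, while the Parquet-specific SQLConf key can override it. A hedged sketch follows; both class names are placeholders, and only the two configuration keys come from this file.

```scala
// Hypothetical committer classes; set the generic key in the *Hadoop* conf.
sqlContext.sparkContext.hadoopConfiguration.set(
  "spark.sql.sources.outputCommitterClass",
  "com.example.MyOutputCommitter")

// For Parquet output, this SQLConf key, when set, takes precedence.
sqlContext.setConf(
  "spark.sql.parquet.output.committer.class",
  "com.example.MyParquetOutputCommitter")
```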
val OUTPUT_COMMITTER_CLASS = stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) @@ -504,6 +504,11 @@ object SQLConf { " method", isPublic = false) + val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes", + defaultValue = Some(128 * 1024 * 1024), // parquet.block.size + doc = "The maximum number of bytes to pack into a single partition when reading files.", + isPublic = true) + val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", defaultValue = Some(true), doc = "When true, the planner will try to find out duplicated exchanges and re-use them", @@ -538,6 +543,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin /** ************************ Spark SQL Params/Hints ******************* */ + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) + def useCompression: Boolean = getConf(COMPRESS_CACHED) def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) @@ -605,7 +612,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) - def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 98ada4d58af7e..e6be0ab3bc420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal -import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration} +import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog} import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -40,6 +40,8 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val conf = new SQLConf + lazy val experimentalMethods = new ExperimentalMethods + /** * Internal catalog for managing table and database states. */ @@ -73,7 +75,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Logical query plan optimizer. */ - lazy val optimizer: Optimizer = new SparkOptimizer(ctx) + lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods) /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. @@ -83,7 +85,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. 
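Returning to `spark.sql.files.maxPartitionBytes`, added in the SQLConf hunk above: it caps how many bytes of input files get packed into a single read partition, and the FileSourceStrategySuite tests later in this patch shrink it to force multi-partition layouts. A sketch of ordinary use (the 64 MB figure is an arbitrary example):

```scala
// Halve the 128 MB default, roughly doubling the number of file-scan
// partitions produced for large inputs.
sqlContext.setConf("spark.sql.files.maxPartitionBytes", (64 * 1024 * 1024).toString)
```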
*/ - lazy val planner: SparkPlanner = new SparkPlanner(ctx) + lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) /** * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal @@ -91,10 +93,10 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val prepareForExecution = new RuleExecutor[SparkPlan] { override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(ctx)), - Batch("Add exchange", Once, EnsureRequirements(ctx)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx)) + Batch("Subquery", Once, PlanSubqueries(SessionState.this)), + Batch("Add exchange", Once, EnsureRequirements(conf)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), + Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index e251b52f6c0f9..95ffc33011e8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -30,11 +30,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.FileRelation import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.collection.BitSet @@ -57,7 +57,7 @@ trait DataSourceRegister { * overridden by children to provide a nice alias for the data source. For example: * * {{{ - * override def format(): String = "parquet" + * override def shortName(): String = "parquet" * }}} * * @since 1.5.0 @@ -409,7 +409,7 @@ case class HadoopFsRelation( def partitionSchemaOption: Option[StructType] = if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec(partitionSchemaOption) + def partitionSpec: PartitionSpec = location.partitionSpec() def refresh(): Unit = location.refresh() @@ -454,11 +454,41 @@ trait FileFormat { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] + + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be prepended to the rows that + * are produced by the iterator. + * @param dataSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. 
+   * @param filters A set of filters that can optionally be used to reduce the number of rows output
+   * @param options A set of string -> string configuration options.
+   * @return
+   */
+  def buildReader(
+      sqlContext: SQLContext,
+      partitionSchema: StructType,
+      dataSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+    // TODO: Remove this default implementation when the other formats have been ported
+    // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats.
+    throw new UnsupportedOperationException(s"buildReader is not supported for $this")
+  }
 }
 
+/**
+ * A collection of data files from a partitioned relation, along with the partition values in the
+ * form of an [[InternalRow]].
+ */
+case class Partition(values: InternalRow, files: Seq[FileStatus])
+
 /**
  * An interface for objects capable of enumerating the files that comprise a relation as well
  * as the partitioning characteristics of those files.
@@ -466,7 +496,18 @@ trait FileFormat {
 trait FileCatalog {
   def paths: Seq[Path]
 
-  def partitionSpec(schema: Option[StructType]): PartitionSpec
+  def partitionSpec(): PartitionSpec
+
+  /**
+   * Returns all valid files grouped into partitions when the data is partitioned. If the data is
+   * unpartitioned, this will return a single partition with no partition values.
+   *
+   * @param filters the filters used to prune which partitions are returned. These filters must
+   *                only refer to partition columns and this method will only return files
+   *                where these predicates are guaranteed to evaluate to `true`. Thus, these
+   *                filters will not need to be evaluated again on the returned data.
+   */
+  def listFiles(filters: Seq[Expression]): Seq[Partition]
 
   def allFiles(): Seq[FileStatus]
 
@@ -478,11 +519,17 @@ trait FileCatalog {
 /**
  * A file catalog that caches metadata gathered by scanning all the files present in `paths`
  * recursively.
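A hedged sketch of how a caller might consume the `listFiles` contract defined above. The attribute construction is illustrative only; a real caller would reuse the resolved partition-column attributes from the relation rather than build one by hand.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

// Illustrative only: list the files under partition p = 1. Because listFiles
// guarantees the predicate holds for every file it returns, the filter does
// not need to be re-applied to the rows read from those files.
def filesForP1(catalog: FileCatalog): Seq[Partition] = {
  val p = AttributeReference("p", IntegerType, nullable = true)()
  catalog.listFiles(EqualTo(p, Literal(1)) :: Nil)
}
```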
+ *
+ * @param parameters a set of options to control discovery
+ * @param paths a list of paths to scan
+ * @param partitionSchema an optional partition schema that will be used to provide types for the
+ *                        discovered partitions
 */
class HDFSFileCatalog(
    val sqlContext: SQLContext,
    val parameters: Map[String, String],
-    val paths: Seq[Path])
+    val paths: Seq[Path],
+    val partitionSchema: Option[StructType])
  extends FileCatalog with Logging {
 
  private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
@@ -491,9 +538,9 @@ class HDFSFileCatalog(
  var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]]
  var cachedPartitionSpec: PartitionSpec = _
 
-  def partitionSpec(schema: Option[StructType]): PartitionSpec = {
+  def partitionSpec(): PartitionSpec = {
    if (cachedPartitionSpec == null) {
-      cachedPartitionSpec = inferPartitioning(schema)
+      cachedPartitionSpec = inferPartitioning(partitionSchema)
    }
 
    cachedPartitionSpec
@@ -501,6 +548,53 @@ class HDFSFileCatalog(
 
  refresh()
 
+  override def listFiles(filters: Seq[Expression]): Seq[Partition] = {
+    if (partitionSpec().partitionColumns.isEmpty) {
+      Partition(InternalRow.empty, allFiles()) :: Nil
+    } else {
+      prunePartitions(filters, partitionSpec()).map {
+        case PartitionDirectory(values, path) => Partition(values, getStatus(path))
+      }
+    }
+  }
+
+  protected def prunePartitions(
+      predicates: Seq[Expression],
+      partitionSpec: PartitionSpec): Seq[PartitionDirectory] = {
+    val PartitionSpec(partitionColumns, partitions) = partitionSpec
+    val partitionColumnNames = partitionColumns.map(_.name).toSet
+    val partitionPruningPredicates = predicates.filter {
+      _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+    }
+
+    if (partitionPruningPredicates.nonEmpty) {
+      val predicate =
+        partitionPruningPredicates
+          .reduceOption(expressions.And)
+          .getOrElse(Literal(true))
+
+      val boundPredicate = InterpretedPredicate.create(predicate.transform {
+        case a: AttributeReference =>
+          val index = partitionColumns.indexWhere(a.name == _.name)
+          BoundReference(index, partitionColumns(index).dataType, nullable = true)
+      })
+
+      val selected = partitions.filter {
+        case PartitionDirectory(values, _) => boundPredicate(values)
+      }
+      logInfo {
+        val total = partitions.length
+        val selectedSize = selected.length
+        val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100
+        s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions."
+      }
+
+      selected
+    } else {
+      partitions
+    }
+  }
+
  def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq
 
  def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path)
@@ -522,7 +616,7 @@ class HDFSFileCatalog(
      }
    }.filterNot { status =>
      val name = status.getPath.getName
-      name.toLowerCase == "_temporary" || name.startsWith(".")
+      HadoopFsRelation.shouldFilterOut(name)
    }
 
    val (dirs, files) = statuses.partition(_.isDirectory)
@@ -560,7 +654,7 @@ class HDFSFileCatalog(
        PartitionSpec(userProvidedSchema, spec.partitions.map { part =>
          part.copy(values = castPartitionValuesToUserSchema(part.values))
        })
-      case None =>
+      case _ =>
        PartitioningUtils.parsePartitions(
          leafDirs,
          PartitioningUtils.DEFAULT_PARTITION_NAME,
@@ -616,6 +710,16 @@ class HDFSFileCatalog(
 * Helper methods for gathering metadata from HDFS.
 */
private[sql] object HadoopFsRelation extends Logging {
+
+  /** Checks if we should filter out this path name. */
+  def shouldFilterOut(pathName: String): Boolean = {
+    // TODO: We should try to filter out all files/dirs starting with "." or "_".
+    // The only reason that we are not doing it now is that Parquet needs to find those
+    // metadata files from leaf files returned by this method. We should refactor
+    // this logic to not mix metadata files with data files.
+    pathName == "_SUCCESS" || pathName == "_temporary" || pathName.startsWith(".")
+  }
+
  // We don't filter files/directories whose name start with "_" except "_temporary" here, as
  // specific data sources may take advantages over them (e.g. Parquet _metadata and
  // _common_metadata files). "_temporary" directories are explicitly ignored since failed
@@ -624,19 +728,21 @@ private[sql] object HadoopFsRelation extends Logging {
  def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = {
    logInfo(s"Listing ${status.getPath}")
    val name = status.getPath.getName.toLowerCase
-    if (name == "_temporary" || name.startsWith(".")) {
+    if (shouldFilterOut(name)) {
      Array.empty
    } else {
      // Dummy jobconf to get to the pathFilter defined in configuration
      val jobConf = new JobConf(fs.getConf, this.getClass())
      val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
-      if (pathFilter != null) {
-        val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory)
-        files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
-      } else {
-        val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory)
-        files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
-      }
+      val statuses =
+        if (pathFilter != null) {
+          val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory)
+          files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
+        } else {
+          val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory)
+          files ++ dirs.flatMap(dir => listLeafFiles(fs, dir))
+        }
+      statuses.filterNot(status => shouldFilterOut(status.getPath.getName))
    }
  }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 0f9e453d26db8..9e65158eb0a33 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -40,7 +40,6 @@ public class JavaSaveLoadSuite {
  private transient JavaSparkContext sc;
  private transient SQLContext sqlContext;
 
-  String originalDefaultSource;
  File path;
  Dataset df;
 
@@ -57,7 +56,6 @@ public void setUp() throws IOException {
    sqlContext = new SQLContext(_sc);
    sc = new JavaSparkContext(_sc);
 
-    originalDefaultSource = sqlContext.conf().defaultDataSourceName();
    path =
      Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
    if (path.exists()) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index aff9efe4b2b16..2aa6f8d4acf7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql
 
+import java.nio.charset.StandardCharsets
+
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -167,12 +169,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
  }
 
  test("misc sha1 function") {
-    val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b")
+    val df = Seq(("ABC",
"ABC".getBytes(StandardCharsets.UTF_8))).toDF("a", "b") checkAnswer( df.select(sha1($"a"), sha1($"b")), Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) - val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + val dfEmpty = Seq(("", "".getBytes(StandardCharsets.UTF_8))).toDF("a", "b") checkAnswer( dfEmpty.selectExpr("sha1(a)", "sha1(b)"), Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index e865dbe6b5063..a7a826bc7a8d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -80,7 +80,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // Verify that the splits span the entire dataset assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits don't overalap + // Verify that the splits don't overlap assert(splits(0).intersect(splits(1)).collect().isEmpty) // Verify that the results are deterministic across multiple runs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e6e27ec413bb2..2333fa27ca623 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.File +import java.nio.charset.StandardCharsets import scala.language.postfixOps import scala.util.Random @@ -665,8 +666,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("showString: binary") { val df = Seq( - ("12".getBytes, "ABC.".getBytes), - ("34".getBytes, "12346".getBytes) + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) ).toDF() val expectedAnswer = """+-------+----------------+ || _1| _2| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9f32c8bf95ad6..d7fa23651bcee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -46,7 +46,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } - test("SPARK-12404: Datatype Helper Serializablity") { + test("SPARK-12404: Datatype Helper Serializability") { val ds = sparkContext.parallelize(( new Timestamp(0), new Date(0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 2c4b4f80ff9ed..b1987c690811d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -29,7 +29,9 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value val row = new GenericInternalRow(Array[Any](str)) - sparkContext.parallelize(Seq(row)) + val unsafeProj = UnsafeProjection.create(schema) + val unsafeRow = unsafeProj(row).copy() + sparkContext.parallelize(Seq(unsafeRow)) } override def 
producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2bd29ef19b649..50647c28402eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -139,7 +139,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 013a90875e2b2..f5a67fd782d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} import org.apache.spark.sql.test.SharedSQLContext @@ -262,9 +264,9 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { test("unhex") { val data = Seq(("1C", "737472696E67")).toDF("a", "b") checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) - checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) - checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index ec19d97d8cec2..2ad92b52c4ff0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -76,6 +76,6 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("Catalyst optimization passes are modifiable at runtime") { val sqlContext = SQLContext.getOrCreate(sc) sqlContext.experimental.extraOptimizations = Seq(DummyRule) - assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule)) + assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 182f287dd001c..836fb1ce853c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -54,7 +54,8 @@ class SQLQuerySuite 
extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction() + .filter(regex.matcher(_).matches()).map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions(".*")) Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern => @@ -986,7 +987,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SET commands with illegal or inappropriate argument") { sqlContext.conf.clear() - // Set negative mapred.reduce.tasks for automatically determing + // Set negative mapred.reduce.tasks for automatically determining // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 7a5b63911546f..81078dc6a0450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -80,7 +80,7 @@ trait StreamTest extends QueryTest with Timeouts { trait StreamMustBeRunning /** - * Adds the given data to the stream. Subsuquent check answers will block until this data has + * Adds the given data to the stream. Subsequent check answers will block until this data has * been processed. */ object AddData { @@ -109,7 +109,7 @@ trait StreamTest extends QueryTest with Timeouts { /** * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. - * This operation automatically blocks untill all added data has been processed. + * This operation automatically blocks until all added data has been processed. 
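Taken together, `AddData` and `CheckAnswer` support tests of the following shape. This mirrors the usage the trait is designed for; it assumes a suite mixing in `StreamTest` with an implicit `SQLContext` and encoders in scope.

```scala
// Sketch of the StreamTest DSL: enqueue data, then block until the sink
// contains exactly the expected rows.
val inputData = MemoryStream[Int]
val mapped = inputData.toDS().map(_ + 1)

testStream(mapped)(
  AddData(inputData, 1, 2, 3),   // enqueue one batch
  CheckAnswer(2, 3, 4),          // blocks until that batch is processed
  AddData(inputData, 4),
  CheckAnswer(2, 3, 4, 5))
```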
*/ object CheckAnswer { def apply[A : Encoder](data: A*): CheckAnswerRows = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 2d3e34d0e1292..9f33e4ab62298 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -428,4 +428,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ benchmark.run() } + + ignore("collect") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect", N) + benchmark.addCase("collect 1 million") { iter => + sqlContext.range(N).collect() + } + benchmark.addCase("collect 2 millions") { iter => + sqlContext.range(N * 2).collect() + } + benchmark.addCase("collect 4 millions") { iter => + sqlContext.range(N * 4).collect() + } + benchmark.run() + + /** + * Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect 1 million 775 / 1170 1.4 738.9 1.0X + collect 2 millions 1153 / 1758 0.9 1099.3 0.7X + collect 4 millions 4451 / 5124 0.2 4244.9 0.2X + */ + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ab0a7ff628962..88fbcda296cac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -37,7 +37,7 @@ class PlannerSuite extends SharedSQLContext { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.planner + val planner = sqlContext.sessionState.planner import planner._ val plannedOption = Aggregation(query).headOption val planned = @@ -294,7 +294,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -314,7 +314,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -332,7 +332,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -352,7 +352,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), 
requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") @@ -375,7 +375,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -407,7 +407,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") @@ -424,7 +424,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -443,7 +443,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") @@ -463,7 +463,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") @@ -491,7 +491,7 @@ class PlannerSuite extends SharedSQLContext { shuffle, shuffle) - val outputPlan = ReuseExchange(sqlContext).apply(inputPlan) + val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) if 
(outputPlan.collect { case e: ReusedExchange => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } @@ -507,7 +507,7 @@ class PlannerSuite extends SharedSQLContext { ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) - val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2) + val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index aa928cfc8096f..ed0d3f56e5ca9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -233,7 +233,7 @@ object SparkPlanTest { private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 0000a5d1efd09..1aadd700d7443 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution @@ -313,7 +314,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { } for (i <- 0 until count) { testData.putInt(strLen) - testData.put(g().getBytes) + testData.put(g().getBytes(StandardCharsets.UTF_8)) } testData.rewind() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 0d632a8a130ed..6f1eea273fafa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -48,26 +48,26 @@ class DDLCommandSuite extends PlanTest { val sql1 = """ |CREATE TEMPORARY FUNCTION helloworld as - |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar', - |FILE 'path/to/file' + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' """.stripMargin val sql2 = """ |CREATE FUNCTION hello.world as |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', - |FILE 'path/to/file' + |FILE '/path/to/file' """.stripMargin val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) val expected1 = CreateFunction( 
"helloworld", "com.matthewrathbone.example.SimpleUDFExample", - Map("jar" -> "/path/to/jar", "file" -> "path/to/file"), + Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")), isTemp = true)(sql1) val expected2 = CreateFunction( "hello.world", "com.matthewrathbone.example.SimpleUDFExample", - Map("archive" -> "/path/to/archive", "file" -> "path/to/file"), + Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")), isTemp = false)(sql2) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala new file mode 100644 index 0000000000000..2f8129c5da40d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.{File, FilenameFilter} + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} +import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.execution.{DataSourceScan, PhysicalRDD} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.BitSet + +class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { + import testImplicits._ + + test("unpartitioned table, single partition") { + val table = + createTable( + files = Seq( + "file1" -> 1, + "file2" -> 1, + "file3" -> 1, + "file4" -> 1, + "file5" -> 1, + "file6" -> 1, + "file7" -> 1, + "file8" -> 1, + "file9" -> 1, + "file10" -> 1)) + + checkScan(table.select('c1)) { partitions => + // 10 one byte files should fit in a single partition with 10 files. + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 10, "when checking partition 1") + // 1 byte files are too small to split so we should read the whole thing. 
+      assert(partitions.head.files.head.start == 0)
+      assert(partitions.head.files.head.length == 1)
+    }
+
+    checkPartitionSchema(StructType(Nil))
+    checkDataSchema(StructType(Nil).add("c1", IntegerType))
+  }
+
+  test("unpartitioned table, multiple partitions") {
+    val table =
+      createTable(
+        files = Seq(
+          "file1" -> 5,
+          "file2" -> 5,
+          "file3" -> 5))
+
+    withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+      checkScan(table.select('c1)) { partitions =>
+        // 5 byte files should be laid out [(5, 5), (5)]
+        assert(partitions.size == 2, "when checking partitions")
+        assert(partitions(0).files.size == 2, "when checking partition 1")
+        assert(partitions(1).files.size == 1, "when checking partition 2")
+
+        // 5 byte files are too small to split so we should read the whole thing.
+        assert(partitions.head.files.head.start == 0)
+        assert(partitions.head.files.head.length == 5)
+      }
+
+      checkPartitionSchema(StructType(Nil))
+      checkDataSchema(StructType(Nil).add("c1", IntegerType))
+    }
+  }
+
+  test("Unpartitioned table, large file that gets split") {
+    val table =
+      createTable(
+        files = Seq(
+          "file1" -> 15,
+          "file2" -> 4))
+
+    withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+      checkScan(table.select('c1)) { partitions =>
+        // Files should be laid out [(0-10), (10-15, 4)]
+        assert(partitions.size == 2, "when checking partitions")
+        assert(partitions(0).files.size == 1, "when checking partition 1")
+        assert(partitions(1).files.size == 2, "when checking partition 2")
+
+        // Start by reading 10 bytes of the first file
+        assert(partitions.head.files.head.start == 0)
+        assert(partitions.head.files.head.length == 10)
+
+        // Second partition reads the remaining 5
+        assert(partitions(1).files.head.start == 10)
+        assert(partitions(1).files.head.length == 5)
+      }
+
+      checkPartitionSchema(StructType(Nil))
+      checkDataSchema(StructType(Nil).add("c1", IntegerType))
+    }
+  }
+
+  test("partitioned table") {
+    val table =
+      createTable(
+        files = Seq(
+          "p1=1/file1" -> 10,
+          "p1=2/file2" -> 10))
+
+    // Only one file should be read.
+    checkScan(table.where("p1 = 1")) { partitions =>
+      assert(partitions.size == 1, "when checking partitions")
+      assert(partitions.head.files.size == 1, "when checking files in partition 1")
+    }
+    // We don't need to reevaluate filters that are only on partitions.
+    checkDataFilters(Set.empty)
+
+    // Only one file should be read.
+    checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions =>
+      assert(partitions.size == 1, "when checking partitions")
+      assert(partitions.head.files.size == 1, "when checking files in partition 1")
+      assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
+        "when checking partition values")
+    }
+    // Only the filters that do not contain the partition column should be pushed down
+    checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1)))
+  }
+
+  test("partitioned table - after scan filters") {
+    val table =
+      createTable(
+        files = Seq(
+          "p1=1/file1" -> 10,
+          "p1=2/file2" -> 10))
+
+    val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
+    // Filters on data only are advisory, so we have to reevaluate.
+    assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1"))
+    // Need to evaluate filters that are not pushed down.
+    assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2"))
+    // Don't reevaluate partition-only filters.
+    assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1")))
+  }
+
+  test("bucketed table") {
+    val table =
+      createTable(
+        files = Seq(
+          "p1=1/file1_0000" -> 1,
+          "p1=1/file2_0000" -> 1,
+          "p1=1/file3_0002" -> 1,
+          "p1=2/file4_0002" -> 1,
+          "p1=2/file5_0000" -> 1,
+          "p1=2/file6_0000" -> 1,
+          "p1=2/file7_0000" -> 1),
+        buckets = 3)
+
+    // No partition pruning
+    checkScan(table) { partitions =>
+      assert(partitions.size == 3)
+      assert(partitions(0).files.size == 5)
+      assert(partitions(1).files.size == 0)
+      assert(partitions(2).files.size == 2)
+    }
+
+    // With partition pruning
+    checkScan(table.where("p1=2")) { partitions =>
+      assert(partitions.size == 3)
+      assert(partitions(0).files.size == 3)
+      assert(partitions(1).files.size == 0)
+      assert(partitions(2).files.size == 1)
+    }
+  }
+
+  // Helpers for checking the arguments passed to the FileFormat.
+
+  protected val checkPartitionSchema =
+    checkArgument("partition schema", _.partitionSchema, _: StructType)
+  protected val checkDataSchema =
+    checkArgument("data schema", _.dataSchema, _: StructType)
+  protected val checkDataFilters =
+    checkArgument("data filters", _.filters.toSet, _: Set[Filter])
+
+  /** Helper for building checks on the arguments passed to the reader. */
+  protected def checkArgument[T](name: String, arg: LastArguments.type => T, expected: T): Unit = {
+    if (arg(LastArguments) != expected) {
+      fail(
+        s"""
+           |Wrong $name
+           |expected: $expected
+           |actual: ${arg(LastArguments)}
+         """.stripMargin)
+    }
+  }
+
+  /** Returns a resolved expression for `str` in the context of `df`. */
+  def resolve(df: DataFrame, str: String): Expression = {
+    df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head
+  }
+
+  /** Returns a set with all the filters present in the physical plan. */
+  def getPhysicalFilters(df: DataFrame): ExpressionSet = {
+    ExpressionSet(
+      df.queryExecution.executedPlan.collect {
+        case execution.Filter(f, _) => splitConjunctivePredicates(f)
+      }.flatten)
+  }
+
+  /** Plans the query and calls the provided validation function with the planned partitioning. */
+  def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
+    val fileScan = df.queryExecution.executedPlan.collect {
+      case DataSourceScan(_, scan: FileScanRDD, _, _) => scan
+    }.headOption.getOrElse {
+      fail(s"No FileScan in query\n${df.queryExecution}")
+    }
+
+    func(fileScan.filePartitions)
+  }
+
+  /**
+   * Constructs a new table given a list of file names and sizes expressed in bytes. The table
+   * is written out in a temporary directory and any nested directories in the file names
+   * are automatically created.
+   *
+   * When `buckets` is > 0 the returned [[DataFrame]] will have metadata specifying that number of
+   * buckets. However, it is the responsibility of the caller to assign files to each bucket
+   * by appending the bucket id to the file names.
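The bucket assignment in the test above is purely name-based: the caller encodes the bucket id in a trailing `_NNNN` suffix. A hypothetical helper (not part of the patch) that recovers that id:

```scala
// Illustrative only, assuming names always end in an "_NNNN" bucket suffix:
// "file4_0002" belongs to bucket 2.
def bucketIdFromName(name: String): Int =
  name.substring(name.lastIndexOf('_') + 1).toInt

assert(bucketIdFromName("file4_0002") == 2)
```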
+   */
+  def createTable(
+      files: Seq[(String, Int)],
+      buckets: Int = 0): DataFrame = {
+    val tempDir = Utils.createTempDir()
+    files.foreach {
+      case (name, size) =>
+        val file = new File(tempDir, name)
+        assert(file.getParentFile.exists() || file.getParentFile.mkdirs())
+        util.stringToFile(file, "*" * size)
+    }
+
+    val df = sqlContext.read
+      .format(classOf[TestFileFormat].getName)
+      .load(tempDir.getCanonicalPath)
+
+    if (buckets > 0) {
+      val bucketed = df.queryExecution.analyzed transform {
+        case l @ LogicalRelation(r: HadoopFsRelation, _, _) =>
+          l.copy(relation =
+            r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil))))
+      }
+      Dataset.newDataFrame(sqlContext, bucketed)
+    } else {
+      df
+    }
+  }
+}
+
+/** Holds the last arguments passed to [[TestFileFormat]]. */
+object LastArguments {
+  var partitionSchema: StructType = _
+  var dataSchema: StructType = _
+  var filters: Seq[Filter] = _
+  var options: Map[String, String] = _
+}
+
+/** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. */
+class TestFileFormat extends FileFormat {
+
+  override def toString: String = "TestFileFormat"
+
+  /**
+   * When possible, this method should return the schema of the given `files`. When the format
+   * does not support inference, or no valid files are given, it should return None. In these
+   * cases Spark will require that the user specify the schema manually.
+   */
+  override def inferSchema(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] =
+    Some(
+      StructType(Nil)
+        .add("c1", IntegerType)
+        .add("c2", IntegerType))
+
+  /**
+   * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
+   * be put here. For example, user defined output committer can be configured here
+   * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
+   */
+  override def prepareWrite(
+      sqlContext: SQLContext,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType): OutputWriterFactory = {
+    throw new NotImplementedError("JUST FOR TESTING")
+  }
+
+  override def buildInternalScan(
+      sqlContext: SQLContext,
+      dataSchema: StructType,
+      requiredColumns: Array[String],
+      filters: Array[Filter],
+      bucketSet: Option[BitSet],
+      inputFiles: Seq[FileStatus],
+      broadcastedConf: Broadcast[SerializableConfiguration],
+      options: Map[String, String]): RDD[InternalRow] = {
+    throw new NotImplementedError("JUST FOR TESTING")
+  }
+
+  override def buildReader(
+      sqlContext: SQLContext,
+      partitionSchema: StructType,
+      dataSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+
+    // Record the arguments so they can be checked in the test case.
+ LastArguments.partitionSchema = partitionSchema + LastArguments.dataSchema = dataSchema + LastArguments.filters = filters + LastArguments.options = options + + (file: PartitionedFile) => { Iterator.empty } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 7af3f94aefea2..3a7cb25b4fa9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -80,7 +80,7 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType) } - test("Merging Nulltypes should yeild Nulltype.") { + test("Merging Nulltypes should yield Nulltype.") { val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) assert(mergedNullTypes.deep == Array(NullType).deep) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 680759d4572a3..58d9d69d9a8a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -160,12 +160,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { |OPTIONS (path "${testFile(carsFile8859)}", header "true", |charset "iso-8859-1", delimiter "þ") """.stripMargin.replaceAll("\n", " ")) - //scalstyle:on + // scalastyle:on verifyCars(sqlContext.table("carsTable"), withHeader = true) } test("test aliases sep and encoding for delimiter and charset") { + // scalastyle:off val cars = sqlContext .read .format("csv") @@ -173,6 +174,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("encoding", "iso-8859-1") .option("sep", "þ") .load(testFile(carsFile8859)) + // scalastyle:on verifyCars(cars, withHeader = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index b74b9d3f3bbca..f875b54cd6649 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext @@ -706,6 +706,29 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } + test("_SUCCESS should not break partitioning discovery") { + Seq(1, 32).foreach { threshold => + // We have two paths to list files, one at driver side, another one that we 
use + // a Spark job. We need to test both ways. + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> threshold.toString) { + withTempPath { dir => + val tablePath = new File(dir, "table") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1", "_SUCCESS")) + Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1", "_SUCCESS")) + Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1/d=1", "_SUCCESS")) + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + } + } + } + test("listConflictingPartitionColumns") { def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a256ee95a153c..6d5b777733f41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -63,7 +63,8 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan) + val plan = + EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 7eb15249ebbd6..eeb44404e9e47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -98,7 +98,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext).apply(broadcastJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) } def makeSortMergeJoin( @@ -109,7 +109,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext).apply(sortMergeJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 0d1c29fe574a6..45254864309eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -98,7 +98,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") 
{ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( + EnsureRequirements(sqlContext.sessionState.conf).apply( SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index bc341db5571be..d8c9564f1e4fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -76,7 +76,7 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( + EnsureRequirements(left.sqlContext.sessionState.conf).apply( LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala deleted file mode 100644 index cd9277d3bcf1a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.internal.SQLConf - -/** - * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. 
- */ -private[local] case class DummyNode( - output: Seq[Attribute], - relation: LocalRelation, - conf: SQLConf) - extends LocalNode(conf) { - - import DummyNode._ - - private var index: Int = CLOSED - private val input: Seq[InternalRow] = relation.data - - def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { - this(output, LocalRelation.fromProduct(output, data), conf) - } - - def isOpen: Boolean = index != CLOSED - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - input(index) - } - - override def close(): Unit = { - index = CLOSED - } -} - -private object DummyNode { - val CLOSED: Int = Int.MinValue -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala deleted file mode 100644 index bbd94d8da2d11..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.dsl.expressions._ - - -class ExpandNodeSuite extends LocalNodeTest { - - private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { - val inputNode = new DummyNode(kvIntAttributes, inputData) - val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) - val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) - val resolvedNode = resolveExpressions(expandNode) - val expectedOutput = { - val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } - val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } - firstHalf ++ secondHalf - } - val actualOutput = resolvedNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput.toSet === expectedOutput.toSet) - } - - test("empty") { - testExpand() - } - - test("basic") { - testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala deleted file mode 100644 index 4eadce646d379..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. 
See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.dsl.expressions._ - - -class FilterNodeSuite extends LocalNodeTest { - - private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { - val cond = 'k % 2 === 0 - val inputNode = new DummyNode(kvIntAttributes, inputData) - val filterNode = new FilterNode(conf, cond, inputNode) - val resolvedNode = resolveExpressions(filterNode) - val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } - val actualOutput = resolvedNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testFilter() - } - - test("basic") { - testFilter((1 to 100).map { i => (i, i) }.toArray) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala deleted file mode 100644 index 74142ea598d9d..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.mockito.Mockito.{mock, when} - -import org.apache.spark.broadcast.TorrentBroadcast -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection} -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} -import org.apache.spark.sql.internal.SQLConf - -class HashJoinNodeSuite extends LocalNodeTest { - - // Test all combinations of the two dimensions: with/out unsafe and build sides - private val buildSides = Seq(BuildLeft, BuildRight) - buildSides.foreach { buildSide => - testJoin(buildSide) - } - - /** - * Builds a [[HashedRelation]] based on a resolved `buildKeys` - * and a resolved `buildNode`. 
- */ - private def buildHashedRelation( - conf: SQLConf, - buildKeys: Seq[Expression], - buildNode: LocalNode): HashedRelation = { - - val buildSideKeyGenerator = UnsafeProjection.create(buildKeys, buildNode.output) - buildNode.prepare() - buildNode.open() - val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) - buildNode.close() - - hashedRelation - } - - /** - * Test inner hash join with varying degrees of matches. - */ - private def testJoin(buildSide: BuildSide): Unit = { - val testNamePrefix = buildSide - val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray - val conf = new SQLConf - - // Actual test body - def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { - val rightInputMap = rightInput.toMap - val leftNode = new DummyNode(joinNameAttributes, leftInput) - val rightNode = new DummyNode(joinNicknameAttributes, rightInput) - val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => { - val binaryHashJoinNode = - BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) - resolveExpressions(binaryHashJoinNode) - } - val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => { - val leftKeys = Seq('id1.attr) - val rightKeys = Seq('id2.attr) - // Figure out the build side and stream side. - val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { - case BuildLeft => (node1, leftKeys, node2, rightKeys) - case BuildRight => (node2, rightKeys, node1, leftKeys) - } - // Resolve the expressions of the build side and then create a HashedRelation. - val resolvedBuildNode = resolveExpressions(buildNode) - val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode) - val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode) - val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]]) - when(broadcastHashedRelation.value).thenReturn(hashedRelation) - - val hashJoinNode = - BroadcastHashJoinNode( - conf, - streamedKeys, - streamedNode, - buildSide, - resolvedBuildNode.output, - broadcastHashedRelation) - resolveExpressions(hashJoinNode) - } - - val expectedOutput = leftInput - .filter { case (k, _) => rightInputMap.contains(k) } - .map { case (k, v) => (k, v, k, rightInputMap(k)) } - - Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode => - val makeUnsafeNode = wrapForUnsafe(makeNode) - val hashJoinNode = makeUnsafeNode(leftNode, rightNode) - - val actualOutput = hashJoinNode.collect().map { row => - // (id, name, id, nickname) - (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) - } - assert(actualOutput === expectedOutput) - } - } - - test(s"$testNamePrefix: empty") { - runTest(Array.empty, Array.empty) - runTest(someData, Array.empty) - runTest(Array.empty, someData) - } - - test(s"$testNamePrefix: no matches") { - val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray - runTest(someData, Array.empty) - runTest(Array.empty, someData) - runTest(someData, someIrrelevantData) - runTest(someIrrelevantData, someData) - } - - test(s"$testNamePrefix: partial matches") { - val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray - runTest(someData, someOtherData) - runTest(someOtherData, someData) - } - - test(s"$testNamePrefix: full matches") { - val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray - runTest(someData, someSuperRelevantData) - runTest(someSuperRelevantData, someData) - } - } - -} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala deleted file mode 100644 index c0ad2021b204a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - - -class IntersectNodeSuite extends LocalNodeTest { - - test("basic") { - val n = 100 - val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray - val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray - val leftNode = new DummyNode(kvIntAttributes, leftData) - val rightNode = new DummyNode(kvIntAttributes, rightData) - val intersectNode = new IntersectNode(conf, leftNode, rightNode) - val expectedOutput = leftData.intersect(rightData) - val actualOutput = intersectNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala deleted file mode 100644 index fb790636a3689..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.local - - -class LimitNodeSuite extends LocalNodeTest { - - private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { - val inputNode = new DummyNode(kvIntAttributes, inputData) - val limitNode = new LimitNode(conf, limit, inputNode) - val expectedOutput = inputData.take(limit) - val actualOutput = limitNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testLimit() - } - - test("basic") { - testLimit((1 to 100).map { i => (i, i) }.toArray, 20) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala deleted file mode 100644 index 0d1ed99eec6cd..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - - -class LocalNodeSuite extends LocalNodeTest { - private val data = (1 to 100).map { i => (i, i) }.toArray - - test("basic open, next, fetch, close") { - val node = new DummyNode(kvIntAttributes, data) - assert(!node.isOpen) - node.open() - assert(node.isOpen) - data.foreach { case (k, v) => - assert(node.next()) - // fetch should be idempotent - val fetched = node.fetch() - assert(node.fetch() === fetched) - assert(node.fetch() === fetched) - assert(node.fetch().numFields === 2) - assert(node.fetch().getInt(0) === k) - assert(node.fetch().getInt(1) === v) - } - assert(!node.next()) - node.close() - assert(!node.isOpen) - } - - test("asIterator") { - val node = new DummyNode(kvIntAttributes, data) - val iter = node.asIterator - node.open() - data.foreach { case (k, v) => - // hasNext should be idempotent - assert(iter.hasNext) - assert(iter.hasNext) - val item = iter.next() - assert(item.numFields === 2) - assert(item.getInt(0) === k) - assert(item.getInt(1) === v) - } - intercept[NoSuchElementException] { - iter.next() - } - node.close() - } - - test("collect") { - val node = new DummyNode(kvIntAttributes, data) - node.open() - val collected = node.collect() - assert(collected.size === data.size) - assert(collected.forall(_.size === 2)) - assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) - node.close() - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala deleted file mode 100644 index cd67a66ebf576..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* -* 
Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, StringType} - -class LocalNodeTest extends SparkFunSuite { - - protected val conf: SQLConf = new SQLConf - protected val kvIntAttributes = Seq( - AttributeReference("k", IntegerType)(), - AttributeReference("v", IntegerType)()) - protected val joinNameAttributes = Seq( - AttributeReference("id1", IntegerType)(), - AttributeReference("name", StringType)()) - protected val joinNicknameAttributes = Seq( - AttributeReference("id2", IntegerType)(), - AttributeReference("nickname", StringType)()) - - /** - * Wrap a function processing two [[LocalNode]]s such that: - * (1) all input rows are automatically converted to unsafe rows - * (2) all output rows are automatically converted back to safe rows - */ - protected def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } - - /** - * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. - */ - protected def resolveExpressions(outputNode: LocalNode): LocalNode = { - outputNode transform { - case node: LocalNode => - val inputMap = node.output.map { a => (a.name, a) }.toMap - node transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - - /** - * Resolve all expressions in `expressions` based on the `output` of `localNode`. - * It assumes that all expressions in the `localNode` are resolved. 
- */ - protected def resolveExpressions( - expressions: Seq[Expression], - localNode: LocalNode): Seq[Expression] = { - require(localNode.expressions.forall(_.resolved)) - val inputMap = localNode.output.map { a => (a.name, a) }.toMap - expressions.map { expression => - expression.transformUp { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala deleted file mode 100644 index bcc87a9175517..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.internal.SQLConf - - -class NestedLoopJoinNodeSuite extends LocalNodeTest { - - // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types - private val buildSides = Seq(BuildLeft, BuildRight) - private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) - buildSides.foreach { buildSide => - joinTypes.foreach { joinType => - testJoin(buildSide, joinType) - } - } - - /** - * Test outer nested loop joins with varying degrees of matches. 
- */ - private def testJoin(buildSide: BuildSide, joinType: JoinType): Unit = { - val testNamePrefix = s"$buildSide / $joinType" - val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray - val conf = new SQLConf - - // Actual test body - def runTest( - joinType: JoinType, - leftInput: Array[(Int, String)], - rightInput: Array[(Int, String)]): Unit = { - val leftNode = new DummyNode(joinNameAttributes, leftInput) - val rightNode = new DummyNode(joinNicknameAttributes, rightInput) - val cond = 'id1 === 'id2 - val makeNode = (node1: LocalNode, node2: LocalNode) => { - resolveExpressions( - new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) - } - val makeUnsafeNode = wrapForUnsafe(makeNode) - val hashJoinNode = makeUnsafeNode(leftNode, rightNode) - val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) - val actualOutput = hashJoinNode.collect().map { row => - // ( - // id, name, - // id, nickname - // ) - ( - Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)), - Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3)) - ) - } - assert(actualOutput.toSet === expectedOutput.toSet) - } - - test(s"$testNamePrefix: empty") { - runTest(joinType, Array.empty, Array.empty) - } - - test(s"$testNamePrefix: no matches") { - val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray - runTest(joinType, someData, Array.empty) - runTest(joinType, Array.empty, someData) - runTest(joinType, someData, someIrrelevantData) - runTest(joinType, someIrrelevantData, someData) - } - - test(s"$testNamePrefix: partial matches") { - val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray - runTest(joinType, someData, someOtherData) - runTest(joinType, someOtherData, someData) - } - - test(s"$testNamePrefix: full matches") { - val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } - runTest(joinType, someData, someSuperRelevantData) - runTest(joinType, someSuperRelevantData, someData) - } - } - - /** - * Helper method to generate the expected output of a test based on the join type. 
- */ - private def generateExpectedOutput( - leftInput: Array[(Int, String)], - rightInput: Array[(Int, String)], - joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = { - joinType match { - case LeftOuter => - val rightInputMap = rightInput.toMap - leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k } - val rightValue = rightInputMap.get(k) - (Some(k), Some(v), rightKey, rightValue) - } - - case RightOuter => - val leftInputMap = leftInput.toMap - rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k } - val leftValue = leftInputMap.get(k) - (leftKey, leftValue, Some(k), Some(v)) - } - - case FullOuter => - val leftInputMap = leftInput.toMap - val rightInputMap = rightInput.toMap - val leftOutput = leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k } - val rightValue = rightInputMap.get(k) - (Some(k), Some(v), rightKey, rightValue) - } - val rightOutput = rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k } - val leftValue = leftInputMap.get(k) - (leftKey, leftValue, Some(k), Some(v)) - } - (leftOutput ++ rightOutput).distinct - - case other => - throw new IllegalArgumentException(s"Join type $other is not applicable") - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala deleted file mode 100644 index 02ecb23d34b2f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} -import org.apache.spark.sql.types.{IntegerType, StringType} - - -class ProjectNodeSuite extends LocalNodeTest { - private val pieAttributes = Seq( - AttributeReference("id", IntegerType)(), - AttributeReference("age", IntegerType)(), - AttributeReference("name", StringType)()) - - private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { - val inputNode = new DummyNode(pieAttributes, inputData) - val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) - val projectNode = new ProjectNode(conf, columns, inputNode) - val expectedOutput = inputData.map { case (id, age, name) => (id, name) } - val actualOutput = projectNode.collect().map { case row => - (row.getInt(0), row.getString(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testProject() - } - - test("basic") { - testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala deleted file mode 100644 index a3e83bbd51457..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} - - -class SampleNodeSuite extends LocalNodeTest { - - private def testSample(withReplacement: Boolean): Unit = { - val seed = 0L - val lowerb = 0.0 - val upperb = 0.3 - val maybeOut = if (withReplacement) "" else "out" - test(s"with$maybeOut replacement") { - val inputData = (1 to 1000).map { i => (i, i) }.toArray - val inputNode = new DummyNode(kvIntAttributes, inputData) - val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) - val sampler = - if (withReplacement) { - new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) - } else { - new BernoulliCellSampler[(Int, Int)](lowerb, upperb) - } - sampler.setSeed(seed) - val expectedOutput = sampler.sample(inputData.iterator).toArray - val actualOutput = sampleNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - } - - testSample(withReplacement = true) - testSample(withReplacement = false) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala deleted file mode 100644 index 42ebc7bfcaadc..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.local - -import scala.util.Random - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.SortOrder - - -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - - private def testTakeOrderedAndProject(desc: Boolean): Unit = { - val limit = 10 - val ascOrDesc = if (desc) "desc" else "asc" - test(ascOrDesc) { - val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray - val inputNode = new DummyNode(kvIntAttributes, inputData) - val firstColumn = inputNode.output(0) - val sortDirection = if (desc) Descending else Ascending - val sortOrder = SortOrder(firstColumn, sortDirection) - val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( - conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) - val expectedOutput = inputData - .map { case (k, _) => k } - .sortBy { k => k * (if (desc) -1 else 1) } - .take(limit) - val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } - assert(actualOutput === expectedOutput) - } - } - - testTakeOrderedAndProject(desc = false) - testTakeOrderedAndProject(desc = true) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala deleted file mode 100644 index 666b0235c061d..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.local - - -class UnionNodeSuite extends LocalNodeTest { - - private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { - val inputNodes = inputData.map { data => - new DummyNode(kvIntAttributes, data) - } - val unionNode = new UnionNode(conf, inputNodes) - val expectedOutput = inputData.flatten - val actualOutput = unionNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testUnion(Seq(Array.empty)) - testUnion(Seq(Array.empty, Array.empty)) - } - - test("self") { - val data = (1 to 100).map { i => (i, i) }.toArray - testUnion(Seq(data)) - testUnion(Seq(data, data)) - testUnion(Seq(data, data, data)) - } - - test("basic") { - val zero = Array.empty[(Int, Int)] - val one = (1 to 100).map { i => (i, i) }.toArray - val two = (50 to 150).map { i => (i, i) }.toArray - val three = (800 to 900).map { i => (i, i) }.toArray - testUnion(Seq(zero, one, two, three)) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala new file mode 100644 index 0000000000000..4ddc218455eb2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.ConcurrentModificationException + +import org.scalatest.concurrent.AsyncAssertions._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext + +class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { + + test("basic") { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + + assert(metadataLog.add(1, "batch1")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(1) === Some("batch1")) + assert(metadataLog.getLatest() === Some(1 -> "batch1")) + assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + + // Adding the same batch does nothing + metadataLog.add(1, "batch1-duplicated") + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(1) === Some("batch1")) + assert(metadataLog.getLatest() === Some(1 -> "batch1")) + assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + } + } + + test("restart") { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.add(1, "batch1")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(1) === Some("batch1")) + assert(metadataLog.getLatest() === Some(1 -> "batch1")) + assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + + val metadataLog2 = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.get(1) === Some("batch1")) + assert(metadataLog2.getLatest() === Some(1 -> "batch1")) + assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + } + } + + test("metadata directory collision") { + withTempDir { temp => + val waiter = new Waiter + val maxBatchId = 100 + for (id <- 0 until 10) { + new Thread() { + override def run(): Unit = waiter { + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + try { + var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) + nextBatchId += 1 + while (nextBatchId <= maxBatchId) { + metadataLog.add(nextBatchId, nextBatchId.toString) + nextBatchId += 1 + } + } catch { + case e: ConcurrentModificationException => + // This is expected since there are multiple writers + } finally { + waiter.dismiss() + } + } + }.start() + } + + waiter.await(timeout(10.seconds), dismissals(10)) + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString)) + assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 97638a66ab473..67b3d98c1daed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -17,6 +17,7 @@ 
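For reference, the calls exercised by HDFSMetadataLogSuite above reduce to the following usage sketch. It assumes only the constructor and methods the suite itself invokes (add, get, getLatest, and the ranged get), and is placed in the suite's own package so the class is visible; it is not an independent specification of the API.

```scala
// Usage distilled from HDFSMetadataLogSuite above; mirrors the suite's calls only.
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.SQLContext

object MetadataLogUsage {
  def example(sqlContext: SQLContext, path: String): Unit = {
    val log = new HDFSMetadataLog[String](sqlContext, path)

    assert(log.add(0, "batch0"))                    // write batch 0
    assert(log.add(1, "batch1"))                    // write batch 1
    log.add(1, "batch1-duplicated")                 // re-adding an existing batch id changes nothing

    assert(log.get(0) == Some("batch0"))            // look up a single batch
    assert(log.getLatest() == Some(1L -> "batch1")) // highest batch id plus its metadata
    assert(log.get(None, 1).toSeq == Seq(0L -> "batch0", 1L -> "batch1")) // all batches up to 1
  }
}
```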
package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import scala.util.Random @@ -357,7 +358,8 @@ object ColumnarBatchBenchmark { val maxString = 32 val count = 4 * 1000 - val data = Seq.fill(count)(randomString(minString, maxString)).map(_.getBytes).toArray + val data = Seq.fill(count)(randomString(minString, maxString)) + .map(_.getBytes(StandardCharsets.UTF_8)).toArray def column(memoryMode: MemoryMode) = { i: Int => val column = ColumnVector.allocate(count, BinaryType, memoryMode) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index b3c3e66fbcbd5..ed97f59ea1690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.vectorized +import java.nio.charset.StandardCharsets + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random @@ -329,18 +331,21 @@ class ColumnarBatchSuite extends SparkFunSuite { var idx = 0 val values = ("Hello" :: "abc" :: Nil).toArray - column.putByteArray(idx, values(0).getBytes, 0, values(0).getBytes().length) + column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), + 0, values(0).getBytes(StandardCharsets.UTF_8).length) reference += values(0) idx += 1 assert(column.arrayData().elementsAppended == 5) - column.putByteArray(idx, values(1).getBytes, 0, values(1).getBytes().length) + column.putByteArray(idx, values(1).getBytes(StandardCharsets.UTF_8), + 0, values(1).getBytes(StandardCharsets.UTF_8).length) reference += values(1) idx += 1 assert(column.arrayData().elementsAppended == 8) // Just put llo - val offset = column.putByteArray(idx, values(0).getBytes, 2, values(0).getBytes().length - 2) + val offset = column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), + 2, values(0).getBytes(StandardCharsets.UTF_8).length - 2) reference += "llo" idx += 1 assert(column.arrayData().elementsAppended == 11) @@ -353,7 +358,7 @@ class ColumnarBatchSuite extends SparkFunSuite { // Put a long string val s = "abcdefghijklmnopqrstuvwxyz" - column.putByteArray(idx, (s + s).getBytes) + column.putByteArray(idx, (s + s).getBytes(StandardCharsets.UTF_8)) reference += (s + s) idx += 1 assert(column.arrayData().elementsAppended == 11 + (s + s).length) @@ -473,7 +478,7 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.column(0).putInt(0, 1) batch.column(1).putDouble(0, 1.1) batch.column(2).putNull(0) - batch.column(3).putByteArray(0, "Hello".getBytes) + batch.column(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(1) // Verify the results of the row. 
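The recurring edit in these two files replaces the platform-default String.getBytes with an explicit getBytes(StandardCharsets.UTF_8), so the byte arrays handed to ColumnVector.putByteArray no longer depend on the JVM's default charset. A small standalone illustration of the difference (no Spark involved):

```scala
// Why the charset is now spelled out: the no-arg getBytes uses the JVM default
// charset (file.encoding), which varies across machines and locales.
import java.nio.charset.{Charset, StandardCharsets}

object CharsetExample extends App {
  val s = "Hello, wörld"

  val defaultBytes = s.getBytes                       // depends on the platform default
  val utf8Bytes = s.getBytes(StandardCharsets.UTF_8)  // always the same bytes everywhere

  println(s"default charset here:   ${Charset.defaultCharset()}")
  println(s"default encoding bytes: ${defaultBytes.length}")
  println(s"UTF-8 bytes:            ${utf8Bytes.length}") // 13: 'ö' takes two bytes in UTF-8
}
```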
@@ -519,17 +524,17 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.column(0).putNull(0) batch.column(1).putDouble(0, 2.2) batch.column(2).putInt(0, 2) - batch.column(3).putByteArray(0, "abc".getBytes) + batch.column(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) batch.column(0).putInt(1, 3) batch.column(1).putNull(1) batch.column(2).putInt(1, 3) - batch.column(3).putByteArray(1, "".getBytes) + batch.column(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) batch.column(0).putInt(2, 4) batch.column(1).putDouble(2, 4.4) batch.column(2).putInt(2, 4) - batch.column(3).putByteArray(2, "world".getBytes) + batch.column(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(3) def rowEquals(x: InternalRow, y: Row): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 1ef517324d7cb..f66deea06589c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -359,7 +359,7 @@ class JDBCSuite extends SparkFunSuite .collect().length === 3) } - test("Partioning on column that might have null values.") { + test("Partitioning on column that might have null values.") { assert( sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) .collect().length === 4) @@ -372,7 +372,7 @@ class JDBCSuite extends SparkFunSuite .collect().length === 4) } - test("SELECT * on partitioned table with a nullable partioncolumn") { + test("SELECT * on partitioned table with a nullable partition column") { assert(sql("SELECT * FROM nullparts").collect().size == 4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 26c1ff520406c..99f1661ad0d15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -339,7 +339,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { test("exceptions") { // Make sure we do throw correct exception when users use a relation provider that - // only implements the RelationProvier or the SchemaRelationProvider. + // only implements the RelationProvider or the SchemaRelationProvider. 
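The comment above distinguishes the two provider interfaces that the exception test below exercises. Under the org.apache.spark.sql.sources API of this era, the difference is whether createRelation accepts a user-supplied schema; a hedged sketch, with the relation bodies elided via ???, is:

```scala
// Sketch only: the signatures follow the org.apache.spark.sql.sources traits; the
// relation implementations are elided because only the createRelation shape matters here.
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Infers its own schema, so supplying one in the DDL should be rejected
// (the "schemaNotAllowed" case in the test below).
class InferredSchemaSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = ???
}

// Requires the user's schema, so omitting it should be rejected.
class UserSchemaSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = ???
}
```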
val schemaNotAllowed = intercept[Exception] { sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala index dac1a398ff5e6..84ed017a9d0d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala @@ -27,7 +27,7 @@ class ContinuousQuerySuite extends StreamTest with SharedSQLContext { import AwaitTerminationTester._ import testImplicits._ - test("lifecycle states and awaitTermination") { + testQuietly("lifecycle states and awaitTermination") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map { 6 / _} @@ -59,7 +59,7 @@ class ContinuousQuerySuite extends StreamTest with SharedSQLContext { ) } - test("source and sink statuses") { + testQuietly("source and sink statuses") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index e9d77abb8c23c..4c18e38db8280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.streaming -import java.io.{ByteArrayInputStream, File, FileNotFoundException, InputStream} - -import com.google.common.base.Charsets.UTF_8 +import java.io.File import org.apache.spark.sql.{AnalysisException, StreamTest} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.FileStreamSource._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.Utils @@ -360,59 +357,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { Utils.deleteRecursively(tmp) } - test("fault tolerance with corrupted metadata file") { - val src = Utils.createTempDir("streaming.src") - assert(new File(src, "_metadata").mkdirs()) - stringToFile( - new File(src, "_metadata/0"), - s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n") - stringToFile(new File(src, "_metadata/1"), s"${FileStreamSource.VERSION}\nSTART\n-") - - val textSource = createFileStreamSource("text", src.getCanonicalPath) - // the metadata file of batch is corrupted, so currentOffset should be 0 - assert(textSource.currentOffset === LongOffset(0)) - - Utils.deleteRecursively(src) - } - - test("fault tolerance with normal metadata file") { - val src = Utils.createTempDir("streaming.src") - assert(new File(src, "_metadata").mkdirs()) - stringToFile( - new File(src, "_metadata/0"), - s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n") - stringToFile( - new File(src, "_metadata/1"), - s"${FileStreamSource.VERSION}\nSTART\n-/x/y/z\nEND\n") - - val textSource = createFileStreamSource("text", src.getCanonicalPath) - assert(textSource.currentOffset === LongOffset(1)) - - Utils.deleteRecursively(src) - } - - test("readBatch") { - def stringToStream(str: String): InputStream = new ByteArrayInputStream(str.getBytes(UTF_8)) - - // Invalid metadata - assert(readBatch(stringToStream("")) === Nil) - assert(readBatch(stringToStream(FileStreamSource.VERSION)) === Nil) - 
assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\n")) === Nil) - assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART")) === Nil) - assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-")) === Nil) - assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c")) === Nil) - assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n")) === Nil) - assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEN")) === Nil) - - // Valid metadata - assert(readBatch(stringToStream( - s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND")) === Seq("/a/b/c")) - assert(readBatch(stringToStream( - s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND\n")) === Seq("/a/b/c")) - assert(readBatch(stringToStream( - s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n")) - === Seq("/a/b/c", "/e/f/g")) - } } class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 83c63e04f344a..7fa6760b71c8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.test +import java.nio.charset.StandardCharsets + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} @@ -103,11 +105,11 @@ private[sql] trait SQLTestData { self => protected lazy val binaryData: DataFrame = { val df = sqlContext.sparkContext.parallelize( - BinaryData("12".getBytes, 1) :: - BinaryData("22".getBytes, 5) :: - BinaryData("122".getBytes, 3) :: - BinaryData("121".getBytes, 2) :: - BinaryData("123".getBytes, 4) :: Nil).toDF() + BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: + BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: + BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: + BinaryData("121".getBytes(StandardCharsets.UTF_8), 2) :: + BinaryData("123".getBytes(StandardCharsets.UTF_8), 4) :: Nil).toDF() df.registerTempTable("binaryData") df } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 81508e134695a..694bd97515b86 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ +import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.Date @@ -67,7 +68,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { * with one of these strings is found, fail the test immediately. 
* The default value is `Seq("Error:")` * - * @param queriesAndExpectedAnswers one or more tupes of query + answer + * @param queriesAndExpectedAnswers one or more tuples of query + answer */ def runCliWithin( timeout: FiniteDuration, @@ -121,7 +122,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val process = new ProcessBuilder(command: _*).start() - val stdinWriter = new OutputStreamWriter(process.getOutputStream) + val stdinWriter = new OutputStreamWriter(process.getOutputStream, StandardCharsets.UTF_8) stdinWriter.write(queriesString) stdinWriter.flush() stdinWriter.close() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index c05527b519daa..e89bb1c470d5a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL +import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable @@ -28,7 +29,6 @@ import scala.concurrent.duration._ import scala.io.Source import scala.util.{Random, Try} -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver @@ -700,7 +700,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n """.stripMargin, new File(s"$tempLog4jConf/log4j.properties"), - UTF_8) + StandardCharsets.UTF_8) tempLog4jConf } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 081d849a88886..a78b7b0cc4961 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.io.File import java.net.{URL, URLClassLoader} +import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.concurrent.TimeUnit import java.util.regex.Pattern @@ -209,6 +210,7 @@ class HiveContext private[hive]( logInfo(s"Initializing execution hive, version $hiveExecutionVersion") val loader = new IsolatedClientLoader( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + sparkConf = sc.conf, execJars = Seq(), config = newTemporaryConfiguration(useInMemoryDerby = true), isolationOn = false, @@ -277,6 +279,7 @@ class HiveContext private[hive]( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using Spark classes.") new IsolatedClientLoader( version = metaVersion, + sparkConf = sc.conf, execJars = jars.toSeq, config = allConfig, isolationOn = true, @@ -289,6 +292,7 @@ class HiveContext private[hive]( IsolatedClientLoader.forVersion( hiveMetastoreVersion = hiveMetastoreVersion, hadoopVersion = VersionInfo.getVersion, + sparkConf = sc.conf, config = allConfig, barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) @@ -316,6 +320,7 @@ class HiveContext private[hive]( s"using ${jars.mkString(":")}") new 
IsolatedClientLoader( version = metaVersion, + sparkConf = sc.conf, execJars = jars.toSeq, config = allConfig, isolationOn = true, @@ -343,12 +348,12 @@ class HiveContext private[hive]( * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -362,7 +367,7 @@ class HiveContext private[hive]( * @since 1.2.0 */ def analyze(tableName: String) { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubqueryAliases(catalog.lookupRelation(tableIdent)) relation match { @@ -715,7 +720,7 @@ private[hive] object HiveContext { case (null, _) => "NULL" case (d: Int, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString - case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8f6cd66f1f681..c70510b4834d6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,11 +41,11 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.{datasources, FileRelation} -import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource, ParquetRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.{HadoopFsRelation, HDFSFileCatalog} import org.apache.spark.sql.types._ private[hive] case class HiveSerDe( @@ -469,7 +469,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte parquetRelation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) + PartitionSpec(StructType(Nil), Array.empty[datasources.PartitionDirectory]) } if (useCached) { @@ -499,7 +499,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) - ParquetPartition(values, location) + PartitionDirectory(values, location) } val partitionSpec = PartitionSpec(partitionSchema, partitions) @@ -753,7 +753,7 @@ class MetaStoreFileCatalog( hive: HiveContext, 
paths: Seq[Path], partitionSpecFromHive: PartitionSpec) - extends HDFSFileCatalog(hive, Map.empty, paths) { + extends HDFSFileCatalog(hive, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { override def getStatus(path: Path): Array[FileStatus] = { @@ -761,7 +761,7 @@ class MetaStoreFileCatalog( fs.listStatus(path) } - override def partitionSpec(schema: Option[StructType]): PartitionSpec = partitionSpecFromHive + override def partitionSpec(): PartitionSpec = partitionSpecFromHive } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 56acb87c800d3..739fbaf4446ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -245,7 +245,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") DropTable(tableName, ifExists.nonEmpty) - // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" + // Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan" case Token("TOK_ANALYZE", Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) => // Reference: @@ -535,7 +535,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging case Token("TOK_STORAGEHANDLER", _) => throw new AnalysisException( "CREATE TABLE AS SELECT cannot be used for a non-native table") - case _ => // Unsupport features + case _ => // Unsupported features } CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 614f9e05d76f9..d9cd96d66f493 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -74,11 +74,12 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) * Planner that takes into account Hive-specific strategies. 
*/ override lazy val planner: SparkPlanner = { - new SparkPlanner(ctx) with HiveStrategies { + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies { override val hiveContext = ctx override def strategies: Seq[Strategy] = { - ctx.experimental.extraStrategies ++ Seq( + experimentalMethods.extraStrategies ++ Seq( + FileSourceStrategy, DataSourceStrategy, HiveCommandStrategy(ctx), HiveDDLStrategy, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a19dc216382ef..f44937ec6f980 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -102,18 +102,8 @@ private[hive] trait HiveStrategies { case class HiveCommandStrategy(context: HiveContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case describe: DescribeCommand => - val resolvedTable = context.executePlan(describe.table).analyzed - resolvedTable match { - case t: MetastoreRelation => - ExecutedCommand( - DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil - - case o: LogicalPlan => - val resultPlan = context.executePlan(o).executedPlan - ExecutedCommand(RunnableDescribeCommand( - resultPlan, describe.output, describe.isExtended)) :: Nil - } - + ExecutedCommand( + DescribeHiveTableCommand(describe.table, describe.output, describe.isExtended)) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c1c8e631ee740..c108750c383cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -60,6 +60,7 @@ import org.apache.spark.util.{CircularBuffer, Utils} */ private[hive] class HiveClientImpl( override val version: HiveVersion, + sparkConf: SparkConf, config: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) @@ -90,7 +91,6 @@ private[hive] class HiveClientImpl( // instance of SparkConf is needed for the original value of spark.yarn.keytab // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the // keytab configuration for the link name in distributed cache - val sparkConf = new SparkConf if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { val principalName = sparkConf.get("spark.yarn.principal") val keytabFileName = sparkConf.get("spark.yarn.keytab") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 1653371d89f07..024f4dfeba9d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -27,7 +27,7 @@ import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkSubmitUtils import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext @@ -41,6 +41,7 @@ private[hive] object IsolatedClientLoader extends Logging { def forVersion( hiveMetastoreVersion: String, hadoopVersion: 
String, + sparkConf: SparkConf, config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, @@ -75,7 +76,8 @@ private[hive] object IsolatedClientLoader extends Logging { } new IsolatedClientLoader( - version = hiveVersion(hiveMetastoreVersion), + hiveVersion(hiveMetastoreVersion), + sparkConf, execJars = files, config = config, sharesHadoopClasses = sharesHadoopClasses, @@ -146,6 +148,7 @@ private[hive] object IsolatedClientLoader extends Logging { */ private[hive] class IsolatedClientLoader( val version: HiveVersion, + val sparkConf: SparkConf, val execJars: Seq[URL] = Seq.empty, val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, @@ -235,7 +238,7 @@ private[hive] class IsolatedClientLoader( /** The isolated client interface to Hive. */ private[hive] def createClient(): HiveClient = { if (!isolationOn) { - return new HiveClientImpl(version, config, baseClassLoader, this) + return new HiveClientImpl(version, sparkConf, config, baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -246,7 +249,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head - .newInstance(version, config, classLoader, this) + .newInstance(version, sparkConf, config, classLoader, this) .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 57293fce97295..8481324086c34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -22,8 +22,10 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.command.{DescribeCommand, RunnableCommand} import org.apache.spark.sql.hive.MetastoreRelation /** @@ -31,33 +33,44 @@ import org.apache.spark.sql.hive.MetastoreRelation */ private[hive] case class DescribeHiveTableCommand( - table: MetastoreRelation, + tableId: TableIdentifier, override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - // Trying to mimic the format of Hive's output. But not exactly the same. - var results: Seq[(String, String, String)] = Nil - - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala - results ++= columns.map(field => (field.getName, field.getType, field.getComment)) - if (partitionColumns.nonEmpty) { - val partColumnInfo = - partitionColumns.map(field => (field.getName, field.getType, field.getComment)) - results ++= - partColumnInfo ++ - Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ - partColumnInfo - } + // There are two modes here: + // For metastore tables, create an output similar to Hive's. 
+ // For other tables, delegate to DescribeCommand. - if (isExtended) { - results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) - } + // In the future, we will consolidate the two and simply report what the catalog reports. + sqlContext.sessionState.catalog.lookupRelation(tableId) match { + case table: MetastoreRelation => + // Trying to mimic the format of Hive's output. But not exactly the same. + var results: Seq[(String, String, String)] = Nil + + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala + results ++= columns.map(field => (field.getName, field.getType, field.getComment)) + if (partitionColumns.nonEmpty) { + val partColumnInfo = + partitionColumns.map(field => (field.getName, field.getType, field.getComment)) + results ++= + partColumnInfo ++ + Seq(("# Partition Information", "", "")) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ + partColumnInfo + } + + if (isExtended) { + results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) + } + + results.map { case (name, dataType, comment) => + Row(name, dataType, comment) + } - results.map { case (name, dataType, comment) => - Row(name, dataType, comment) + case o: LogicalPlan => + DescribeCommand(tableId, output, isExtended).run(sqlContext) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 5e6641693798f..b6e2f1f6b3ab7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io._ +import java.nio.charset.StandardCharsets import java.util.Properties import javax.annotation.Nullable @@ -113,7 +114,7 @@ case class ScriptTransformation( ioschema.initOutputSerDe(output).getOrElse((null, null)) } - val reader = new BufferedReader(new InputStreamReader(inputStream)) + val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 059ad8b1f7274..8240f2f2220cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -89,7 +89,7 @@ private[orc] object OrcFileOperator extends Logging { } def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { - // TODO: Check if the paths comming in are already qualified and simplify. + // TODO: Check if the paths coming in are already qualified and simplify. 
val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 8a39d95fc5677..ae041c5137f0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -111,7 +111,7 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes @@ -221,7 +221,7 @@ private[orc] case class OrcTableScan( @transient sqlContext: SQLContext, attributes: Seq[Attribute], filters: Array[Filter], - @transient inputPaths: Array[FileStatus]) + @transient inputPaths: Seq[FileStatus]) extends Logging with HiveInspectors { diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 5a539eaec7507..e9356541c22df 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -48,7 +48,6 @@ public class JavaMetastoreDataSourcesSuite { private transient JavaSparkContext sc; private transient HiveContext sqlContext; - String originalDefaultSource; File path; Path hiveManagedPath; FileSystem fs; @@ -66,7 +65,6 @@ public void setUp() throws IOException { sqlContext = TestHive$.MODULE$; sc = new JavaSparkContext(sqlContext.sparkContext()); - originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); if (path.exists()) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala index f557abcd522e6..2809f9439b823 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.hive import org.apache.hadoop.util.VersionInfo +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader} import org.apache.spark.util.Utils - /** * Test suite for the [[HiveCatalog]]. 
*/ @@ -32,7 +32,8 @@ class HiveCatalogSuite extends CatalogTestCases { private val client: HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = HiveContext.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion).createClient() + hadoopVersion = VersionInfo.getVersion, + sparkConf = new SparkConf()).createClient() } protected override val tableInputFormat: String = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index d850d522be297..6292f6c3af02b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.hadoop.util.VersionInfo -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression} import org.apache.spark.sql.catalyst.util.quietly @@ -39,6 +39,8 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { + private val sparkConf = new SparkConf() + // In order to speed up test execution during development or in Jenkins, you can specify the path // of an existing Ivy cache: private val ivyPath: Option[String] = { @@ -59,6 +61,7 @@ class VersionsSuite extends SparkFunSuite with Logging { val badClient = IsolatedClientLoader.forVersion( hiveMetastoreVersion = HiveContext.hiveExecutionVersion, hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, config = buildConf(), ivyPath = ivyPath).createClient() val db = new CatalogDatabase("default", "desc", "loc", Map()) @@ -93,6 +96,7 @@ class VersionsSuite extends SparkFunSuite with Logging { IsolatedClientLoader.forVersion( hiveMetastoreVersion = "13", hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, config = buildConf(), ivyPath = ivyPath).createClient() } @@ -112,6 +116,7 @@ class VersionsSuite extends SparkFunSuite with Logging { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, config = buildConf(), ivyPath = ivyPath).createClient() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 1053246fc2958..5e452d107dc75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -143,7 +143,7 @@ abstract class HiveComparisonTest 0D } - s"""SQLBuiler statistics: + s"""SQLBuilder statistics: |- Total query number: $numTotalQueries |- Number of convertible queries: $numConvertibleQueries |- Percentage of convertible queries: $percentage% diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 10024874472f2..d905f0cd68a4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -602,7 +602,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |select 
* where key = 4 """.stripMargin) - // test get_json_object again Hive, because the HiveCompatabilitySuite cannot handle result + // test get_json_object again Hive, because the HiveCompatibilitySuite cannot handle result // with newline in it. createQueryTest("get_json_object #1", "SELECT get_json_object(src_json.json, '$') FROM src_json") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 9ca07e96eb088..8cfb32f00a884 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.io.File +import java.nio.charset.StandardCharsets import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind @@ -73,7 +74,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Read/write binary data") { withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => val bytes = read.orc(file).head().getAs[Array[Byte]](0) - assert(new String(bytes, "utf8") === "test") + assert(new String(bytes, StandardCharsets.UTF_8) === "test") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index d77c88fa4b384..33c1bb059e2fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -69,7 +69,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") def tableDir: File = { - val identifier = hiveContext.sqlParser.parseTableIdentifier("bucketed_table") + val identifier = hiveContext.sessionState.sqlParser.parseTableIdentifier("bucketed_table") new File(URI.create(hiveContext.catalog.hiveDefaultTableFilePath(identifier))) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 11a4c7dfd011f..16c575bcc13ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -205,7 +205,7 @@ class CheckpointWriter( // also use the latest checkpoint time as the file name, so that we can recovery from the // latest checkpoint file. // - // Note: there is only one thread writting the checkpoint files, so we don't need to worry + // Note: there is only one thread writing the checkpoint files, so we don't need to worry // about thread-safety. val checkpointFile = Checkpoint.checkpointFile(checkpointDir, latestCheckpointTime) val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, latestCheckpointTime) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 25e61578a1860..e7f3a213d468e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -299,7 +299,7 @@ class StreamingContext private[streaming] ( /** * Create a input stream from TCP source hostname:port. 
Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given + * a TCP socket and the receive bytes it interpreted as object using the given * converter. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index aad9a12c15246..2a80cf4466588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -155,7 +155,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control - * thepartitioning of each RDD. + * the partitioning of each RDD. */ def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { dstream.reduceByKey(func, partitioner) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 860b8027253fd..05f4da6face4d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -530,7 +530,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Return the current state of the context. The context can be in three possible states - *

* <ul>
*   <li>
- * StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * StreamingContextState.INITIALIZED - The context has been created, but not been started yet. * Input DStreams, transformations and output operations can be created on the context. *   </li>
*   <li>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 1dcdb64e289bd..d6ff96e1fc696 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -446,7 +446,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * remember the partitioner despite the key being changed. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new * DStream - * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs. * @tparam S State type */ def updateStateByKey[S: ClassTag]( @@ -490,7 +490,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * remember the partitioner despite the key being changed. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new * DStream - * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs. * @param initialRDD initial state value of each key. * @tparam S State type */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 441477479167a..f7519c10c8eb1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.dstream import java.io._ import java.net.{ConnectException, Socket} +import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -113,7 +114,8 @@ object SocketReceiver { * to '\n' delimited strings and returns an iterator to access the strings. */ def bytesToLines(inputStream: InputStream): Iterator[String] = { - val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + val dataInputStream = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.UTF_8)) new NextIterator[String] { protected override def getNext() = { val nextValue = dataInputStream.readLine() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 080bc873fa0a8..47eb9b806fa7d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -55,7 +55,7 @@ class TransformedDStream[U: ClassTag] ( /** * Wrap a body of code such that the call site and operation scope * information are passed to the RDDs created in this body properly. - * This has been overriden to make sure that `displayInnerRDDOps` is always `true`, that is, + * This has been overridden to make sure that `displayInnerRDDOps` is always `true`, that is, * the inner scopes and callsites of RDDs generated in `DStream.transform` are always * displayed in the UI.
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 430f35a400dbe..d6fcc582b9c4d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -257,7 +257,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } batchUIData.foreach { _batchUIData => // We use an Iterable rather than explicitly converting to a seq so that updates - // will propegate + // will propagate val outputOpIdToSparkJobIds: Iterable[OutputOpIdAndSparkJobId] = Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime).asScala) .getOrElse(Seq.empty) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 2be1d6df86f89..3a21cfae5ac2f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -177,7 +177,7 @@ private[streaming] class OpenHashMapBasedStateMap[K, S]( new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) } - /** Whether the delta chain lenght is long enough that it should be compacted */ + /** Whether the delta chain length is long enough that it should be compacted */ def shouldCompact: Boolean = { deltaChainLength >= deltaChainThreshold } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 806cea24caddb..66448fd40057d 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,7 +18,7 @@ package org.apache.spark.streaming; import java.io.*; -import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; @@ -1866,7 +1866,8 @@ public void testSocketString() { @Override public Iterable call(InputStream in) throws IOException { List out = new ArrayList<>(); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(in, StandardCharsets.UTF_8))) { for (String line; (line = reader.readLine()) != null;) { out.add(line); } @@ -1930,7 +1931,7 @@ public void testRawSocketStream() { private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); - Files.write("0\n", existingFile, Charset.forName("UTF-8")); + Files.write("0\n", existingFile, StandardCharsets.UTF_8); Assert.assertTrue(existingFile.setLastModified(1000)); Assert.assertEquals(1000, existingFile.lastModified()); return Arrays.asList(Arrays.asList("0")); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index d09258e0e4a85..091ccbfd85cad 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -38,6 +38,7 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; +import java.nio.charset.StandardCharsets; 
import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -126,7 +127,8 @@ private void receive() { BufferedReader in = null; try { socket = new Socket(host, port); - in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + in = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); String userInput; while ((userInput = in.readLine()) != null) { store(userInput); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ca716cf4e6f11..9a3248b3e8175 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.streaming import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream} +import java.nio.charset.StandardCharsets import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -609,7 +609,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester */ def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) - Files.write(i + "\n", file, Charsets.UTF_8) + Files.write(i + "\n", file, StandardCharsets.UTF_8) assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date is actually the value we wrote, since rounding or // truncation will break the test: diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index fa17b3a15c4b6..cc2a67187e710 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming import java.io.{BufferedWriter, File, OutputStreamWriter} import java.net.{ServerSocket, Socket, SocketException} -import java.nio.charset.Charset +import java.nio.charset.StandardCharsets import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger @@ -146,7 +146,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, Charset.forName("UTF-8")) + Files.write("0\n", existingFile, StandardCharsets.UTF_8) assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -369,7 +369,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, Charset.forName("UTF-8")) + Files.write("0\n", existingFile, StandardCharsets.UTF_8) assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -393,7 +393,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val input = Seq(1, 2, 
3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) - Files.write(i + "\n", file, Charset.forName("UTF-8")) + Files.write(i + "\n", file, StandardCharsets.UTF_8) assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) @@ -448,7 +448,7 @@ class TestServer(portToBind: Int = 0) extends Logging { try { clientSocket.setTcpNoDelay(true) val outputStream = new BufferedWriter( - new OutputStreamWriter(clientSocket.getOutputStream)) + new OutputStreamWriter(clientSocket.getOutputStream, StandardCharsets.UTF_8)) while (clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 403400904bac2..3b662ec1833aa 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -518,7 +518,7 @@ class MapWithStateSuite extends SparkFunSuite val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState( StateSpec.function(runningCount)) - // Set internval make sure there is one RDD checkpointing + // Set interval make sure there is one RDD checkpointing mapWithStateStream.checkpoint(checkpointDuration) mapWithStateStream.stateSnapshots() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index faa9c4f0cbd6a..6406d53f8941a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming import java.io.{File, IOException} -import java.nio.charset.Charset +import java.nio.charset.StandardCharsets import java.util.UUID import scala.collection.JavaConverters._ @@ -371,7 +371,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) val localFile = new File(localTestDir, (i + 1).toString) val hadoopFile = new Path(testDir, (i + 1).toString) val tempHadoopFile = new Path(testDir, ".tmp_" + (i + 1).toString) - Files.write(input(i) + "\n", localFile, Charset.forName("UTF-8")) + Files.write(input(i) + "\n", localFile, StandardCharsets.UTF_8) var tries = 0 var done = false while (!done && tries < maxTries) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 45424f9bac05a..95c1609d8e9a0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -202,13 +202,13 @@ class ReceivedBlockHandlerSuite blockManager = createBlockManager(12000, sparkConf) // there is not enough space to store this block in MEMORY, - // But BlockManager will be able to sereliaze this block to WAL + // But BlockManager will be able to serialize this block to WAL // and hence count returns correct value. 
testRecordcount(false, StorageLevel.MEMORY_ONLY, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) // there is not enough space to store this block in MEMORY, - // But BlockManager will be able to sereliaze this block to DISK + // But BlockManager will be able to serialize this block to DISK // and hence count returns correct value. testRecordcount(true, StorageLevel.MEMORY_AND_DISK, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) @@ -272,7 +272,7 @@ class ReceivedBlockHandlerSuite } /** - * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * Test storing of data using different types of Handler, StorageLevel and ReceivedBlocks * and verify the correct record count */ private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala index a4871b460eb4d..6763ac64da287 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -97,7 +97,7 @@ class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) } - testWithWAL("createBlockRDD creates BlockRDD when some block info dont have WAL info") { + testWithWAL("createBlockRDD creates BlockRDD when some block info don't have WAL info") { receiverStream => val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 7a76cafc9a11c..484f3733e8423 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -182,7 +182,7 @@ class StateMapSuite extends SparkFunSuite { * * - These operations are done on a test map in "sets". After each set, the map is "copied" * to create a new map, and the next set of operations are done on the new one. This tests - * whether the map data persistes correctly across copies. + * whether the map data persist correctly across copies. * * - Within each set, there are a number of operations to test whether the map correctly * updates and removes data without affecting the parent state map. 
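A recurring pattern in the hunks above is the replacement of string-named charsets ("UTF-8", Charset.forName("UTF-8")) and Guava's Charsets.UTF_8 with java.nio.charset.StandardCharsets.UTF_8, so the encoding is an explicit compile-time constant rather than a runtime lookup or, where the argument was omitted, the platform default. A minimal self-contained Scala sketch of the before/after pattern; the object name and sample strings are made up for illustration and are not part of the patch:

// Hypothetical helper; only the StandardCharsets usage mirrors the patch.
import java.io.{BufferedReader, ByteArrayInputStream, InputStreamReader}
import java.nio.charset.StandardCharsets

object CharsetPatternSketch {
  def main(args: Array[String]): Unit = {
    // Old style: the charset is a string looked up at runtime, and is easy to omit entirely,
    // in which case the platform default encoding silently leaks in.
    val legacy = "spark".getBytes("UTF-8")

    // New style used throughout the patch: an explicit Charset constant.
    val explicit = "spark".getBytes(StandardCharsets.UTF_8)
    assert(java.util.Arrays.equals(legacy, explicit))

    // Readers and writers follow the same pattern.
    val reader = new BufferedReader(
      new InputStreamReader(new ByteArrayInputStream(explicit), StandardCharsets.UTF_8))
    println(reader.readLine()) // prints "spark"
  }
}

The Files.write and OutputStreamWriter call sites in the test hunks below apply the same substitution.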
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 197b3d143995a..2159edce2bf52 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -147,7 +147,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } } - test("start with non-seriazable DStream checkpoints") { + test("start with non-serializable DStream checkpoints") { val checkpointDir = Utils.createTempDir() ssc = new StreamingContext(conf, batchDuration) ssc.checkpoint(checkpointDir.getAbsolutePath) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 82cd63bcafc97..8269963edffa8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -56,7 +56,7 @@ private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputD /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, * replayable, reliable message queue like Kafka. It requires a sequence as input, and - * returns the i_th element at the i_th batch unde manual clock. + * returns the i_th element at the i_th batch under manual clock. */ class TestInputStream[T: ClassTag](_ssc: StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](_ssc) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index 6e95bb97105fd..498471b23b51e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -115,7 +115,7 @@ private[yarn] class AMDelegationTokenRenewer( } } // Schedule update of credentials. This handles the case of updating the tokens right now - // as well, since the renenwal interval will be 0, and the thread will get scheduled + // as well, since the renewal interval will be 0, and the thread will get scheduled // immediately. 
scheduleRenewal(driverTokenRenewerRunnable) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 0b5ceb768cc86..40cf9b68edaac 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, I OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -29,7 +30,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal -import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files import org.apache.hadoop.conf.Configuration @@ -619,7 +619,7 @@ private[spark] class Client( val props = new Properties() sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) - val writer = new OutputStreamWriter(confStream, UTF_8) + val writer = new OutputStreamWriter(confStream, StandardCharsets.UTF_8) props.store(writer, "Spark configuration.") writer.flush() confStream.closeEntry() @@ -1087,9 +1087,9 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), "pyspark.zip not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.9.1-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.9.2-src.zip") require(py4jFile.exists(), - "py4j-0.9.1-src.zip not found; cannot run pyspark application in YARN mode.") + "py4j-0.9.2-src.zip not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 9f91d182ebc32..9cdbd6da62185 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -186,9 +186,9 @@ private[yarn] class ExecutorRunnable( else { // If no java_opts specified, default to using -XX:+CMSIncrementalMode // It might be possible that other modes/config is being done in - // spark.executor.extraJavaOptions, so we dont want to mess with it. - // In our expts, using (default) throughput collector has severe perf ramnifications in - // multi-tennent machines + // spark.executor.extraJavaOptions, so we don't want to mess with it. 
+ // In our expts, using (default) throughput collector has severe perf ramifications in + // multi-tenant machines // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use // %20the%20Concurrent%20Low%20Pause%20Collector|outline diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a96cb4957be88..e34cd8d1b710b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -265,25 +265,52 @@ private[yarn] class YarnAllocator( // For locality unmatched and locality free container requests, cancel these container // requests, since required locality preference has been changed, recalculating using // container placement strategy. - val (localityMatched, localityUnMatched, localityFree) = splitPendingAllocationsByLocality( + val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality( hostToLocalTaskCounts, pendingAllocate) - // Remove the outdated container request and recalculate the requested container number - localityUnMatched.foreach(amClient.removeContainerRequest) - localityFree.foreach(amClient.removeContainerRequest) - val updatedNumContainer = missing + localityUnMatched.size + localityFree.size + // cancel "stale" requests for locations that are no longer needed + staleRequests.foreach { stale => + amClient.removeContainerRequest(stale) + } + val cancelledContainers = staleRequests.size + logInfo(s"Canceled $cancelledContainers container requests (locality no longer needed)") + + // consider the number of new containers and cancelled stale containers available + val availableContainers = missing + cancelledContainers + + // to maximize locality, include requests with no locality preference that can be cancelled + val potentialContainers = availableContainers + anyHostRequests.size val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( - updatedNumContainer, numLocalityAwareTasks, hostToLocalTaskCounts, - allocatedHostToContainersMap, localityMatched) + potentialContainers, numLocalityAwareTasks, hostToLocalTaskCounts, + allocatedHostToContainersMap, localRequests) + + val newLocalityRequests = new mutable.ArrayBuffer[ContainerRequest] + containerLocalityPreferences.foreach { + case ContainerLocalityPreferences(nodes, racks) if nodes != null => + newLocalityRequests.append(createContainerRequest(resource, nodes, racks)) + case _ => + } - for (locality <- containerLocalityPreferences) { - val request = createContainerRequest(resource, locality.nodes, locality.racks) + if (availableContainers >= newLocalityRequests.size) { + // more containers are available than needed for locality, fill in requests for any host + for (i <- 0 until (availableContainers - newLocalityRequests.size)) { + newLocalityRequests.append(createContainerRequest(resource, null, null)) + } + } else { + val numToCancel = newLocalityRequests.size - availableContainers + // cancel some requests without locality preferences to schedule more local containers + anyHostRequests.slice(0, numToCancel).foreach { nonLocal => + amClient.removeContainerRequest(nonLocal) + } + logInfo(s"Canceled $numToCancel container requests for any host to resubmit with locality") + } + + newLocalityRequests.foreach { request => amClient.addContainerRequest(request) - val nodes = request.getNodes - val hostStr = if 
(nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last - logInfo(s"Container request (host: $hostStr, capability: $resource)") + logInfo(s"Submitted container request (host: ${hostStr(request)}, capability: $resource)") } + } else if (missing < 0) { val numToCancel = math.min(numPendingAllocate, -missing) logInfo(s"Canceling requests for $numToCancel executor containers") @@ -298,6 +325,13 @@ private[yarn] class YarnAllocator( } } + private def hostStr(request: ContainerRequest): String = { + Option(request.getNodes) match { + case Some(nodes) => nodes.asScala.mkString(",") + case None => "Any" + } + } + /** * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index ed56d4bd44fe8..2915e664beffe 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -65,7 +65,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { override def isYarnMode(): Boolean = { true } // Return an appropriate (subclass) of Configuration. Creating a config initializes some Hadoop - // subsystems. Always create a new config, dont reuse yarnConf. + // subsystems. Always create a new config, don't reuse yarnConf. override def newConfiguration(conf: SparkConf): Configuration = new YarnConfiguration(super.newConfiguration(conf)) @@ -217,7 +217,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down // to a Configuration and used without reflection val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - // using the (Configuration, Class) constructor allows the current configuratin to be included + // using the (Configuration, Class) constructor allows the current configuration to be included // in the hive config. val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], classOf[Object].getClass) @@ -502,7 +502,7 @@ object YarnSparkHadoopUtil { /** * Getting the initial target number of executors depends on whether dynamic allocation is * enabled. - * If not using dynamic allocation it gets the number of executors reqeusted by the user. + * If not using dynamic allocation it gets the number of executors requested by the user. 
*/ def getInitialTargetExecutorNumber( conf: SparkConf, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index b12e506033e39..78b57da482f70 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.nio.charset.StandardCharsets import java.util.Properties import java.util.concurrent.TimeUnit @@ -25,7 +26,6 @@ import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -75,7 +75,7 @@ abstract class BaseYarnClusterSuite System.setProperty("SPARK_YARN_MODE", "true") val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) + Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8) // Disable the disk utilization check to avoid the test hanging when people's disks are // getting full. @@ -191,7 +191,7 @@ abstract class BaseYarnClusterSuite result: File, expected: String): Unit = { finalState should be (SparkAppHandle.State.FINISHED) - val resultString = Files.toString(result, UTF_8) + val resultString = Files.toString(result, StandardCharsets.UTF_8) resultString should be (expected) } @@ -231,7 +231,7 @@ abstract class BaseYarnClusterSuite extraConf.foreach { case (k, v) => props.setProperty(k, v) } val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), StandardCharsets.UTF_8) props.store(writer, "Spark properties.") writer.close() propsFile.getAbsolutePath() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index e935163c3487f..8a92a7ecda544 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URL +import java.nio.charset.StandardCharsets import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -147,14 +147,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { private def testPySpark(clientMode: Boolean): Unit = { val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8) // When running tests, let's not assume the user has built the assembly module, which also // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the // needed locations. 
val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.9.1-src.zip", + s"$sparkHome/python/lib/py4j-0.9.2-src.zip", s"$sparkHome/python") val extraEnv = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), @@ -171,7 +171,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { subdir } val pyModule = new File(moduleDir, "mod1.py") - Files.write(TEST_PYMODULE, pyModule, UTF_8) + Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") @@ -245,7 +245,7 @@ private object YarnClusterDriver extends Logging with Matchers { data should be (Set(1, 2, 3, 4)) result = "success" } finally { - Files.write(result, status, UTF_8) + Files.write(result, status, StandardCharsets.UTF_8) sc.stop() } @@ -319,14 +319,14 @@ private object YarnClasspathTest extends Logging { val ccl = Thread.currentThread().getContextClassLoader() val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) - result = new String(bytes, 0, bytes.length, UTF_8) + result = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8) } catch { case t: Throwable => error(s"loading test.resource to $resultPath", t) // set the exit code if not yet set exitCode = 2 } finally { - Files.write(result, new File(resultPath), UTF_8) + Files.write(result, new File(resultPath), StandardCharsets.UTF_8) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index c17e8695c24fb..05c1e1613dd35 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.nio.charset.StandardCharsets -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -78,7 +78,7 @@ private object YarnExternalShuffleDriver extends Logging with Matchers { s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: ExternalShuffleDriver [result file] [registed exec file] + |Usage: ExternalShuffleDriver [result file] [registered exec file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -104,7 +104,7 @@ private object YarnExternalShuffleDriver extends Logging with Matchers { } finally { sc.stop() FileUtils.deleteDirectory(execStateCopy) - Files.write(result, status, UTF_8) + Files.write(result, status, StandardCharsets.UTF_8) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 9202bd892f01b..70b8732946a2b 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException +import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.conf.Configuration @@ -59,7 +60,7 @@ class 
YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") try { val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") - Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile) + Files.write(("bash -c \"echo " + argLine + "\"").getBytes(StandardCharsets.UTF_8), scriptFile) scriptFile.setExecutable(true) val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath()))
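
Editorial aside on the YarnAllocator hunk above: the new code cancels stale locality requests, re-consults the placement strategy, and then either tops up with any-host requests or cancels pending any-host requests to make room for locality-preferring ones. The following is a minimal, self-contained sketch of just that accounting; every name in it (ContainerRequestPlanSketch, planRequests, anyHostPending, localityPreferred, Plan) is made up for illustration and is not part of the patch, which instead drives amClient.addContainerRequest/removeContainerRequest and containerPlacementStrategy directly.

    // Illustrative only: a simplified model of the request accounting in
    // YarnAllocator.updateResourceRequests after this patch. No YARN APIs are used.
    object ContainerRequestPlanSketch {

      // How many locality-preferring and any-host requests to submit, and how many
      // already-pending any-host requests to cancel so locality requests can replace them.
      final case class Plan(newLocality: Int, newAnyHost: Int, anyHostCancelled: Int)

      def planRequests(
          missing: Int,            // executors still needed
          staleCancelled: Int,     // pending requests whose locality preference is now stale
          anyHostPending: Int,     // pending requests with no locality preference
          localityPreferred: Int   // containers the placement strategy wants to be local
        ): Plan = {
        // Cancelling stale requests frees capacity on top of the newly missing executors.
        val availableContainers = missing + staleCancelled
        // Any-host requests could also be re-issued with a locality preference, so the
        // placement strategy is consulted for this larger "potential" number.
        val potentialContainers = availableContainers + anyHostPending
        val newLocality = math.min(localityPreferred, potentialContainers)
        if (availableContainers >= newLocality) {
          // Headroom left over: top it up with plain any-host requests.
          Plan(newLocality, availableContainers - newLocality, anyHostCancelled = 0)
        } else {
          // Not enough headroom: cancel some pending any-host requests so the
          // locality-preferring ones can be submitted in their place.
          Plan(newLocality, 0, anyHostCancelled = newLocality - availableContainers)
        }
      }

      def main(args: Array[String]): Unit = {
        // 3 new executors needed, 2 stale requests cancelled, 4 any-host requests pending,
        // placement strategy wants 7 locality-preferring containers:
        println(planRequests(missing = 3, staleCancelled = 2,
          anyHostPending = 4, localityPreferred = 7))
        // => Plan(7,0,2): submit 7 locality requests, cancel 2 pending any-host requests
      }
    }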
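
The repeated UTF_8 changes across Client.scala and the test suites all follow one pattern: replace Guava's com.google.common.base.Charsets.UTF_8 with the JDK's java.nio.charset.StandardCharsets.UTF_8. Both constants are java.nio.charset.Charset instances, so this is a drop-in substitution that trims a Guava dependency from these call sites. A small sketch of the pattern, using a hypothetical temp file rather than any path from the patch:

    import java.io.{File, FileOutputStream, OutputStreamWriter}
    import java.nio.charset.StandardCharsets

    object Utf8WriterSketch {
      def main(args: Array[String]): Unit = {
        // Before: new OutputStreamWriter(out, com.google.common.base.Charsets.UTF_8)
        // After:  the JDK constant is a drop-in replacement.
        val out = new FileOutputStream(File.createTempFile("spark", ".properties"))
        val writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)
        writer.write("spark.app.name=sketch\n")
        writer.close()
      }
    }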