diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a329e14f25aeb..b2d92bdf4840e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -29,6 +29,7 @@ exportMethods("arrange", "count", "crosstab", "describe", + "dim", "distinct", "dropna", "dtypes", @@ -45,11 +46,16 @@ exportMethods("arrange", "isLocal", "join", "limit", + "merge", + "names", + "ncol", + "nrow", "orderBy", "mutate", "names", "persist", "printSchema", + "rbind", "registerTempTable", "rename", "repartition", @@ -64,8 +70,10 @@ exportMethods("arrange", "show", "showDF", "summarize", + "summary", "take", "unionAll", + "unique", "unpersist", "where", "withColumn", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c93d3c7dd67..895603235011e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -255,6 +255,16 @@ setMethod("names", columns(x) }) +#' @rdname columns +setMethod("names<-", + signature(x = "DataFrame"), + function(x, value) { + if (!is.null(value)) { + sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + dataFrame(sdf) + } + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext @@ -473,6 +483,18 @@ setMethod("distinct", dataFrame(sdf) }) +#' @title Distinct rows in a DataFrame +# +#' @description Returns a new DataFrame containing distinct rows in this DataFrame +#' +#' @rdname unique +#' @aliases unique +setMethod("unique", + signature(x = "DataFrame"), + function(x) { + distinct(x) + }) + #' Sample #' #' Return a sampled subset of this DataFrame using a random seed. 
@@ -534,6 +556,58 @@ setMethod("count", callJMethod(x@sdf, "count") }) +#' @title Number of rows for a DataFrame +#' @description Returns the number of rows in a DataFrame +#' +#' @name nrow +#' +#' @rdname nrow +#' @aliases count +setMethod("nrow", + signature(x = "DataFrame"), + function(x) { + count(x) + }) + +#' Returns the number of columns in a DataFrame +#' +#' @param x a SparkSQL DataFrame +#' +#' @rdname ncol +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' ncol(df) +#' } +setMethod("ncol", + signature(x = "DataFrame"), + function(x) { + length(columns(x)) + }) + +#' Returns the dimensions (number of rows and columns) of a DataFrame +#' @param x a SparkSQL DataFrame +#' +#' @rdname dim +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' dim(df) +#' } +setMethod("dim", + signature(x = "DataFrame"), + function(x) { + c(count(x), ncol(x)) + }) + #' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. #' #' @param x A SparkSQL DataFrame @@ -1205,6 +1279,15 @@ setMethod("join", dataFrame(sdf) }) +#' @rdname merge +#' @aliases join +setMethod("merge", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL, ...) { + join(x, y, joinExpr, joinType) + }) + + #' UnionAll #' #' Return a new DataFrame containing the union of rows in this DataFrame @@ -1231,6 +1314,22 @@ setMethod("unionAll", dataFrame(unioned) }) +#' @title Union two or more DataFrames +# +#' @description Returns a new DataFrame containing rows of all parameters. +# +#' @rdname rbind +#' @aliases unionAll +setMethod("rbind", + signature(... = "DataFrame"), + function(x, ..., deparse.level = 1) { + if (nargs() == 3) { + unionAll(x, ...) 
+ } else { + unionAll(x, Recall(..., deparse.level = 1)) + } + }) + #' Intersect #' #' Return a new DataFrame containing rows only in both this DataFrame @@ -1322,9 +1421,11 @@ setMethod("write.df", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { @@ -1384,9 +1485,11 @@ setMethod("saveAsTable", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) @@ -1430,6 +1533,19 @@ setMethod("describe", dataFrame(sdf) }) +#' @title Summary +#' +#' @description Computes statistics for numeric columns of the DataFrame +#' +#' @rdname summary +#' @aliases describe +setMethod("summary", + signature(x = "DataFrame"), + function(x) { + describe(x) + }) + + #' dropna #' #' Returns a new DataFrame omitting rows with null values. 
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index d2d096709245d..051e441d4e063 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -85,7 +85,9 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) isPipelinable <- function(rdd) { e <- rdd@env + # nolint start !(e$isCached || e$isCheckpointed) + # nolint end } if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { @@ -97,7 +99,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { pipelinedFunc <- function(partIndex, part) { - func(partIndex, prev@func(partIndex, part)) + f <- prev@func + func(partIndex, f(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -841,7 +844,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- rpois(1, fraction) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { @@ -1261,12 +1264,12 @@ setMethod("pipeRDD", signature(x = "RDD", command = "character"), function(x, command, env = list()) { func <- function(part) { - trim.trailing.func <- function(x) { + trim_trailing_func <- function(x) { sub("[\r\n]*$", "", toString(x)) } - input <- unlist(lapply(part, trim.trailing.func)) + input <- unlist(lapply(part, trim_trailing_func)) res <- system2(command, stdout = TRUE, input = input, env = env) - lapply(res, trim.trailing.func) + lapply(res, trim_trailing_func) } lapplyPartition(x, func) }) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 2892e1416cc65..eeaf9f193b728 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -65,7 +65,7 @@ functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", "expm1", "floor", "log", "log10", "log1p", "rint", "sign", "sin", "sinh", 
"tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions<- c("atan2", "hypot") +binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 43be9c904fdf6..720990e1c6087 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -121,7 +121,7 @@ parallelize <- function(sc, coll, numSlices = 1) { numSlices <- length(coll) sliceLen <- ceiling(length(coll) / numSlices) - slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) # Serialize each slice: obtain a list of raws, or a list of lists (slices) of # 2-tuples of raws diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a3a121058e165..c43b947129e87 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -461,6 +461,10 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) +#' rdname merge +#' @export +setGeneric("merge") + #' @rdname withColumn #' @export setGeneric("mutate", function(x, ...) {standardGeneric("mutate") }) @@ -531,6 +535,10 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) +##' rdname summary +##' @export +setGeneric("summary", function(x, ...) 
{ standardGeneric("summary") }) + # @rdname tojson # @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -669,3 +677,7 @@ setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname glm #' @export setGeneric("glm") + +#' @rdname rbind +#' @export +setGeneric("rbind", signature = "...") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index efddcc1d8d71c..b524d1fd87496 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -86,12 +86,12 @@ setMethod("predict", signature(object = "PipelineModel"), #' model <- glm(y ~ x, trainingData) #' summary(model) #'} -setMethod("summary", signature(object = "PipelineModel"), - function(object) { +setMethod("summary", signature(x = "PipelineModel"), + function(x, ...) { features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", object@model) + "getModelFeatures", x@model) weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelWeights", object@model) + "getModelWeights", x@model) coefficients <- as.matrix(unlist(weights)) colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 83801d3209700..199c3fd6ab1b2 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -879,7 +879,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- rpois(1, frac) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3f45589a50443..4f9f4d9cad2a8 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -32,7 +32,7 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, } results <- if (arrSize > 0) { - lapply(0:(arrSize - 1), + lapply(0 : (arrSize - 1), function(index) { obj <- callJMethod(jList, "get", as.integer(index)) @@ -572,7 +572,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if 
(lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[ (lengthOfKeys + 1) : (len - 1) ] } else { values <- list() } diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index dca0657c57e0d..f054ac9a87d61 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -40,7 +40,7 @@ test_that("union on two RDDs", { expect_equal(actual, c(as.list(nums), mockFile)) expect_equal(getSerializedMode(union.rdd), "byte") - rdd<- map(text.rdd, function(x) {x}) + rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 6c3aaab8c711e..71aed2bb9d6a8 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -250,7 +250,7 @@ test_that("flatMapValues() on pairwise RDDs", { expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -293,7 +293,7 @@ test_that("sumRDD() on RDDs", { }) test_that("keyBy on RDDs", { - func <- function(x) { x*x } + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collect(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) @@ -311,7 +311,7 @@ test_that("repartition/coalesce on RDDs", { r2 <- repartition(rdd, 6) expect_equal(numPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) - expect_true(count >=0 && count <= 4) + expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 
61c8a7ec7d837..7377fc8f1ca9c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -88,6 +88,9 @@ test_that("create DataFrame from RDD", { df <- createDataFrame(sqlContext, rdd, list("a", "b")) expect_is(df, "DataFrame") expect_equal(count(df), 10) + expect_equal(nrow(df), 10) + expect_equal(ncol(df), 2) + expect_equal(dim(df), c(10, 2)) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) @@ -128,7 +131,9 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) - localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) + localDF <- data.frame(name=c("John", "Smith", "Sarah"), + age=c(19, 23, 18), + height=c(164.10, 181.4, 173.7)) df <- createDataFrame(sqlContext, localDF, schema) expect_is(df, "DataFrame") expect_equal(count(df), 3) @@ -489,7 +494,7 @@ test_that("head() and first() return the correct data", { expect_equal(nrow(testFirst), 1) }) -test_that("distinct() on DataFrames", { +test_that("distinct() and unique on DataFrames", { lines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}", @@ -501,6 +506,10 @@ test_that("distinct() on DataFrames", { uniques <- distinct(df) expect_is(uniques, "DataFrame") expect_equal(count(uniques), 3) + + uniques2 <- unique(df) + expect_is(uniques2, "DataFrame") + expect_equal(count(uniques2), 3) }) test_that("sample on a DataFrame", { @@ -666,10 +675,12 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start expect_equal(collect(select(df, 
hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end }) test_that("string operators", { @@ -754,7 +765,7 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered6), 2) }) -test_that("join() on a DataFrame", { +test_that("join() and merge() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -783,6 +794,12 @@ test_that("join() on a DataFrame", { expect_equal(names(joined4), c("newAge", "name", "test")) expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + + merged <- select(merge(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(merged), c("newAge", "name", "test")) + expect_equal(count(merged), 4) + expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { @@ -811,7 +828,7 @@ test_that("isLocal()", { expect_false(isLocal(df)) }) -test_that("unionAll(), except(), and intersect() on a DataFrame", { +test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", @@ -826,6 +843,11 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { expect_equal(count(unioned), 6) expect_equal(first(unioned)$name, "Michael") + unioned2 <- arrange(rbind(unioned, df, df2), df$age) + expect_is(unioned2, "DataFrame") + expect_equal(count(unioned2), 12) + expect_equal(first(unioned2)$name, "Michael") + excepted <- arrange(except(df, df2), desc(df$age)) expect_is(unioned, "DataFrame") expect_equal(count(excepted), 2) @@ -849,7 +871,7 @@ 
test_that("withColumn() and withColumnRenamed()", { expect_equal(columns(newDF2)[1], "newerAge") }) -test_that("mutate() and rename()", { +test_that("mutate(), rename() and names()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) @@ -859,6 +881,10 @@ test_that("mutate() and rename()", { newDF2 <- rename(df, newerAge = df$age) expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") + + names(newDF2) <- c("newerName", "evenNewerAge") + expect_equal(length(names(newDF2)), 2) + expect_equal(names(newDF2)[1], "newerName") }) test_that("write.df() on DataFrame and works with parquetFile", { @@ -876,10 +902,10 @@ test_that("parquetFile works with multiple input paths", { write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_is(parquetDF, "DataFrame") - expect_equal(count(parquetDF), count(df)*2) + expect_equal(count(parquetDF), count(df) * 2) }) -test_that("describe() on a DataFrame", { +test_that("describe() and summarize() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") @@ -888,6 +914,10 @@ test_that("describe() on a DataFrame", { stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") + + stats2 <- summary(df) + expect_equal(collect(stats2)[4, "name"], "Andy") + expect_equal(collect(stats2)[5, "age"], "30") }) test_that("dropna() on a DataFrame", { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java similarity index 73% rename from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java rename to core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 198e0684f32f8..481375f493a50 100644 --- 
a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -24,7 +24,10 @@ import java.util.List; import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.*; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -45,6 +48,8 @@ */ public final class BytesToBytesMap { + private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; @@ -54,7 +59,9 @@ public final class BytesToBytesMap { */ private static final int END_OF_PAGE_MARKER = -1; - private final TaskMemoryManager memoryManager; + private final TaskMemoryManager taskMemoryManager; + + private final ShuffleMemoryManager shuffleMemoryManager; /** * A linked list for tracking all allocated data pages so that we can free all of our memory. @@ -120,7 +127,7 @@ public final class BytesToBytesMap { /** * Number of keys defined in the map. */ - private int size; + private int numElements; /** * The map will be expanded once the number of keys exceeds this threshold. 
@@ -150,12 +157,14 @@ public final class BytesToBytesMap { private long numHashCollisions = 0; public BytesToBytesMap( - TaskMemoryManager memoryManager, + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { - this.memoryManager = memoryManager; + this.taskMemoryManager = taskMemoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -175,26 +184,34 @@ public BytesToBytesMap( } public BytesToBytesMap( - TaskMemoryManager memoryManager, + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes) { - this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false); + this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( - TaskMemoryManager memoryManager, + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { - this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics); + this( + taskMemoryManager, + shuffleMemoryManager, + initialCapacity, + 0.70, + pageSizeBytes, + enablePerfMetrics); } /** * Returns the number of keys defined in the map. 
*/ - public int size() { return size; } + public int numElements() { return numElements; } - private static final class BytesToBytesMapIterator implements Iterator { + public static final class BytesToBytesMapIterator implements Iterator { private final int numRecords; private final Iterator dataPagesIterator; @@ -204,7 +221,8 @@ private static final class BytesToBytesMapIterator implements Iterator private Object pageBaseObject; private long offsetInPage; - BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) { + private BytesToBytesMapIterator( + int numRecords, Iterator dataPagesIterator, Location loc) { this.numRecords = numRecords; this.dataPagesIterator = dataPagesIterator; this.loc = loc; @@ -226,13 +244,13 @@ public boolean hasNext() { @Override public Location next() { - int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); - if (keyLength == END_OF_PAGE_MARKER) { + int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); } loc.with(pageBaseObject, offsetInPage); - offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + offsetInPage += 8 + totalLength; currentRecordNumber++; return loc; } @@ -251,8 +269,8 @@ public void remove() { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. 
*/ - public Iterator iterator() { - return new BytesToBytesMapIterator(size, dataPages.iterator(), loc); + public BytesToBytesMapIterator iterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc); } /** @@ -330,18 +348,22 @@ public final class Location { private void updateAddressesAndSizes(long fullKeyAddress) { updateAddressesAndSizes( - memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress)); + taskMemoryManager.getPage(fullKeyAddress), + taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(Object page, long keyOffsetInPage) { - long position = keyOffsetInPage; - keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); - position += 8; // word used to store the key size - keyMemoryLocation.setObjAndOffset(page, position); - position += keyLength; - valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); - position += 8; // word used to store the key size - valueMemoryLocation.setObjAndOffset(page, position); + private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) { + long position = keyOffsetInPage; + final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + position += 4; + keyLength = PlatformDependent.UNSAFE.getInt(page, position); + position += 4; + valueLength = totalLength - keyLength; + + keyMemoryLocation.setObjAndOffset(page, position); + + position += keyLength; + valueMemoryLocation.setObjAndOffset(page, position); } Location with(int pos, int keyHashcode, boolean isDefined) { @@ -411,7 +433,8 @@ public int getValueLength() { /** * Store a new key and value. This method may only be called once for a given key; if you want * to update the value associated with a key, then you can directly manipulate the bytes stored - * at the value address. + * at the value address. 
The return value indicates whether the put succeeded or whether it + * failed because additional memory could not be acquired. *

* It is only valid to call this method immediately after calling `lookup()` using the same key. *

@@ -428,14 +451,19 @@ public int getValueLength() { *
      *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
      *   if (!loc.isDefined()) {
-     *     loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *       // handle failure to grow map (by spilling, for example)
+     *     }
      *   }
      * 
*

* Unspecified behavior if the key is not defined. *

+ * + * @return true if the put() was successful and false if the put() failed because memory could + * not be acquired. */ - public void putNewKey( + public boolean putNewKey( Object keyBaseObject, long keyBaseOffset, int keyLengthBytes, @@ -445,63 +473,111 @@ public void putNewKey( assert (!isDefined) : "Can only set value once for a key"; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); - if (size == MAX_CAPACITY) { + if (numElements == MAX_CAPACITY) { throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); } + // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) - final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker. - size++; - bitset.set(pos); - - // If there's not enough space in the current page, allocate a new page (8 bytes are reserved - // for the end-of-page marker). - if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { + final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; + + // --- Figure out where to insert the new record --------------------------------------------- + + final MemoryBlock dataPage; + final Object dataPageBaseObject; + final long dataPageInsertOffset; + boolean useOverflowPage = requiredSize > pageSizeBytes - 8; + if (useOverflowPage) { + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. 
+ final long memoryRequested = requiredSize + 8; + final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryGranted != memoryRequested) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", memoryRequested); + return false; + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested); + dataPages.add(overflowPage); + dataPage = overflowPage; + dataPageBaseObject = overflowPage.getBaseObject(); + dataPageInsertOffset = overflowPage.getBaseOffset(); + } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { + // The record can fit in a data page, but either we have not allocated any pages yet or + // the current page does not have enough space. if (currentDataPage != null) { // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; } - MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes); + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); dataPages.add(newPage); pageCursor = 0; currentDataPage = newPage; + dataPage = currentDataPage; + dataPageBaseObject = currentDataPage.getBaseObject(); + dataPageInsertOffset = currentDataPage.getBaseOffset(); + } else { + // There is enough space in the current data page. 
+ dataPage = currentDataPage; + dataPageBaseObject = currentDataPage.getBaseObject(); + dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor; } - // Compute all of our offsets up-front: - final Object pageBaseObject = currentDataPage.getBaseObject(); - final long pageBaseOffset = currentDataPage.getBaseOffset(); - final long keySizeOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += 8; // word used to store the key size - final long keyDataOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += keyLengthBytes; - final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += 8; // word used to store the value size - final long valueDataOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += valueLengthBytes; + // --- Append the key and value data to the current data page -------------------------------- + long insertCursor = dataPageInsertOffset; + + // Compute all of our offsets up-front: + final long totalLengthOffset = insertCursor; + insertCursor += 4; + final long keyLengthOffset = insertCursor; + insertCursor += 4; + final long keyDataOffsetInPage = insertCursor; + insertCursor += keyLengthBytes; + final long valueDataOffsetInPage = insertCursor; + insertCursor += valueLengthBytes; // skip past the value bytes + + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset, + keyLengthBytes + valueLengthBytes); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); PlatformDependent.copyMemory( - keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); - PlatformDependent.copyMemory( - valueBaseObject, valueBaseOffset, pageBaseObject, 
valueDataOffsetInPage, valueLengthBytes); + PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + valueDataOffsetInPage, valueLengthBytes); + + // --- Update bookkeeping data structures ---------------------------------------------------- + + if (useOverflowPage) { + // Store the end-of-page marker at the end of the data page + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + } else { + pageCursor += requiredSize; + } - final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( - currentDataPage, keySizeOffsetInPage); + numElements++; + bitset.set(pos); + final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( + dataPage, totalLengthOffset); longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); isDefined = true; - if (size > growthThreshold && longArray.size() < MAX_CAPACITY) { + if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) { growAndRehash(); } + return true; } } @@ -516,7 +592,7 @@ private void allocate(int capacity) { // The capacity needs to be divisible by 64 so that our bit set can be sized properly capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2)); + longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); @@ -530,18 +606,14 @@ private void allocate(int capacity) { * This method is idempotent. */ public void free() { - if (longArray != null) { - memoryManager.free(longArray.memoryBlock()); - longArray = null; - } - if (bitset != null) { - // The bitset's heap memory isn't managed by a memory manager, so no need to free it here. 
- bitset = null; - } + longArray = null; + bitset = null; Iterator dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { - memoryManager.freePage(dataPagesIterator.next()); + MemoryBlock dataPage = dataPagesIterator.next(); dataPagesIterator.remove(); + taskMemoryManager.freePage(dataPage); + shuffleMemoryManager.release(dataPage.size()); } assert(dataPages.isEmpty()); } @@ -628,8 +700,6 @@ void growAndRehash() { } } - // Deallocate the old data structures. - memoryManager.free(oldLongArray.memoryBlock()); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java rename to core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 866e0b4151577..c05f2c332eee3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -282,6 +282,21 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } + /** + * Write a record to the sorter. 
The record is broken down into two different parts, and + * + */ + public void insertRecord( + Object recordBaseObject1, + long recordBaseOffset1, + int lengthInBytes1, + Object recordBaseObject2, + long recordBaseOffset2, + int lengthInBytes2, + long prefix) throws IOException { + + } + public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4161792976c7b..08bab4bf2739f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -548,7 +548,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.askTimeout" -> Seq( AlternateConfig("spark.akka.askTimeout", "1.4")), "spark.rpc.lookupTimeout" -> Seq( - AlternateConfig("spark.akka.lookupTimeout", "1.4")) + AlternateConfig("spark.akka.lookupTimeout", "1.4")), + "spark.streaming.fileStream.minRememberDuration" -> Seq( + AlternateConfig("spark.streaming.minRememberDuration", "1.5")) ) /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ac6ac6c216767..2d8aa25d81daa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1689,33 +1689,57 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Utils.removeShutdownHook(_shutdownHookRef) } - postApplicationEnd() - _ui.foreach(_.stop()) + Utils.tryLogNonFatalError { + postApplicationEnd() + } + Utils.tryLogNonFatalError { + _ui.foreach(_.stop()) + } if (env != null) { - env.metricsSystem.report() + Utils.tryLogNonFatalError { + env.metricsSystem.report() + } } if (metadataCleaner != null) { - metadataCleaner.cancel() 
+ Utils.tryLogNonFatalError { + metadataCleaner.cancel() + } + } + Utils.tryLogNonFatalError { + _cleaner.foreach(_.stop()) + } + Utils.tryLogNonFatalError { + _executorAllocationManager.foreach(_.stop()) } - _cleaner.foreach(_.stop()) - _executorAllocationManager.foreach(_.stop()) if (_dagScheduler != null) { - _dagScheduler.stop() + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } _dagScheduler = null } if (_listenerBusStarted) { - listenerBus.stop() - _listenerBusStarted = false + Utils.tryLogNonFatalError { + listenerBus.stop() + _listenerBusStarted = false + } + } + Utils.tryLogNonFatalError { + _eventLogger.foreach(_.stop()) } - _eventLogger.foreach(_.stop()) if (env != null && _heartbeatReceiver != null) { - env.rpcEnv.stop(_heartbeatReceiver) + Utils.tryLogNonFatalError { + env.rpcEnv.stop(_heartbeatReceiver) + } + } + Utils.tryLogNonFatalError { + _progressBar.foreach(_.stop()) } - _progressBar.foreach(_.stop()) _taskScheduler = null // TODO: Cache.stop()? if (_env != null) { - _env.stop() + Utils.tryLogNonFatalError { + _env.stop() + } SparkEnv.set(null) } SparkContext.clearActiveContext() diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 79b251e7e62fe..a659abf70395d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. 
Takes a master URL, @@ -248,7 +248,8 @@ private[spark] class AppClient( def stop() { if (endpoint != null) { try { - endpoint.askWithRetry[Boolean](StopAppClient) + val timeout = RpcUtils.askRpcTimeout(conf) + timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 82e9578bbcba5..0276c24f85368 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -25,7 +25,7 @@ import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext import scala.util.Random import scala.util.control.NonFatal @@ -115,13 +115,18 @@ private[worker] class Worker( } var workDir: File = null - val finishedExecutors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new LinkedHashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] val executors = new HashMap[String, ExecutorRunner] - val finishedDrivers = new HashMap[String, DriverRunner] + val finishedDrivers = new LinkedHashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors", + WorkerWebUI.DEFAULT_RETAINED_EXECUTORS) + val retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers", + WorkerWebUI.DEFAULT_RETAINED_DRIVERS) + // The shuffle service is not actually started unless configured. 
private val shuffleService = new ExternalShuffleService(conf, securityMgr) @@ -461,25 +466,7 @@ private[worker] class Worker( } case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => - sendToMaster(executorStateChanged) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - executors.get(fullId) match { - case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - executors -= fullId - finishedExecutors(fullId) = executor - coresUsed -= executor.cores - memoryUsed -= executor.memory - case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - } - maybeCleanupApplication(appId) - } + handleExecutorStateChanged(executorStateChanged) case KillExecutor(masterUrl, appId, execId) => if (masterUrl != activeMasterUrl) { @@ -523,24 +510,8 @@ private[worker] class Worker( } } - case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { - state match { - case DriverState.ERROR => - logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") - case DriverState.FAILED => - logWarning(s"Driver $driverId exited with failure") - case DriverState.FINISHED => - logInfo(s"Driver $driverId exited successfully") - case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") - case _ => - logDebug(s"Driver $driverId changed state to $state") - } - sendToMaster(driverStageChanged) - val driver = drivers.remove(driverId).get - finishedDrivers(driverId) = driver - memoryUsed -= driver.driverDesc.mem - coresUsed -= driver.driverDesc.cores + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + handleDriverStateChanged(driverStateChanged) } case ReregisterWithMaster => @@ -614,6 +585,78 
@@ private[worker] class Worker( webUi.stop() metricsSystem.stop() } + + private def trimFinishedExecutorsIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedExecutors.size > retainedExecutors) { + finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach { + case (executorId, _) => finishedExecutors.remove(executorId) + } + } + } + + private def trimFinishedDriversIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedDrivers.size > retainedDrivers) { + finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach { + case (driverId, _) => finishedDrivers.remove(driverId) + } + } + } + + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { + val driverId = driverStateChanged.driverId + val exception = driverStateChanged.exception + val state = driverStateChanged.state + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FAILED => + logWarning(s"Driver $driverId exited with failure") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + case _ => + logDebug(s"Driver $driverId changed state to $state") + } + sendToMaster(driverStateChanged) + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + trimFinishedDriversIfNecessary() + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + + private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): + Unit = { + sendToMaster(executorStateChanged) + val state = executorStateChanged.state + if (ExecutorState.isFinished(state)) { + val appId = 
executorStateChanged.appId + val fullId = appId + "/" + executorStateChanged.execId + val message = executorStateChanged.message + val exitStatus = executorStateChanged.exitStatus + executors.get(fullId) match { + case Some(executor) => + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + executors -= fullId + finishedExecutors(fullId) = executor + trimFinishedExecutorsIfNecessary() + coresUsed -= executor.cores + memoryUsed -= executor.memory + case None => + logInfo("Unknown Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + } + maybeCleanupApplication(appId) + } + } } private[deploy] object Worker extends Logging { @@ -669,5 +712,4 @@ private[deploy] object Worker extends Logging { cmd } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 334a5b10142aa..709a27233598c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -53,6 +53,8 @@ class WorkerWebUI( } } -private[ui] object WorkerWebUI { +private[worker] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR + val DEFAULT_RETAINED_DRIVERS = 1000 + val DEFAULT_RETAINED_EXECUTORS = 1000 } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index f038b722957b8..00c1e078a441c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -85,7 +85,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { return toGrant 
} else { logInfo( - s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") + s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { @@ -116,6 +116,12 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } + + /** Returns the memory consumption, in bytes, for the current task */ + def getMemoryConsumptionForThisTask(): Long = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.getOrElse(taskAttemptId, 0L) + } } private object ShuffleMemoryManager { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java similarity index 63% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java rename to core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 0be94ad371255..70f8ca4d21345 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -21,15 +21,14 @@ import java.nio.ByteBuffer; import java.util.*; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import org.junit.*; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import static org.hamcrest.Matchers.greaterThan; import static org.mockito.AdditionalMatchers.geq; import static org.mockito.Mockito.*; +import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; import org.apache.spark.unsafe.PlatformDependent; @@ -41,32 +40,39 @@ public abstract class AbstractBytesToBytesMapSuite { private 
final Random rand = new Random(42); - private TaskMemoryManager memoryManager; - private TaskMemoryManager sizeLimitedMemoryManager; + private ShuffleMemoryManager shuffleMemoryManager; + private TaskMemoryManager taskMemoryManager; + private TaskMemoryManager sizeLimitedTaskMemoryManager; private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes @Before public void setup() { - memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); // Mocked memory manager for tests that check the maximum array size, since actually allocating // such large arrays will cause us to run out of memory in our tests. - sizeLimitedMemoryManager = spy(memoryManager); - when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer() { - @Override - public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { - if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { - throw new OutOfMemoryError("Requested array size exceeds VM limit"); + sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class); + when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer( + new Answer() { + @Override + public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { + if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { + throw new OutOfMemoryError("Requested array size exceeds VM limit"); + } + return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]); } - return memoryManager.allocate(1L << 20); } - }); + ); } @After public void tearDown() { - if (memoryManager != null) { - memoryManager.cleanUpAllAllocatedMemory(); - memoryManager = null; + if (taskMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + Assert.assertEquals(0, 
taskMemoryManager.cleanUpAllAllocatedMemory()); + Assert.assertEquals(0, leakedShuffleMemory); + shuffleMemoryManager = null; + taskMemoryManager = null; } } @@ -85,7 +91,7 @@ private static byte[] getByteArray(MemoryLocation loc, int size) { } private byte[] getRandomByteArray(int numWords) { - Assert.assertTrue(numWords > 0); + Assert.assertTrue(numWords >= 0); final int lengthInBytes = numWords * 8; final byte[] bytes = new byte[lengthInBytes]; rand.nextBytes(bytes); @@ -111,9 +117,10 @@ private static boolean arrayEquals( @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); try { - Assert.assertEquals(0, map.size()); + Assert.assertEquals(0, map.numElements()); final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); @@ -126,7 +133,8 @@ public void emptyMap() { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); @@ -135,14 +143,14 @@ public void setAndRetrieveAKey() { final BytesToBytesMap.Location loc = map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); - loc.putNewKey( + Assert.assertTrue(loc.putNewKey( keyData, BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, BYTE_ARRAY_OFFSET, recordLengthBytes - ); + )); // After storing the key and value, the other location methods should return results that // reflect the result of this store without us having to call lookup() again on the same key. 
Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); @@ -158,14 +166,14 @@ public void setAndRetrieveAKey() { Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); try { - loc.putNewKey( + Assert.assertTrue(loc.putNewKey( keyData, BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, BYTE_ARRAY_OFFSET, recordLengthBytes - ); + )); Assert.fail("Should not be able to set a new value for a key"); } catch (AssertionError e) { // Expected exception; do nothing. @@ -178,7 +186,8 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -187,23 +196,23 @@ public void iteratorTest() throws Exception { Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { - loc.putNewKey( + Assert.assertTrue(loc.putNewKey( null, PlatformDependent.LONG_ARRAY_OFFSET, 0, value, PlatformDependent.LONG_ARRAY_OFFSET, 8 - ); + )); } else { - loc.putNewKey( + Assert.assertTrue(loc.putNewKey( value, PlatformDependent.LONG_ARRAY_OFFSET, 8, value, PlatformDependent.LONG_ARRAY_OFFSET, 8 - ); + )); } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); @@ -234,16 +243,17 @@ public void iteratorTest() throws Exception { @Test public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; - final int KEY_LENGTH = 16; + final int KEY_LENGTH = 24; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); - // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. 
Our 64-megabyte + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); + // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record // handling branch in iterator(). try { for (int i = 0; i < NUM_ENTRIES; i++) { - final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes + final long[] key = new long[] { i, i, i }; // 3 * 8 = 24 bytes final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, @@ -251,14 +261,14 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); - loc.putNewKey( + Assert.assertTrue(loc.putNewKey( key, LONG_ARRAY_OFFSET, KEY_LENGTH, value, LONG_ARRAY_OFFSET, VALUE_LENGTH - ); + )); } Assert.assertEquals(2, map.getNumDataPages()); @@ -305,7 +315,8 @@ public void randomizedStressTest() { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES); + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing @@ -320,14 +331,14 @@ public void randomizedStressTest() { key.length ); Assert.assertFalse(loc.isDefined()); - loc.putNewKey( + Assert.assertTrue(loc.putNewKey( key, BYTE_ARRAY_OFFSET, key.length, value, BYTE_ARRAY_OFFSET, value.length - ); + )); // After calling putNewKey, the following should be true, even before calling // lookup(): Assert.assertTrue(loc.isDefined()); @@ -351,10 +362,102 @@ public void randomizedStressTest() { } } + @Test + public void randomizedTestWithRecordsLargerThanPageSize() { + final long pageSizeBytes = 128; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes); + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here. 
+ final Map expected = new HashMap(); + try { + for (int i = 0; i < 1000; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(128)); + final byte[] value = getRandomByteArray(rand.nextInt(128)); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + Assert.assertTrue(loc.putNewKey( + key, + BYTE_ARRAY_OFFSET, + key.length, + value, + BYTE_ARRAY_OFFSET, + value.length + )); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + for (Map.Entry entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } + + @Test + public void failureToAllocateFirstPage() { + shuffleMemoryManager = new ShuffleMemoryManager(1024); + BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + try { + final long[] emptyArray = new long[0]; + final BytesToBytesMap.Location loc = + map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0); + Assert.assertFalse(loc.isDefined()); + Assert.assertFalse(loc.putNewKey( + emptyArray, LONG_ARRAY_OFFSET, 0, + emptyArray, LONG_ARRAY_OFFSET, 0 + )); + } finally { + map.free(); + } + } + 
+ + @Test + public void failureToGrow() { + shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); + try { + boolean success = true; + int i; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8); + success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8); + if (!success) { + break; + } + } + Assert.assertThat(i, greaterThan(0)); + Assert.assertFalse(success); + } finally { + map.free(); + } + } + @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES); + new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -362,23 +465,34 @@ public void initialCapacityBoundsChecking() { try { new BytesToBytesMap( - sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); + sizeLimitedTaskMemoryManager, + shuffleMemoryManager, + BytesToBytesMap.MAX_CAPACITY + 1, + PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } - // Can allocate _at_ the max capacity - BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES); - map.free(); + // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager + // Can allocate _at_ the max capacity + // BytesToBytesMap map = new BytesToBytesMap( + // sizeLimitedTaskMemoryManager, + // shuffleMemoryManager, + // BytesToBytesMap.MAX_CAPACITY, + // PAGE_SIZE_BYTES); + // map.free(); } - @Test + // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager + @Ignore 
public void resizingLargeMap() { // As long as a map's capacity is below the max, we should be able to resize up to the max BytesToBytesMap map = new BytesToBytesMap( - sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES); + sizeLimitedTaskMemoryManager, + shuffleMemoryManager, + BytesToBytesMap.MAX_CAPACITY - 64, + PAGE_SIZE_BYTES); map.growAndRehash(); map.free(); } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java rename to core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java rename to core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala new file mode 100644 index 0000000000000..967aa0976f0ce --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.File +import java.util.Date + +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.{SecurityManager, SparkConf} + +private[deploy] object DeployTestUtils { + def createAppDesc(): ApplicationDescription = { + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) + new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") + } + + def createAppInfo() : ApplicationInfo = { + val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, + "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + appInfo.endTime = JsonConstants.currTimeInMillis + appInfo + } + + def createDriverCommand(): Command = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") + ) + + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", + createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis + workerInfo + } + + def createExecutorRunner(execId: Int): ExecutorRunner = { + new ExecutorRunner( + "appId", + execId, + 
createAppDesc(), + 4, + 1234, + null, + "workerId", + "host", + 123, + "publicAddress", + new File("sparkHome"), + new File("workDir"), + "akka://worker", + new SparkConf, + Seq("localDir"), + ExecutorState.RUNNING) + } + + def createDriverRunner(driverId: String): DriverRunner = { + val conf = new SparkConf() + new DriverRunner( + conf, + driverId, + new File("workDir"), + new File("sparkHome"), + createDriverDesc(), + null, + "akka://worker", + new SecurityManager(conf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 08529e0ef2806..0a9f128a3a6b6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy -import java.io.File import java.util.Date import com.fasterxml.jackson.core.JsonParseException @@ -25,12 +24,14 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { + import org.apache.spark.deploy.DeployTestUtils._ + test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) assertValidJson(output) @@ -50,7 +51,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { } test("writeExecutorRunner") { - val output = JsonProtocol.writeExecutorRunner(createExecutorRunner()) 
+ val output = JsonProtocol.writeExecutorRunner(createExecutorRunner(123)) assertValidJson(output) assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr)) } @@ -77,9 +78,10 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeWorkerState") { val executors = List[ExecutorRunner]() - val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) - val drivers = List(createDriverRunner()) - val finishedDrivers = List(createDriverRunner(), createDriverRunner()) + val finishedExecutors = List[ExecutorRunner](createExecutorRunner(123), + createExecutorRunner(123)) + val drivers = List(createDriverRunner("driverId")) + val finishedDrivers = List(createDriverRunner("driverId"), createDriverRunner("driverId")) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) @@ -87,47 +89,6 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr)) } - def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) - new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") - } - - def createAppInfo() : ApplicationInfo = { - val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) - appInfo.endTime = JsonConstants.currTimeInMillis - appInfo - } - - def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") - ) - - def createDriverDesc(): DriverDescription = - new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, 
createDriverCommand()) - - def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) - - def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") - workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis - workerInfo - } - - def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, - "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", - new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - } - - def createDriverRunner(): DriverRunner = { - val conf = new SparkConf() - new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) - } - def assertValidJson(json: JValue) { try { JsonMethods.parse(JsonMethods.compact(json)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 0f4d3b28d09df..faed4bdc68447 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.Command - import org.scalatest.Matchers +import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} + class WorkerSuite extends SparkFunSuite with Matchers { + import org.apache.spark.deploy.DeployTestUtils._ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, 
Seq(javaOpts : _*)) } @@ -56,4 +61,126 @@ class WorkerSuite extends SparkFunSuite with Matchers { "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") } + + test("test clearing of finishedExecutors (small number of executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 4) + for (i <- 1 until 5) { + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 2) + if (i > 1) { + assert(!worker.finishedExecutors.contains(s"app1/${i - 2}")) + } + assert(worker.executors.size === 4 - i) + } + } + + test("test clearing of finishedExecutors (more executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + 
assert(worker.executors.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedExecutors.size < 30) { + worker.finishedExecutors.size + 1 + } else { + 28 + } + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedExecutors.contains(s"app1/$j")) + } + } + assert(worker.executors.size === 49 - i) + assert(worker.finishedExecutors.size === expectedValue) + } + } + + test("test clearing of finishedDrivers (small number of drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.drivers.size === 4) + assert(worker.finishedDrivers.size === 1) + for (i <- 1 until 5) { + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (i > 1) { + assert(!worker.finishedDrivers.contains(s"driverId-${i - 2}")) + } + assert(worker.drivers.size === 4 - i) + assert(worker.finishedDrivers.size === 2) + } + } + + test("test clearing of finishedDrivers (more drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", 
"Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.finishedDrivers.size === 1) + assert(worker.drivers.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedDrivers.size < 30) { + worker.finishedDrivers.size + 1 + } else { + 28 + } + } + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedDrivers.contains(s"driverId-$j")) + } + } + assert(worker.drivers.size === 49 - i) + assert(worker.finishedDrivers.size === expectedValue) + } + } } diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 86a7a4068c40e..4311c8c9e4ca6 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ if [[ ! 
"$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Pyarn -Phive -Prelease\ + build/mvn -DskipTests -Pyarn -Phive \ -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-scala-version.sh 2.11 - build/mvn -DskipTests -Pyarn -Phive -Prelease\ + build/mvn -DskipTests -Pyarn -Phive \ -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install diff --git a/dev/run-tests.py b/dev/run-tests.py index 1eff2b4d5c071..69e11fd9db861 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -303,7 +303,8 @@ def build_spark_sbt(hadoop_version): "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", "streaming-mqtt-assembly/assembly", - "streaming-mqtt/test:assembly"] + "streaming-mqtt/test:assembly", + "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3e273af10ffae..9b588fbc1164e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -138,6 +138,7 @@ def contains_file(self, filename): dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", + "extras/kinesis-asl-assembly/", ], build_profile_flags=[ "-Pkinesis-asl", @@ -306,7 +307,8 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, - streaming_mqtt + streaming_mqtt, + streaming_kinesis_asl ], source_file_regexes=[ "python/pyspark/streaming" diff --git a/docs/configuration.md b/docs/configuration.md index fd236137cb96e..24b606356a149 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -557,6 +557,20 @@ Apart from these, the following properties are also available, and may be useful collecting. + + spark.worker.ui.retainedExecutors + 1000 + + How many finished executors the Spark UI and status APIs remember before garbage collecting. 
+ + + + spark.worker.ui.retainedDrivers + 1000 + + How many finished drivers the Spark UI and status APIs remember before garbage collecting. + + #### Compression and Serialization diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 8c46adf256a9a..b6ca50e98db02 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -561,7 +561,7 @@ test = sc.parallelize([(4L, "spark i j k"), prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() {% endhighlight %} diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 4ca0bb06b26a6..7066d5c97418c 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -302,10 +302,10 @@ predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp metrics = BinaryClassificationMetrics(predictionAndLabels) # Area under precision-recall curve -print "Area under PR = %s" % metrics.areaUnderPR +print("Area under PR = %s" % metrics.areaUnderPR) # Area under ROC curve -print "Area under ROC = %s" % metrics.areaUnderROC +print("Area under ROC = %s" % metrics.areaUnderROC) {% endhighlight %} @@ -606,24 +606,24 @@ metrics = MulticlassMetrics(predictionAndLabels) precision = metrics.precision() recall = metrics.recall() f1Score = metrics.fMeasure() -print "Summary Stats" -print "Precision = %s" % precision -print "Recall = %s" % recall -print "F1 Score = %s" % f1Score +print("Summary Stats") +print("Precision = %s" % precision) +print("Recall = %s" % recall) +print("F1 Score = %s" % f1Score) # Statistics by class labels = data.map(lambda lp: lp.label).distinct().collect() for label in sorted(labels): - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + print("Class %s precision = %s" % (label, 
metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) # Weighted stats -print "Weighted recall = %s" % metrics.weightedRecall -print "Weighted precision = %s" % metrics.weightedPrecision -print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() -print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) -print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +print("Weighted recall = %s" % metrics.weightedRecall) +print("Weighted precision = %s" % metrics.weightedPrecision) +print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) +print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) +print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) {% endhighlight %} @@ -881,28 +881,28 @@ scoreAndLabels = sc.parallelize([ metrics = MultilabelMetrics(scoreAndLabels) # Summary stats -print "Recall = %s" % metrics.recall() -print "Precision = %s" % metrics.precision() -print "F1 measure = %s" % metrics.f1Measure() -print "Accuracy = %s" % metrics.accuracy +print("Recall = %s" % metrics.recall()) +print("Precision = %s" % metrics.precision()) +print("F1 measure = %s" % metrics.f1Measure()) +print("Accuracy = %s" % metrics.accuracy) # Individual label stats labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() for label in labels: - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) # Micro stats -print "Micro precision = %s" % metrics.microPrecision -print "Micro recall = %s" % metrics.microRecall -print 
"Micro F1 measure = %s" % metrics.microF1Measure +print("Micro precision = %s" % metrics.microPrecision) +print("Micro recall = %s" % metrics.microRecall) +print("Micro F1 measure = %s" % metrics.microF1Measure) # Hamming loss -print "Hamming loss = %s" % metrics.hammingLoss +print("Hamming loss = %s" % metrics.hammingLoss) # Subset accuracy -print "Subset accuracy = %s" % metrics.subsetAccuracy +print("Subset accuracy = %s" % metrics.subsetAccuracy) {% endhighlight %} @@ -1283,10 +1283,10 @@ scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) metrics = RegressionMetrics(scoreAndLabels) # Root mean sqaured error -print "RMSE = %s" % metrics.rootMeanSquaredError +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = %s" % metrics.r2 +print("R-squared = %s" % metrics.r2) {% endhighlight %} @@ -1479,17 +1479,17 @@ valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.l metrics = RegressionMetrics(valuesAndPreds) # Squared Error -print "MSE = %s" % metrics.meanSquaredError -print "RMSE = %s" % metrics.rootMeanSquaredError +print("MSE = %s" % metrics.meanSquaredError) +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = %s" % metrics.r2 +print("R-squared = %s" % metrics.r2) # Mean absolute error -print "MAE = %s" % metrics.meanAbsoluteError +print("MAE = %s" % metrics.meanAbsoluteError) # Explained variance -print "Explained variance = %s" % metrics.explainedVariance +print("Explained variance = %s" % metrics.explainedVariance) {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index a69e41e2a1936..de86aba2ae627 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -221,7 +221,7 @@ model = word2vec.fit(inp) synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) {% 
endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de5d6485f9b5f..be04d0b4b53a8 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors # Compute column summary statistics. summary = Statistics.colStats(mat) -print summary.mean() -print summary.variance() -print summary.numNonzeros() +print(summary.mean()) +print(summary.variance()) +print(summary.numNonzeros()) {% endhighlight %} @@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a # method is not specified, Pearson's method will be used by default. -print Statistics.corr(seriesX, seriesY, method="pearson") +print(Statistics.corr(seriesX, seriesY, method="pearson")) data = ... # an RDD of Vectors # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. # If a method is not specified, Pearson's method will be used by default. -print Statistics.corr(data, method="pearson") +print(Statistics.corr(data, method="pearson")) {% endhighlight %} @@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events # compute the goodness of fit. If a second vector to test against is not supplied as a parameter, # the test runs against a uniform distribution. goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. +print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. mat = Matrices.dense(...) 
# a contingency matrix # conduct Pearson's independence test on the input contingency matrix independenceTestResult = Statistics.chiSqTest(mat) -print independenceTestResult # summary of the test including the p-value, degrees of freedom... +print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... obs = sc.parallelize(...) # LabeledPoint(feature, label) . @@ -415,8 +415,8 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) . featureTestResults = Statistics.chiSqTest(obs) for i, result in enumerate(featureTestResults): - print "Column $d:" % (i + 1) - print result + print("Column %d:" % (i + 1)) + print(result) {% endhighlight %} diff --git a/docs/quick-start.md b/docs/quick-start.md index bb39e4111f244..ce2cc9d2169cd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -406,7 +406,7 @@ logData = sc.textFile(logFile).cache() numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() -print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) {% endhighlight %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 95945eb7fc8a0..d31baa080cbce 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -570,7 +570,7 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 1 # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} @@ -752,7 +752,7 @@ results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations.
names = results.map(lambda p: "Name: " + p.name) for name in names.collect(): - print name + print(name) {% endhighlight %} @@ -1006,7 +1006,7 @@ parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index aa9749afbc867..a7bcaec6fcd84 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -51,6 +51,17 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + +
+ from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + + kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) + + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. +
@@ -135,6 +146,14 @@ To run the example, bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + +
+ + bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name] +
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 77d62047c3525..ddf0b39292b3f 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1525,7 +1525,7 @@ def getSqlContextInstance(sparkContext): words = ... # DStream of strings def process(time, rdd): - print "========= %s =========" % str(time) + print("========= %s =========" % str(time)) try: # Get the singleton instance of SQLContext sqlContext = getSqlContextInstance(rdd.context) diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml new file mode 100644 index 0000000000000..70d2c9c58f54e --- /dev/null +++ b/extras/kinesis-asl-assembly/pom.xml @@ -0,0 +1,103 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kinesis-asl-assembly_2.10 + jar + Spark Project Kinesis Assembly + http://spark.apache.org/ + + + streaming-kinesis-asl-assembly + + + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py new file mode 100644 index 0000000000000..f428f64da3c42 --- /dev/null +++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ 
-0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Consumes messages from an Amazon Kinesis stream and does wordcount. + + This example spins up 1 Kinesis Receiver per shard for the given stream. + It then starts pulling from the last checkpointed sequence number of the given stream. + + Usage: kinesis_wordcount_asl.py [app-name] [stream-name] [endpoint-url] [region-name] + [app-name] is the name of the consumer app, used to track the read data in DynamoDB + [stream-name] name of the Kinesis stream (e.g. mySparkStream) + [endpoint-url] endpoint of the Kinesis service + (e.g. https://kinesis.us-east-1.amazonaws.com) + [region-name] name of the AWS region passed to the Kinesis receiver (e.g. us-east-1) + + Example: + # export AWS keys if necessary + $ export AWS_ACCESS_KEY_ID=[your-access-key] + $ export AWS_SECRET_KEY=[your-secret-key] + + # run the example + $ bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com us-east-1 + + There is a companion helper class called KinesisWordProducerASL which puts dummy data + onto the Kinesis stream. 
+ + This code uses the DefaultAWSCredentialsProviderChain to find credentials + in the following order: + Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + Java System Properties - aws.accessKeyId and aws.secretKey + Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + Instance profile credentials - delivered through the Amazon EC2 metadata service + For more information, see + http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + + See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + the Kinesis Spark Streaming integration. +""" +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + +if __name__ == "__main__": + if len(sys.argv) != 5: + print( + "Usage: kinesis_wordcount_asl.py ", + file=sys.stderr) + sys.exit(-1) + + sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl") + ssc = StreamingContext(sc, 1) + appName, streamName, endpointUrl, regionName = sys.argv[1:] + lines = KinesisUtils.createStream( + ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index ca39358b75cb6..255ac27f793ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -36,9 +36,15 @@ import org.apache.spark.Logging /** * Shared utility methods for performing Kinesis tests that actually 
transfer data */ -private class KinesisTestUtils( - val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com", - _regionName: String = "") extends Logging { +private class KinesisTestUtils(val endpointUrl: String, _regionName: String) extends Logging { + + def this() { + this("https://kinesis.us-west-2.amazonaws.com", "") + } + + def this(endpointUrl: String) { + this(endpointUrl, "") + } val regionName = if (_regionName.length == 0) { RegionUtils.getRegionByEndpoint(endpointUrl).getName() @@ -117,6 +123,13 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } + /** + * Expose a Python friendly API. + */ + def pushData(testData: java.util.List[Int]): Unit = { + pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + } + def deleteStream(): Unit = { try { if (streamCreated) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index e5acab50181e1..7dab17eba8483 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -86,19 +86,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. 
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( ssc: StreamingContext, @@ -130,7 +130,7 @@ object KinesisUtils { * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in * [[org.apache.spark.SparkConf]]. * - * @param ssc Java StreamingContext object + * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Endpoint url of Kinesis service * (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -175,15 +175,15 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. 
* The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ @@ -206,8 +206,8 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library @@ -216,19 +216,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. 
* The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( jssc: JavaStreamingContext, @@ -297,3 +297,49 @@ object KinesisUtils { } } } + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { + initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. 
Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + } + + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala new file mode 100644 index 0000000000000..7429f9d652ac5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +/** + * In-place DGEMM and DGEMV for Breeze + */ +private[ann] object BreezeUtil { + + // TODO: switch to MLlib BLAS interface + private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" + + /** + * DGEMM: C := alpha * A * B + beta * C + * @param alpha alpha + * @param a A + * @param b B + * @param beta beta + * @param c C + */ + def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = { + // TODO: add code if matrices isTranspose!!! 
+ require(a.cols == b.rows, "A & B Dimension mismatch!") + require(a.rows == c.rows, "A & C Dimension mismatch!") + require(b.cols == c.cols, "A & C Dimension mismatch!") + NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, + alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, + beta, c.data, c.offset, c.rows) + } + + /** + * DGEMV: y := alpha * A * x + beta * y + * @param alpha alpha + * @param a A + * @param x x + * @param beta beta + * @param y y + */ + def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { + require(a.cols == x.length, "A & b Dimension mismatch!") + NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, + alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, + beta, y.data, y.offset, y.stride) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala new file mode 100644 index 0000000000000..b5258ff348477 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, + sum => Bsum} +import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * Trait that holds Layer properties, that are needed to instantiate it. + * Implements Layer instantiation. + * + */ +private[ann] trait Layer extends Serializable { + /** + * Returns the instance of the layer based on weights provided + * @param weights vector with layer weights + * @param position position of weights in the vector + * @return the layer model + */ + def getInstance(weights: Vector, position: Int): LayerModel + + /** + * Returns the instance of the layer with random generated weights + * @param seed seed + * @return the layer model + */ + def getInstance(seed: Long): LayerModel +} + +/** + * Trait that holds Layer weights (or parameters). + * Implements functions needed for forward propagation, computing delta and gradient. + * Can return weights in Vector format. 
+ */ +private[ann] trait LayerModel extends Serializable { + /** + * number of weights + */ + val size: Int + + /** + * Evaluates the data (process the data through the layer) + * @param data data + * @return processed data + */ + def eval(data: BDM[Double]): BDM[Double] + + /** + * Computes the delta for back propagation + * @param nextDelta delta of the next layer + * @param input input data + * @return delta + */ + def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + + /** + * Computes the gradient + * @param delta delta for this layer + * @param input input data + * @return gradient + */ + def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] + + /** + * Returns weights for the layer in a single vector + * @return layer weights + */ + def weights(): Vector +} + +/** + * Layer properties of affine transformations, that is y=A*x+b + * @param numIn number of inputs + * @param numOut number of outputs + */ +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { + + override def getInstance(weights: Vector, position: Int): LayerModel = { + AffineLayerModel(this, weights, position) + } + + override def getInstance(seed: Long = 11L): LayerModel = { + AffineLayerModel(this, seed) + } +} + +/** + * Model of Affine layer y=A*x+b + * @param w weights (matrix A) + * @param b bias (vector b) + */ +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { + val size = w.size + b.length + val gwb = new Array[Double](size) + private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) + private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) + private var z: BDM[Double] = null + private var d: BDM[Double] = null + private var ones: BDV[Double] = null + + override def eval(data: BDM[Double]): BDM[Double] = { + if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) + z(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, z) + z + } + + override def 
prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) + BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) + gwb + } + + override def weights(): Vector = AffineLayerModel.roll(w, b) +} + +/** + * Fabric for Affine layer models + */ +private[ann] object AffineLayerModel { + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param weights vector with weights + * @param position position of weights in the vector + * @return model of Affine layer + */ + def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { + val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) + new AffineLayerModel(w, b) + } + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param seed seed + * @return model of Affine layer + */ + def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { + val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) + new AffineLayerModel(w, b) + } + + /** + * Unrolls the weights from the vector + * @param weights vector with weights + * @param position position of weights for this layer + * @param numIn number of layer inputs + * @param numOut number of layer outputs + * @return matrix A and vector b + */ + def unroll( + weights: Vector, + position: Int, + numIn: Int, + numOut: Int): (BDM[Double], BDV[Double]) = { + val weightsCopy = weights.toArray + // TODO: the array is not copied to BDMs, make sure this is OK! 
+ val a = new BDM[Double](numOut, numIn, weightsCopy, position) + val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) + (a, b) + } + + /** + * Roll the layer weights into a vector + * @param a matrix A + * @param b vector b + * @return vector of weights + */ + def roll(a: BDM[Double], b: BDV[Double]): Vector = { + val result = new Array[Double](a.size + b.length) + // TODO: make sure that we need to copy! + System.arraycopy(a.toArray, 0, result, 0, a.size) + System.arraycopy(b.toArray, 0, result, a.size, b.length) + Vectors.dense(result) + } + + /** + * Generate random weights for the layer + * @param numIn number of inputs + * @param numOut number of outputs + * @param seed seed + * @return (matrix A, vector b) + */ + def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { + val rand: XORShiftRandom = new XORShiftRandom(seed) + val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } + val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } + (weights, bias) + } +} + +/** + * Trait for functions and their derivatives for functional layers + */ +private[ann] trait ActivationFunction extends Serializable { + + /** + * Implements a function + * @param x input data + * @param y output data + */ + def eval(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a derivative of a function (needed for the back propagation) + * @param x input data + * @param y output data + */ + def derivative(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a cross entropy error of a function. + * Needed if the functional layer that contains this function is the output layer + * of the network. 
+ * @param target target output + * @param output computed output + * @param result intermediate result + * @return cross-entropy + */ + def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + + /** + * Implements a mean squared error of a function + * @param target target output + * @param output computed output + * @param result intermediate result + * @return mean squared error + */ + def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double +} + +/** + * Implements in-place application of functions + */ +private[ann] object ActivationFunction { + + def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { + var i = 0 + while (i < x.rows) { + var j = 0 + while (j < x.cols) { + y(i, j) = func(x(i, j)) + j += 1 + } + i += 1 + } + } + + def apply( + x1: BDM[Double], + x2: BDM[Double], + y: BDM[Double], + func: (Double, Double) => Double): Unit = { + var i = 0 + while (i < x1.rows) { + var j = 0 + while (j < x1.cols) { + y(i, j) = func(x1(i, j), x2(i, j)) + j += 1 + } + i += 1 + } + } +} + +/** + * Implements SoftMax activation function + */ +private[ann] class SoftmaxFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < x.cols) { + var i = 0 + var max = Double.MinValue + while (i < x.rows) { + if (x(i, j) > max) { + max = x(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < x.rows) { + val res = Math.exp(x(i, j) - max) + y(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < x.rows) { + y(i, j) /= sum + i += 1 + } + j += 1 + } + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: 
BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") + } +} + +/** + * Implements Sigmoid activation function + */ +private[ann] class SigmoidFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + def s(z: Double): Double = Bsigmoid(z) + ActivationFunction(x, y, s) + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum(target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + // TODO: make it readable + def m(o: Double, t: Double): Double = (o - t) + ActivationFunction(output, target, result, m) + val e = Bsum(result :* result) / 2 / output.cols + def m2(x: Double, o: Double) = x * (o - o * o) + ActivationFunction(result, output, result, m2) + e + } +} + +/** + * Functional layer properties, y = f(x) + * @param activationFunction activation function + */ +private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { + override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) + + override def getInstance(seed: Long): LayerModel = + FunctionalLayerModel(this) +} + +/** + * Functional layer model. Holds no weights. 
+ * @param activationFunction activation function + */ +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) + extends LayerModel { + val size = 0 + // matrices for in-place computations + // outputs + private var f: BDM[Double] = null + // delta + private var d: BDM[Double] = null + // matrix for error computation + private var e: BDM[Double] = null + // delta gradient + private lazy val dg = new Array[Double](0) + + override def eval(data: BDM[Double]): BDM[Double] = { + if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) + activationFunction.eval(data, f) + f + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) + activationFunction.derivative(input, d) + d :*= nextDelta + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg + + override def weights(): Vector = Vectors.dense(new Array[Double](0)) + + def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.crossEntropy(output, target, e) + (e, error) + } + + def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.squared(output, target, e) + (e, error) + } + + def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + // TODO: allow user pick error + activationFunction match { + case sigmoid: SigmoidFunction => squared(output, target) + case softmax: SoftmaxFunction => crossEntropy(output, target) + } + } +} + +/** + * Fabric of functional layer models + */ +private[ann] object FunctionalLayerModel { + def apply(layer: FunctionalLayer): FunctionalLayerModel = + new 
FunctionalLayerModel(layer.activationFunction) +} + +/** + * Trait for the artificial neural network (ANN) topology properties + */ +private[ann] trait Topology extends Serializable{ + def getInstance(weights: Vector): TopologyModel + def getInstance(seed: Long): TopologyModel +} + +/** + * Trait for ANN topology model + */ +private[ann] trait TopologyModel extends Serializable{ + /** + * Forward propagation + * @param data input data + * @return array of outputs for each of the layers + */ + def forward(data: BDM[Double]): Array[BDM[Double]] + + /** + * Prediction of the model + * @param data input data + * @return prediction + */ + def predict(data: Vector): Vector + + /** + * Computes gradient for the network + * @param data input data + * @param target target output + * @param cumGradient cumulative gradient + * @param blockSize block size + * @return error + */ + def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + blockSize: Int): Double + + /** + * Returns the weights of the ANN + * @return weights + */ + def weights(): Vector +} + +/** + * Feed forward ANN + * @param layers array of layers + */ +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { + override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + + override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) +} + +/** + * Factory for some of the frequently-used topologies + */ +private[ml] object FeedForwardTopology { + /** + * Creates a feed forward topology from the array of layers + * @param layers array of layers + * @return feed forward topology + */ + def apply(layers: Array[Layer]): FeedForwardTopology = { + new FeedForwardTopology(layers) + } + + /** + * Creates a multi-layer perceptron + * @param layerSizes sizes of layers including input and output size + * @param softmax whether to use SoftMax or Sigmoid function for an output layer. 
+ * Softmax is default + * @return multilayer perceptron topology + */ + def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + val layers = new Array[Layer]((layerSizes.length - 1) * 2) + for(i <- 0 until layerSizes.length - 1){ + layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) + layers(i * 2 + 1) = + if (softmax && i == layerSizes.length - 2) { + new FunctionalLayer(new SoftmaxFunction()) + } else { + new FunctionalLayer(new SigmoidFunction()) + } + } + FeedForwardTopology(layers) + } +} + +/** + * Model of Feed Forward Neural Network. + * Implements forward, gradient computation and can return weights in vector format. + * @param layerModels models of layers + * @param topology topology of the network + */ +private[ml] class FeedForwardModel private( + val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { + override def forward(data: BDM[Double]): Array[BDM[Double]] = { + val outputs = new Array[BDM[Double]](layerModels.length) + outputs(0) = layerModels(0).eval(data) + for (i <- 1 until layerModels.length) { + outputs(i) = layerModels(i).eval(outputs(i-1)) + } + outputs + } + + override def computeGradient( + data: BDM[Double], + target: BDM[Double], + cumGradient: Vector, + realBatchSize: Int): Double = { + val outputs = forward(data) + val deltas = new Array[BDM[Double]](layerModels.length) + val L = layerModels.length - 1 + val (newE, newError) = layerModels.last match { + case flm: FunctionalLayerModel => flm.error(outputs.last, target) + case _ => + throw new UnsupportedOperationException("Non-functional layer not supported at the top") + } + deltas(L) = new BDM[Double](0, 0) + deltas(L - 1) = newE + for (i <- (L - 2) to (0, -1)) { + deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) + } + val grads = new Array[Array[Double]](layerModels.length) + for (i <- 0 until layerModels.length) { + val input = if (i==0) data else outputs(i 
- 1) + grads(i) = layerModels(i).grad(deltas(i), input) + } + // update cumGradient + val cumGradientArray = cumGradient.toArray + var offset = 0 + // TODO: extract roll + for (i <- 0 until grads.length) { + val gradArray = grads(i) + var k = 0 + while (k < gradArray.length) { + cumGradientArray(offset + k) += gradArray(k) + k += 1 + } + offset += gradArray.length + } + newError + } + + // TODO: do we really need to copy the weights? they should be read-only + override def weights(): Vector = { + // TODO: extract roll + var size = 0 + for (i <- 0 until layerModels.length) { + size += layerModels(i).size + } + val array = new Array[Double](size) + var offset = 0 + for (i <- 0 until layerModels.length) { + val layerWeights = layerModels(i).weights().toArray + System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) + offset += layerWeights.length + } + Vectors.dense(array) + } + + override def predict(data: Vector): Vector = { + val size = data.size + val result = forward(new BDM[Double](size, 1, data.toArray)) + Vectors.dense(result.last.toArray) + } +} + +/** + * Factory for feed forward ANN models + */ +private[ann] object FeedForwardModel { + + /** + * Creates a model from a topology and weights + * @param topology topology + * @param weights weights + * @return model + */ + def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for (i <- 0 until layers.length) { + layerModels(i) = layers(i).getInstance(weights, offset) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } + + /** + * Creates a model given a topology and seed + * @param topology topology + * @param seed seed for generating the weights + * @return model + */ + def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { + val layers = topology.layers + val layerModels = new
Array[LayerModel](layers.length) + var offset = 0 + for(i <- 0 until layers.length){ + layerModels(i) = layers(i).getInstance(seed) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } +} + +/** + * Neural network gradient. Does nothing but call the model's gradient computation + * @param topology topology + * @param dataStacker data stacker + */ +private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val (input, target, realBatchSize) = dataStacker.unstack(data) + val model = topology.getInstance(weights) + model.computeGradient(input, target, cumGradient, realBatchSize) + } +} + +/** + * Stacks pairs of training samples (input, output) in one vector allowing them to pass + * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks + * or matrices of inputs and outputs and then stacks them in one vector. + * This can be used for further batch computations after unstacking.
+ * @param stackSize stack size + * @param inputSize size of the input vectors + * @param outputSize size of the output vectors + */ +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) + extends Serializable { + + /** + * Stacks the data + * @param data RDD of vector pairs + * @return RDD of double (always zero) and vector that contains the stacked vectors + */ + def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { + val stackedData = if (stackSize == 1) { + data.map { v => + (0.0, + Vectors.fromBreeze(BDV.vertcat( + v._1.toBreeze.toDenseVector, + v._2.toBreeze.toDenseVector)) + ) } + } else { + data.mapPartitions { it => + it.grouped(stackSize).map { seq => + val size = seq.size + val bigVector = new Array[Double](inputSize * size + outputSize * size) + var i = 0 + seq.foreach { case (in, out) => + System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize) + System.arraycopy(out.toArray, 0, bigVector, + inputSize * size + i * outputSize, outputSize) + i += 1 + } + (0.0, Vectors.dense(bigVector)) + } + } + } + stackedData + } + + /** + * Unstack the stacked vectors into matrices for batch operations + * @param data stacked vector + * @return pair of matrices holding input and output data and the real stack size + */ + def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = { + val arrData = data.toArray + val realStackSize = arrData.length / (inputSize + outputSize) + val input = new BDM(inputSize, realStackSize, arrData) + val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize) + (input, target, realStackSize) + } +} + +/** + * Simple updater + */ +private[ann] class ANNUpdater extends Updater { + + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { + val thisIterStepSize = stepSize + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + Baxpy(-thisIterStepSize, gradient.toBreeze, 
brzWeights) + (Vectors.fromBreeze(brzWeights), 0) + } +} + +/** + * MLlib-style trainer class that trains a network given the data and topology + * @param topology topology of ANN + * @param inputSize input size + * @param outputSize output size + */ +private[ml] class FeedForwardTrainer( + topology: Topology, + val inputSize: Int, + val outputSize: Int) extends Serializable { + + // TODO: what if we need to pass random seed? + private var _weights = topology.getInstance(11L).weights() + private var _stackSize = 128 + private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) + private var _gradient: Gradient = new ANNGradient(topology, dataStacker) + private var _updater: Updater = new ANNUpdater() + private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + + /** + * Returns weights + * @return weights + */ + def getWeights: Vector = _weights + + /** + * Sets weights + * @param value weights + * @return trainer + */ + def setWeights(value: Vector): FeedForwardTrainer = { + _weights = value + this + } + + /** + * Sets the stack size + * @param value stack size + * @return trainer + */ + def setStackSize(value: Int): FeedForwardTrainer = { + _stackSize = value + dataStacker = new DataStacker(value, inputSize, outputSize) + this + } + + /** + * Sets the SGD optimizer + * @return SGD optimizer + */ + def SGDOptimizer: GradientDescent = { + val sgd = new GradientDescent(_gradient, _updater) + optimizer = sgd + sgd + } + + /** + * Sets the LBFGS optimizer + * @return LBFGS optimizer + */ + def LBFGSOptimizer: LBFGS = { + val lbfgs = new LBFGS(_gradient, _updater) + optimizer = lbfgs + lbfgs + } + + /** + * Sets the updater + * @param value updater + * @return trainer + */ + def setUpdater(value: Updater): FeedForwardTrainer = { + _updater = value + updateUpdater(value) + this + } + + /** + * Sets the gradient + * @param value gradient + * @return trainer + */ + def setGradient(value: Gradient): FeedForwardTrainer
= { + _gradient = value + updateGradient(value) + this + } + + private[this] def updateGradient(gradient: Gradient): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setGradient(gradient) + case sgd: GradientDescent => sgd.setGradient(gradient) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + private[this] def updateUpdater(updater: Updater): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setUpdater(updater) + case sgd: GradientDescent => sgd.setUpdater(updater) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + /** + * Trains the ANN + * @param data RDD of input and output vector pairs + * @return model + */ + def train(data: RDD[(Vector, Vector)]): TopologyModel = { + val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) + topology.getInstance(newWeights) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 36fe1bd40469c..f27cfd0331419 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,12 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import 
org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} @@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame */ @Experimental final class DecisionTreeClassifier(override val uid: String) - extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("dtc")) @@ -106,8 +105,9 @@ object DecisionTreeClassifier { @Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, - override val rootNode: Node) - extends PredictionModel[Vector, DecisionTreeClassificationModel] + override val rootNode: Node, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { require(rootNode != null, @@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. 
*/ - def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode) + def this(rootNode: Node, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction + } + + override protected def predictRaw(features: Vector): Vector = { + Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val sum = dv.values.sum + while (i < size) { + dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0 + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) } override def toString: String = { @@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel { s" DecisionTreeClassificationModel (new API). 
Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode) + new DecisionTreeClassificationModel(uid, rootNode, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index eb0b1a0a405fc..c3891a9599262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -190,7 +190,7 @@ final class GBTClassificationModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala new file mode 100644 index 0000000000000..8cd2103d7d5e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.DataFrame + +/** Params for Multilayer Perceptron. */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams + with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = new IntArrayParam(this, "layers", + "Sizes of layers from input layer to output layer" + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "one hidden layer with 100 neurons and output layer of 10 neurons.", + // TODO: how to check ALSO that all elements are greater than 0? + ParamValidators.arrayLengthGt(1) + ) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** + * Block size for stacking input data in matrices to speed up the computation. + * Data is stacked within partitions. 
If block size is more than remaining data in + * a partition then it is adjusted to the size of this data. + * Recommended size is between 10 and 1000. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices. Data is stacked within partitions." + + " If block size is more than remaining data in a partition then " + + "it is adjusted to the size of this data. Recommended size is between 10 and 1000", + ParamValidators.gt(0)) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) +} + +/** Label to vector converter. */ +private object LabelConverter { + // TODO: Use OneHotEncoder instead + /** + * Encodes a label as a vector. + * Returns a vector of given length with zeroes at all positions + * and value 1.0 at the position that corresponds to the label. 
+ * + * @param labeledPoint labeled point + * @param labelCount total number of labels + * @return pair of features and vector encoding of a label + */ + def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount)(0.0) + output(labeledPoint.label.toInt) = 1.0 + (labeledPoint.features, Vectors.dense(output)) + } + + /** + * Converts a vector to a label. + * Returns the position of the maximal element of a vector. + * + * @param output label encoded with a vector + * @return label + */ + def decodeLabel(output: Vector): Double = { + output.argmax.toDouble + } +} + +/** + * :: Experimental :: + * Classifier trainer based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * Number of inputs has to be equal to the size of feature vectors. + * Number of outputs has to be equal to the total number of labels. + * + */ +@Experimental +class MultilayerPerceptronClassifier(override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + with MultilayerPerceptronParams { + + def this() = this(Identifiable.randomUID("mlpc")) + + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. 
+ * + * @param dataset Training dataset + * @return Fitted model + */ + override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + val myLayers = $(layers) + val labels = myLayers.last + val lpData = extractLabeledPoints(dataset) + val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) + FeedForwardTrainer.setStackSize($(blockSize)) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + } +} + +/** + * :: Experimental :: + * Classifier model based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * @param uid uid + * @param layers array of layer sizes including input and output layers + * @param weights vector of initial weights for the model that consists of the weights of layers + * @return prediction model + */ +@Experimental +class MultilayerPerceptronClassifierModel private[ml] ( + override val uid: String, + layers: Array[Int], + weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + with Serializable { + + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
+ */ + override protected def predict(features: Vector): Double = { + LabelConverter.decodeLabel(mlpModel.predict(features)) + } + + override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { + copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1f547e4a98af7..b46b676204e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * (default = 1.0). * @group param */ - final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", + final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.", ParamValidators.gtEq(0)) /** @group getParam */ - final def getLambda: Double = $(lambda) + final def getSmoothing: Double = $(smoothing) /** * The model type which is a string (case-sensitive). @@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * The input feature values must be nonnegative. */ class NaiveBayes(override val uid: String) - extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { def this() = this(Identifiable.randomUID("nb")) @@ -79,8 +79,8 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ - def setLambda(value: Double): this.type = set(lambda, value) - setDefault(lambda -> 1.0) + def setSmoothing(value: Double): this.type = set(smoothing, value) + setDefault(smoothing -> 1.0) /** * Set the model type using a string (case-sensitive). 
@@ -92,7 +92,7 @@ class NaiveBayes(override val uid: String) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) + val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } @@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } - override protected def predict(features: Vector): Double = { + override val numClasses: Int = pi.size + + private def multinomialCalculation(features: Vector) = { + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob + } + + private def bernoulliCalculation(features: Vector) = { + features.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + override protected def predictRaw(features: Vector): Vector = { $(modelType) match { case Multinomial => - val prob = theta.multiply(features) - BLAS.axpy(1.0, pi, prob) - prob.argmax + multinomialCalculation(features) case Bernoulli => - features.foreachActive{ (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") - } - } - val prob = thetaMinusNegTheta.get.multiply(features) - BLAS.axpy(1.0, pi, prob) - 
BLAS.axpy(1.0, negThetaSum.get, prob) - prob.argmax + bernoulliCalculation(features) case _ => // This should never happen. throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } } + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val maxLog = dv.values.max + while (i < size) { + dv.values(i) = math.exp(dv.values(i) - maxLog) + i += 1 + } + val probSum = dv.values.sum + i = 0 + while (i < size) { + dv.values(i) = dv.values(i) / probSum + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in NaiveBayesModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bc19bd6df894f..0c7eb4a662fdb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] ( // Ignore the weights since all are 1.0 for now. 
val votes = new Array[Double](numClasses) _trees.view.foreach { tree => - val prediction = tree.rootNode.predict(features).toInt + val prediction = tree.rootNode.predictImpl(features).prediction.toInt votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } Vectors.dense(votes) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala new file mode 100644 index 0000000000000..44f779c1908d7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for multiclass classification, which expects two input columns: prediction and label. + */ +@Experimental +class MulticlassClassificationEvaluator (override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("mcEval")) + + /** + * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, + * `"weightedPrecision"`, `"weightedRecall"`) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("f1", "precision", + "recall", "weightedPrecision", "weightedRecall")) + new Param(this, "metricName", "metric name in evaluation " + + "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "f1") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label:
Double) => + (prediction, label) + } + val metrics = new MulticlassMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "f1" => metrics.weightedFMeasure + case "precision" => metrics.precision + case "recall" => metrics.recall + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall + } + metric + } + + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala new file mode 100644 index 0000000000000..3cc41424460f2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType} +import org.apache.spark.sql.functions.{col, udf} + +/** + * stop words list + */ +private object StopWords { + + /** + * Use the same default stopwords list as scikit-learn. + * The original list can be found from "Glasgow Information Retrieval Group" + * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] + */ + val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again", + "against", "all", "almost", "alone", "along", "already", "also", "although", "always", + "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", + "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", + "around", "as", "at", "back", "be", "became", "because", "become", + "becomes", "becoming", "been", "before", "beforehand", "behind", "being", + "below", "beside", "besides", "between", "beyond", "bill", "both", + "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con", + "could", "couldnt", "cry", "de", "describe", "detail", "do", "done", + "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else", + "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone", + "everything", "everywhere", "except", "few", "fifteen", "fify", "fill", + "find", "fire", "first", "five", "for", "former", "formerly", "forty", + "found", "four", "from", "front", "full", "further", "get", "give", "go", + "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter", + "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", + "how", 
"however", "hundred", "i", "ie", "if", "in", "inc", "indeed", + "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter", + "latterly", "least", "less", "ltd", "made", "many", "may", "me", + "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly", + "move", "much", "must", "my", "myself", "name", "namely", "neither", + "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone", + "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on", + "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", + "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps", + "please", "put", "rather", "re", "same", "see", "seem", "seemed", + "seeming", "seems", "serious", "several", "she", "should", "show", "side", + "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone", + "something", "sometime", "sometimes", "somewhere", "still", "such", + "system", "take", "ten", "than", "that", "the", "their", "them", + "themselves", "then", "thence", "there", "thereafter", "thereby", + "therefore", "therein", "thereupon", "these", "they", "thick", "thin", + "third", "this", "those", "though", "three", "through", "throughout", + "thru", "thus", "to", "together", "too", "top", "toward", "towards", + "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us", + "very", "via", "was", "we", "well", "were", "what", "whatever", "when", + "whence", "whenever", "where", "whereafter", "whereas", "whereby", + "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", + "who", "whoever", "whole", "whom", "whose", "why", "will", "with", + "within", "without", "would", "yet", "you", "your", "yours", "yourself", "yourselves") +} + +/** + * :: Experimental :: + * A feature transformer that filters out stop words from input. + * Note: null values from input array are preserved unless adding null to stopWords explicitly. 
+ * @see [[http://en.wikipedia.org/wiki/Stop_words]] + */ +@Experimental +class StopWordsRemover(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("stopWords")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * the stop words set to be filtered out + * @group param + */ + val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") + + /** @group setParam */ + def setStopWords(value: Array[String]): this.type = set(stopWords, value) + + /** @group getParam */ + def getStopWords: Array[String] = $(stopWords) + + /** + * whether to do a case sensitive comparison over the stop words + * @group param + */ + val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive", + "whether to do case-sensitive comparison during filtering") + + /** @group setParam */ + def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value) + + /** @group getParam */ + def getCaseSensitive: Boolean = $(caseSensitive) + + setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) + + override def transform(dataset: DataFrame): DataFrame = { + val outputSchema = transformSchema(dataset.schema) + val t = if ($(caseSensitive)) { + val stopWordsSet = $(stopWords).toSet + udf { terms: Seq[String] => + terms.filter(s => !stopWordsSet.contains(s)) + } + } else { + val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lowerStopWords = $(stopWords).map(toLower(_)).toSet + udf { terms: Seq[String] => + terms.filter(s => !lowerStopWords.contains(toLower(s))) + } + } + + val metadata = outputSchema($(outputCol)).metadata + dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = 
schema($(inputCol)).dataType + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + val outputFields = schema.fields :+ + StructField($(outputCol), inputType, schema($(inputCol)).nullable) + StructType(outputFields) + } + + override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index bf7be363b8224..ebfa972532358 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{NumericType, StringType, StructType} +import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -151,4 +152,105 @@ class StringIndexerModel private[ml] ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra) } + + /** + * Return a model to perform the inverse transformation. + * Note: By default we keep the original columns during this transformation, so the inverse + * should only be used on new columns such as predicted labels. 
+ */ + def invert(inputCol: String, outputCol: String): StringIndexerInverse = { + new StringIndexerInverse() + .setInputCol(inputCol) + .setOutputCol(outputCol) + .setLabels(labels) + } +} + +/** + * :: Experimental :: + * Transform a provided column back to the original input types using either the metadata + * on the input column, or if provided using the labels supplied by the user. + * Note: By default we keep the original columns during this transformation, + * so the inverse should only be used on new columns such as predicted labels. + */ +@Experimental +class StringIndexerInverse private[ml] ( + override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { + + def this() = + this(Identifiable.randomUID("strIdxInv")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Optional labels to be provided by the user, if not supplied column + * metadata is read for labels. The default value is an empty array, + * but the empty array is ignored and column metadata used instead. + * @group setParam + */ + def setLabels(value: Array[String]): this.type = set(labels, value) + + /** + * Param for array of labels. + * Optional labels to be provided by the user, if not supplied column + * metadata is read for labels. + * @group param + */ + final val labels: StringArrayParam = new StringArrayParam(this, "labels", + "array of labels, if not provided metadata from inputCol is used instead.") + setDefault(labels, Array.empty[String]) + + /** + * Optional labels to be provided by the user, if not supplied column + * metadata is read for labels. 
+ * @group getParam + */ + final def getLabels: Array[String] = $(labels) + + /** Transform the schema for the inverse transformation */ + override def transformSchema(schema: StructType): StructType = { + val inputColName = $(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be a numeric type, " + + s"but got $inputDataType.") + val inputFields = schema.fields + val outputColName = $(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName($(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } + + override def transform(dataset: DataFrame): DataFrame = { + val inputColSchema = dataset.schema($(inputCol)) + // If the labels array is empty use column metadata + val values = if ($(labels).isEmpty) { + Attribute.fromStructField(inputColSchema) + .asInstanceOf[NominalAttribute].values.get + } else { + $(labels) + } + val indexer = udf { index: Double => + val idx = index.toInt + if (0 <= idx && idx < values.size) { + values(idx) + } else { + throw new SparkException(s"Unseen index: $index ??") + } + } + val outputColName = $(outputCol) + dataset.select(col("*"), + indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) + } + + override def copy(extra: ParamMap): StringIndexerInverse = { + defaultCopy(extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 954aa17e26a02..d68f5ff0053c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -166,6 +166,11 @@ object ParamValidators { def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => allowed.contains(value) } + + /** Check that the array length is 
greater than lowerBound. */ + def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => + value.length > lowerBound + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6f3340c2f02be..4d30e4b5548aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] ( def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index e38dc73ee0ba7..5633bc320273a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -180,7 +180,7 @@ final class GBTRegressionModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? 
SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 506a878c2553b..17fb1ad5e15d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predict(features)).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index bbc2427ca7d3d..8879352a600a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict} + Node => OldNode, Predict => OldPredict, ImpurityStats} /** * :: DeveloperApi :: @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable { /** Impurity measure at this node (for training data) */ def impurity: Double + /** + * Statistics 
aggregated from training data at this node, used to compute prediction, impurity, + * and probabilities. + * For classification, the array of class counts must be normalized to a probability distribution. + */ + private[tree] def impurityStats: ImpurityCalculator + /** Recursive prediction helper method */ - private[ml] def predict(features: Vector): Double = prediction + private[ml] def predictImpl(features: Vector): LeafNode /** * Get the number of nodes in tree below this node, including leaf nodes. @@ -75,7 +83,8 @@ private[ml] object Node { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain @@ -85,7 +94,7 @@ private[ml] object Node { new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } @@ -99,11 +108,13 @@ private[ml] object Node { @DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, - override val impurity: Double) extends Node { + override val impurity: Double, + override val impurityStats: ImpurityCalculator) extends Node { - override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" - override private[ml] def predict(features: Vector): Double = prediction + override private[ml] def predictImpl(features: Vector): LeafNode = this override 
private[tree] def numDescendants: Int = 0 @@ -115,9 +126,8 @@ final class LeafNode private[ml] ( override private[tree] def subtreeDepth: Int = 0 override private[ml] def toOld(id: Int): OldNode = { - // NOTE: We do NOT store 'prob' in the new API currently. - new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, - None, None, None, None) + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), + impurity, isLeaf = true, None, None, None, None) } } @@ -139,17 +149,18 @@ final class InternalNode private[ml] ( val gain: Double, val leftChild: Node, val rightChild: Node, - val split: Split) extends Node { + val split: Split, + override val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" } - override private[ml] def predict(features: Vector): Double = { + override private[ml] def predictImpl(features: Vector): LeafNode = { if (split.shouldGoLeft(features)) { - leftChild.predict(features) + leftChild.predictImpl(features) } else { - rightChild.predict(features) + rightChild.predictImpl(features) } } @@ -172,9 +183,8 @@ final class InternalNode private[ml] ( override private[ml] def toOld(id: Int): OldNode = { assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + " since the old API does not support deep trees.") - // NOTE: We do NOT store 'prob' in the new API currently. 
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, - Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), Some(rightChild.toOld(OldNode.rightChildIndex(id))), Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, new OldPredict(leftChild.prediction, prob = 0.0), @@ -223,36 +233,36 @@ private object InternalNode { * * @param id We currently use the same indexing as the old implementation in * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. - * @param predictionStats Predicted label + class probability (for classification). - * We will later modify this to store aggregate statistics for labels - * to provide all class probabilities (for classification) and maybe a - * distribution (for regression). * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, * so that we do not need to consider splitting it further. - * @param stats Old structure for storing stats about information gain, prediction, etc. - * This is legacy and will be modified in the future. + * @param stats Impurity statistics for this node. */ private[tree] class LearningNode( var id: Int, - var predictionStats: OldPredict, - var impurity: Double, var leftChild: Option[LearningNode], var rightChild: Option[LearningNode], var split: Option[Split], var isLeaf: Boolean, - var stats: Option[OldInformationGainStats]) extends Serializable { + var stats: ImpurityStats) extends Serializable { /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. 
*/ def toNode: Node = { if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty, + assert(rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(predictionStats.predict, impurity, stats.get.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get) + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - new LeafNode(predictionStats.predict, impurity) + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + } + } } @@ -263,16 +273,14 @@ private[tree] object LearningNode { /** Create a node with some of its fields set. */ def apply( id: Int, - predictionStats: OldPredict, - impurity: Double, - isLeaf: Boolean): LearningNode = { - new LearningNode(id, predictionStats, impurity, None, None, None, false, None) + isLeaf: Boolean, + stats: ImpurityStats): LearningNode = { + new LearningNode(id, None, None, None, false, stats) } /** Create an empty node with the given node index. Values must be set later on. 
*/ def emptyNode(nodeIndex: Int): LearningNode = { - new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN, - None, None, None, false, None) + new LearningNode(nodeIndex, None, None, None, false, null) } // The below indexing methods were copied from spark.mllib.tree.model.Node diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 15b56bd844bad..a8b90d9d266a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict} +import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging { parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) } case None => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) } @@ -549,9 
+553,9 @@ private[ml] object RandomForest extends Logging { } // find best split for each node - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats, predict)) + (nodeIndex, (split, stats)) }.collectAsMap() timer.stop("chooseSplits") @@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging { val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.predictionStats = predict node.isLeaf = isLeaf - node.stats = Some(stats) - node.impurity = stats.impurity + node.stats = stats logDebug("Node = " + node) if (!isLeaf) { @@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging { val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) if (nodeIdCache.nonEmpty) { val nodeIndexUpdater = NodeIndexUpdater( @@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging { } /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. 
+ * Calculate the impurity statistics for a given (feature, split) based on left/right aggregates. + * @param stats the recycled impurity statistics shared by this feature's splits, or null; + * only 'impurity' and 'impurityCalculator' are valid between each iteration * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) */ - private def calculateGainForSplit( + private def calculateImpurityStats( + stats: ImpurityStats, leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, - impurity: Double): InformationGainStats = { + metadata: DecisionTreeMetadata): ImpurityStats = { + + val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { + leftImpurityCalculator.copy.add(rightImpurityCalculator) + } else { + stats.impurityCalculator + } + + val impurity: Double = if (stats == null) { + parentImpurityCalculator.calculate() + } else { + stats.impurity + } + val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count + val totalCount = leftCount + rightCount + // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats. 
if ((leftCount < metadata.minInstancesPerNode) || (rightCount < metadata.minInstancesPerNode)) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - val totalCount = leftCount + rightCount - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging { // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. if (gain < metadata.minInfoGain) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - // calculate left and right predict - val leftPredict = calculatePredict(leftImpurityCalculator) - val rightPredict = calculatePredict(rightImpurityCalculator) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, - leftPredict, rightPredict) - } - - private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { - val predict = impurityCalculator.predict - val prob = impurityCalculator.prob(predict) - new Predict(predict, prob) - } - - /** - * Calculate predict value for current node, given stats of any split. - * Note that this function is called only once for each node. 
- * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node - */ - private def calculatePredictImpurity( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - val predict = calculatePredict(parentNodeAgg) - val impurity = parentNodeAgg.calculate() - - (predict, impurity) + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) } /** @@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging { binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, InformationGainStats, Predict) = { + node: LearningNode): (Split, ImpurityStats) = { - // Calculate prediction and impurity if current node is top node + // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) { - None + var gainAndImpurityStats: ImpurityStats = if (level == 0) { + null } else { - Some((node.predictionStats, node.impurity)) + node.stats } // For each (feature, split), calculate the gain, and select the best (feature, split). 
@@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIdx, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { @@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { @@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, 
binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) @@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging { } }.maxBy(_._2.gain) - (bestSplit, bestSplitStats, predictionAndImpurity.get._1) + (bestSplit, bestSplitStats) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cfad3fbbdb87..0cdac84eeb591 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -217,22 +217,28 @@ class LocalLDAModel private[clustering] ( LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) } - // TODO - // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + /** + * Calculates a lower bound on the log likelihood of the entire corpus. 
+ * @param documents test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + def logLikelihood(documents: RDD[(Long, Vector)]): Double = bound(documents, + docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, + vocabSize) /** - * Calculate the log variational bound on perplexity. See Equation (16) in original Online + * Calculate an upper bound on perplexity. See Equation (16) in original Online * LDA paper. * @param documents test corpus to use for calculating perplexity - * @return the log perplexity per word + * @return variational upper bound on log perplexity per word */ def logPerplexity(documents: RDD[(Long, Vector)]): Double = { val corpusWords = documents .map { case (_, termCounts) => termCounts.toArray.sum } .sum() - val batchVariationalBound = bound(documents, docConcentration, - topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) - val perWordBound = batchVariationalBound / corpusWords + val perWordBound = -logLikelihood(documents) / corpusWords perWordBound } @@ -510,6 +516,43 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top documents for each topic + * + * This is approximate; it may not return exactly the top-weighted documents for each topic. + * To get a more precise set of top documents, increase maxDocumentsPerTopic. + * + * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic. + * @return Array over topics. Each element is represented as a pair of matching arrays: + * (IDs for the documents, weights of the topic in these documents). + * For each topic, documents are sorted in order of decreasing topic weights. 
+ */ + def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { + val numTopics = k + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = + topicDistributions.mapPartitions { docVertices => + // For this partition, collect the most common docs for each topic in queues: + // queues(topic) = queue of (doc topic, doc ID). + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic)) + for ((docId, docTopics) <- docVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (docTopics(topic) -> docId) + topic += 1 + } + } + Iterator(queues) + }.treeReduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b } + q1 + } + topicsInQueues.map { q => + val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip + (docs.toArray, docTopics.toArray) + } + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -591,6 +634,23 @@ class DistributedLDAModel private[clustering] ( JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } + /** + * For each document, return the top k weighted topics for that document and their weights. + * @return RDD of (doc ID, topic indices, topic weights) + */ + def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + val topIndices = argtopk(topicCounts, k) + val sumCounts = sum(topicCounts) + val weights = if (sumCounts != 0) { + topicCounts(topIndices) / sumCounts + } else { + topicCounts(topIndices) + } + (docID.toLong, topIndices.toArray, weights.toArray) + } + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index d6f8b29a43dfd..b0e14cb8296a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} -import breeze.numerics.{abs, exp} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum} +import breeze.numerics.{trigamma, abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.DeveloperApi @@ -239,22 +239,26 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** alias for docConcentration */ private var alpha: Vector = Vectors.dense(0) - /** (private[clustering] for debugging) Get docConcentration */ + /** (for debugging) Get docConcentration */ private[clustering] def getAlpha: Vector = alpha /** alias for topicConcentration */ private var eta: Double = 0 - /** (private[clustering] for debugging) Get topicConcentration */ + /** (for debugging) Get topicConcentration */ private[clustering] def getEta: Double = eta private var randomGenerator: java.util.Random = null + /** (for debugging) Whether to sample mini-batches with replacement. 
(default = true) */ + private var sampleWithReplacement: Boolean = true + // Online LDA specific parameters // Learning rate is: (tau0 + t)^{-kappa} private var tau0: Double = 1024 private var kappa: Double = 0.51 private var miniBatchFraction: Double = 0.05 + private var optimizeAlpha: Boolean = false // internal data structure private var docs: RDD[(Long, Vector)] = null @@ -262,7 +266,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** Dirichlet parameter for the posterior over topics */ private var lambda: BDM[Double] = null - /** (private[clustering] for debugging) Get parameter for topics */ + /** (for debugging) Get parameter for topics */ private[clustering] def getLambda: BDM[Double] = lambda /** Current iteration (count of invocations of [[next()]]) */ @@ -325,7 +329,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * (private[clustering]) + * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) + * will be optimized during training. + */ + def getOptimzeAlpha: Boolean = this.optimizeAlpha + + /** + * Sets whether to optimize alpha parameter during training. + * + * Default: false + */ + def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { + this.optimizeAlpha = optimizeAlpha + this + } + + /** * Set the Dirichlet parameter for the posterior over topics. * This is only used for testing now. In the future, it can help support training stop/resume. */ @@ -335,7 +354,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * (private[clustering]) * Used for random initialization of the variational parameters. * Larger value produces values closer to 1.0. * This is only used for testing currently. @@ -345,6 +363,15 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this } + /** + * Sets whether to sample mini-batches with or without replacement. (default = true) + * This is only used for testing currently. 
+ */ + private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = { + this.sampleWithReplacement = replace + this + } + override private[clustering] def initialize( docs: RDD[(Long, Vector)], lda: LDA): OnlineLDAOptimizer = { @@ -376,7 +403,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } override private[clustering] def next(): OnlineLDAOptimizer = { - val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong()) + val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction, + randomGenerator.nextLong()) if (batch.isEmpty()) return this submitMiniBatch(batch) } @@ -418,6 +446,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) + if (optimizeAlpha) updateAlpha(gammat) this } @@ -433,13 +462,39 @@ final class OnlineLDAOptimizer extends LDAOptimizer { weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) } - /** Calculates learning rate rho, which decays as a function of [[iteration]] */ + /** + * Update alpha based on `gammat`, the inferred topic distributions for documents in the + * current mini-batch. Uses Newton-Raphson method. 
+ * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters + * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf) + */ + private def updateAlpha(gammat: BDM[Double]): Unit = { + val weight = rho() + val N = gammat.rows.toDouble + val alpha = this.alpha.toBreeze.toDenseVector + val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N + val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector) + + val c = N * trigamma(sum(alpha)) + val q = -N * trigamma(alpha) + val b = sum(gradf / q) / (1D / c + sum(1D / q)) + + val dalpha = -(gradf - b) / q + + if (all((weight * dalpha + alpha) :> 0D)) { + alpha :+= weight * dalpha + this.alpha = Vectors.dense(alpha.toArray) + } + } + + + /** Calculate learning rate rho for the current [[iteration]]. */ private def rho(): Double = { math.pow(getTau0 + this.iteration, -getKappa) } /** - * Get a random matrix to initialize lambda + * Get a random matrix to initialize lambda. */ private def getGammaMatrix(row: Int, col: Int): BDM[Double] = { val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 0ea792081086d..ccebf951c850d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -25,7 +25,7 @@ import org.apache.spark.Logging * Calculate all patterns of a projected database in local. */ private[fpm] object LocalPrefixSpan extends Logging with Serializable { - + import PrefixSpan._ /** * Calculate all patterns of a projected database. 
* @param minCount minimum count @@ -39,12 +39,19 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { def run( minCount: Long, maxPatternLength: Int, - prefixes: List[Int], - database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = { - if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty - val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) - val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) - frequentItemAndCounts.iterator.flatMap { case (item, count) => + prefixes: List[Set[Int]], + database: Iterable[List[Set[Int]]]): Iterator[(List[Set[Int]], Long)] = { + if (prefixes.length == maxPatternLength || database.isEmpty) { + return Iterator.empty + } + val freqItemSetsAndCounts = getFreqItemAndCounts(minCount, database) + val freqItems = freqItemSetsAndCounts.keys.flatten.toSet + val filteredDatabase = database.map { suffix => + suffix + .map(item => freqItems.intersect(item)) + .filter(_.nonEmpty) + } + freqItemSetsAndCounts.iterator.flatMap { case (item, count) => val newPrefixes = item :: prefixes val newProjected = project(filteredDatabase, item) Iterator.single((newPrefixes, count)) ++ @@ -54,20 +61,23 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** * Calculate suffix sequence immediately after the first occurrence of an item. 
- * @param item item to get suffix after + * @param item itemset to get suffix after * @param sequence sequence to extract suffix from * @return suffix sequence */ - def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = { - val index = sequence.indexOf(item) + def getSuffix(item: Set[Int], sequence: List[Set[Int]]): List[Set[Int]] = { + val itemsetSeq = sequence + val index = itemsetSeq.indexWhere(item.subsetOf(_)) if (index == -1) { - Array() + List() } else { - sequence.drop(index + 1) + itemsetSeq.drop(index + 1) } } - def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { + def project( + database: Iterable[List[Set[Int]]], + prefix: Set[Int]): Iterable[List[Set[Int]]] = { database .map(getSuffix(prefix, _)) .filter(_.nonEmpty) @@ -81,14 +91,16 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Iterable[Array[Int]]): mutable.Map[Int, Long] = { + database: Iterable[List[Set[Int]]]): Map[Set[Int], Long] = { // TODO: use PrimitiveKeyOpenHashMap - val counts = mutable.Map[Int, Long]().withDefaultValue(0L) + val counts = mutable.Map[Set[Int], Long]().withDefaultValue(0L) database.foreach { sequence => - sequence.distinct.foreach { item => + sequence.flatMap(nonemptySubsets(_)).distinct.foreach { item => counts(item) += 1L } } - counts.filter(_._2 >= minCount) + counts + .filter { case (_, count) => count >= minCount } + .toMap } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index e6752332cdeeb..22b4ddb8b3495 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.fpm -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import org.apache.spark.Logging import 
org.apache.spark.annotation.Experimental @@ -44,13 +44,14 @@ import org.apache.spark.storage.StorageLevel class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { + import PrefixSpan._ /** * The maximum number of items allowed in a projected database before local processing. If a * projected database exceeds this size, another iteration of distributed PrefixSpan is run. */ - // TODO: make configurable with a better default value, 10000 may be too small - private val maxLocalProjDBSize: Long = 10000 + // TODO: make configurable with a better default value + private val maxLocalProjDBSize: Long = 32000000L /** * Constructs a default instance with default parameters @@ -90,35 +91,41 @@ class PrefixSpan private ( /** * Find the complete set of sequential patterns in the input sequences. - * @param sequences input data set, contains a set of sequences, - * a sequence is an ordered list of elements. + * @param data ordered sequences of itemsets. Items are represented by non-negative integers. + * Each itemset has one or more items and is delimited by [[DELIMITER]]. * @return a set of sequential pattern pairs, * the key of pair is pattern (a list of elements), * the value of pair is the pattern's count. 
*/ - def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { - val sc = sequences.sparkContext + // TODO: generalize to arbitrary item-types and use mapping to Ints for internal algorithm + def run(data: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + val sc = data.sparkContext - if (sequences.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } + // Use List[Set[Item]] for internal computation + val sequences = data.map { seq => splitSequence(seq.toList) } + // Convert min support to a min number of transactions for this dataset val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold val freqItemCounts = sequences - .flatMap(seq => seq.distinct.map(item => (item, 1L))) + .flatMap(seq => seq.flatMap(nonemptySubsets(_)).distinct.map(item => (item, 1L))) .reduceByKey(_ + _) - .filter(_._2 >= minCount) + .filter { case (item, count) => (count >= minCount) } .collect() + .toMap // Pairs of (length 1 prefix, suffix consisting of frequent items) val itemSuffixPairs = { - val freqItems = freqItemCounts.map(_._1).toSet + val freqItemSets = freqItemCounts.keys.toSet + val freqItems = freqItemSets.flatten sequences.flatMap { seq => - val filteredSeq = seq.filter(freqItems.contains(_)) - freqItems.flatMap { item => + val filteredSeq = seq.map(item => freqItems.intersect(item)).filter(_.nonEmpty) + freqItemSets.flatMap { item => val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) candidateSuffix match { case suffix if !suffix.isEmpty => Some((List(item), suffix)) @@ -130,14 +137,15 @@ class PrefixSpan private ( // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. 
// frequent length-one prefixes) - var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2)) + var resultsAccumulator = freqItemCounts.map { case (item, count) => (List(item), count) }.toList // Remaining work to be locally and distributively processed respectfully var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have - // projected database sizes <= `maxLocalProjDBSize`) - while (pairsForDistributed.count() != 0) { + // projected database sizes <= `maxLocalProjDBSize`) or `maxPatternLength` is reached + var patternLength = 1 + while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) { val (nextPatternAndCounts, nextPrefixSuffixPairs) = extendPrefixes(minCount, pairsForDistributed) pairsForDistributed.unpersist() @@ -146,6 +154,7 @@ class PrefixSpan private ( pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) pairsForLocal ++= smallerPairsPart resultsAccumulator ++= nextPatternAndCounts.collect() + patternLength += 1 // pattern length grows one per iteration } // Process the small projected databases locally @@ -153,7 +162,7 @@ class PrefixSpan private ( minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) - .map { case (pattern, count) => (pattern.toArray, count) } + .map { case (pattern, count) => (flattenSequence(pattern.reverse).toArray, count) } } @@ -163,8 +172,8 @@ class PrefixSpan private ( * @return prefix-suffix pairs partitioned by whether their projected database size is <= or * greater than [[maxLocalProjDBSize]] */ - private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])]) - : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = { + private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])]) + : (List[(List[Set[Int]], List[Set[Int]])], RDD[(List[Set[Int]], 
List[Set[Int]])]) = { val prefixToSuffixSize = prefixSuffixPairs .aggregateByKey(0)( seqOp = { case (count, suffix) => count + suffix.length }, @@ -176,12 +185,12 @@ class PrefixSpan private ( .toSet val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } - (small.collect(), large) + (small.collect().toList, large) } /** - * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes - * and remaining work. + * Extends all prefixes by one itemset from their suffix and computes the resulting frequent + * prefixes and remaining work. * @param minCount minimum count * @param prefixSuffixPairs prefix (length N) and suffix pairs, * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended @@ -189,15 +198,16 @@ class PrefixSpan private ( */ private def extendPrefixes( minCount: Long, - prefixSuffixPairs: RDD[(List[Int], Array[Int])]) - : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = { + prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])]) + : (RDD[(List[Set[Int]], Long)], RDD[(List[Set[Int]], List[Set[Int]])]) = { - // (length N prefix, item from suffix) pairs and their corresponding number of occurrences + // (length N prefix, itemset from suffix) pairs and their corresponding number of occurrences // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` val prefixItemPairAndCounts = prefixSuffixPairs - .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } + .flatMap { case (prefix, suffix) => + suffix.flatMap(nonemptySubsets(_)).distinct.map(y => ((prefix, y), 1L)) } .reduceByKey(_ + _) - .filter(_._2 >= minCount) + .filter { case (item, count) => (count >= minCount) } // Map from prefix to set of possible next items from suffix val prefixToNextItems = prefixItemPairAndCounts @@ -207,7 +217,6 @@ class 
PrefixSpan private ( .collect() .toMap - // Frequent patterns with length N+1 and their corresponding counts val extendedPrefixAndCounts = prefixItemPairAndCounts .map { case ((prefix, item), count) => (item :: prefix, count) } @@ -216,9 +225,12 @@ class PrefixSpan private ( val extendedPrefixAndSuffix = prefixSuffixPairs .filter(x => prefixToNextItems.contains(x._1)) .flatMap { case (prefix, suffix) => - val frequentNextItems = prefixToNextItems(prefix) - val filteredSuffix = suffix.filter(frequentNextItems.contains(_)) - frequentNextItems.flatMap { item => + val frequentNextItemSets = prefixToNextItems(prefix) + val frequentNextItems = frequentNextItemSets.flatten + val filteredSuffix = suffix + .map(item => frequentNextItems.intersect(item)) + .filter(_.nonEmpty) + frequentNextItemSets.flatMap { item => LocalPrefixSpan.getSuffix(item, filteredSuffix) match { case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) case _ => None @@ -237,13 +249,38 @@ class PrefixSpan private ( */ private def getPatternsInLocal( minCount: Long, - data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = { + data: RDD[(List[Set[Int]], Iterable[List[Set[Int]]])]): RDD[(List[Set[Int]], Long)] = { data.flatMap { - case (prefix, projDB) => - LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB) - .map { case (pattern: List[Int], count: Long) => - (pattern.reverse, count) - } + case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB) + } + } + +} + +private[fpm] object PrefixSpan { + private[fpm] val DELIMITER = -1 + + /** Splits a sequence of itemsets delimited by [[DELIMITER]]. */ + private[fpm] def splitSequence(sequence: List[Int]): List[Set[Int]] = { + sequence.span(_ != DELIMITER) match { + case (x, xs) if xs.length > 1 => x.toSet :: splitSequence(xs.tail) + case (x, xs) => List(x.toSet) + } + } + + /** Flattens a sequence of itemsets into an Array, inserting[[DELIMITER]] between itemsets. 
*/ + private[fpm] def flattenSequence(sequence: List[Set[Int]]): List[Int] = { + val builder = ArrayBuilder.make[Int]() + for (itemSet <- sequence) { + builder += DELIMITER + builder ++= itemSet.toSeq.sorted } + builder.result().toList.drop(1) // drop leading delimiter + } + + /** Returns an iterator over all non-empty subsets of `itemSet` */ + private[fpm] def nonemptySubsets(itemSet: Set[Int]): Iterator[Set[Int]] = { + // TODO: improve complexity by using partial prefixes, considering one item at a time + itemSet.subsets.filter(_ != Set.empty[Int]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 88914fa875990..1c858348bf20e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -179,12 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) + val values = row.getArray(5).toDoubleArray() val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) - val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) + val colPtrs = row.getArray(3).toIntArray() + val rowIndices = row.getArray(4).toIntArray() new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 89a1818db0d1d..96d1f48ba2ba3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -209,11 +209,11 @@ private[spark] class VectorUDT extends 
UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) - val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) + val indices = row.getArray(2).toIntArray() + val values = row.getArray(3).toDoubleArray() new SparseVector(size, indices, values) case 1 => - val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) + val values = row.getArray(3).toDoubleArray() new DenseVector(values) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index ab7611fd077ef..8f0d1e4aa010a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 5ac10f3fd32dd..0768204c33914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). 
*/ -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 19d318203c344..d0077db6832e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 578749d85a4e6..86cee7e430b0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { /** * Make a deep copy of this [[ImpurityCalculator]]. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 7104a7fa4dd4c..04d0cd24e6632 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -98,7 +98,7 @@ private[tree] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { require(stats.size == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index dc9e0f9f51ffb..508bf9c1bdb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** * :: DeveloperApi :: @@ -66,7 +67,6 @@ class InformationGainStats( } } - private[spark] object InformationGainStats { /** * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to @@ -76,3 +76,62 @@ private[spark] object InformationGainStats { val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } + +/** + * :: DeveloperApi :: + * Impurity statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param impurityCalculator impurity 
statistics for current node + * @param leftImpurityCalculator impurity statistics for left child node + * @param rightImpurityCalculator impurity statistics for right child node + * @param valid whether the current split satisfies minimum info gain or + * minimum number of instances per node + */ +@DeveloperApi +private[spark] class ImpurityStats( + val gain: Double, + val impurity: Double, + val impurityCalculator: ImpurityCalculator, + val leftImpurityCalculator: ImpurityCalculator, + val rightImpurityCalculator: ImpurityCalculator, + val valid: Boolean = true) extends Serializable { + + override def toString: String = { + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" + } + + def leftImpurity: Double = if (leftImpurityCalculator != null) { + leftImpurityCalculator.calculate() + } else { + -1.0 + } + + def rightImpurity: Double = if (rightImpurityCalculator != null) { + rightImpurityCalculator.calculate() + } else { + -1.0 + } +} + +private[spark] object ImpurityStats { + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ + def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + impurityCalculator, null, null, false) + } + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object + * that only 'impurity' and 'impurityCalculator' are defined. 
+ */ + def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 09a9fba0c19cf..a700c9cddb206 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -68,7 +68,7 @@ public void naiveBayesDefaultParams() { assert(nb.getLabelCol() == "label"); assert(nb.getFeaturesCol() == "features"); assert(nb.getPredictionCol() == "prediction"); - assert(nb.getLambda() == 1.0); + assert(nb.getSmoothing() == 1.0); assert(nb.getModelType() == "multinomial"); } @@ -89,7 +89,7 @@ public void testNaiveBayes() { }); DataFrame dataset = jsql.createDataFrame(jrdd, schema); - NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); + NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala new file mode 100644 index 0000000000000..1292e57d7c01a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + + +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { + + // TODO: test for weights comparison with Weka MLP + test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array(0.0, 1.0, 1.0, 0.0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 1) + trainer.setWeights(initialWeights) + trainer.LBFGSOptimizer.setNumIterations(20) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input)(0), label(0)) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(math.round(p) === l) + } + } + + test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 
0.0), + Array(1.0, 1.0) + ) + val outputs = Array( + Array(1.0, 0.0), + Array(0.0, 1.0), + Array(0.0, 1.0), + Array(1.0, 0.0) + ) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 2) + trainer.SGDOptimizer.setNumIterations(2000) + trainer.setWeights(initialWeights) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input), label) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(p ~== l absTol 0.5) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 73b4805c4c597..c7bbf1ce07a23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import 
org.apache.spark.sql.Row class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) ParamsSuite.checkParams(model) } @@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + val predictions = newTree.transform(newData) + .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index a7bc77965fefd..d4b5896c12c06 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), Array(1.0)) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala new file mode 100644 index 0000000000000..ddc948f65df45 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning as binary classification problem with two outputs.") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100) + val model = trainer.fit(dataFrame) + val result = model.transform(dataFrame) + val predictionAndLabels = result.select("prediction", "label").collect() + predictionAndLabels.foreach { case Row(p: Double, l: Double) => + assert(p == l) + } + } + + // TODO: implement a more rigorous test + test("3 class classification with 2 hidden layers") { + val nPoints = 1000 + + // The following weights are taken from OneVsRestSuite.scala + // they represent 3-class iris dataset + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val numClasses = 3 + val numIterations = 100 + val 
layers = Array[Int](4, 5, 4, numClasses) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(numIterations) + val model = trainer.fit(dataFrame) + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } + // train multinomial logistic regression + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(true) + .setNumClasses(numClasses) + lr.optimizer.setRegParam(0.0) + .setNumIterations(numIterations) + val lrModel = lr.run(rdd) + val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // MLP's predictions should not differ a lot from LR's. + val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) + val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) + assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 76381a2741296..aea3d9b694490 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.ml.classification +import breeze.linalg.{Vector => BV} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.classification.NaiveBayes import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -28,6 +31,8 @@ import org.apache.spark.sql.Row class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + import NaiveBayes.{Multinomial, Bernoulli} + def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = 
predictionAndLabels.collect().count { case Row(prediction: Double, label: Double) => @@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") } + def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) + val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) + val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def validateProbabilities( + featureAndProbabilities: DataFrame, + model: NaiveBayesModel, + modelType: String): Unit = { + featureAndProbabilities.collect().foreach { + case Row(features: Vector, probability: Vector) => { + assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) + val expected = modelType match { + case Multinomial => + expectedMultinomialProbabilities(model, features) + case Bernoulli => + expectedBernoulliProbabilities(model, features) + case _ => + throw new UnknownError(s"Invalid modelType: $modelType.") + } + assert(probability ~== expected relTol 1.0e-10) + } + } + } + test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -58,7 +100,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { 
assert(nb.getLabelCol === "label") assert(nb.getFeaturesCol === "features") assert(nb.getPredictionCol === "prediction") - assert(nb.getLambda === 1.0) + assert(nb.getSmoothing === 1.0) assert(nb.getModelType === "multinomial") } @@ -75,7 +117,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 42, "multinomial")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) @@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "multinomial") } test("Naive Bayes Bernoulli") { @@ -101,7 +147,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 45, "bernoulli")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) @@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) - val predictionAndLabels = 
model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "bernoulli") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ab711c8e4b215..dbb2577c6204d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..6d8412b0b3701 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new MulticlassClassificationEvaluator) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala new file mode 100644 index 0000000000000..f01306f89cb5f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +object StopWordsRemoverSuite extends SparkFunSuite { + def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("filtered", "expected") + .collect() + .foreach { case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} + +class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { + import StopWordsRemoverSuite._ + + test("StopWordsRemover default") { + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("test", "test"), Seq("test", "test")), + (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), + (Seq("a", "the", "an"), Seq()), + (Seq("A", "The", "AN"), Seq()), + (Seq(null), Seq(null)), + (Seq(), Seq()) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover case sensitive") { + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setCaseSensitive(true) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("A"), Seq("A")), + (Seq("The", "the"), Seq("The")) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with additional words") { + val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("python", "scala", "a"), Seq()), + (Seq("Python", "Scala", "swift"), Seq("swift")) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } +} diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 99f82bea42688..d0295a0fe2fc1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) + // convert reverse our transform + val reversed = indexer.invert("labelIndex", "label2") + .transform(transformed) + .select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) + // Check invert using only metadata + val inverse2 = new StringIndexerInverse() + .setInputCol("labelIndex") + .setOutputCol("label2") + val reversed2 = inverse2.transform(transformed).select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexer with a numeric input column") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index c43e1e575c09c..fdc2554ab853e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, max, argmax} +import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge @@ -108,9 +108,42 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(topicDistribution.toArray.sum 
~== 1.0 absTol 1e-5) } + val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3))) + model.topicDistributions.join(top2TopicsPerDoc).collect().foreach { + case (docId, (topicDistribution, (indices, weights))) => + assert(indices.length == 2) + assert(weights.length == 2) + val bdvTopicDist = topicDistribution.toBreeze + val top2Indices = argtopk(bdvTopicDist, 2) + assert(top2Indices.toArray === indices) + assert(bdvTopicDist(top2Indices).toArray === weights) + } + // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } } test("vertex indexing") { @@ -199,16 +232,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with toy data") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new 
OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -231,30 +255,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("LocalLDAModel logPerplexity") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + test("LocalLDAModel logLikelihood") { + val ldaModel: LocalLDAModel = toyModel - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) + val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1))) + .zipWithIndex + .map { case (wordCounts, docId) => (docId.toLong, wordCounts) }) + val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5))) + .zipWithIndex + .map { case (wordCounts, docId) => (docId.toLong, wordCounts) }) + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + docsSingleWord = [[(0, 1.0)]] + docsRepeatedWord = [[(0, 5.0)]] + print(lda.bound(docsSingleWord)) + 
> -25.9706969833 + print(lda.bound(docsRepeatedWord)) + > -31.4413908227 + */ - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D) + assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D) + } + + test("LocalLDAModel logPerplexity") { + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -274,32 +313,13 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { > -3.69051285096 */ - assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + // Gensim's definition of perplexity is negative our (and Stanford NLP's) definition + assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D) } test("LocalLDAModel predict") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) - - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) - - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -340,16 +360,7 @@ class LDASuite extends 
SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with asymmetric prior") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -389,6 +400,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("OnlineLDAOptimizer alpha hyperparameter optimization") { + val k = 2 + val docs = sc.parallelize(toyData) + val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) + .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false) + val lda = new LDA().setK(k) + .setDocConcentration(1D / k) + .setTopicConcentration(0.01) + .setMaxIterations(100) + .setOptimizer(op) + .setSeed(12345) + val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel] + + /* Verify the results with gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(lda.alpha) + > [ 0.42582646 0.43511073] + */ + + assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05) + } + test("model save/load") { // Test for LocalLDAModel. 
val localModel = new LocalLDAModel(tinyTopics, @@ -520,4 +565,27 @@ private[clustering] object LDASuite { def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter { case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0 } + + def toyData: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + + def toyModel: LocalLDAModel = { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + ldaModel + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 6dd2dc926acc5..457f32670fd4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { - test("PrefixSpan using Integer type") { + test("PrefixSpan using Integer type, singleton itemsets") { /* library("arulesSequences") @@ -35,12 +35,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { */ val sequences = Array( - Array(1, 3, 4, 5), - Array(2, 3, 1), - Array(2, 4, 1), - Array(3, 1, 3, 4, 5), - Array(3, 4, 4, 
3), - Array(6, 5, 3)) + Array(1, -1, 3, -1, 4, -1, 5), + Array(2, -1, 3, -1, 1), + Array(2, -1, 4, -1, 1), + Array(3, -1, 1, -1, 3, -1, 4, -1, 5), + Array(3, -1, 4, -1, 4, -1, 3), + Array(6, -1, 5, -1, 3)) val rdd = sc.parallelize(sequences, 2).cache() @@ -50,64 +50,225 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val result1 = prefixspan.run(rdd) val expectedValue1 = Array( (Array(1), 4L), - (Array(1, 3), 2L), - (Array(1, 3, 4), 2L), - (Array(1, 3, 4, 5), 2L), - (Array(1, 3, 5), 2L), - (Array(1, 4), 2L), - (Array(1, 4, 5), 2L), - (Array(1, 5), 2L), + (Array(1, -1, 3), 2L), + (Array(1, -1, 3, -1, 4), 2L), + (Array(1, -1, 3, -1, 4, -1, 5), 2L), + (Array(1, -1, 3, -1, 5), 2L), + (Array(1, -1, 4), 2L), + (Array(1, -1, 4, -1, 5), 2L), + (Array(1, -1, 5), 2L), (Array(2), 2L), - (Array(2, 1), 2L), + (Array(2, -1, 1), 2L), (Array(3), 5L), - (Array(3, 1), 2L), - (Array(3, 3), 2L), - (Array(3, 4), 3L), - (Array(3, 4, 5), 2L), - (Array(3, 5), 2L), + (Array(3, -1, 1), 2L), + (Array(3, -1, 3), 2L), + (Array(3, -1, 4), 3L), + (Array(3, -1, 4, -1, 5), 2L), + (Array(3, -1, 5), 2L), (Array(4), 4L), - (Array(4, 5), 2L), + (Array(4, -1, 5), 2L), (Array(5), 3L) ) - assert(compareResults(expectedValue1, result1.collect())) + compareResults(expectedValue1, result1.collect()) prefixspan.setMinSupport(0.5).setMaxPatternLength(50) val result2 = prefixspan.run(rdd) val expectedValue2 = Array( (Array(1), 4L), (Array(3), 5L), - (Array(3, 4), 3L), + (Array(3, -1, 4), 3L), (Array(4), 4L), (Array(5), 3L) ) - assert(compareResults(expectedValue2, result2.collect())) + compareResults(expectedValue2, result2.collect()) prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) val expectedValue3 = Array( (Array(1), 4L), - (Array(1, 3), 2L), - (Array(1, 4), 2L), - (Array(1, 5), 2L), - (Array(2, 1), 2L), + (Array(1, -1, 3), 2L), + (Array(1, -1, 4), 2L), + (Array(1, -1, 5), 2L), + (Array(2, -1, 1), 2L), (Array(2), 2L), (Array(3), 5L), - 
(Array(3, 1), 2L), - (Array(3, 3), 2L), - (Array(3, 4), 3L), - (Array(3, 5), 2L), + (Array(3, -1, 1), 2L), + (Array(3, -1, 3), 2L), + (Array(3, -1, 4), 3L), + (Array(3, -1, 5), 2L), (Array(4), 4L), - (Array(4, 5), 2L), + (Array(4, -1, 5), 2L), (Array(5), 3L) ) - assert(compareResults(expectedValue3, result3.collect())) + compareResults(expectedValue3, result3.collect()) + } + + test("PrefixSpan using Integer type, variable-size itemsets") { + val sequences = Array( + Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6), + Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5), + Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2), + Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3)) + val rdd = sc.parallelize(sequences, 2).cache() + val prefixspan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5) + val result = prefixspan.run(rdd) + + /* + To verify results, create file "prefixSpanSeqs" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 1 1 + 1 2 3 1 2 3 + 1 3 2 1 3 + 1 4 1 4 + 1 5 2 3 6 + 2 1 2 1 4 + 2 2 1 3 + 2 3 2 2 3 + 2 4 2 1 5 + 3 1 2 5 6 + 3 2 2 1 2 + 3 3 2 4 6 + 3 4 1 3 + 3 5 1 2 + 4 1 1 5 + 4 2 1 7 + 4 3 2 1 6 + 4 4 1 3 + 4 5 1 2 + 4 6 1 3 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = list(support = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 1.00 + 2 <{2}> 1.00 + 3 <{3}> 1.00 + 4 <{4}> 0.75 + 5 <{5}> 0.75 + 6 <{6}> 0.75 + 7 <{1},{6}> 0.50 + 8 <{2},{6}> 0.50 + 9 <{5},{6}> 0.50 + 10 <{1,2},{6}> 0.50 + 11 <{1},{4}> 0.50 + 12 <{2},{4}> 0.50 + 13 <{1,2},{4}> 0.50 + 14 <{1},{3}> 1.00 + 15 <{2},{3}> 0.75 + 16 <{2,3}> 0.50 + 17 <{3},{3}> 0.75 + 18 <{4},{3}> 0.75 + 19 <{5},{3}> 0.50 + 20 <{6},{3}> 0.50 + 21 <{5},{6},{3}> 0.50 + 22 <{6},{2},{3}> 0.50 + 23 <{5},{2},{3}> 0.50 + 24 <{5},{1},{3}> 0.50 + 25 <{2},{4},{3}> 0.50 + 26 <{1},{4},{3}> 0.50 + 
27 <{1,2},{4},{3}> 0.50 + 28 <{1},{3},{3}> 0.75 + 29 <{1,2},{3}> 0.50 + 30 <{1},{2},{3}> 0.50 + 31 <{1},{2,3}> 0.50 + 32 <{1},{2}> 1.00 + 33 <{1,2}> 0.50 + 34 <{3},{2}> 0.75 + 35 <{4},{2}> 0.50 + 36 <{5},{2}> 0.50 + 37 <{6},{2}> 0.50 + 38 <{5},{6},{2}> 0.50 + 39 <{6},{3},{2}> 0.50 + 40 <{5},{3},{2}> 0.50 + 41 <{5},{1},{2}> 0.50 + 42 <{4},{3},{2}> 0.50 + 43 <{1},{3},{2}> 0.75 + 44 <{5},{6},{3},{2}> 0.50 + 45 <{5},{1},{3},{2}> 0.50 + 46 <{1},{1}> 0.50 + 47 <{2},{1}> 0.50 + 48 <{3},{1}> 0.50 + 49 <{5},{1}> 0.50 + 50 <{2,3},{1}> 0.50 + 51 <{1},{3},{1}> 0.50 + 52 <{1},{2,3},{1}> 0.50 + 53 <{1},{2},{1}> 0.50 + */ + val expectedValue = Array( + (Array(1), 4L), + (Array(2), 4L), + (Array(3), 4L), + (Array(4), 3L), + (Array(5), 3L), + (Array(6), 3L), + (Array(1, -1, 6), 2L), + (Array(2, -1, 6), 2L), + (Array(5, -1, 6), 2L), + (Array(1, 2, -1, 6), 2L), + (Array(1, -1, 4), 2L), + (Array(2, -1, 4), 2L), + (Array(1, 2, -1, 4), 2L), + (Array(1, -1, 3), 4L), + (Array(2, -1, 3), 3L), + (Array(2, 3), 2L), + (Array(3, -1, 3), 3L), + (Array(4, -1, 3), 3L), + (Array(5, -1, 3), 2L), + (Array(6, -1, 3), 2L), + (Array(5, -1, 6, -1, 3), 2L), + (Array(6, -1, 2, -1, 3), 2L), + (Array(5, -1, 2, -1, 3), 2L), + (Array(5, -1, 1, -1, 3), 2L), + (Array(2, -1, 4, -1, 3), 2L), + (Array(1, -1, 4, -1, 3), 2L), + (Array(1, 2, -1, 4, -1, 3), 2L), + (Array(1, -1, 3, -1, 3), 3L), + (Array(1, 2, -1, 3), 2L), + (Array(1, -1, 2, -1, 3), 2L), + (Array(1, -1, 2, 3), 2L), + (Array(1, -1, 2), 4L), + (Array(1, 2), 2L), + (Array(3, -1, 2), 3L), + (Array(4, -1, 2), 2L), + (Array(5, -1, 2), 2L), + (Array(6, -1, 2), 2L), + (Array(5, -1, 6, -1, 2), 2L), + (Array(6, -1, 3, -1, 2), 2L), + (Array(5, -1, 3, -1, 2), 2L), + (Array(5, -1, 1, -1, 2), 2L), + (Array(4, -1, 3, -1, 2), 2L), + (Array(1, -1, 3, -1, 2), 3L), + (Array(5, -1, 6, -1, 3, -1, 2), 2L), + (Array(5, -1, 1, -1, 3, -1, 2), 2L), + (Array(1, -1, 1), 2L), + (Array(2, -1, 1), 2L), + (Array(3, -1, 1), 2L), + (Array(5, -1, 1), 2L), + (Array(2, 3, -1, 1), 2L), + 
(Array(1, -1, 3, -1, 1), 2L), + (Array(1, -1, 2, 3, -1, 1), 2L), + (Array(1, -1, 2, -1, 1), 2L)) + + compareResults(expectedValue, result.collect()) } private def compareResults( - expectedValue: Array[(Array[Int], Long)], - actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toSeq, x._2)).toSet == - actualValue.map(x => (x._1.toSeq, x._2)).toSet + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Unit = { + val expectedSet = expectedValue.map(x => (x._1.toSeq, x._2)).toSet + val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet + assert(expectedSet === actualSet) + } + + private def insertDelimiter(sequence: Array[Int]): Array[Int] = { + sequence.zip(Seq.fill(sequence.length)(PrefixSpan.DELIMITER)).map { case (a, b) => + List(a, b) + }.flatten } } diff --git a/pom.xml b/pom.xml index b4410c6c56de8..eaf44597d2fb1 100644 --- a/pom.xml +++ b/pom.xml @@ -161,9 +161,6 @@ 2.4.4 1.1.1.7 1.1.2 - - false - ${java.home} - ${create.dependency.reduced.pom} @@ -1643,6 +1642,7 @@ kinesis-asl extras/kinesis-asl + extras/kinesis-asl-assembly @@ -1836,26 +1836,6 @@ - - - release - - - true - - -