From 32ce873bf5eb0b87c781b2f242b678e55767a8ac Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Mon, 9 Jan 2017 15:50:36 -0800 Subject: [PATCH 1/6] take additional parameters for spark.kmeans --- R/pkg/R/mllib_clustering.R | 13 +++++++++++-- R/pkg/inst/tests/testthat/test_mllib_clustering.R | 3 ++- .../scala/org/apache/spark/ml/r/KMeansWrapper.scala | 9 ++++++++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index c44358838703f..ca5182d527cfe 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -175,6 +175,10 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @param k number of centers. #' @param maxIter maximum iteration number. #' @param initMode the initialization algorithm choosen to fit the model. +#' @param seed the random seed for cluster initialization. +#' @param initSteps the number of steps for the k-means|| initialization mode. +#' This is an advanced setting; the default of 2 is almost always enough. Must be > 0. +#' @param tol convergence tolerance of iterations. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.kmeans} returns a fitted k-means model. 
#' @rdname spark.kmeans @@ -204,11 +208,16 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @note spark.kmeans since 2.0.0 #' @seealso \link{predict}, \link{read.ml}, \link{write.ml} setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) { + function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random"), + seed = NULL, initSteps = 2, tol = 1E-4) { formula <- paste(deparse(formula), collapse = "") initMode <- match.arg(initMode) + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, - as.integer(k), as.integer(maxIter), initMode) + as.integer(k), as.integer(maxIter), initMode, seed, + as.integer(initSteps), as.numeric(tol)) new("KMeansModel", jobj = jobj) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1980fffd80cc6..812ca92894d27 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -99,7 +99,8 @@ test_that("spark.kmeans", { take(training, 1) - model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") + model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random", seed = 1, initSteps = 3, + tol = 1E-5) sample <- take(select(predict(model, training), "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index ea9458525aa31..a1fefd31c0579 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -68,7 +68,10 @@ private[r] object 
KMeansWrapper extends MLReadable[KMeansWrapper] { formula: String, k: Int, maxIter: Int, - initMode: String): KMeansWrapper = { + initMode: String, + seed: String, + initSteps: Int, + tol: Double): KMeansWrapper = { val rFormula = new RFormula() .setFormula(formula) @@ -87,6 +90,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] { .setMaxIter(maxIter) .setInitMode(initMode) .setFeaturesCol(rFormula.getFeaturesCol) + .setInitSteps(initSteps) + .setTol(tol) + + if (seed != null && seed.length > 0) kMeans.setSeed(seed.toInt) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, kMeans)) From 961f60147efd8699edb258a62b6f8ab3a4171376 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Mon, 9 Jan 2017 16:20:12 -0800 Subject: [PATCH 2/6] fix R style --- R/pkg/inst/tests/testthat/test_mllib_clustering.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 812ca92894d27..e4fce9baf7ad3 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -99,8 +99,8 @@ test_that("spark.kmeans", { take(training, 1) - model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random", seed = 1, initSteps = 3, - tol = 1E-5) + model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random", seed = 1, + initSteps = 3, tol = 1E-5) sample <- take(select(predict(model, training), "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1) From 73f8f2e1c6770b0858f090f76e2d28ab701143e6 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 11 Jan 2017 12:17:22 -0800 Subject: [PATCH 3/6] add a test that is sensitive to seed value --- .../tests/testthat/test_mllib_clustering.R | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git 
a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index e4fce9baf7ad3..b823eb5f50666 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -99,8 +99,7 @@ test_that("spark.kmeans", { take(training, 1) - model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random", seed = 1, - initSteps = 3, tol = 1E-5) + model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") sample <- take(select(predict(model, training), "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1) @@ -133,6 +132,26 @@ test_that("spark.kmeans", { expect_true(summary2$is.loaded) unlink(modelPath) + + # Test Kmeans on dataset that is sensitive to seed value + col1 <- c(1,2,3,4,0,1,2,3,4,0) + col2 <- c(1,2,3,4,0,1,2,3,4,0) + col3 <- c(1,2,3,4,0,1,2,3,4,0) + cols <- as.data.frame(cbind(col1, col2 , col3)) + df <- createDataFrame(cols) + + model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 1, tol = 1E-5) + model2 <- model <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 22222, tol = 1E-5) + + fitted.model1 <- fitted(model1) + fitted.model2 <- fitted(model2) + # number of predicted clusters is different + expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction), + c(0, 1, 2, 3)) + expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction), + c(0, 1, 2)) }) test_that("spark.lda with libsvm", { From 44a0c73cc0b3caf257009cfc43239de29cdc6435 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 11 Jan 2017 12:18:48 -0800 Subject: [PATCH 4/6] modify comment --- R/pkg/inst/tests/testthat/test_mllib_clustering.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R 
b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index b823eb5f50666..593a0ee5ed15b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -147,7 +147,7 @@ test_that("spark.kmeans", { fitted.model1 <- fitted(model1) fitted.model2 <- fitted(model2) - # number of predicted clusters is different + # The predicted clusters are different expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction), c(0, 1, 2, 3)) expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction), From c840c4d839cccb6097e1eed90314e3d7f02d6084 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 11 Jan 2017 12:37:11 -0800 Subject: [PATCH 5/6] fix style --- R/pkg/inst/tests/testthat/test_mllib_clustering.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 593a0ee5ed15b..d5bfe3bf6d641 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -134,10 +134,10 @@ test_that("spark.kmeans", { unlink(modelPath) # Test Kmeans on dataset that is sensitive to seed value - col1 <- c(1,2,3,4,0,1,2,3,4,0) - col2 <- c(1,2,3,4,0,1,2,3,4,0) - col3 <- c(1,2,3,4,0,1,2,3,4,0) - cols <- as.data.frame(cbind(col1, col2 , col3)) + col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + cols <- as.data.frame(cbind(col1, col2, col3)) df <- createDataFrame(cols) model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, From 1c27df4d8e9ed405ef5a9705cffd67f9d3adc4ff Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Thu, 12 Jan 2017 10:30:04 -0800 Subject: [PATCH 6/6] fix typo --- R/pkg/inst/tests/testthat/test_mllib_clustering.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index d5bfe3bf6d641..f013991002a02 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -142,7 +142,7 @@ test_that("spark.kmeans", { model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, initMode = "random", seed = 1, tol = 1E-5) - model2 <- model <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, initMode = "random", seed = 22222, tol = 1E-5) fitted.model1 <- fitted(model1)