From 4e92737959182724a5b37a6afc3d641d13a8586d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 25 Jul 2016 06:33:49 -0700 Subject: [PATCH 1/2] spark.glm should support weightCol --- R/pkg/R/mllib.R | 15 +++++++++---- R/pkg/inst/tests/testthat/test_mllib.R | 22 +++++++++++++++++++ .../GeneralizedLinearRegressionWrapper.scala | 2 ++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 50c601fcd9e1b..b376399a8145b 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,6 +89,8 @@ NULL #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance +#' weights as 1.0. #' @param tol Positive convergence tolerance of iterations. #' @param maxIter Integer giving the maximal number of IRLS iterations. #' @aliases spark.glm,SparkDataFrame,formula-method @@ -119,7 +121,7 @@ NULL #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) { + function(data, formula, family = gaussian, weightCol = NULL, tol = 1e-6, maxIter = 25) { if (is.character(family)) { family <- get(family, mode = "function", envir = parent.frame()) } @@ -132,9 +134,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), } formula <- paste(deparse(formula), collapse = "") + if (is.null(weightCol)) { + weightCol <- "" + } jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, + "fit", formula, data@sdf, family$family, family$link, weightCol, tol, as.integer(maxIter)) return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) @@ -149,6 +154,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance +#' weights as 1.0. #' @param epsilon Positive convergence tolerance of iterations. #' @param maxit Integer giving the maximal number of IRLS iterations. #' @return \code{glm} returns a fitted generalized linear model. @@ -165,8 +172,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit) + function(formula, family = gaussian, data, weightCol = NULL, epsilon = 1e-6, maxit = 25) { + spark.glm(data, formula, family, weightCol, tol = epsilon, maxIter = maxit) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index ab390a86d1ccd..bc18224680586 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -118,6 +118,28 @@ test_that("spark.glm summary", { expect_equal(stats$df.residual, rStats$df.residual) expect_equal(stats$aic, rStats$aic) + # Test spark.glm works with weighted dataset + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + w <- c(1, 2, 3, 4) + b <- c(1, 0, 1, 0) + data <- as.data.frame(cbind(a1, a2, w, b)) + df <- suppressWarnings(createDataFrame(data)) + + stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) + rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-3)) + expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + # Test summary works on base GLM models baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) baseSummary <- summary(baseModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 5642abc6450f1..6b35b2ec71493 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -67,6 +67,7 @@ private[r] object GeneralizedLinearRegressionWrapper data: DataFrame, family: String, link: String, + weightCol: String, tol: Double, maxIter: Int): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula() @@ -82,6 +83,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setFamily(family) .setLink(link) .setFitIntercept(rFormula.hasIntercept) + .setWeightCol(weightCol) .setTol(tol) .setMaxIter(maxIter) val pipeline = new Pipeline() From 5f96b6ec3179cac7c5076a2fddf29ff8f5a7566b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 1 Aug 2016 07:37:37 -0700 Subject: [PATCH 2/2] Move weightCol to the end of arguments. --- R/pkg/R/mllib.R | 18 +++++++++--------- .../r/GeneralizedLinearRegressionWrapper.scala | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b376399a8145b..25d9f077b487c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,10 +89,10 @@ NULL #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance -#' weights as 1.0. #' @param tol Positive convergence tolerance of iterations. #' @param maxIter Integer giving the maximal number of IRLS iterations. +#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance +#' weights as 1.0. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model #' @rdname spark.glm @@ -121,7 +121,7 @@ NULL #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, weightCol = NULL, tol = 1e-6, maxIter = 25) { + function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) { if (is.character(family)) { family <- get(family, mode = "function", envir = parent.frame()) } @@ -139,8 +139,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), } jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, weightCol, - tol, as.integer(maxIter)) + "fit", formula, data@sdf, family$family, family$link, + tol, as.integer(maxIter), weightCol) return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) @@ -154,10 +154,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance -#' weights as 1.0. #' @param epsilon Positive convergence tolerance of iterations. #' @param maxit Integer giving the maximal number of IRLS iterations. +#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance +#' weights as 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -172,8 +172,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, weightCol = NULL, epsilon = 1e-6, maxit = 25) { - spark.glm(data, formula, family, weightCol, tol = epsilon, maxIter = maxit) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 6b35b2ec71493..0d3181d0acb48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -67,9 +67,9 @@ private[r] object GeneralizedLinearRegressionWrapper data: DataFrame, family: String, link: String, - weightCol: String, tol: Double, - maxIter: Int): GeneralizedLinearRegressionWrapper = { + maxIter: Int, + weightCol: String): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula() .setFormula(formula) val rFormulaModel = rFormula.fit(data) @@ -83,9 +83,9 @@ private[r] object GeneralizedLinearRegressionWrapper .setFamily(family) .setLink(link) .setFitIntercept(rFormula.hasIntercept) - .setWeightCol(weightCol) .setTol(tol) .setMaxIter(maxIter) + .setWeightCol(weightCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data)