From b7f934ad2eb3f39125d9bc29289e8ce3a49f48b7 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 10 Jan 2017 22:02:44 -0800 Subject: [PATCH] fix Gamma family --- R/pkg/R/mllib.R | 7 ++++++- R/pkg/inst/tests/testthat/test_mllib.R | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b33a16a7cef97..cd07f278ecb1c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,6 +89,8 @@ NULL #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, and \code{poisson}. #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param ... additional arguments passed to the method. @@ -134,8 +136,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), formula <- paste(deparse(formula), collapse = "") + # Lower-case the family name for the JVM wrapper; of the known families, only R's Gamma is capitalized jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, + "fit", formula, data@sdf, tolower(family$family), family$link, tol, as.integer(maxIter)) return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) @@ -150,6 +153,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, and \code{poisson}. #' @param epsilon positive convergence tolerance of iterations. 
#' @param maxit integer giving the maximal number of IRLS iterations. #' @return \code{glm} returns a fitted generalized linear model. diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 753da81760971..e0d2e53e5f301 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -69,6 +69,14 @@ test_that("spark.glm and predict", { data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15)