From 380250e319203df9c5aed857ae8b1f865b86d70b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 15:12:51 -0800 Subject: [PATCH 1/3] fix glm with long fomular --- R/pkg/R/mllib.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 60bfadb8e7503..b0d73dd93a79d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -48,8 +48,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, standardize = TRUE, solver = "auto") { family <- match.arg(family) + formula <- paste(deparse(formula), collapse="") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + "fitRModelFormula", formula, data@sdf, family, lambda, alpha, standardize, solver) return(new("PipelineModel", model = model)) }) From 9b0879bb694c81ac09cdd9a24483814ae791042d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 15:36:36 -0800 Subject: [PATCH 2/3] add regression test --- R/pkg/inst/tests/test_mllib.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032cfef061fd3..520dac9f24983 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -33,6 +33,20 @@ test_that("glm and predict", { expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") }) +test_that("glm should work with long formula", { + training <- createDataFrame(sqlContext, iris) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLongLongLongLongLongLongLongLongLongName <- training$Sepal_Length + training$AnotherVeryLongLongLongLongLongLongLongLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ + VeryLongLongLongLongLongLongLongLongLongLongLongLongName + + AnotherVeryLongLongLongLongLongLongLongLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) From 23f95c6c465d288afac0e649b6e2bfe7b99e8a25 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 15:54:59 -0800 Subject: [PATCH 3/3] fix style --- R/pkg/inst/tests/test_mllib.R | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 520dac9f24983..4761e285a2479 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -36,11 +36,9 @@ test_that("glm and predict", { test_that("glm should work with long formula", { training <- createDataFrame(sqlContext, iris) training$LongLongLongLongLongName <- training$Sepal_Width - training$VeryLongLongLongLongLongLongLongLongLongLongLongLongName <- training$Sepal_Length - training$AnotherVeryLongLongLongLongLongLongLongLongLongLongLongName <- training$Species - model <- glm(LongLongLongLongLongName ~ - VeryLongLongLongLongLongLongLongLongLongLongLongLongName + - AnotherVeryLongLongLongLongLongLongLongLongLongLongLongName, + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)