From 44be8c95ef683610a31f9261e484fd82b480a430 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 11 Oct 2016 02:17:42 -0700 Subject: [PATCH] Fix SparkR spark.naiveBayes error when label is numeric type --- R/pkg/inst/tests/testthat/test_mllib.R | 10 ++++++++++ .../org/apache/spark/ml/r/NaiveBayesWrapper.scala | 1 + 2 files changed, 11 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index a1eaaf20916a2..c99315726a22c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -481,6 +481,16 @@ test_that("spark.naiveBayes", { expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) expect_equal(as.character(predict(m, t1[1, ])), "Yes") } + + # Test numeric response variable + t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) + t2 <- t1[-4] + df <- suppressWarnings(createDataFrame(t2)) + m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) }) test_that("spark.survreg", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index d1a39fea76ef8..4fdab2dd94655 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) + .setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema