From 87c4eb8e2e030cf033418901fd7c0533efb090f3 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Wed, 4 Feb 2015 15:32:29 -0800
Subject: [PATCH] small cleanups

---
 .../mllib/classification/NaiveBayes.scala           |  1 -
 .../mllib/regression/RegressionModel.scala          |  2 +-
 .../mllib/regression/RidgeRegression.scala          |  2 +-
 .../LogisticRegressionSuite.scala                   | 24 ++++++++++---------
 4 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index c8fe19855ddab..4bafd495f90b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -18,7 +18,6 @@ package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 
-import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.{SparkContext, SparkException, Logging}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index d6bbe7bbf4409..843e59bdfbdd2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.regression
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{DataFrame, Row}
 
 @Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 32a40b9a51d83..f2a5f1db1ece6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression.impl.GLMRegressionModel
-import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 6be1b290a9b60..d2b40f2cae020 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -157,8 +157,17 @@
 
   /** 3 classes, 2 features */
   private val multiclassModel = new LogisticRegressionModel(
     weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3)
+
+  private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = {
+    assert(a.weights == b.weights)
+    assert(a.intercept == b.intercept)
+    assert(a.numClasses == b.numClasses)
+    assert(a.numFeatures == b.numFeatures)
+    assert(a.getThreshold == b.getThreshold)
+  }
 }
+
 class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
   def validatePrediction(
       predictions: Seq[Double],
@@ -486,11 +495,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     try {
       model.save(sc, path)
       val sameModel = LogisticRegressionModel.load(sc, path)
-      assert(model.weights == sameModel.weights)
-      assert(model.intercept == sameModel.intercept)
-      assert(model.numClasses == sameModel.numClasses)
-      assert(model.numFeatures == sameModel.numFeatures)
-      assert(sameModel.getThreshold.isEmpty)
+      LogisticRegressionSuite.checkModelsEqual(model, sameModel)
     } finally {
       Utils.deleteRecursively(tempDir)
     }
@@ -499,8 +504,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     try {
      model.setThreshold(0.7)
       model.save(sc, path)
-      val sameModel2 = LogisticRegressionModel.load(sc, path)
-      assert(model.getThreshold.get == sameModel2.getThreshold.get)
+      val sameModel = LogisticRegressionModel.load(sc, path)
+      LogisticRegressionSuite.checkModelsEqual(model, sameModel)
     } finally {
       Utils.deleteRecursively(tempDir)
     }
@@ -517,10 +522,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     try {
       model.save(sc, path)
       val sameModel = LogisticRegressionModel.load(sc, path)
-      assert(model.weights == sameModel.weights)
-      assert(model.intercept == sameModel.intercept)
-      assert(model.numClasses == sameModel.numClasses)
-      assert(model.numFeatures == sameModel.numFeatures)
+      LogisticRegressionSuite.checkModelsEqual(model, sameModel)
     } finally {
       Utils.deleteRecursively(tempDir)
     }