Skip to content

Commit

Permalink
small cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 4, 2015
1 parent 12d9059 commit 87c4eb8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.apache.spark.mllib.classification.impl.GLMClassificationModel

import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.mllib.regression

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}

@Experimental
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,17 @@ object LogisticRegressionSuite {
/** 3 classes, 2 features */
private val multiclassModel = new LogisticRegressionModel(
weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3)

private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = {
assert(a.weights == b.weights)
assert(a.intercept == b.intercept)
assert(a.numClasses == b.numClasses)
assert(a.numFeatures == b.numFeatures)
assert(a.getThreshold == b.getThreshold)
}
}


class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
def validatePrediction(
predictions: Seq[Double],
Expand Down Expand Up @@ -486,11 +495,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
try {
model.save(sc, path)
val sameModel = LogisticRegressionModel.load(sc, path)
assert(model.weights == sameModel.weights)
assert(model.intercept == sameModel.intercept)
assert(model.numClasses == sameModel.numClasses)
assert(model.numFeatures == sameModel.numFeatures)
assert(sameModel.getThreshold.isEmpty)
LogisticRegressionSuite.checkModelsEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
Expand All @@ -499,8 +504,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
try {
model.setThreshold(0.7)
model.save(sc, path)
val sameModel2 = LogisticRegressionModel.load(sc, path)
assert(model.getThreshold.get == sameModel2.getThreshold.get)
val sameModel = LogisticRegressionModel.load(sc, path)
LogisticRegressionSuite.checkModelsEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
Expand All @@ -517,10 +522,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
try {
model.save(sc, path)
val sameModel = LogisticRegressionModel.load(sc, path)
assert(model.weights == sameModel.weights)
assert(model.intercept == sameModel.intercept)
assert(model.numClasses == sameModel.numClasses)
assert(model.numFeatures == sameModel.numFeatures)
LogisticRegressionSuite.checkModelsEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
Expand Down

0 comments on commit 87c4eb8

Please sign in to comment.