Skip to content

Commit

Permalink
small cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 3, 2015
1 parent 418ba1b commit b1fc5ec
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, MLUtils}
import org.apache.spark.mllib.util.{Importable, DataValidators, Exportable}
import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable, MLUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}


/**
* Classification model trained using Multinomial/Binary Logistic Regression.
*
Expand Down Expand Up @@ -143,8 +143,8 @@ class LogisticRegressionModel (
import sqlContext._
// TODO: Do we need to use a SELECT statement to make the column ordering deterministic?
// Create JSON metadata.
val metadata =
LogisticRegressionModel.Metadata(clazz = this.getClass.getName, version = Exportable.version)
val metadata = LogisticRegressionModel.Metadata(
clazz = this.getClass.getName, version = Exportable.latestVersion)
val metadataRDD: SchemaRDD = sc.parallelize(Seq(metadata))
metadataRDD.toJSON.saveAsTextFile(path + "/metadata")
// Create Parquet data.
Expand All @@ -156,6 +156,10 @@ class LogisticRegressionModel (

object LogisticRegressionModel extends Importable[LogisticRegressionModel] {

private case class Metadata(clazz: String, version: String)

private case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
val sqlContext = new SQLContext(sc)
import sqlContext._
Expand All @@ -169,7 +173,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] {
case Row(clazz: String, version: String) =>
assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" +
s" was given model file with metadata specifying a different model class: $clazz")
assert(version == Importable.version, // only 1 version exists currently
assert(version == Exportable.latestVersion, // only 1 version exists currently
s"LogisticRegressionModel.load did not recognize model format version: $version")
}

Expand All @@ -192,10 +196,6 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] {
lr
}

private case class Metadata(clazz: String, version: String)

private case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.mllib.util
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi


/**
* :: DeveloperApi ::
*
Expand Down Expand Up @@ -50,7 +49,7 @@ trait Exportable {
object Exportable {

/** Current version of model import/export format. */
val version: String = "1.0"
val latestVersion: String = "1.0"

}

Expand All @@ -75,10 +74,3 @@ trait Importable[Model <: Exportable] {
def load(sc: SparkContext, path: String): Model

}

object Importable {

/** Current version of model import/export format. */
val version: String = Exportable.version

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

object LogisticRegressionSuite {

Expand Down Expand Up @@ -481,20 +482,19 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString

// Save model
// Save model, load it back, and compare.
model.save(sc, path)
val sameModel = LogisticRegressionModel.load(sc, path)
assert(model.weights == sameModel.weights)
assert(model.intercept == sameModel.intercept)
assert(sameModel.getThreshold.isEmpty)
Utils.deleteRecursively(tempDir)

// Save model with threshold
// Save model with threshold.
model.setThreshold(0.7)
model.save(sc, path)
val sameModel2 = LogisticRegressionModel.load(sc, path)
assert(model.getThreshold.get == sameModel2.getThreshold.get)

Utils.deleteRecursively(tempDir)
}

Expand Down

0 comments on commit b1fc5ec

Please sign in to comment.