From b1fc5eca06808f2250fef75aa10c816c889dd5f1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 27 Jan 2015 14:08:10 -0800 Subject: [PATCH] small cleanups --- .../classification/LogisticRegression.scala | 18 +++++++++--------- .../spark/mllib/util/modelImportExport.scala | 10 +--------- .../LogisticRegressionSuite.scala | 6 +++--- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index dde34195163df..22e4d2ef3af3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -23,11 +23,11 @@ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} -import org.apache.spark.mllib.util.{Importable, DataValidators, Exportable} +import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable, MLUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + /** * Classification model trained using Multinomial/Binary Logistic Regression. * @@ -143,8 +143,8 @@ class LogisticRegressionModel ( import sqlContext._ // TODO: Do we need to use a SELECT statement to make the column ordering deterministic? // Create JSON metadata. - val metadata = - LogisticRegressionModel.Metadata(clazz = this.getClass.getName, version = Exportable.version) + val metadata = LogisticRegressionModel.Metadata( + clazz = this.getClass.getName, version = Exportable.latestVersion) val metadataRDD: SchemaRDD = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") // Create Parquet data. @@ -156,6 +156,10 @@ class LogisticRegressionModel ( object LogisticRegressionModel extends Importable[LogisticRegressionModel] { + private case class Metadata(clazz: String, version: String) + + private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val sqlContext = new SQLContext(sc) import sqlContext._ @@ -169,7 +173,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { case Row(clazz: String, version: String) => assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" + s" was given model file with metadata specifying a different model class: $clazz") - assert(version == Importable.version, // only 1 version exists currently + assert(version == Exportable.latestVersion, // only 1 version exists currently s"LogisticRegressionModel.load did not recognize model format version: $version") } @@ -192,10 +196,6 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { lr } - private case class Metadata(clazz: String, version: String) - - private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index f6f312a18cc5f..c6ac6a45edaea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.util import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi - /** * :: DeveloperApi :: * @@ -50,7 +49,7 @@ trait Exportable { object Exportable { /** Current version of model import/export format. */ - val version: String = "1.0" + val latestVersion: String = "1.0" } @@ -75,10 +74,3 @@ trait Importable[Model <: Exportable] { def load(sc: SparkContext, path: String): Model } - -object Importable { - - /** Current version of model import/export format. */ - val version: String = Exportable.version - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 82b2fdb4608eb..18deca92e06a0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils object LogisticRegressionSuite { @@ -481,7 +482,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - // Save model + // Save model, load it back, and compare. model.save(sc, path) val sameModel = LogisticRegressionModel.load(sc, path) assert(model.weights == sameModel.weights) @@ -489,12 +490,11 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(sameModel.getThreshold.isEmpty) Utils.deleteRecursively(tempDir) - // Save model with threshold + // Save model with threshold. model.setThreshold(0.7) model.save(sc, path) val sameModel2 = LogisticRegressionModel.load(sc, path) assert(model.getThreshold.get == sameModel2.getThreshold.get) - Utils.deleteRecursively(tempDir) }