-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes #4233
Changes from 13 commits
418ba1b
b1fc5ec
64914a3
1577d70
8d46386
1496852
c495dba
2935963
d1e5882
79675d5
ee99228
a34aef5
b4ee064
12d9059
87c4eb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,14 +17,17 @@ | |
|
||
package org.apache.spark.mllib.classification | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.mllib.classification.impl.GLMClassificationModel | ||
import org.apache.spark.mllib.linalg.BLAS.dot | ||
import org.apache.spark.mllib.linalg.{DenseVector, Vector} | ||
import org.apache.spark.mllib.optimization._ | ||
import org.apache.spark.mllib.regression._ | ||
import org.apache.spark.mllib.util.{DataValidators, MLUtils} | ||
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} | ||
import org.apache.spark.rdd.RDD | ||
|
||
|
||
/** | ||
* Classification model trained using Multinomial/Binary Logistic Regression. | ||
* | ||
|
@@ -42,7 +45,8 @@ class LogisticRegressionModel ( | |
override val intercept: Double, | ||
val numFeatures: Int, | ||
val numClasses: Int) | ||
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { | ||
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable | ||
with Saveable { | ||
|
||
def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) | ||
|
||
|
@@ -60,6 +64,13 @@ class LogisticRegressionModel ( | |
this | ||
} | ||
|
||
/** | ||
* :: Experimental :: | ||
* Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. | ||
*/ | ||
@Experimental | ||
def getThreshold: Option[Double] = threshold | ||
|
||
/** | ||
* :: Experimental :: | ||
* Clears the threshold so that `predict` will output raw prediction scores. | ||
|
@@ -126,6 +137,35 @@ class LogisticRegressionModel ( | |
bestClass.toDouble | ||
} | ||
} | ||
|
||
override def save(sc: SparkContext, path: String): Unit = { | ||
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any proposed guidelines about when to change the minor version and when the major version? I'm not expecting many versions, so I'm not sure whether it is necessary to have minor versions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking minor versions could be used for format changes and major ones for model changes. But I'm OK with a single version number too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have no strong preference here. It is okay with the current versioning. |
||
weights, intercept, threshold) | ||
} | ||
|
||
override protected def formatVersion: String = "1.0" | ||
} | ||
|
||
object LogisticRegressionModel extends Loader[LogisticRegressionModel] { | ||
|
||
override def load(sc: SparkContext, path: String): LogisticRegressionModel = { | ||
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) | ||
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Maybe we should add a comment here explaining why a literal string name is used. |
||
(loadedClassName, version) match { | ||
case (className, "1.0") if className == classNameV1_0 => | ||
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) | ||
val model = new LogisticRegressionModel(data.weights, data.intercept) | ||
data.threshold match { | ||
case Some(t) => model.setThreshold(t) | ||
case None => model.clearThreshold() | ||
} | ||
model | ||
case _ => throw new Exception( | ||
s"LogisticRegressionModel.load did not recognize model with (className, format version):" + | ||
s"($loadedClassName, $version). Supported:\n" + | ||
s" ($classNameV1_0, 1.0)") | ||
} | ||
} | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification | |
|
||
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} | ||
|
||
import org.apache.spark.{SparkException, Logging} | ||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.{SparkContext, SparkException, Logging} | ||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} | ||
import org.apache.spark.mllib.regression.LabeledPoint | ||
import org.apache.spark.mllib.util.{Loader, Saveable} | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.sql.{DataFrame, SQLContext} | ||
|
||
|
||
/** | ||
* Model for Naive Bayes Classifiers. | ||
|
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD | |
class NaiveBayesModel private[mllib] ( | ||
val labels: Array[Double], | ||
val pi: Array[Double], | ||
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { | ||
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { | ||
|
||
private val brzPi = new BDV[Double](pi) | ||
private val brzTheta = new BDM[Double](theta.length, theta(0).length) | ||
|
@@ -65,6 +67,68 @@ class NaiveBayesModel private[mllib] ( | |
override def predict(testData: Vector): Double = { | ||
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) | ||
} | ||
|
||
override def save(sc: SparkContext, path: String): Unit = { | ||
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) | ||
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) | ||
} | ||
|
||
override protected def formatVersion: String = "1.0" | ||
} | ||
|
||
object NaiveBayesModel extends Loader[NaiveBayesModel] { | ||
|
||
private object SaveLoadV1_0 { | ||
|
||
def thisFormatVersion = "1.0" | ||
|
||
def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" | ||
|
||
/** Model data for model import/export */ | ||
case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) | ||
|
||
def save(sc: SparkContext, path: String, data: Data): Unit = { | ||
val sqlContext = new SQLContext(sc) | ||
import sqlContext._ | ||
|
||
// Create JSON metadata. | ||
val metadataRDD = | ||
sc.parallelize(Seq((thisClassName, thisFormatVersion))).toDataFrame("class", "version") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") | ||
|
||
// Create Parquet data. | ||
val dataRDD: DataFrame = sc.parallelize(Seq(data)) | ||
dataRDD.repartition(1).saveAsParquetFile(path + "/data") | ||
} | ||
|
||
def load(sc: SparkContext, path: String): NaiveBayesModel = { | ||
val sqlContext = new SQLContext(sc) | ||
// Load Parquet data. | ||
val dataRDD = sqlContext.parquetFile(path + "/data") | ||
// Check schema explicitly since erasure makes it hard to use match-case for checking. | ||
Loader.checkSchema[Data](dataRDD.schema) | ||
val dataArray = dataRDD.select("labels", "pi", "theta").take(1) | ||
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") | ||
val data = dataArray(0) | ||
val labels = data.getAs[Seq[Double]](0).toArray | ||
val pi = data.getAs[Seq[Double]](1).toArray | ||
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray | ||
new NaiveBayesModel(labels, pi, theta) | ||
} | ||
} | ||
|
||
override def load(sc: SparkContext, path: String): NaiveBayesModel = { | ||
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) | ||
val classNameV1_0 = SaveLoadV1_0.thisClassName | ||
(loadedClassName, version) match { | ||
case (className, "1.0") if className == classNameV1_0 => | ||
SaveLoadV1_0.load(sc, path) | ||
case _ => throw new Exception( | ||
s"NaiveBayesModel.load did not recognize model with (className, format version):" + | ||
s"($loadedClassName, $version). Supported:\n" + | ||
s" ($classNameV1_0, 1.0)") | ||
} | ||
} | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.classification.impl | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.mllib.linalg.Vector | ||
import org.apache.spark.sql.{DataFrame, Row, SQLContext} | ||
|
||
/** | ||
* Helper class for import/export of GLM classification models. | ||
*/ | ||
private[classification] object GLMClassificationModel { | ||
|
||
object SaveLoadV1_0 { | ||
|
||
def thisFormatVersion = "1.0" | ||
|
||
/** Model data for model import/export */ | ||
case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) | ||
|
||
def save( | ||
sc: SparkContext, | ||
path: String, | ||
modelClass: String, | ||
weights: Vector, | ||
intercept: Double, | ||
threshold: Option[Double]): Unit = { | ||
val sqlContext = new SQLContext(sc) | ||
import sqlContext._ | ||
|
||
// Create JSON metadata. | ||
val metadataRDD = | ||
sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version") | ||
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") | ||
|
||
// Create Parquet data. | ||
val data = Data(weights, intercept, threshold) | ||
val dataRDD: DataFrame = sc.parallelize(Seq(data)) | ||
// TODO: repartition with 1 partition after SPARK-5532 gets fixed | ||
dataRDD.saveAsParquetFile(path + "/data") | ||
} | ||
|
||
def loadData(sc: SparkContext, path: String, modelClass: String): Data = { | ||
val sqlContext = new SQLContext(sc) | ||
val dataRDD = sqlContext.parquetFile(path + "/data") | ||
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) | ||
assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}") | ||
val data = dataArray(0) | ||
assert(data.size == 3, s"Unable to load $modelClass data from: ${path + "/data"}") | ||
val (weights, intercept) = data match { | ||
case Row(weights: Vector, intercept: Double, _) => | ||
(weights, intercept) | ||
} | ||
val threshold = if (data.isNullAt(2)) { | ||
None | ||
} else { | ||
Some(data.getDouble(2)) | ||
} | ||
Data(weights, intercept, threshold) | ||
} | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We merged multinomial logistic regression. LRModel holds numFeatures and numClasses now. We need a specialized implementation and a test for it. Or, for all classification models, we save numFeatures and numClasses.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll save numFeatures and numClasses in all classification models' metadata. I'm going for metadata instead of data in case the model data requires multiple RDD rows (e.g., for decision tree).