[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes
This is a PR for Parquet-based model import/export.  Please see the design doc on [the JIRA](https://issues.apache.org/jira/browse/SPARK-4587).

Note: This includes only a subset of regression and classification models:
* NaiveBayes, SVM, LogisticRegression
* LinearRegression, RidgeRegression, Lasso

Follow-up PRs will cover other models.

Sketch of current contents:
* New traits: Saveable, Loader (a usage sketch follows this list)
* Implementations for some algorithms
* Also: Added LogisticRegressionModel.getThreshold method (so that unit tests can check the threshold)
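As a usage sketch of the new API (not part of this PR's diff): given an existing SparkContext `sc` and an `RDD[LabeledPoint]` called `training`, a save/load round trip would look roughly like this; the path is illustrative:

```scala
// Hypothetical round trip with the save/load API added in this PR.
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}

val model = new LogisticRegressionWithLBFGS().run(training)
model.save(sc, "/tmp/lrModel")                     // writes JSON metadata + Parquet data
val sameModel = LogisticRegressionModel.load(sc, "/tmp/lrModel")
```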

CC: mengxr  selvinsource

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4233 from jkbradley/ml-import-export and squashes the following commits:

87c4eb8 [Joseph K. Bradley] small cleanups
12d9059 [Joseph K. Bradley] Many cleanups after code review.  Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
b4ee064 [Joseph K. Bradley] Reorganized save/load for regression and classification.  Renamed concepts to Saveable, Loader
a34aef5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into ml-import-export
ee99228 [Joseph K. Bradley] scala style fix
79675d5 [Joseph K. Bradley] cleanups in LogisticRegression after rebasing after multinomial PR
d1e5882 [Joseph K. Bradley] organized imports
2935963 [Joseph K. Bradley] Added save/load and tests for most classification and regression models
c495dba [Joseph K. Bradley] made version for model import/export local to each model
1496852 [Joseph K. Bradley] Added save/load for NaiveBayes
8d46386 [Joseph K. Bradley] Added save/load to NaiveBayes
1577d70 [Joseph K. Bradley] fixed issues after rebasing on master (DataFrame patch)
64914a3 [Joseph K. Bradley] added getThreshold to SVMModel
b1fc5ec [Joseph K. Bradley] small cleanups
418ba1b [Joseph K. Bradley] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite

(cherry picked from commit 975bcef)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
jkbradley authored and mengxr committed Feb 5, 2015
1 parent 59fb5c7 commit 885bcbb
Showing 18 changed files with 863 additions and 29 deletions.
mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala

@@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

/**
 * :: Experimental ::
@@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable {
  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}

private[mllib] object ClassificationModel {

  /**
   * Helper method for loading GLM classification model metadata.
   *
   * @param modelClass String name for model class (used for error messages)
   * @return (numFeatures, numClasses)
   */
  def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
    metadata.select("numFeatures", "numClasses").take(1)(0) match {
      case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
      case _ => throw new Exception(s"$modelClass unable to load" +
        s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
    }
  }

}
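For context on the metadata this helper reads: the save paths below write a one-row DataFrame (serialized as JSON) with `class`, `version`, `numFeatures`, and `numClasses` columns. A hypothetical sketch, assuming a SparkContext `sc` and using example values (692 features, 2 classes):

```scala
// Sketch: build a metadata DataFrame the way save() does below, then extract
// the two counts the way getNumFeaturesClasses matches them. Values are illustrative.
import org.apache.spark.sql.{Row, SQLContext}

val sqlContext = new SQLContext(sc)
import sqlContext._  // enables the tuple-RDD-to-DataFrame conversion used in this PR

val metadata = sc.parallelize(Seq(
    ("org.apache.spark.mllib.classification.LogisticRegressionModel", "1.0", 692, 2)), 1)
  .toDataFrame("class", "version", "numFeatures", "numClasses")

val Row(nFeatures: Int, nClasses: Int) =
  metadata.select("numFeatures", "numClasses").take(1)(0)
```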
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

@@ -17,14 +17,17 @@

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


/**
 * Classification model trained using Multinomial/Binary Logistic Regression.
 *
@@ -42,7 +45,22 @@ class LogisticRegressionModel (
    override val intercept: Double,
    val numFeatures: Int,
    val numClasses: Int)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {

  if (numClasses == 2) {
    require(weights.size == numFeatures,
      s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
        s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
  } else {
    val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
    val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
    require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
      s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
        s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
        s" or $weightsSizeWithIntercept (with intercept)," +
        s" but was given weights of length ${weights.size}")
  }

  def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)

@@ -60,6 +78,13 @@ class LogisticRegressionModel (
    this
  }

  /**
   * :: Experimental ::
   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
   */
  @Experimental
  def getThreshold: Option[Double] = threshold

  /**
   * :: Experimental ::
   * Clears the threshold so that `predict` will output raw prediction scores.
@@ -70,7 +95,9 @@ class LogisticRegressionModel (
    this
  }

-  override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+  override protected def predictPoint(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
      intercept: Double) = {
    require(dataMatrix.size == numFeatures)

@@ -126,6 +153,40 @@ class LogisticRegressionModel (
      bestClass.toDouble
    }
  }

  override def save(sc: SparkContext, path: String): Unit = {
    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
      numFeatures, numClasses, weights, intercept, threshold)
  }

  override protected def formatVersion: String = "1.0"
}

object LogisticRegressionModel extends Loader[LogisticRegressionModel] {

  override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
    // Hard-code class name string in case it changes in the future
    val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val (numFeatures, numClasses) =
          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
        // numFeatures, numClasses, weights are checked in model initialization
        val model =
          new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
        data.threshold match {
          case Some(t) => model.setThreshold(t)
          case None => model.clearThreshold()
        }
        model
      case _ => throw new Exception(
        s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
          s"($loadedClassName, $version). Supported:\n" +
          s" ($classNameV1_0, 1.0)")
    }
  }
}
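One behavioral detail of the load path above: the optional threshold round-trips. A hypothetical sketch, assuming a trained LogisticRegressionModel `model`, an existing SparkContext `sc`, and an illustrative path:

```scala
// The threshold is persisted as Option[Double] and re-applied on load
// via setThreshold/clearThreshold, so a cleared threshold stays cleared.
model.clearThreshold()                              // predict will return raw scores
model.save(sc, "/tmp/lr-raw")                       // threshold stored as None
val loaded = LogisticRegressionModel.load(sc, "/tmp/lr-raw")
assert(loaded.getThreshold.isEmpty)                 // still cleared after loading
```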

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}

-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}


/**
 * Model for Naive Bayes Classifiers.
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD
class NaiveBayesModel private[mllib] (
    val labels: Array[Double],
    val pi: Array[Double],
-    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {

  private val brzPi = new BDV[Double](pi)
  private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -65,6 +67,85 @@ class NaiveBayesModel private[mllib] (
  override def predict(testData: Vector): Double = {
    labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
  }

  override def save(sc: SparkContext, path: String): Unit = {
    val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
    NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
  }

  override protected def formatVersion: String = "1.0"
}

object NaiveBayesModel extends Loader[NaiveBayesModel] {

  import Loader._

  private object SaveLoadV1_0 {

    def thisFormatVersion = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"

    /** Model data for model import/export */
    case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])

    def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext._

      // Create JSON metadata.
      val metadataRDD =
        sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
          .toDataFrame("class", "version", "numFeatures", "numClasses")
      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))

      // Create Parquet data.
      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
      dataRDD.saveAsParquetFile(dataPath(path))
    }

    def load(sc: SparkContext, path: String): NaiveBayesModel = {
      val sqlContext = new SQLContext(sc)
      // Load Parquet data.
      val dataRDD = sqlContext.parquetFile(dataPath(path))
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      checkSchema[Data](dataRDD.schema)
      val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
      val data = dataArray(0)
      val labels = data.getAs[Seq[Double]](0).toArray
      val pi = data.getAs[Seq[Double]](1).toArray
      val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
      new NaiveBayesModel(labels, pi, theta)
    }
  }

  override def load(sc: SparkContext, path: String): NaiveBayesModel = {
    val (loadedClassName, version, metadata) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val (numFeatures, numClasses) =
          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
        val model = SaveLoadV1_0.load(sc, path)
        assert(model.pi.size == numClasses,
          s"NaiveBayesModel.load expected $numClasses classes," +
            s" but class priors vector pi had ${model.pi.size} elements")
        assert(model.theta.size == numClasses,
          s"NaiveBayesModel.load expected $numClasses classes," +
            s" but class conditionals array theta had ${model.theta.size} elements")
        assert(model.theta.forall(_.size == numFeatures),
          s"NaiveBayesModel.load expected $numFeatures features," +
            s" but class conditionals array theta had elements of size:" +
            s" ${model.theta.map(_.size).mkString(",")}")
        model
      case _ => throw new Exception(
        s"NaiveBayesModel.load did not recognize model with (className, format version):" +
          s"($loadedClassName, $version). Supported:\n" +
          s" ($classNameV1_0, 1.0)")
    }
  }
}
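To make the NaiveBayes save/load path concrete, a hypothetical end-to-end sketch; the two-point training set and the path are illustrative, and `sc` is an assumed SparkContext:

```scala
// Train, save, reload, and check that predictions survive the round trip.
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0))))
val model = NaiveBayes.train(training)

model.save(sc, "/tmp/nbModel")   // writes JSON metadata + Parquet data under the path
val loaded = NaiveBayesModel.load(sc, "/tmp/nbModel")
assert(loaded.predict(Vectors.dense(1.0, 0.0)) == model.predict(Vectors.dense(1.0, 0.0)))
```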

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

@@ -17,13 +17,16 @@

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


/**
 * Model for Support Vector Machines (SVMs).
 *
@@ -33,7 +36,8 @@ class SVMModel (
class SVMModel (
    override val weights: Vector,
    override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {

  private var threshold: Option[Double] = Some(0.0)

@@ -49,6 +53,13 @@ class SVMModel (
    this
  }

  /**
   * :: Experimental ::
   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
   */
  @Experimental
  def getThreshold: Option[Double] = threshold

  /**
   * :: Experimental ::
   * Clears the threshold so that `predict` will output raw prediction scores.
@@ -69,6 +80,42 @@ class SVMModel (
      case None => margin
    }
  }

  override def save(sc: SparkContext, path: String): Unit = {
    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
      numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
  }

  override protected def formatVersion: String = "1.0"
}

object SVMModel extends Loader[SVMModel] {

  override def load(sc: SparkContext, path: String): SVMModel = {
    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
    // Hard-code class name string in case it changes in the future
    val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val (numFeatures, numClasses) =
          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
        val model = new SVMModel(data.weights, data.intercept)
        assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
          s" was given non-matching weights vector of size ${model.weights.size}")
        assert(numClasses == 2,
          s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes")
        data.threshold match {
          case Some(t) => model.setThreshold(t)
          case None => model.clearThreshold()
        }
        model
      case _ => throw new Exception(
        s"SVMModel.load did not recognize model with (className, format version):" +
          s"($loadedClassName, $version). Supported:\n" +
          s" ($classNameV1_0, 1.0)")
    }
  }
}
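Finally, a hypothetical sketch of the SVM threshold behavior that save/load preserves; the path and feature values are illustrative, and `sc` is an assumed SparkContext:

```scala
// Toggle a loaded SVM between raw margins and 0/1 labels.
import org.apache.spark.mllib.linalg.Vectors

val svm = SVMModel.load(sc, "/tmp/svmModel")        // example path
svm.clearThreshold()
val margin = svm.predict(Vectors.dense(0.5, -1.2))  // raw score: weights · x + intercept
svm.setThreshold(0.0)
val label = svm.predict(Vectors.dense(0.5, -1.2))   // 0.0 or 1.0 after thresholding
```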

[Diffs for the remaining 14 changed files are not shown.]