[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes #4233

Status: Closed. Wants to merge 15 commits; showing changes from 13 commits.
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,14 +17,17 @@

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, MLUtils}
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


/**
* Classification model trained using Multinomial/Binary Logistic Regression.
*
@@ -42,7 +45,8 @@ class LogisticRegressionModel (
override val intercept: Double,
val numFeatures: Int,
val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {

def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)

@@ -60,6 +64,13 @@ class LogisticRegressionModel (
this
}

/**
* :: Experimental ::
* Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
*/
@Experimental
def getThreshold: Option[Double] = threshold

/**
* :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
@@ -126,6 +137,35 @@ class LogisticRegressionModel (
bestClass.toDouble
}
}

override def save(sc: SparkContext, path: String): Unit = {
Contributor: We merged multinomial logistic regression, so LogisticRegressionModel now holds numFeatures and numClasses. We need a specialized implementation and a test for it. Alternatively, we could save numFeatures and numClasses for all classification models.

Member (author): I'll save numFeatures and numClasses in all classification models' metadata. I'm going with metadata instead of data in case the model data requires multiple RDD rows (e.g., for decision trees). A sketch of what such metadata could look like appears after this class.

GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
Contributor: Are there any proposed guidelines on when to bump the minor version versus the major version? I'm not expecting many versions, so I'm not sure minor versions are necessary.

Member (author): I was thinking minor versions could be used for format changes and major ones for model changes, but I'm fine with a single version number too.

Contributor: I have no strong preference here; the current versioning is fine.

weights, intercept, threshold)
}

override protected def formatVersion: String = "1.0"
}
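Per the thread above, here is a minimal sketch of how numFeatures and numClasses could be recorded in the JSON metadata rather than in the Parquet data. The helper and field names are hypothetical, not part of this diff:

```scala
import org.apache.spark.SparkContext

// Hypothetical sketch (not part of this PR): extend the JSON metadata with
// numFeatures and numClasses so model data spanning multiple RDD rows is untouched.
def saveClassificationMetadata(
    sc: SparkContext,
    path: String,
    className: String,
    version: String,
    numFeatures: Int,
    numClasses: Int): Unit = {
  val json = s"""{"class":"$className","version":"$version",""" +
    s""""numFeatures":$numFeatures,"numClasses":$numClasses}"""
  // A single partition keeps the metadata in one part file.
  sc.parallelize(Seq(json), 1).saveAsTextFile(path + "/metadata")
}
```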

object LogisticRegressionModel extends Loader[LogisticRegressionModel] {

override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
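// Use a hard-coded class name string instead of this.getClass.getName so that,
// if this class is renamed or moved, the name stored in existing model files still matches.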
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
Contributor: Maybe we should add a comment here explaining why a literal string name is used.

(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new LogisticRegressionModel(data.weights, data.intercept)
data.threshold match {
case Some(t) => model.setThreshold(t)
case None => model.clearThreshold()
}
model
case _ => throw new Exception(
s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}
}
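A round-trip usage sketch for the new save/load API. It assumes an active SparkContext `sc` and training data `training: RDD[LabeledPoint]`; the path is illustrative:

```scala
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}

val model = new LogisticRegressionWithLBFGS().run(training)
model.save(sc, "/tmp/lrModel")   // writes metadata/ (JSON) and data/ (Parquet)
val sameModel = LogisticRegressionModel.load(sc, "/tmp/lrModel")
assert(sameModel.weights == model.weights)
assert(sameModel.getThreshold == model.getThreshold)
```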

…

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}

import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}


/**
* Model for Naive Bayes Classifiers.
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {

private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -65,6 +67,68 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
}

override protected def formatVersion: String = "1.0"
}

object NaiveBayesModel extends Loader[NaiveBayesModel] {

private object SaveLoadV1_0 {

def thisFormatVersion = "1.0"

def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"

/** Model data for model import/export */
case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])

def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((thisClassName, thisFormatVersion))).toDataFrame("class", "version")
Contributor: Use sc.parallelize(..., 1) and remove repartition(1).
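// As suggested above, an equivalent sketch with an explicit partition count
// (illustrative only, not the code as merged):
//   sc.parallelize(Seq((thisClassName, thisFormatVersion)), 1)
//     .toDataFrame("class", "version").toJSON.saveAsTextFile(path + "/metadata")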

metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")

// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.repartition(1).saveAsParquetFile(path + "/data")
}

def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
// Load Parquet data.
val dataRDD = sqlContext.parquetFile(path + "/data")
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
new NaiveBayesModel(labels, pi, theta)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
SaveLoadV1_0.load(sc, path)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}
}
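The same pattern as for the GLMs; a usage sketch assuming `sc` and `points: RDD[LabeledPoint]`:

```scala
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}

val model = NaiveBayes.train(points)
model.save(sc, "/tmp/nbModel")
val sameModel = NaiveBayesModel.load(sc, "/tmp/nbModel")
assert(sameModel.labels.sameElements(model.labels))
assert(sameModel.pi.sameElements(model.pi))
```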

…

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,13 +17,16 @@

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


/**
* Model for Support Vector Machines (SVMs).
*
@@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD
class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {

private var threshold: Option[Double] = Some(0.0)

@@ -49,6 +53,13 @@ class SVMModel (
this
}

/**
* :: Experimental ::
* Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
*/
@Experimental
def getThreshold: Option[Double] = threshold

/**
* :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
@@ -69,6 +80,35 @@ class SVMModel (
case None => margin
}
}

override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
weights, intercept, threshold)
}

override protected def formatVersion: String = "1.0"
}

object SVMModel extends Loader[SVMModel] {

override def load(sc: SparkContext, path: String): SVMModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new SVMModel(data.weights, data.intercept)
data.threshold match {
case Some(t) => model.setThreshold(t)
case None => model.clearThreshold()
}
model
case _ => throw new Exception(
s"SVMModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}
}
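A sketch showing that a cleared threshold survives the round trip. It assumes `sc`; the weights here are made up for illustration:

```scala
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.linalg.Vectors

val svm = new SVMModel(Vectors.dense(0.5, -0.2), 0.1)
svm.clearThreshold()                 // predict() now returns raw margins
svm.save(sc, "/tmp/svmModel")
val loaded = SVMModel.load(sc, "/tmp/svmModel")
assert(loaded.getThreshold.isEmpty)  // None is preserved as a null Parquet value
```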

…

mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala (new file)
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.classification.impl

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
* Helper class for import/export of GLM classification models.
*/
private[classification] object GLMClassificationModel {

object SaveLoadV1_0 {

def thisFormatVersion = "1.0"

/** Model data for model import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

def save(
sc: SparkContext,
path: String,
modelClass: String,
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")

// Create Parquet data.
val data = Data(weights, intercept, threshold)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(path + "/data")
}

def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
val sqlContext = new SQLContext(sc)
val dataRDD = sqlContext.parquetFile(path + "/data")
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}")
val data = dataArray(0)
assert(data.size == 3, s"Unable to load $modelClass data from: ${path + "/data"}")
val (weights, intercept) = data match {
case Row(weights: Vector, intercept: Double, _) =>
(weights, intercept)
}
val threshold = if (data.isNullAt(2)) {
None
} else {
Some(data.getDouble(2))
}
Data(weights, intercept, threshold)
}
}

}
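One detail worth noting in loadData: an `Option[Double]` field such as `threshold` is stored as a nullable Parquet column, so `None` round-trips as a SQL null, which is what the `isNullAt(2)` check recovers. A minimal sketch of that decoding step (the helper name is ours, not the PR's):

```scala
import org.apache.spark.sql.Row

// Column 2 is nullable because Data declares threshold: Option[Double].
def decodeThreshold(row: Row): Option[Double] =
  if (row.isNullAt(2)) None else Some(row.getDouble(2))
```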
mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala (31 changes: 28 additions & 3 deletions)
@@ -17,9 +17,11 @@

package org.apache.spark.mllib.regression

import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD

/**
@@ -32,20 +34,43 @@ class LassoModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
with RegressionModel with Serializable with Saveable {

override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}

override def save(sc: SparkContext, path: String): Unit = {
GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
}

override protected def formatVersion: String = "1.0"
}

object LassoModel extends Loader[LassoModel] {

override def load(sc: SparkContext, path: String): LassoModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
new LassoModel(data.weights, data.intercept)
case _ => throw new Exception(
s"LassoModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}
}
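Usage sketch, assuming `sc` and `trainingData: RDD[LabeledPoint]`:

```scala
import org.apache.spark.mllib.regression.{LassoModel, LassoWithSGD}

val model = LassoWithSGD.train(trainingData, 100)   // 100 iterations of SGD
model.save(sc, "/tmp/lassoModel")
val sameModel = LassoModel.load(sc, "/tmp/lassoModel")
assert(sameModel.weights == model.weights)
assert(sameModel.intercept == model.intercept)
```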

/**
* Train a regression model with L1-regularization using Stochastic Gradient Descent.
* This solves the l1-regularized least squares regression formulation
* f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1
* f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
* See also the documentation for the precise formulation.
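For reference, the objective in this scaladoc reads, in standard notation:

```latex
f(w) = \frac{1}{2n} \lVert A w - y \rVert_2^2 + \mathrm{regParam} \cdot \lVert w \rVert_1
```

where the data matrix A has n rows and y is the vector of labels.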