Skip to content

Commit

Permalink
Added save/load to NaiveBayes
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 3, 2015
1 parent 1577d70 commit 8d46386
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,10 @@ class LogisticRegressionModel (

object LogisticRegressionModel extends Importable[LogisticRegressionModel] {

/** Metadata for model import/export */
private case class Metadata(clazz: String, version: String)

/** Model data for model import/export */
private case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}

import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Importable, Exportable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, DataFrame, SQLContext}

import scala.collection.mutable.ArrayBuffer


/**
* Model for Naive Bayes Classifiers.
Expand All @@ -36,7 +40,7 @@ import org.apache.spark.rdd.RDD
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Exportable {

private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
Expand Down Expand Up @@ -65,6 +69,60 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}

override def save(sc: SparkContext, path: String): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._

// Create JSON metadata.
val metadata = NaiveBayesModel.Metadata(
clazz = this.getClass.getName, version = Exportable.latestVersion)
val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
metadataRDD.toJSON.saveAsTextFile(path + "/metadata")
// Create Parquet data.
val data = NaiveBayesModel.Data(labels, pi, theta)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.saveAsParquetFile(path + "/data")
}
}

object NaiveBayesModel extends Importable[NaiveBayesModel] {

/** Metadata for model import/export */
private case class Metadata(clazz: String, version: String)

/** Model data for model import/export */
private case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
import sqlContext._

// Load JSON metadata.
val metadataRDD = sqlContext.jsonFile(path + "/metadata")
val metadataArray = metadataRDD.select("clazz", "version").take(1)
assert(metadataArray.size == 1,
s"Unable to load NaiveBayesModel metadata from: ${path + "/metadata"}")
metadataArray(0) match {
case Row(clazz: String, version: String) =>
assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" +
s" was given model file with metadata specifying a different model class: $clazz")
assert(version == Exportable.latestVersion, // only 1 version exists currently
s"NaiveBayesModel.load did not recognize model format version: $version")
}

// Load Parquet data.
val dataRDD = sqlContext.parquetFile(path + "/data")
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
val data = dataArray(0)
assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
val nb = data match {
case Row(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]]) =>
new NaiveBayesModel(labels.toArray, pi.toArray, theta.map(_.toArray).toArray)
}
nb
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ trait Exportable {

}

/**
* :: DeveloperApi ::
*/
@DeveloperApi
object Exportable {

/** Current version of model import/export format. */
Expand Down Expand Up @@ -74,3 +78,35 @@ trait Importable[Model <: Exportable] {
def load(sc: SparkContext, path: String): Model

}

/*
/**
* :: DeveloperApi ::
*
* Trait for models and transformers which may be saved as files.
* This should be inherited by the class which implements model instances.
*
* This specializes [[Exportable]] for local models which can be stored on a single machine.
* This provides helper functionality, but developers can choose to use [[Exportable]] instead,
* even for local models.
*/
@DeveloperApi
trait LocalExportable {
/**
* Save this model to the given path.
*
* This saves:
* - human-readable (JSON) model metadata to path/metadata/
* - Parquet formatted data to path/data/
*
* The model may be loaded using [[Importable.load]].
*
* @param sc Spark context used to save model data.
* @param path Path specifying the directory in which to save this model.
* This directory and any intermediate directory will be created if needed.
*/
def save(sc: SparkContext, path: String): Unit
}
*/
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.mllib.classification

import org.apache.spark.util.Utils

import scala.util.Random

import org.scalatest.FunSuite
Expand Down Expand Up @@ -58,6 +60,14 @@ object NaiveBayesSuite {
LabeledPoint(y, Vectors.dense(xi))
}
}

private val smallPi = Array(0.5, 0.3, 0.2).map(math.log)

private val smallTheta = Array(
Array(0.91, 0.03, 0.03, 0.03), // label 0
Array(0.03, 0.91, 0.03, 0.03), // label 1
Array(0.03, 0.03, 0.91, 0.03) // label 2
).map(_.map(math.log))
}

class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Expand All @@ -74,12 +84,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
test("Naive Bayes") {
val nPoints = 10000

val pi = Array(0.5, 0.3, 0.2).map(math.log)
val theta = Array(
Array(0.91, 0.03, 0.03, 0.03), // label 0
Array(0.03, 0.91, 0.03, 0.03), // label 1
Array(0.03, 0.03, 0.91, 0.03) // label 2
).map(_.map(math.log))
val pi = NaiveBayesSuite.smallPi
val theta = NaiveBayesSuite.smallTheta

val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
Expand Down Expand Up @@ -123,6 +129,30 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
NaiveBayes.train(sc.makeRDD(nan, 2))
}
}

test("model export/import") {
val nPoints = 10

val pi = NaiveBayesSuite.smallPi
val theta = NaiveBayesSuite.smallTheta

val data = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
val rdd = sc.parallelize(data, 2)
rdd.cache()

val model = NaiveBayes.train(rdd)

val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString

// Save model, load it back, and compare.
model.save(sc, path)
val sameModel = NaiveBayesModel.load(sc, path)
assert(model.labels === sameModel.labels)
assert(model.pi === sameModel.pi)
assert(model.theta === sameModel.theta)
Utils.deleteRecursively(tempDir)
}
}

class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
Expand Down

0 comments on commit 8d46386

Please sign in to comment.