
Commit

Added save/load for NaiveBayes
jkbradley committed Feb 3, 2015
1 parent 8d46386 commit 1496852
Showing 3 changed files with 38 additions and 34 deletions.
@@ -147,6 +147,7 @@ class LogisticRegressionModel (
clazz = this.getClass.getName, version = Exportable.latestVersion)
val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
metadataRDD.toJSON.saveAsTextFile(path + "/metadata")

// Create Parquet data.
val data = LogisticRegressionModel.Data(weights, intercept, threshold)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
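The metadata written by the block above is plain JSON text under path + "/metadata", so it is easy to inspect. A minimal sketch (not part of the diff) of reading it back, assuming the same sc and path that were passed to save:

// Hypothetical check, not in the commit: the metadata saved above is JSON text,
// so it can be read back with an ordinary textFile call.
val metadataJson = sc.textFile(path + "/metadata").collect().mkString("\n")
println(metadataJson)  // expected to contain the clazz and version fields set above
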
@@ -18,6 +18,8 @@
package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, StructField, StructType}

import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
@@ -79,6 +81,7 @@ class NaiveBayesModel private[mllib] (
clazz = this.getClass.getName, version = Exportable.latestVersion)
val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
metadataRDD.toJSON.saveAsTextFile(path + "/metadata")

// Create Parquet data.
val data = NaiveBayesModel.Data(labels, pi, theta)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
@@ -117,11 +120,12 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] {
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
val data = dataArray(0)
assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
-val nb = data match {
-  case Row(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]]) =>
-    new NaiveBayesModel(labels.toArray, pi.toArray, theta.map(_.toArray).toArray)
-}
-nb
+// Check schema explicitly since erasure makes it hard to use match-case for checking.
+Importable.checkSchema[Data](dataRDD.schema)
+val labels = data.getAs[Seq[Double]](0).toArray
+val pi = data.getAs[Seq[Double]](1).toArray
+val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+new NaiveBayesModel(labels, pi, theta)
}
}
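Taken together, the NaiveBayes changes give the model the same save/load round trip that LogisticRegressionModel already has. A hedged usage sketch (not part of the diff), assuming a spark-shell session with sc, an already trained model, and a hypothetical /tmp output path:

// Hypothetical usage, based on the Exportable/Importable signatures in this commit.
import org.apache.spark.mllib.classification.NaiveBayesModel

model.save(sc, "/tmp/naiveBayesModel")             // writes metadata/ (JSON) and data/ (Parquet)
val sameModel = NaiveBayesModel.load(sc, "/tmp/naiveBayesModel")
assert(sameModel.pi.sameElements(model.pi))        // labels, pi, and theta are restored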

@@ -17,8 +17,13 @@

package org.apache.spark.mllib.util

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{DataType, StructType, StructField}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
* :: DeveloperApi ::
@@ -46,11 +51,7 @@ trait Exportable {

}

-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-object Exportable {
+private[mllib] object Exportable {

/** Current version of model import/export format. */
val latestVersion: String = "1.0"
@@ -79,34 +80,32 @@ trait Importable[Model <: Exportable] {

}

-/*
-/**
- * :: DeveloperApi ::
- *
- * Trait for models and transformers which may be saved as files.
- * This should be inherited by the class which implements model instances.
- *
- * This specializes [[Exportable]] for local models which can be stored on a single machine.
- * This provides helper functionality, but developers can choose to use [[Exportable]] instead,
- * even for local models.
- */
-@DeveloperApi
-trait LocalExportable {
-
-  /**
-   * Save this model to the given path.
-   *
-   * This saves:
-   *  - human-readable (JSON) model metadata to path/metadata/
-   *  - Parquet formatted data to path/data/
-   *
-   * The model may be loaded using [[Importable.load]].
-   *
-   * @param sc  Spark context used to save model data.
-   * @param path  Path specifying the directory in which to save this model.
-   *              This directory and any intermediate directory will be created if needed.
-   */
-  def save(sc: SparkContext, path: String): Unit
-
-}
-*/
+private[mllib] object Importable {
+
+  /**
+   * Check the schema of loaded model data.
+   *
+   * This checks every field in the expected schema to make sure that a field with the same
+   * name and DataType appears in the loaded schema. Note that this does NOT check metadata
+   * or containsNull.
+   *
+   * @param loadedSchema  Schema for model data loaded from file.
+   * @tparam Data  Expected data type from which an expected schema can be derived.
+   */
+  def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    val expectedFields: Array[StructField] =
+      ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+    val loadedFields: Map[String, DataType] =
+      loadedSchema.map(field => field.name -> field.dataType).toMap
+    expectedFields.foreach { field =>
+      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+        s" Expected field with name ${field.name} was missing in loaded schema:" +
+        s" ${loadedFields.mkString(", ")}")
+      assert(loadedFields(field.name) == field.dataType,
+        s"Unable to parse model data. Expected field $field but found field" +
+        s" with different type: ${loadedFields(field.name)}")
+    }
+  }
+
+}
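The new Importable.checkSchema helper exists because the element types inside a Row are erased at runtime, so a pattern like case Row(labels: Seq[Double], ...) cannot actually verify what the Parquet file contained. A small sketch (not in the commit; the standalone Data case class here is a hypothetical stand-in for NaiveBayesModel.Data) of the reflection call it relies on:

// Hypothetical illustration: deriving the expected schema from a case class with
// ScalaReflection, the same mechanism checkSchema uses internally.
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

case class Data(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]])

val expectedSchema = ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType]
expectedSchema.fields.foreach(f => println(s"${f.name}: ${f.dataType}"))
// A loaded DataFrame's schema can then be compared field by field, as checkSchema does.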
