From 149685247cb53f213f505fdbea527c42a085ba5d Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sat, 31 Jan 2015 02:00:00 -0800
Subject: [PATCH] Added save/load for NaiveBayes

---
 .../classification/LogisticRegression.scala   |  1 +
 .../mllib/classification/NaiveBayes.scala     | 14 +++--
 .../spark/mllib/util/modelImportExport.scala  | 57 +++++++++----------
 3 files changed, 38 insertions(+), 34 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 4497a81e2d90d..093aa391dfaa8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -147,6 +147,7 @@ class LogisticRegressionModel (
       clazz = this.getClass.getName, version = Exportable.latestVersion)
     val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
     metadataRDD.toJSON.saveAsTextFile(path + "/metadata")
+
     // Create Parquet data.
     val data = LogisticRegressionModel.Data(weights, intercept, threshold)
     val dataRDD: DataFrame = sc.parallelize(Seq(data))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 7616d5067e649..fa46b64e80fbf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -18,6 +18,8 @@ package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, StructField, StructType}
 import org.apache.spark.{SparkContext, SparkException, Logging}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 
@@ -79,6 +81,7 @@ class NaiveBayesModel private[mllib] (
       clazz = this.getClass.getName, version = Exportable.latestVersion)
     val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
     metadataRDD.toJSON.saveAsTextFile(path + "/metadata")
+
     // Create Parquet data.
     val data = NaiveBayesModel.Data(labels, pi, theta)
     val dataRDD: DataFrame = sc.parallelize(Seq(data))
@@ -117,11 +120,12 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] {
     assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
     val data = dataArray(0)
     assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
-    val nb = data match {
-      case Row(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]]) =>
-        new NaiveBayesModel(labels.toArray, pi.toArray, theta.map(_.toArray).toArray)
-    }
-    nb
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    Importable.checkSchema[Data](dataRDD.schema)
+    val labels = data.getAs[Seq[Double]](0).toArray
+    val pi = data.getAs[Seq[Double]](1).toArray
+    val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+    new NaiveBayesModel(labels, pi, theta)
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala
index 729ba860728d9..06cd822afff5f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala
@@ -17,8 +17,13 @@
 
 package org.apache.spark.mllib.util
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 /**
  * :: DeveloperApi ::
@@ -46,11 +51,7 @@ trait Exportable {
 
 }
 
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-object Exportable {
+private[mllib] object Exportable {
 
   /** Current version of model import/export format. */
   val latestVersion: String = "1.0"
@@ -79,34 +80,32 @@ trait Importable[Model <: Exportable] {
 
 }
 
-/*
-/**
- * :: DeveloperApi ::
- *
- * Trait for models and transformers which may be saved as files.
- * This should be inherited by the class which implements model instances.
- *
- * This specializes [[Exportable]] for local models which can be stored on a single machine.
- * This provides helper functionality, but developers can choose to use [[Exportable]] instead,
- * even for local models.
- */
-@DeveloperApi
-trait LocalExportable {
+private[mllib] object Importable {
 
   /**
-   * Save this model to the given path.
-   *
-   * This saves:
-   *  - human-readable (JSON) model metadata to path/metadata/
-   *  - Parquet formatted data to path/data/
+   * Check the schema of loaded model data.
    *
-   * The model may be loaded using [[Importable.load]].
+   * This checks every field in the expected schema to make sure that a field with the same
+   * name and DataType appears in the loaded schema. Note that this does NOT check metadata
+   * or containsNull.
    *
-   * @param sc  Spark context used to save model data.
-   * @param path  Path specifying the directory in which to save this model.
-   *              This directory and any intermediate directory will be created if needed.
+   * @param loadedSchema  Schema for model data loaded from file.
+   * @tparam Data  Expected data type from which an expected schema can be derived.
    */
-  def save(sc: SparkContext, path: String): Unit
+  def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    val expectedFields: Array[StructField] =
+      ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+    val loadedFields: Map[String, DataType] =
+      loadedSchema.map(field => field.name -> field.dataType).toMap
+    expectedFields.foreach { field =>
+      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+        s" Expected field with name ${field.name} was missing in loaded schema:" +
+        s" ${loadedFields.mkString(", ")}")
+      assert(loadedFields(field.name) == field.dataType,
+        s"Unable to parse model data. Expected field $field but found field" +
+        s" with different type: ${loadedFields(field.name)}")
+    }
+  }
 }
-*/
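
Note (not part of the patch): a minimal sketch of how the new save/load round trip
might be exercised, assuming a SparkContext named sc, an RDD[LabeledPoint] named
training, and a writable output path (all three are hypothetical names, not
identifiers from this patch):

    import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}

    // Train a model and save it. Per this patch, save() writes human-readable
    // JSON metadata to path/metadata and Parquet model data to path/data.
    val model = NaiveBayes.train(training, lambda = 1.0)
    model.save(sc, "target/tmp/testNaiveBayesModel")  // hypothetical path

    // Load it back. load() reads the Parquet data, and checkSchema compares the
    // loaded schema against the schema derived from the Data case class before
    // the model is reconstructed.
    val sameModel = NaiveBayesModel.load(sc, "target/tmp/testNaiveBayesModel")
    assert(sameModel.labels.sameElements(model.labels))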