
Commit

add save/load for kmeans and naive bayes
yinxusen committed Nov 19, 2015
1 parent 599a8c6 commit 4cfa86e
Showing 4 changed files with 195 additions and 28 deletions.
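
To set expectations before the diff: a minimal usage sketch of the persistence API this commit adds for the ml NaiveBayes estimator and model. The `training` DataFrame (label and features columns), the paths, and the active SQLContext are assumptions for illustration, not part of the commit.

import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

// Estimators pick up save/load via DefaultParamsWritable / DefaultParamsReadable
// (Params only, no model data).
val nb = new NaiveBayes().setSmoothing(0.1)
nb.save("/tmp/nb-estimator")                        // hypothetical path
val nbLoaded = NaiveBayes.load("/tmp/nb-estimator")

// Fitted models pick up save/load via MLWritable / MLReadable
// (Params plus the pi and theta model data).
val model = nb.fit(training)                        // `training` is assumed to exist
model.write.overwrite().save("/tmp/nb-model")       // hypothetical path
val restored = NaiveBayesModel.load("/tmp/nb-model")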
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -17,12 +17,15 @@

package org.apache.spark.ml.classification

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
@Experimental
class NaiveBayes(override val uid: String)
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {
with NaiveBayesParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("nb"))

@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
}

@Since("1.6.0")
object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {

@Since("1.6.0")
override def load(path: String): NaiveBayes = super.load(path)
}

/**
* :: Experimental ::
* Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
with NaiveBayesParams with MLWritable {

import OldNaiveBayes.{Bernoulli, Multinomial}

@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
}

@Since("1.6.0")
override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
}

private[ml] object NaiveBayesModel {
@Since("1.6.0")
object NaiveBayesModel extends MLReadable[NaiveBayesModel] {

/** Convert a model from the old API */
def fromOld(
private[ml] def fromOld(
oldModel: OldNaiveBayesModel,
parent: NaiveBayes): NaiveBayesModel = {
val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
oldModel.theta.flatten, true)
new NaiveBayesModel(uid, pi, theta)
}

@Since("1.6.0")
override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader

@Since("1.6.0")
override def load(path: String): NaiveBayesModel = super.load(path)

/** [[MLWriter]] instance for [[NaiveBayesModel]] */
private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {

private case class Data(pi: Vector, theta: Matrix)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: pi, theta
val data = Data(instance.pi, instance.theta)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {

/** Checked against metadata when loading model */
private val className = classOf[NaiveBayesModel].getName

override def load(path: String): NaiveBayesModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
val pi = data.getAs[Vector](0)
val theta = data.getAs[Matrix](1)
val model = new NaiveBayesModel(metadata.uid, pi, theta)

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
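
As a hedged aside on what the writer above produces: saveImpl stores the Params metadata via DefaultParamsWriter and a one-row Parquet dataset with the pi and theta columns under the "data" subdirectory, so a saved model can also be inspected directly with public DataFrame APIs. The path and the sqlContext in scope are assumptions.

import org.apache.spark.mllib.linalg.{Matrix, Vector}

// Layout inferred from saveImpl/load above:
//   <path>/metadata -- Params metadata written by DefaultParamsWriter
//   <path>/data     -- Parquet with columns "pi" (Vector) and "theta" (Matrix)
val savedPath = "/tmp/nb-model"                      // hypothetical path
val row = sqlContext.read.parquet(savedPath + "/data").select("pi", "theta").head()
val pi = row.getAs[Vector](0)
val theta = row.getAs[Matrix](1)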
67 changes: 61 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,18 +17,19 @@

package org.apache.spark.ml.clustering

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Row}


/**
* Common params for KMeans and KMeansModel
*/
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Experimental
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with MLWritable {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}

@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
}

@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {

@Since("1.6.0")
override def read: MLReader[KMeansModel] = new KMeansModelReader

@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)

/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {

private case class Data(clusterCenters: Array[Vector])

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
val data = Data(instance.clusterCenters)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class KMeansModelReader extends MLReader[KMeansModel] {

/** Checked against metadata when loading model */
private val className = classOf[KMeansModel].getName

override def load(path: String): KMeansModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
val clusterCenters = data.getAs[Seq[Vector]](0).toArray
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}

/**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
@Experimental
class KMeans @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Estimator[KMeansModel] with KMeansParams {
extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {

setDefault(
k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
}
}

@Since("1.6.0")
object KMeans extends DefaultParamsReadable[KMeans] {

@Since("1.6.0")
override def load(path: String): KMeans = super.load(path)
}
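
Analogous to the NaiveBayes sketch earlier, a minimal round trip for the KMeans additions above. The `points` DataFrame (with a features column), the paths, and the active SQLContext are assumptions.

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

val kmeans = new KMeans().setK(3).setMaxIter(10).setSeed(1L)
kmeans.save("/tmp/kmeans-estimator")                 // Params only, hypothetical path
val sameParams = KMeans.load("/tmp/kmeans-estimator")

val model = kmeans.fit(points)                       // `points` is assumed to exist
model.write.overwrite().save("/tmp/kmeans-model")    // Params plus cluster centers
val restored = KMeansModel.load("/tmp/kmeans-model")
assert(restored.clusterCenters.sameElements(model.clusterCenters))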

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Row}

class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@transient var dataset: DataFrame = _

override def beforeAll(): Unit = {
super.beforeAll()

val pi = Array(0.5, 0.1, 0.4).map(math.log)
val theta = Array(
Array(0.70, 0.10, 0.10, 0.10), // label 0
Array(0.10, 0.70, 0.10, 0.10), // label 1
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))

dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
}

def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "bernoulli")
}

test("read/write") {
def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
assert(model.pi === model2.pi)
assert(model.theta === model2.theta)
}
val nb = new NaiveBayes()
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
}
}

object NaiveBayesSuite {

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPrediction",
"smoothing" -> 0.1
)
}
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -18,23 +18,15 @@
package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

private[clustering] case class TestRow(features: Vector)

object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
sql.createDataFrame(rdd)
}
}

class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

final val k = 5
@transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
}

test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
}
val kmeans = new KMeans()
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
}
}

object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
sql.createDataFrame(rdd)
}

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPrediction",
"k" -> 3,
"maxIter" -> 2,
"tol" -> 0.01
)
}
