[SPARK-11892][ML] Model export/import for spark.ml: OneVsRest
# What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11892

Add save/load for spark.ml OneVsRest and its model. Also add OneVsRest and OneVsRestModel to MetaAlgorithmReadWrite.

# How was this patch tested?

Tested with Scala unit tests.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9934 from yinxusen/SPARK-11892.
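
For orientation, here is a minimal round-trip sketch of the API this patch adds (the dataset, column setup, and path are illustrative; save/load come from the new MLWritable/MLReadable support below):

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}

// Train a one-vs-rest wrapper around a binary classifier.
val ova = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
val ovaModel = ova.fit(dataset)  // dataset: DataFrame with "label" and "features" columns

// Persist the fitted model, then restore it.
ovaModel.save("/tmp/spark-ovr-model")
val restored = OneVsRestModel.load("/tmp/spark-ovr-model")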
yinxusen authored and jkbradley committed Mar 31, 2016
1 parent a0a1991 commit 8b207f3
Showing 3 changed files with 223 additions and 18 deletions.
165 changes: 154 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,29 +21,37 @@ import java.util.UUID

import scala.language.existentials

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JObject, _}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

/**
 * Params for [[OneVsRest]].
 */
private[ml] trait OneVsRestParams extends PredictorParams {

private[ml] trait ClassifierTypeTrait {
  // scalastyle:off structural.type
  type ClassifierType = Classifier[F, E, M] forSome {
    type F
    type M <: ClassificationModel[F, M]
    type E <: Classifier[F, E, M]
  }
  // scalastyle:on structural.type
}

/**
 * Params for [[OneVsRest]].
 */
private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {

/**
 * param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
  def getClassifier: ClassifierType = $(classifier)
}
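
The structural type above lets OneVsRest hold any concrete Classifier while keeping its feature, estimator, and model type parameters abstract. A one-line illustration of the subtyping (assuming spark.ml's LogisticRegression is in scope; this is not code from the patch):

// F = Vector, E = LogisticRegression, M = LogisticRegressionModel
val clf: ClassifierType = new LogisticRegression()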

private[ml] object OneVsRestParams extends ClassifierTypeTrait {

  def validateParams(instance: OneVsRestParams): Unit = {
    def checkElement(elem: Params, name: String): Unit = elem match {
      case stage: MLWritable => // good
      case other =>
        throw new UnsupportedOperationException("OneVsRest write will fail " +
          s"because it contains $name which does not implement MLWritable." +
          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
    }

    instance match {
      case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
      case _ => // no need to check OneVsRest here
    }

    checkElement(instance.getClassifier, "classifier")
  }

  def saveImpl(
      path: String,
      instance: OneVsRestParams,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None): Unit = {

    val params = instance.extractParamMap().toSeq
    val jsonParams = render(params
      .filter { case ParamPair(p, v) => p.name != "classifier" }
      .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
      .toList)

    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))

    val classifierPath = new Path(path, "classifier").toString
    instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
  }

  def loadImpl(
      path: String,
      sc: SparkContext,
      expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {

    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
    val classifierPath = new Path(path, "classifier").toString
    val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
    (metadata, estimator)
  }
}
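
Taken together, saveImpl and loadImpl define the on-disk contract. A sketch of the resulting layout (directory names come from the code above; the metadata fields follow DefaultParamsWriter's usual JSON format, and the model_<idx> entries are written only for fitted models, by the model writer further below):

// <path>/
//   metadata/      JSON: class, uid, sparkVersion, paramMap (classifier param excluded)
//   classifier/    the base classifier, saved via its own MLWritable
//   model_0/ ... model_<numClasses-1>/   per-class binary models (OneVsRestModel only)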

/**
 * :: Experimental ::
 * Model produced by [[OneVsRest]].
@@ -73,10 +130,10 @@ private[ml] trait OneVsRestParams extends PredictorParams {
@Since("1.4.0")
@Experimental
final class OneVsRestModel private[ml] (
    @Since("1.4.0") override val uid: String,
    @Since("1.4.0") labelMetadata: Metadata,
    @Since("1.4.0") override val uid: String,
    private[ml] val labelMetadata: Metadata,
    @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
  extends Model[OneVsRestModel] with OneVsRestParams {
  extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
@@ -143,6 +200,56 @@ final class OneVsRestModel (
      uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
    copyValues(copied, extra).setParent(parent)
  }

@Since("2.0.0")
override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
}

@Since("2.0.0")
object OneVsRestModel extends MLReadable[OneVsRestModel] {

@Since("2.0.0")
override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader

@Since("2.0.0")
override def load(path: String): OneVsRestModel = super.load(path)

/** [[MLWriter]] instance for [[OneVsRestModel]] */
private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {

OneVsRestParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
("numClasses" -> instance.models.length)
OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
val modelPath = new Path(path, s"model_$idx").toString
model.save(modelPath)
}
}
}

private class OneVsRestModelReader extends MLReader[OneVsRestModel] {

/** Checked against metadata when loading model */
private val className = classOf[OneVsRestModel].getName

override def load(path: String): OneVsRestModel = {
implicit val format = DefaultFormats
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val models = Range(0, numClasses).toArray.map { idx =>
val modelPath = new Path(path, s"model_$idx").toString
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
}
val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
DefaultParamsReader.getAndSetParams(ovrModel, metadata)
ovrModel.set("classifier", classifier)
ovrModel
}
}
}
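
Because the classifier param was filtered out of the JSON params at save time, the reader restores it explicitly with set after getAndSetParams has replayed the remaining params. The user-visible effect, roughly (path illustrative):

val m = OneVsRestModel.load("/tmp/spark-ovr-model")
m.getClassifier  // restored from <path>/classifier/, not from the JSON metadata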

/**
@@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] (
@Experimental
final class OneVsRest @Since("1.4.0") (
    @Since("1.4.0") override val uid: String)
  extends Estimator[OneVsRestModel] with OneVsRestParams {
  extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -243,4 +350,40 @@ final class OneVsRest @Since("1.4.0") (
    }
    copied
  }

  @Since("2.0.0")
  override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
}

@Since("2.0.0")
object OneVsRest extends MLReadable[OneVsRest] {

@Since("2.0.0")
override def read: MLReader[OneVsRest] = new OneVsRestReader

@Since("2.0.0")
override def load(path: String): OneVsRest = super.load(path)

/** [[MLWriter]] instance for [[OneVsRest]] */
private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {

OneVsRestParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
OneVsRestParams.saveImpl(path, instance, sc)
}
}

private class OneVsRestReader extends MLReader[OneVsRest] {

/** Checked against metadata when loading model */
private val className = classOf[OneVsRest].getName

override def load(path: String): OneVsRest = {
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val ovr = new OneVsRest(metadata.uid)
DefaultParamsReader.getAndSetParams(ovr, metadata)
ovr.setClassifier(classifier)
}
}
}
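
Note that both writers run OneVsRestParams.validateParams in their constructor body, so an unwritable base classifier fails fast when write is requested rather than midway through saving. A sketch of the failing case (MyCustomClassifier is hypothetical and does not mix in MLWritable):

val ova = new OneVsRest().setClassifier(new MyCustomClassifier())
ova.write  // throws UnsupportedOperationException from validateParams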
8 changes: 3 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
@@ -381,10 +381,8 @@ private[ml] object MetaAlgorithmReadWrite {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRestParams =>
        // TODO: SPARK-11892: This case may require special handling.
        throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" +
          s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.")
      case ovr: OneVsRest => Array(ovr.getClassifier)
      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
      case _: Params => Array()
    }
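
With these two cases, the stage discovery in MetaAlgorithmReadWrite now descends into OneVsRest instead of throwing. Roughly, for a fitted model the surrounding getUidMap helper in ReadWrite.scala collects (a sketch):

// uid -> instance pairs gathered for a fitted OneVsRestModel:
//   ovrModel.uid   -> the OneVsRestModel itself
//   classifier.uid -> the base classifier
//   models(i).uid  -> each per-class ClassificationModel
// Duplicate uids in this map make the write fail fast, as for Pipelines.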
68 changes: 66 additions & 2 deletions mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.Metadata

class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
Expand Down Expand Up @@ -160,6 +160,70 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}

test("read/write: OneVsRest") {
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)

val ova = new OneVsRest()
.setClassifier(lr)
.setLabelCol("myLabel")
.setFeaturesCol("myFeature")
.setPredictionCol("myPrediction")

val ova2 = testDefaultReadWrite(ova, testParams = false)
assert(ova.uid === ova2.uid)
assert(ova.getFeaturesCol === ova2.getFeaturesCol)
assert(ova.getLabelCol === ova2.getLabelCol)
assert(ova.getPredictionCol === ova2.getPredictionCol)

ova2.getClassifier match {
case lr2: LogisticRegression =>
assert(lr.uid === lr2.uid)
assert(lr.getMaxIter === lr2.getMaxIter)
assert(lr.getRegParam === lr2.getRegParam)
case other =>
throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
s" LogisticRegression but found ${other.getClass.getName}")
}
}

test("read/write: OneVsRestModel") {
def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
assert(model.uid === model2.uid)
assert(model.getFeaturesCol === model2.getFeaturesCol)
assert(model.getLabelCol === model2.getLabelCol)
assert(model.getPredictionCol === model2.getPredictionCol)

val classifier = model.getClassifier.asInstanceOf[LogisticRegression]

model2.getClassifier match {
case lr2: LogisticRegression =>
assert(classifier.uid === lr2.uid)
assert(classifier.getMaxIter === lr2.getMaxIter)
assert(classifier.getRegParam === lr2.getRegParam)
case other =>
throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
s" LogisticRegression but found ${other.getClass.getName}")
}

assert(model.labelMetadata === model2.labelMetadata)
model.models.zip(model2.models).foreach {
case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
assert(lrModel1.uid === lrModel2.uid)
assert(lrModel1.coefficients === lrModel2.coefficients)
assert(lrModel1.intercept === lrModel2.intercept)
case other =>
throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
s" LogisticRegressionModel but found ${other.getClass.getName}")
}
}

val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val ova = new OneVsRest().setClassifier(lr)
val ovaModel = ova.fit(dataset)
val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
checkModelData(ovaModel, newOvaModel)
}
}

private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {