From 3b92783e58a66cc9d5ca200f47719d7552cb5c88 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Nov 2015 13:16:40 -0800 Subject: [PATCH 1/7] added save/load to logreg in spark.ml --- .../org/apache/spark/ml/util/ReadWrite.scala | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 85f888c9f2f67..2c0c230c4add0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -283,3 +283,46 @@ private[ml] object DefaultParamsReader { } } } + +private[ml] object DefaultParamsReader { + + case class Metadata(className: String, uid: String, timestamp: Long, sparkVersion: String) + + /** + * Load metadata from file. + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { + implicit val format = DefaultFormats + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + val className = (metadata \ "class").extract[String] + val uid = (metadata \ "uid").extract[String] + val timestamp = (metadata \ "timestamp").extract[Long] + val sparkVersion = (metadata \ "sparkVersion").extract[String] + if (expectedClassName.nonEmpty) { + require(className == expectedClassName, s"Error loading metadata: Expected class name" + + s" $expectedClassName but found class name $className") + } + Metadata(className, uid, timestamp, sparkVersion) + } + + def loadParams(instance: Params, path: String, sc: SparkContext): Unit = { + implicit val format = DefaultFormats + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + (metadata \ "paramMap") match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + } + } +} From a367e7485ed92eaffb99fb48518c332a2c51f7ea Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Nov 2015 13:53:42 -0800 Subject: [PATCH 2/7] fixed read, write for logreg --- .../org/apache/spark/ml/util/ReadWrite.scala | 43 ------------------- 1 file changed, 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 2c0c230c4add0..85f888c9f2f67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -283,46 +283,3 @@ private[ml] object DefaultParamsReader { } } } - -private[ml] object DefaultParamsReader { - - case class Metadata(className: String, uid: String, timestamp: Long, sparkVersion: String) - - /** - * Load metadata from file. - * @param expectedClassName If non empty, this is checked against the loaded metadata. 
- * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata - */ - def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { - implicit val format = DefaultFormats - val metadataPath = new Path(path, "metadata").toString - val metadataStr = sc.textFile(metadataPath, 1).first() - val metadata = parse(metadataStr) - val className = (metadata \ "class").extract[String] - val uid = (metadata \ "uid").extract[String] - val timestamp = (metadata \ "timestamp").extract[Long] - val sparkVersion = (metadata \ "sparkVersion").extract[String] - if (expectedClassName.nonEmpty) { - require(className == expectedClassName, s"Error loading metadata: Expected class name" + - s" $expectedClassName but found class name $className") - } - Metadata(className, uid, timestamp, sparkVersion) - } - - def loadParams(instance: Params, path: String, sc: SparkContext): Unit = { - implicit val format = DefaultFormats - val metadataPath = new Path(path, "metadata").toString - val metadataStr = sc.textFile(metadataPath, 1).first() - val metadata = parse(metadataStr) - (metadata \ "paramMap") match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - case _ => - throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") - } - } -} From 38d262c82aedfcda3c9d7ee9ea5b64c78832f9f7 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 11 Nov 2015 10:56:44 -0800 Subject: [PATCH 3/7] added Pipeline save, load but not PipelineModel --- .../scala/org/apache/spark/ml/Pipeline.scala | 95 ++++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 67 ++++++++++++- .../spark/ml/util/DefaultReadWriteTest.scala | 25 +++-- 4 files changed, 173 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a3e59401c5cfb..854d50fb967f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,12 +22,19 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Reader +import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. 
*/ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { def this() = this(Identifiable.randomUID("pipeline")) @@ -166,6 +173,90 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + override def write: Writer = new PipelineWriter(this) +} + +private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + + import org.json4s.JsonDSL._ + + // Check that all stages are Writable + instance.getStages.foreach { + case stage: Writable => // good + case stage => + throw new UnsupportedOperationException("Pipeline.write will fail on this Pipeline because" + + s" it contains a stage which does not implement Writable. Non-Writable stage: ${stage.uid}") + } + + override protected def saveImpl(path: String): Unit = { + // Copied and edited from DefaultParamsWriter.saveMetadata + // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication + val uid = instance.uid + val cls = instance.getClass.getName + val stages = instance.getStages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stages.toSeq)))) + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + + // Save stages + val stagesDir = new Path(path, "stages").toString + instance.getStages.foreach { + case stage: Writable => + val stagePath = new Path(stagesDir, stage.uid).toString + stage.write.save(stagePath) + } + } +} + +object Pipeline extends Readable[Pipeline] { + + override def read: Reader[Pipeline] = new PipelineReader +} + +private[ml] class PipelineReader extends Reader[Pipeline] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.Pipeline" + + override def load(path: String): Pipeline = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = metadata.params match { + case JObject(pairs) => + if (pairs.length != 1) { + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException( + s"Pipeline.read expected 1 Param (stageUids), but found ${pairs.length}.") + } + pairs.head match { + case ("stageUids", jsonValue) => + parse(compact(render(jsonValue))).extract[Seq[String]].toArray + case (paramName, jsonValue) => + // Should not happen unless file is corrupted or we have a bug. 
+ throw new RuntimeException(s"Pipeline.read encountered unexpected Param $paramName" + + s" in metadata: ${metadata.metadataStr}") + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + } + val stages: Array[PipelineStage] = stageUids.map { stageUid => + val stagePath = new Path(stagesDir, stageUid).toString + val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) + val cls = Utils.classForName(stageMetadata.className) + cls.getMethod("read").invoke(cls).asInstanceOf[Reader[PipelineStage]].load(stagePath) + } + new Pipeline(metadata.uid).setStages(stages) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 85f888c9f2f67..03b524a3adb4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -187,7 +187,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap + * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { val uid = instance.uid diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1f2c9b75b617b..9eade52279c43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -25,11 +25,13 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType -class PipelineSuite extends SparkFunSuite { +class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { abstract class MyModel extends Model[MyModel] @@ -111,4 +113,63 @@ class PipelineSuite extends SparkFunSuite { assert(pipelineModel1.uid === "pipeline1") assert(pipelineModel1.stages === stages) } + + test("read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = new Pipeline().setStages(Array(writableStage)) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.getStages.length === 1) + assert(pipeline2.getStages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage)) + withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } +} + + +/** Used to test [[Pipeline]] with [[Writable]] stages */ +class WritableStage(override val uid: String) extends PipelineStage with Writable { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + def getIntParam: Int = $(intParam) + + def setIntParam(value: Int): this.type = 
set(intParam, value) + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) + + def transformSchema(schema: StructType): StructType = schema +} + +object WritableStage extends Readable[WritableStage] { + + override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + + override def load(path: String): WritableStage = read.load(path) +} + +/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +class UnWritableStage(override val uid: String) extends PipelineStage { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) + + def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index cac4bd9aa3ab8..c37f0503f1332 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Checks "overwrite" option and params. * @param instance ML instance to test saving/loading + * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable](instance: T): T = { + def testDefaultReadWrite[T <: Params with Writable]( + instance: T, + testParams: Boolean = true): T = { val uid = instance.uid val path = new File(tempDir, uid).getPath @@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) - instance.params.foreach { p => - if (instance.isDefined(p)) { - (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { - case (Array(values), Array(newValues)) => - assert(values === newValues, s"Values do not match on param ${p.name}.") - case (value, newValue) => - assert(value === newValue, s"Values do not match on param ${p.name}.") + if (testParams) { + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } - } else { - assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } } From 5d1339347a1379deac59bfbc18943ec3c0432847 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 12 Nov 2015 12:03:22 -0800 Subject: [PATCH 4/7] added PipelineModel save/load --- .../scala/org/apache/spark/ml/Pipeline.scala | 91 +++++++++++++++---- .../org/apache/spark/ml/util/ReadWrite.scala | 2 + .../org/apache/spark/ml/PipelineSuite.scala | 39 ++++++-- 3 files changed, 110 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 854d50fb967f8..2599d9713f007 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.Reader @@ -177,25 +177,32 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W override def write: Writer = new PipelineWriter(this) } -private[ml] class PipelineWriter(instance: Pipeline) extends Writer { +private[ml] object PipelineSharedWriter { import org.json4s.JsonDSL._ - // Check that all stages are Writable - instance.getStages.foreach { - case stage: Writable => // good - case stage => - throw new UnsupportedOperationException("Pipeline.write will fail on this Pipeline because" + - s" it contains a stage which does not implement Writable. Non-Writable stage: ${stage.uid}") + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: Writable => // good + case stage => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. 
Non-Writable stage:" + + s" ${stage.uid}") + } } - override protected def saveImpl(path: String): Unit = { + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { // Copied and edited from DefaultParamsWriter.saveMetadata // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication val uid = instance.uid val cls = instance.getClass.getName - val stages = instance.getStages.map(_.uid) - val jsonParams = List("stageUids" -> parse(compact(render(stages.toSeq)))) + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) val metadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ @@ -207,7 +214,7 @@ private[ml] class PipelineWriter(instance: Pipeline) extends Writer { // Save stages val stagesDir = new Path(path, "stages").toString - instance.getStages.foreach { + stages.foreach { case stage: Writable => val stagePath = new Path(stagesDir, stage.uid).toString stage.write.save(stagePath) @@ -215,9 +222,19 @@ private[ml] class PipelineWriter(instance: Pipeline) extends Writer { } } +private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + + PipelineSharedWriter.validateStages(instance.getStages) + + override protected def saveImpl(path: String): Unit = + PipelineSharedWriter.saveImpl(instance, instance.getStages, sc, path) +} + object Pipeline extends Readable[Pipeline] { override def read: Reader[Pipeline] = new PipelineReader + + override def load(path: String): Pipeline = read.load(path) } private[ml] class PipelineReader extends Reader[Pipeline] { @@ -226,6 +243,14 @@ private[ml] class PipelineReader extends Reader[Pipeline] { private val className = "org.apache.spark.ml.Pipeline" override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = PipelineSharedReader.load(className, sc, path) + new Pipeline(uid).setStages(stages) + } +} + +private[ml] object PipelineSharedReader { + + def load(className: String, sc: SparkContext, path: String): (String, Array[PipelineStage]) = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats @@ -235,14 +260,14 @@ private[ml] class PipelineReader extends Reader[Pipeline] { if (pairs.length != 1) { // Should not happen unless file is corrupted or we have a bug. throw new RuntimeException( - s"Pipeline.read expected 1 Param (stageUids), but found ${pairs.length}.") + s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") } pairs.head match { case ("stageUids", jsonValue) => parse(compact(render(jsonValue))).extract[Seq[String]].toArray case (paramName, jsonValue) => // Should not happen unless file is corrupted or we have a bug. 
- throw new RuntimeException(s"Pipeline.read encountered unexpected Param $paramName" + + throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + s" in metadata: ${metadata.metadataStr}") } case _ => @@ -255,7 +280,7 @@ private[ml] class PipelineReader extends Reader[Pipeline] { val cls = Utils.classForName(stageMetadata.className) cls.getMethod("read").invoke(cls).asInstanceOf[Reader[PipelineStage]].load(stagePath) } - new Pipeline(metadata.uid).setStages(stages) + (metadata.uid, stages) } } @@ -267,7 +292,7 @@ private[ml] class PipelineReader extends Reader[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Logging { + extends Model[PipelineModel] with Writable with Logging { /** A Java/Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -291,4 +316,38 @@ class PipelineModel private[ml] ( override def copy(extra: ParamMap): PipelineModel = { new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + + override def write: Writer = new PipelineModelWriter(this) +} + +object PipelineModel extends Readable[PipelineModel] { + + override def read: Reader[PipelineModel] = new PipelineModelReader + + override def load(path: String): PipelineModel = read.load(path) +} + +private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + + PipelineSharedWriter.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) + + override protected def saveImpl(path: String): Unit = PipelineSharedWriter.saveImpl(instance, + instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) +} + +private[ml] class PipelineModelReader extends Reader[PipelineModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.PipelineModel" + + override def load(path: String): PipelineModel = { + val (uid: String, stages: Array[PipelineStage]) = + PipelineSharedReader.load(className, sc, path) + val transformers = stages map { + case stage: Transformer => stage + case stage => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. Bad stage: ${stage.uid}") + } + new PipelineModel(uid, transformers) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 03b524a3adb4e..836e04a0db3e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -161,6 +161,8 @@ trait Readable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + * + * Note: Implementing classes should override this to be Java-friendly. 
*/ @Since("1.6.0") def load(path: String): T = read.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 9eade52279c43..9e7e3575ab803 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -114,7 +114,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(pipelineModel1.stages === stages) } - test("read/write") { + test("Pipeline read/write") { val writableStage = new WritableStage("writableStage").setIntParam(56) val pipeline = new Pipeline().setStages(Array(writableStage)) @@ -125,7 +125,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(writableStage.getIntParam === writableStage2.getIntParam) } - test("read/write with non-Writable stage") { + test("Pipeline read/write with non-Writable stage") { val unWritableStage = new UnWritableStage("unwritableStage") val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage)) withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") { @@ -134,11 +134,34 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } } + + test("PipelineModel read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = + new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer])) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.stages.length === 1) + assert(pipeline2.stages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("PipelineModel read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = + new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer])) + withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } } /** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends PipelineStage with Writable { +class WritableStage(override val uid: String) extends Transformer with Writable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -152,7 +175,9 @@ class WritableStage(override val uid: String) extends PipelineStage with Writabl override def write: Writer = new DefaultParamsWriter(this) - def transformSchema(schema: StructType): StructType = schema + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema } object WritableStage extends Readable[WritableStage] { @@ -163,7 +188,7 @@ object WritableStage extends Readable[WritableStage] { } /** Used to test [[Pipeline]] with non-[[Writable]] stages */ -class UnWritableStage(override val uid: String) extends PipelineStage { +class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -171,5 +196,7 @@ class UnWritableStage(override val uid: String) extends PipelineStage { override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) - def transformSchema(schema: StructType): StructType = schema + 
override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema } From caf57c2c48d90e3f4160626ca77a173260756cfe Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 12 Nov 2015 12:08:16 -0800 Subject: [PATCH 5/7] reorder for Pipeline.scala classes --- .../scala/org/apache/spark/ml/Pipeline.scala | 170 +++++++++--------- 1 file changed, 86 insertions(+), 84 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 2599d9713f007..eb18ba4d61180 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -177,49 +177,11 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W override def write: Writer = new PipelineWriter(this) } -private[ml] object PipelineSharedWriter { - - import org.json4s.JsonDSL._ - - /** Check that all stages are Writable */ - def validateStages(stages: Array[PipelineStage]): Unit = { - stages.foreach { - case stage: Writable => // good - case stage => - throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + - s" because it contains a stage which does not implement Writable. Non-Writable stage:" + - s" ${stage.uid}") - } - } +object Pipeline extends Readable[Pipeline] { - def saveImpl( - instance: Params, - stages: Array[PipelineStage], - sc: SparkContext, - path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName - val stageUids = stages.map(_.uid) - val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + override def read: Reader[Pipeline] = new PipelineReader - // Save stages - val stagesDir = new Path(path, "stages").toString - stages.foreach { - case stage: Writable => - val stagePath = new Path(stagesDir, stage.uid).toString - stage.write.save(stagePath) - } - } + override def load(path: String): Pipeline = read.load(path) } private[ml] class PipelineWriter(instance: Pipeline) extends Writer { @@ -230,13 +192,6 @@ private[ml] class PipelineWriter(instance: Pipeline) extends Writer { PipelineSharedWriter.saveImpl(instance, instance.getStages, sc, path) } -object Pipeline extends Readable[Pipeline] { - - override def read: Reader[Pipeline] = new PipelineReader - - override def load(path: String): Pipeline = read.load(path) -} - private[ml] class PipelineReader extends Reader[Pipeline] { /** Checked against metadata when loading model */ @@ -248,42 +203,6 @@ private[ml] class PipelineReader extends Reader[Pipeline] { } } -private[ml] object PipelineSharedReader { - - def load(className: String, sc: SparkContext, path: String): (String, Array[PipelineStage]) = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - - implicit val format = DefaultFormats - val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen 
unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - parse(compact(render(jsonValue))).extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } - val stages: Array[PipelineStage] = stageUids.map { stageUid => - val stagePath = new Path(stagesDir, stageUid).toString - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(cls).asInstanceOf[Reader[PipelineStage]].load(stagePath) - } - (metadata.uid, stages) - } -} - /** * :: Experimental :: * Represents a fitted pipeline. @@ -351,3 +270,86 @@ private[ml] class PipelineModelReader extends Reader[PipelineModel] { new PipelineModel(uid, transformers) } } + +/** Methods for [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ +private[ml] object PipelineSharedWriter { + + import org.json4s.JsonDSL._ + + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: Writable => // good + case stage => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + + s" ${stage.uid}") + } + } + + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { + // Copied and edited from DefaultParamsWriter.saveMetadata + // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication + val uid = instance.uid + val cls = instance.getClass.getName + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + + // Save stages + val stagesDir = new Path(path, "stages").toString + stages.foreach { + case stage: Writable => + val stagePath = new Path(stagesDir, stage.uid).toString + stage.write.save(stagePath) + } + } +} + +/** Methods for [[Reader]] shared between [[Pipeline]] and [[PipelineModel]] */ +private[ml] object PipelineSharedReader { + + def load(className: String, sc: SparkContext, path: String): (String, Array[PipelineStage]) = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = metadata.params match { + case JObject(pairs) => + if (pairs.length != 1) { + // Should not happen unless file is corrupted or we have a bug. 
+ throw new RuntimeException( + s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") + } + pairs.head match { + case ("stageUids", jsonValue) => + parse(compact(render(jsonValue))).extract[Seq[String]].toArray + case (paramName, jsonValue) => + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + + s" in metadata: ${metadata.metadataStr}") + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + } + val stages: Array[PipelineStage] = stageUids.map { stageUid => + val stagePath = new Path(stagesDir, stageUid).toString + val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) + val cls = Utils.classForName(stageMetadata.className) + cls.getMethod("read").invoke(cls).asInstanceOf[Reader[PipelineStage]].load(stagePath) + } + (metadata.uid, stages) + } +} From 1d1d31c9af8dc4a2429b0f627fcff76907df06a5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 16 Nov 2015 14:59:21 -0800 Subject: [PATCH 6/7] Cleanups per code review, including adding stage index to stage paths --- .../scala/org/apache/spark/ml/Pipeline.scala | 34 ++++++++++++------- .../org/apache/spark/ml/PipelineSuite.scala | 26 ++++++++++++++ 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index eb18ba4d61180..aba7d511e03c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -264,8 +264,8 @@ private[ml] class PipelineModelReader extends Reader[PipelineModel] { PipelineSharedReader.load(className, sc, path) val transformers = stages map { case stage: Transformer => stage - case stage => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + - s" was not a Transformer. Bad stage: ${stage.uid}") + case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}") } new PipelineModel(uid, transformers) } @@ -280,13 +280,17 @@ private[ml] object PipelineSharedWriter { def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { case stage: Writable => // good - case stage => + case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + - s" ${stage.uid}") + s" ${other.uid} of type ${other.getClass}") } } + /** + * Save metadata to path/metadata + * Save stages to stages/IDX_UID + */ def saveImpl( instance: Params, stages: Array[PipelineStage], @@ -309,12 +313,18 @@ private[ml] object PipelineSharedWriter { // Save stages val stagesDir = new Path(path, "stages").toString - stages.foreach { - case stage: Writable => - val stagePath = new Path(stagesDir, stage.uid).toString - stage.write.save(stagePath) + stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } + + /** Get path for saving the given stage. 
Used by [[PipelineSharedReader]] as well */ + def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { + val stageIdxDigits = numStages.toString.length + val idxFormat = s"%0${stageIdxDigits}d" + val stageDir = idxFormat.format(stageIdx) + "_" + stageUid + new Path(stagesDir, stageDir).toString + } } /** Methods for [[Reader]] shared between [[Pipeline]] and [[PipelineModel]] */ @@ -334,7 +344,7 @@ private[ml] object PipelineSharedReader { } pairs.head match { case ("stageUids", jsonValue) => - parse(compact(render(jsonValue))).extract[Seq[String]].toArray + jsonValue.extract[Seq[String]].toArray case (paramName, jsonValue) => // Should not happen unless file is corrupted or we have a bug. throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + @@ -344,11 +354,11 @@ private[ml] object PipelineSharedReader { throw new IllegalArgumentException( s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") } - val stages: Array[PipelineStage] = stageUids.map { stageUid => - val stagePath = new Path(stagesDir, stageUid).toString + val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => + val stagePath = PipelineSharedWriter.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(cls).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 9e7e3575ab803..a668d7a27f809 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.ml +import java.io.File + import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.{FileSystem, Path} import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock @@ -145,6 +148,29 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(pipeline2.stages(0).isInstanceOf[WritableStage]) val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] assert(writableStage.getIntParam === writableStage2.getIntParam) + + val path = new File(tempDir, pipeline.uid).getPath + val stagesDir = new Path(path, "stages").toString + val expectedStagePath = PipelineSharedWriter.getStagePath(writableStage.uid, 0, 1, stagesDir) + assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)), + s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" + + s" to be saved to path: $expectedStagePath") + } + + test("PipelineModel read/write: getStagePath") { + val stageUid = "myStage" + val stagesDir = new Path("pipeline", "stages").toString + import PipelineSharedWriter.getStagePath + def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = { + val path = getStagePath(stageUid, stageIdx, numStages, stagesDir) + val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString + assert(path === expected) + } + testStage(0, 1, "0") + testStage(0, 9, "0") + testStage(0, 10, "00") + testStage(1, 10, "01") + testStage(12, 999, "012") } test("PipelineModel read/write with 
non-Writable stage") { From f791010ec056ab9684329083e9bb63f142b3150c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 16 Nov 2015 15:07:41 -0800 Subject: [PATCH 7/7] refactored Pipeline reader and writer classes to be under Pipeline, PipelineModel, to clean up namespace --- .../scala/org/apache/spark/ml/Pipeline.scala | 251 +++++++++--------- .../org/apache/spark/ml/PipelineSuite.scala | 6 +- 2 files changed, 131 insertions(+), 126 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aba7d511e03c9..25f0c696f42be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -174,7 +174,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new PipelineWriter(this) + override def write: Writer = new Pipeline.PipelineWriter(this) } object Pipeline extends Readable[Pipeline] { @@ -182,24 +182,121 @@ object Pipeline extends Readable[Pipeline] { override def read: Reader[Pipeline] = new PipelineReader override def load(path: String): Pipeline = read.load(path) -} -private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends Writer { - PipelineSharedWriter.validateStages(instance.getStages) + SharedReadWrite.validateStages(instance.getStages) - override protected def saveImpl(path: String): Unit = - PipelineSharedWriter.saveImpl(instance, instance.getStages, sc, path) -} + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private[ml] class PipelineReader extends Reader[Pipeline] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.Pipeline" + + override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) + } + } + + /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + + import org.json4s.JsonDSL._ + + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: Writable => // good + case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. 
Non-Writable stage:" + + s" ${other.uid} of type ${other.getClass}") + } + } -private[ml] class PipelineReader extends Reader[Pipeline] { + /** + * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * - save metadata to path/metadata + * - save stages to stages/IDX_UID + */ + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { + // Copied and edited from DefaultParamsWriter.saveMetadata + // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication + val uid = instance.uid + val cls = instance.getClass.getName + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + + // Save stages + val stagesDir = new Path(path, "stages").toString + stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) + } + } - /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.Pipeline" + /** + * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * @return (UID, list of stages) + */ + def load( + expectedClassName: String, + sc: SparkContext, + path: String): (String, Array[PipelineStage]) = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = metadata.params match { + case JObject(pairs) => + if (pairs.length != 1) { + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException( + s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") + } + pairs.head match { + case ("stageUids", jsonValue) => + jsonValue.extract[Seq[String]].toArray + case (paramName, jsonValue) => + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + + s" in metadata: ${metadata.metadataStr}") + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + } + val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => + val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) + val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) + val cls = Utils.classForName(stageMetadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + } + (metadata.uid, stages) + } - override def load(path: String): Pipeline = { - val (uid: String, stages: Array[PipelineStage]) = PipelineSharedReader.load(className, sc, path) - new Pipeline(uid).setStages(stages) + /** Get path for saving the given stage. 
*/ + def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { + val stageIdxDigits = numStages.toString.length + val idxFormat = s"%0${stageIdxDigits}d" + val stageDir = idxFormat.format(stageIdx) + "_" + stageUid + new Path(stagesDir, stageDir).toString + } } } @@ -236,130 +333,38 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModelWriter(this) + override def write: Writer = new PipelineModel.PipelineModelWriter(this) } object PipelineModel extends Readable[PipelineModel] { + import Pipeline.SharedReadWrite + override def read: Reader[PipelineModel] = new PipelineModelReader override def load(path: String): PipelineModel = read.load(path) -} -private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { - PipelineSharedWriter.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) + SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) - override protected def saveImpl(path: String): Unit = PipelineSharedWriter.saveImpl(instance, - instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) -} - -private[ml] class PipelineModelReader extends Reader[PipelineModel] { - - /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.PipelineModel" - - override def load(path: String): PipelineModel = { - val (uid: String, stages: Array[PipelineStage]) = - PipelineSharedReader.load(className, sc, path) - val transformers = stages map { - case stage: Transformer => stage - case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + - s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}") - } - new PipelineModel(uid, transformers) + override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance, + instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } -} - -/** Methods for [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ -private[ml] object PipelineSharedWriter { - import org.json4s.JsonDSL._ + private[ml] class PipelineModelReader extends Reader[PipelineModel] { - /** Check that all stages are Writable */ - def validateStages(stages: Array[PipelineStage]): Unit = { - stages.foreach { - case stage: Writable => // good - case other => - throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + - s" because it contains a stage which does not implement Writable. 
Non-Writable stage:" + - s" ${other.uid} of type ${other.getClass}") - } - } + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.PipelineModel" - /** - * Save metadata to path/metadata - * Save stages to stages/IDX_UID - */ - def saveImpl( - instance: Params, - stages: Array[PipelineStage], - sc: SparkContext, - path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName - val stageUids = stages.map(_.uid) - val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) - - // Save stages - val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => - stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) - } - } - - /** Get path for saving the given stage. Used by [[PipelineSharedReader]] as well */ - def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { - val stageIdxDigits = numStages.toString.length - val idxFormat = s"%0${stageIdxDigits}d" - val stageDir = idxFormat.format(stageIdx) + "_" + stageUid - new Path(stagesDir, stageDir).toString - } -} - -/** Methods for [[Reader]] shared between [[Pipeline]] and [[PipelineModel]] */ -private[ml] object PipelineSharedReader { - - def load(className: String, sc: SparkContext, path: String): (String, Array[PipelineStage]) = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - - implicit val format = DefaultFormats - val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } - val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => - val stagePath = PipelineSharedWriter.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + override def load(path: String): PipelineModel = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + val transformers = stages map { + case stage: Transformer => stage + case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. 
Bad stage ${other.uid} of type ${other.getClass}") + } + new PipelineModel(uid, transformers) } - (metadata.uid, stages) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index a668d7a27f809..484026b1ba9ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -27,6 +27,7 @@ import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ @@ -151,7 +152,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val path = new File(tempDir, pipeline.uid).getPath val stagesDir = new Path(path, "stages").toString - val expectedStagePath = PipelineSharedWriter.getStagePath(writableStage.uid, 0, 1, stagesDir) + val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir) assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)), s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" + s" to be saved to path: $expectedStagePath") @@ -160,9 +161,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("PipelineModel read/write: getStagePath") { val stageUid = "myStage" val stagesDir = new Path("pipeline", "stages").toString - import PipelineSharedWriter.getStagePath def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = { - val path = getStagePath(stageUid, stageIdx, numStages, stagesDir) + val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir) val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString assert(path === expected) }
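
Usage sketch (appended for illustration, not part of the patches above): the surface this series introduces is Writable.write / Writer.save on Pipeline and PipelineModel, plus the Readable companions Pipeline.load and PipelineModel.load. A plausible round-trip is sketched below in Scala. The save paths, the training DataFrame, and the choice of HashingTF as a stage are assumptions made only for this example; a stage is persisted only if it implements Writable, otherwise write throws UnsupportedOperationException (see validateStages in the patches).

    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.feature.HashingTF

    // Build a one-stage pipeline. HashingTF stands in for any stage that
    // implements Writable; a non-Writable stage makes write throw.
    val tf = new HashingTF().setInputCol("words").setOutputCol("features")
    val pipeline = new Pipeline().setStages(Array(tf))

    // Persist the unfitted Pipeline: metadata is written to <path>/metadata and
    // each stage to <path>/stages/IDX_UID (zero-padded index + stage uid).
    pipeline.write.save("/tmp/unfit-pipeline")   // output path is hypothetical
    val restored: Pipeline = Pipeline.load("/tmp/unfit-pipeline")

    // A fitted PipelineModel round-trips the same way; `training` is assumed to
    // be an existing DataFrame in scope.
    val model: PipelineModel = pipeline.fit(training)
    model.write.save("/tmp/fit-pipeline-model")
    val restoredModel: PipelineModel = PipelineModel.load("/tmp/fit-pipeline-model")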