From cd1c7eae3246f93b6ee4e044443adfe57fdf1386 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 3 Nov 2015 10:56:22 -0800 Subject: [PATCH 1/9] initial implementation --- .../apache/spark/ml/feature/Binarizer.scala | 10 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../org/apache/spark/ml/util/saveload.scala | 225 ++++++++++++++++++ .../spark/ml/feature/BinarizerSuite.scala | 25 ++ .../spark/ml/util/DefaultSaveLoadTest.scala | 103 ++++++++ .../apache/spark/ml/util/TempDirectory.scala | 44 ++++ 6 files changed, 406 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index edad75443645..c99ee3e01e9b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with Saveable with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,4 +86,10 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) + + override def save: Saver = new DefaultParamsSaver(this) +} + +object Binarizer { + def load: Loader[Binarizer] = new DefaultParamsLoader[Binarizer] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8361406f8729..c9325709187c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected final def set[T](param: Param[T], value: T): this.type = { + final def set[T](param: Param[T], value: T): this.type = { set(param -> value) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala b/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala new file mode 100644 index 000000000000..af5c2dfdf205 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.IOException +import java.{util => ju} + +import scala.annotation.varargs +import scala.collection.mutable +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.annotation.{Since, Experimental} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.ml.param.{Param, ParamPair, Params} +import org.apache.spark.sql.SQLContext + +/** + * Trait for [[Saver]] and [[Loader]]. + */ +private[util] sealed trait BaseSaveLoad { + private var optionSQLContext: Option[SQLContext] = None + + /** + * User-specified options. + */ + protected final var options: mutable.Map[String, String] = mutable.Map.empty + + /** + * Java-friendly version of [[options]]. + */ + protected final def javaOptions: ju.Map[String, String] = options.asJava + + /** + * Sets the SQL context to use for saving/loading. + */ + def context(sqlContext: SQLContext): this.type = { + optionSQLContext = Option(sqlContext) + this + } + + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { + SQLContext.getOrCreate(SparkContext.getOrCreate()) + } + + /** + * Adds one or more options as (key, value) pairs. + */ + def options(first: (String, String), others: (String, String)*): this.type = { + options += first + options ++= others + this + } + + /** + * Adds one or more options with alternating key and value strings. + * @param k1 first key + * @param v1 first value + * @param others other options, must be paired + */ + @varargs + def options(k1: String, v1: String, others: String*): this.type = { + options += k1 -> v1 + require(others.length % 2 == 0, + s"Options must be specified in pairs but got: ${others.mkString(",")}.") + others.grouped(2).foreach { case Seq(k, v) => + options += k -> v + } + this + } + + /** + * Adds options as a Scala map. + * @return + */ + def options(options: Map[String, String]): this.type = { + this.options ++= options + this + } + + /** + * Adds options as a Java map. + */ + def options(options: ju.Map[String, String]): this.type = { + this.options ++= options.asScala + this + } +} + +/** + * Abstract class for utility classes that can save ML instances. + */ +@Experimental +@Since("1.6.0") +abstract class Saver extends BaseSaveLoad { + + /** + * Saves the ML instance to the input path. + */ + def to(path: String): Unit +} + +/** + * Trait for classes that provide [[Saver]]. + */ +@Since("1.6.0") +trait Saveable { + + /** + * Returns a [[Saver]] instance for this class. + */ + def save: Saver +} + +/** + * Abstract class for utility classes that can load ML instances. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +abstract class Loader[T] extends BaseSaveLoad { + + /** + * Loads the ML component from the input path. + */ + def from(path: String): T +} + +/** + * Default [[Saver]] implementation for non-meta transformers and estimators. 
+ * @param instance object to save + */ +private[ml] class DefaultParamsSaver(instance: Params) extends Saver with Logging { + + options("overwrite" -> "false") + + /** + * Saves the ML component to the input path. + */ + override def to(path: String): Unit = { + val sc = sqlContext.sparkContext + + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (options("overwrite").toBoolean) { + logInfo(s"Path $path already exists. It will be overwritten.") + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please set overwrite=true to overwrite it.") + } + } + + val uid = instance.uid + val cls = instance.getClass.getName + val params = instance.params.asInstanceOf[Array[Param[Any]]] + .flatMap(p => instance.get(p).map(v => p -> v)) + val jsonParams = params.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("uid" -> uid) ~ + ("params" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } +} + +/** + * Default [[Loader]] implementation for non-meta transformers and estimators. + * @tparam T ML instance type + */ +private[ml] class DefaultParamsLoader[T] extends Loader[T] { + + /** + * Loads the ML component from the input path. + */ + override def from(path: String): T = { + implicit val format = DefaultFormats + val sc = sqlContext.sparkContext + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + val cls = Class.forName((metadata \ "class").extract[String]) + val uid = (metadata \ "uid").extract[String] + val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] + (metadata \ "params") match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + } + instance.asInstanceOf[T] + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 208604398366..065be99faf43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,18 +17,28 @@ package org.apache.spark.ml.feature +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.Utils class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { + var tmpDir: File = _ @transient var data: Array[Double] = _ override def beforeAll(): Unit = { super.beforeAll() data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) + tmpDir = Utils.createTempDir(namePrefix = "BinarizerSuite") + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tmpDir) + super.afterAll() } test("params") { @@ -66,4 +76,19 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x === y, "The feature value is not 
correct after binarization.") } } + + test("save/load") { + val outputPath = new File(tmpDir, "saveload").getPath + println(outputPath) + val binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.1) + binarizer.save.to(outputPath) + val newBinarizer = Binarizer.load.from(outputPath) + assert(newBinarizer.uid === binarizer.uid) + for (param <- binarizer.params) { + assert(binarizer.get(param) === newBinarizer.get(param)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala new file mode 100644 index 000000000000..502ebbdd95d5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.scalatest.Suite + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +trait DefaultSaveLoadTest extends TempDirectory { self: Suite => + + /** + * Checks "overwrite" option and params. 
+ * @param instance ML instance to test saving/loading + * @tparam T ML instance type + */ + def testDefaultSaveLoad[T <: Params with Saveable](instance: T): Unit = { + val uid = instance.uid + val path = new File(tempDir, uid).getPath + + instance.save.to(path) + intercept[IOException] { + instance.save.to(path) + } + instance.save.options("overwrite" -> "true").to(path) + + val loader = instance.getClass.getMethod("load").invoke(null).asInstanceOf[Loader[T]] + val newInstance = loader.from(path) + + assert(newInstance.uid === instance.uid) + instance.params.foreach { p => + (instance.get(p), newInstance.get(p)) match { + case (None, None) => + case (Some(Array(values)), Some(Array(newValues))) => + assert(values === newValues) + case (Some(value), Some(newValue)) => + assert(value === newValue) + case (value, newValue) => + fail(s"Param values do not match, expecting $value but got $newValue.") + } + } + } +} + +class MyParams(override val uid: String) extends Params with Saveable { + + final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") + final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") + final val longParam: LongParam = new LongParam(this, "longParam", "doc") + final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc") + final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc") + final val doubleArrayParam: DoubleArrayParam = + new DoubleArrayParam(this, "doubleArrayParam", "doc") + final val stringArrayParam: StringArrayParam = + new StringArrayParam(this, "stringArrayParam", "doc") + + setDefault(intParamWithDefault -> 0) + set(intParam -> 1) + set(floatParam -> 2.0f) + set(doubleParam -> 3.0) + set(longParam -> 4L) + set(stringParam -> "5") + set(intArrayParam -> Array(6, 7)) + set(doubleArrayParam -> Array(8.0, 9.0)) + set(stringArrayParam -> Array("10", "11")) + + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override def save: Saver = new DefaultParamsSaver(this) +} + +object MyParams { + def load: Loader[MyParams] = new DefaultParamsLoader[MyParams] +} + +class DefaultSaveLoadSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultSaveLoadTest { + + test("default save/load") { + val uid = "my_params" + val myParams = new MyParams(uid) + testDefaultSaveLoad(myParams) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala new file mode 100644 index 000000000000..9b2de148d2b7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.File + +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, Suite} + +/** + * Trait that creates a temporary directory before all tests and deletes it after all. + */ +trait TempDirectory extends BeforeAndAfterAll { self: Suite => + + private var _tempDir: File = _ + + /** Returns the temporary directory as a [[File]] instance. */ + protected def tempDir: File = _tempDir + + override def beforeAll(): Unit = { + super.beforeAll() + _tempDir = Utils.createTempDir(this.getClass.getName) + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(_tempDir) + super.afterAll() + } +} From df81d61f73c6a854913df638770f0b0409f046a3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 3 Nov 2015 15:41:58 -0800 Subject: [PATCH 2/9] update doc and test --- .../apache/spark/ml/feature/Binarizer.scala | 5 +- .../org/apache/spark/ml/util/saveload.scala | 64 ++++++++++++++----- .../spark/ml/feature/BinarizerSuite.scala | 22 +------ .../spark/ml/util/DefaultSaveLoadTest.scala | 23 +++---- 4 files changed, 67 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index c99ee3e01e9b..dfeeede01ef4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -90,6 +90,7 @@ final class Binarizer(override val uid: String) override def save: Saver = new DefaultParamsSaver(this) } -object Binarizer { - def load: Loader[Binarizer] = new DefaultParamsLoader[Binarizer] +object Binarizer extends Loadable[Binarizer] { + + override def load: Loader[Binarizer] = new DefaultParamsLoader[Binarizer] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala b/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala index af5c2dfdf205..8aab6b3a1d5b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala @@ -17,21 +17,21 @@ package org.apache.spark.ml.util -import java.io.IOException import java.{util => ju} +import java.io.IOException import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.annotation.{Since, Experimental} import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.ml.param.{Param, ParamPair, Params} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.sql.SQLContext /** @@ -43,7 +43,7 @@ private[util] sealed trait BaseSaveLoad { /** * User-specified options. */ - protected final var options: mutable.Map[String, String] = mutable.Map.empty + protected final val options: mutable.Map[String, String] = mutable.Map.empty /** * Java-friendly version of [[options]]. @@ -53,6 +53,7 @@ private[util] sealed trait BaseSaveLoad { /** * Sets the SQL context to use for saving/loading. */ + @Since("1.6.0") def context(sqlContext: SQLContext): this.type = { optionSQLContext = Option(sqlContext) this @@ -68,6 +69,7 @@ private[util] sealed trait BaseSaveLoad { /** * Adds one or more options as (key, value) pairs. 
*/ + @Since("1.6.0") def options(first: (String, String), others: (String, String)*): this.type = { options += first options ++= others @@ -81,6 +83,7 @@ private[util] sealed trait BaseSaveLoad { * @param others other options, must be paired */ @varargs + @Since("1.6.0") def options(k1: String, v1: String, others: String*): this.type = { options += k1 -> v1 require(others.length % 2 == 0, @@ -92,17 +95,18 @@ private[util] sealed trait BaseSaveLoad { } /** - * Adds options as a Scala map. - * @return + * Adds options as a Scala map (overwrites if an option already exists). */ + @Since("1.6.0") def options(options: Map[String, String]): this.type = { this.options ++= options this } /** - * Adds options as a Java map. + * Adds options as a Java map (overwrites if an option already exists). */ + @Since("1.6.0") def options(options: ju.Map[String, String]): this.type = { this.options ++= options.asScala this @@ -115,11 +119,27 @@ private[util] sealed trait BaseSaveLoad { @Experimental @Since("1.6.0") abstract class Saver extends BaseSaveLoad { + import Saver._ /** * Saves the ML instance to the input path. */ + @Since("1.6.0") def to(path: String): Unit + + /** + * Tells whether we should overwrite if the output directory exists (default: false). + */ + protected final def shouldOverwrite: Boolean = { + options.get(Overwrite).map(_.toBoolean).getOrElse(false) + } +} + +@Experimental +@Since("1.6.0") +object Saver { + /** Option key to control overwrite. */ + val Overwrite: String = "overwrite" } /** @@ -129,8 +149,9 @@ abstract class Saver extends BaseSaveLoad { trait Saveable { /** - * Returns a [[Saver]] instance for this class. + * Returns a [[Saver]] instance for this ML instance. */ + @Since("1.6.0") def save: Saver } @@ -145,17 +166,31 @@ abstract class Loader[T] extends BaseSaveLoad { /** * Loads the ML component from the input path. */ + @Since("1.6.0") def from(path: String): T } +/** + * Trait for objects that provide [[Loader]]. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +trait Loadable[T] { + + /** + * Returns a [[Loader]] instance for this class. + */ + @Since("1.6.0") + def load: Loader[T] +} + /** * Default [[Saver]] implementation for non-meta transformers and estimators. * @param instance object to save */ private[ml] class DefaultParamsSaver(instance: Params) extends Saver with Logging { - options("overwrite" -> "false") - /** * Saves the ML component to the input path. */ @@ -166,7 +201,7 @@ private[ml] class DefaultParamsSaver(instance: Params) extends Saver with Loggin val fs = FileSystem.get(hadoopConf) val p = new Path(path) if (fs.exists(p)) { - if (options("overwrite").toBoolean) { + if (shouldOverwrite) { logInfo(s"Path $path already exists. 
It will be overwritten.") fs.delete(p, true) } else { @@ -177,15 +212,14 @@ private[ml] class DefaultParamsSaver(instance: Params) extends Saver with Loggin val uid = instance.uid val cls = instance.getClass.getName - val params = instance.params.asInstanceOf[Array[Param[Any]]] - .flatMap(p => instance.get(p).map(v => p -> v)) + val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] val jsonParams = params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList val metadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("uid" -> uid) ~ - ("params" -> jsonParams) + ("paramMap" -> jsonParams) val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -210,7 +244,7 @@ private[ml] class DefaultParamsLoader[T] extends Loader[T] { val cls = Class.forName((metadata \ "class").extract[String]) val uid = (metadata \ "uid").extract[String] val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] - (metadata \ "params") match { + (metadata \ "paramMap") match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => val param = instance.getParam(paramName) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 065be99faf43..8a43653dd6c3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,28 +17,19 @@ package org.apache.spark.ml.feature -import java.io.File - import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultSaveLoadTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.util.Utils -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultSaveLoadTest { - var tmpDir: File = _ @transient var data: Array[Double] = _ override def beforeAll(): Unit = { super.beforeAll() data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) - tmpDir = Utils.createTempDir(namePrefix = "BinarizerSuite") - } - - override def afterAll(): Unit = { - Utils.deleteRecursively(tmpDir) - super.afterAll() } test("params") { @@ -78,17 +69,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { } test("save/load") { - val outputPath = new File(tmpDir, "saveload").getPath - println(outputPath) val binarizer = new Binarizer() .setInputCol("feature") .setOutputCol("binarized_feature") .setThreshold(0.1) - binarizer.save.to(outputPath) - val newBinarizer = Binarizer.load.from(outputPath) - assert(newBinarizer.uid === binarizer.uid) - for (param <- binarizer.params) { - assert(binarizer.get(param) === newBinarizer.get(param)) - } + testDefaultSaveLoad(binarizer) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala index 502ebbdd95d5..c3b117f69891 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala @@ -47,14 +47,15 @@ trait DefaultSaveLoadTest extends TempDirectory { self: Suite => assert(newInstance.uid === instance.uid) 
instance.params.foreach { p => - (instance.get(p), newInstance.get(p)) match { - case (None, None) => - case (Some(Array(values)), Some(Array(newValues))) => - assert(values === newValues) - case (Some(value), Some(newValue)) => - assert(value === newValue) - case (value, newValue) => - fail(s"Param values do not match, expecting $value but got $newValue.") + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values !== newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } } } @@ -93,11 +94,11 @@ object MyParams { def load: Loader[MyParams] = new DefaultParamsLoader[MyParams] } -class DefaultSaveLoadSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultSaveLoadTest { +class DefaultSaveLoadSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultSaveLoadTest { test("default save/load") { - val uid = "my_params" - val myParams = new MyParams(uid) + val myParams = new MyParams("my_params") testDefaultSaveLoad(myParams) } } From e01e92d92f3f799356dd6a8cebc60002899090e9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 3 Nov 2015 16:49:39 -0800 Subject: [PATCH 3/9] fix Scala style --- mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala b/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala index 8aab6b3a1d5b..5ac401cf344a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala @@ -33,6 +33,7 @@ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils /** * Trait for [[Saver]] and [[Loader]]. 
@@ -241,7 +242,7 @@ private[ml] class DefaultParamsLoader[T] extends Loader[T] { val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() val metadata = parse(metadataStr) - val cls = Class.forName((metadata \ "class").extract[String]) + val cls = Utils.classForName((metadata \ "class").extract[String]) val uid = (metadata \ "uid").extract[String] val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] (metadata \ "paramMap") match { From bc8611d070e58326a31ef6c1d7f95d043839b42e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 5 Nov 2015 08:45:38 -0800 Subject: [PATCH 4/9] rename save/load to write/read to be compatible with DataFrames API --- .../apache/spark/ml/feature/Binarizer.scala | 8 +-- .../util/{saveload.scala => ReadWrite.scala} | 49 ++++++++++++------- .../spark/ml/feature/BinarizerSuite.scala | 8 +-- ...dTest.scala => DefaultReadWriteTest.scala} | 26 +++++----- 4 files changed, 51 insertions(+), 40 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/util/{saveload.scala => ReadWrite.scala} (85%) rename mllib/src/test/scala/org/apache/spark/ml/util/{DefaultSaveLoadTest.scala => DefaultReadWriteTest.scala} (81%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index dfeeede01ef4..e5c25574d4b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Saveable with HasInputCol with HasOutputCol { + extends Transformer with Writable with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("binarizer")) @@ -87,10 +87,10 @@ final class Binarizer(override val uid: String) override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - override def save: Saver = new DefaultParamsSaver(this) + override def write: Writer = new DefaultParamsWriter(this) } -object Binarizer extends Loadable[Binarizer] { +object Binarizer extends Readable[Binarizer] { - override def load: Loader[Binarizer] = new DefaultParamsLoader[Binarizer] + override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala similarity index 85% rename from mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala rename to mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 5ac401cf344a..6a9dd710146b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/saveload.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -36,9 +36,9 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Saver]] and [[Loader]]. + * Trait for [[Writer]] and [[Reader]]. */ -private[util] sealed trait BaseSaveLoad { +private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None /** @@ -119,8 +119,8 @@ private[util] sealed trait BaseSaveLoad { */ @Experimental @Since("1.6.0") -abstract class Saver extends BaseSaveLoad { - import Saver._ +abstract class Writer extends BaseReadWrite { + import Writer._ /** * Saves the ML instance to the input path. 
@@ -128,6 +128,12 @@ abstract class Saver extends BaseSaveLoad { @Since("1.6.0") def to(path: String): Unit + /** + * Saves the ML instances to the input path, the same as [[to()]]. + */ + @Since("1.6.0") + def save(path: String): Unit = to(path) + /** * Tells whether we should overwrite if the output directory exists (default: false). */ @@ -138,22 +144,22 @@ abstract class Saver extends BaseSaveLoad { @Experimental @Since("1.6.0") -object Saver { +object Writer { /** Option key to control overwrite. */ val Overwrite: String = "overwrite" } /** - * Trait for classes that provide [[Saver]]. + * Trait for classes that provide [[Writer]]. */ @Since("1.6.0") -trait Saveable { +trait Writable { /** - * Returns a [[Saver]] instance for this ML instance. + * Returns a [[Writer]] instance for this ML instance. */ @Since("1.6.0") - def save: Saver + def write: Writer } /** @@ -162,35 +168,40 @@ trait Saveable { */ @Experimental @Since("1.6.0") -abstract class Loader[T] extends BaseSaveLoad { +abstract class Reader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. */ @Since("1.6.0") def from(path: String): T + + /** + * Loads the ML component from the input path, the same as [[from()]]. + */ + def load(path: String): T = from(path) } /** - * Trait for objects that provide [[Loader]]. + * Trait for objects that provide [[Reader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Loadable[T] { - +trait Readable[T] { + /** - * Returns a [[Loader]] instance for this class. + * Returns a [[Reader]] instance for this class. */ @Since("1.6.0") - def load: Loader[T] + def read: Reader[T] } /** - * Default [[Saver]] implementation for non-meta transformers and estimators. + * Default [[Writer]] implementation for non-meta transformers and estimators. * @param instance object to save */ -private[ml] class DefaultParamsSaver(instance: Params) extends Saver with Logging { +private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { /** * Saves the ML component to the input path. @@ -228,10 +239,10 @@ private[ml] class DefaultParamsSaver(instance: Params) extends Saver with Loggin } /** - * Default [[Loader]] implementation for non-meta transformers and estimators. + * Default [[Reader]] implementation for non-meta transformers and estimators. * @tparam T ML instance type */ -private[ml] class DefaultParamsLoader[T] extends Loader[T] { +private[ml] class DefaultParamsReader[T] extends Reader[T] { /** * Loads the ML component from the input path. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 8a43653dd6c3..9dfa1439cc30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultSaveLoadTest +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultSaveLoadTest { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Double] = _ @@ -68,11 +68,11 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } - test("save/load") { + test("read/write") { val binarizer = new Binarizer() .setInputCol("feature") .setOutputCol("binarized_feature") .setThreshold(0.1) - testDefaultSaveLoad(binarizer) + testDefaultReadWrite(binarizer) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala similarity index 81% rename from mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala rename to mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index c3b117f69891..3dba486a1981 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultSaveLoadTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -25,24 +25,24 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext -trait DefaultSaveLoadTest extends TempDirectory { self: Suite => +trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Checks "overwrite" option and params. 
* @param instance ML instance to test saving/loading * @tparam T ML instance type */ - def testDefaultSaveLoad[T <: Params with Saveable](instance: T): Unit = { + def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = { val uid = instance.uid val path = new File(tempDir, uid).getPath - instance.save.to(path) + instance.write.to(path) intercept[IOException] { - instance.save.to(path) + instance.write.to(path) } - instance.save.options("overwrite" -> "true").to(path) + instance.write.options("overwrite" -> "true").to(path) - val loader = instance.getClass.getMethod("load").invoke(null).asInstanceOf[Loader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] val newInstance = loader.from(path) assert(newInstance.uid === instance.uid) @@ -61,7 +61,7 @@ trait DefaultSaveLoadTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Saveable { +class MyParams(override val uid: String) extends Params with Writable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -87,18 +87,18 @@ class MyParams(override val uid: String) extends Params with Saveable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def save: Saver = new DefaultParamsSaver(this) + override def write: Writer = new DefaultParamsWriter(this) } object MyParams { - def load: Loader[MyParams] = new DefaultParamsLoader[MyParams] + def load: Reader[MyParams] = new DefaultParamsReader[MyParams] } -class DefaultSaveLoadSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultSaveLoadTest { +class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { - test("default save/load") { + test("default read/write") { val myParams = new MyParams("my_params") - testDefaultSaveLoad(myParams) + testDefaultReadWrite(myParams) } } From dd57812b2e2b102ff04c9d94a198fa5ca49b28c3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 5 Nov 2015 10:34:11 -0800 Subject: [PATCH 5/9] add a test in Java --- .../org/apache/spark/ml/util/ReadWrite.scala | 2 + .../ml/util/JavaDefaultReadWriteSuite.java | 72 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 6a9dd710146b..ebf6426995bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -126,12 +126,14 @@ abstract class Writer extends BaseReadWrite { * Saves the ML instance to the input path. */ @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") def to(path: String): Unit /** * Saves the ML instances to the input path, the same as [[to()]]. 
*/ @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") def save(path: String): Unit = to(path) /** diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java new file mode 100644 index 000000000000..e1e67e02a3c6 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util; + +import java.io.File; +import java.io.IOException; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.util.Utils; + +public class JavaDefaultReadWriteSuite { + + JavaSparkContext jsc = null; + File tempDir = null; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + } + + @After + public void tearDown() { + if (jsc != null) { + jsc.stop(); + jsc = null; + } + Utils.deleteRecursively(tempDir); + } + + @Test + public void testDefaultReadWrite() throws IOException { + String uid = "my_params"; + MyParams instance = new MyParams(uid); + instance.set(instance.intParam(), 2); + String outputPath = new File(tempDir, uid).getPath(); + instance.write().to(outputPath); + try { + instance.write().to(outputPath); + Assert.fail( + "Write without overwrite enabled should fail if the output directory already exists."); + } catch (IOException e) { + // expected + } + instance.write().options("overwrite", "true").to(outputPath); + MyParams newInstance = MyParams.load().from(outputPath); + Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); + Assert.assertEquals("Params should be preserved.", + 2, newInstance.getOrDefault(newInstance.intParam())); + } +} From 59d1c5eac4294928961cec8b3caeb3ac4cfb15d2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 5 Nov 2015 12:08:20 -0800 Subject: [PATCH 6/9] remove options --- .../org/apache/spark/ml/util/ReadWrite.scala | 75 ++----------------- .../ml/util/JavaDefaultReadWriteSuite.java | 2 +- .../spark/ml/util/DefaultReadWriteTest.scala | 3 +- .../apache/spark/ml/util/TempDirectory.scala | 3 +- 4 files changed, 11 insertions(+), 72 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ebf6426995bd..6f4db5a4e777 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -41,16 +41,6 @@ import org.apache.spark.util.Utils private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None - /** - * User-specified options. - */ - protected final val options: mutable.Map[String, String] = mutable.Map.empty - - /** - * Java-friendly version of [[options]]. - */ - protected final def javaOptions: ju.Map[String, String] = options.asJava - /** * Sets the SQL context to use for saving/loading. */ @@ -66,52 +56,6 @@ private[util] sealed trait BaseReadWrite { protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { SQLContext.getOrCreate(SparkContext.getOrCreate()) } - - /** - * Adds one or more options as (key, value) pairs. - */ - @Since("1.6.0") - def options(first: (String, String), others: (String, String)*): this.type = { - options += first - options ++= others - this - } - - /** - * Adds one or more options with alternating key and value strings. - * @param k1 first key - * @param v1 first value - * @param others other options, must be paired - */ - @varargs - @Since("1.6.0") - def options(k1: String, v1: String, others: String*): this.type = { - options += k1 -> v1 - require(others.length % 2 == 0, - s"Options must be specified in pairs but got: ${others.mkString(",")}.") - others.grouped(2).foreach { case Seq(k, v) => - options += k -> v - } - this - } - - /** - * Adds options as a Scala map (overwrites if an option already exists). - */ - @Since("1.6.0") - def options(options: Map[String, String]): this.type = { - this.options ++= options - this - } - - /** - * Adds options as a Java map (overwrites if an option already exists). - */ - @Since("1.6.0") - def options(options: ju.Map[String, String]): this.type = { - this.options ++= options.asScala - this - } } /** @@ -120,7 +64,8 @@ private[util] sealed trait BaseReadWrite { @Experimental @Since("1.6.0") abstract class Writer extends BaseReadWrite { - import Writer._ + + protected var shouldOverwrite: Boolean = false /** * Saves the ML instance to the input path. @@ -135,22 +80,16 @@ abstract class Writer extends BaseReadWrite { @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") def save(path: String): Unit = to(path) - + /** - * Tells whether we should overwrite if the output directory exists (default: false). + * Overwrites if the output path already exists. */ - protected final def shouldOverwrite: Boolean = { - options.get(Overwrite).map(_.toBoolean).getOrElse(false) + def overwrite(): this.type = { + shouldOverwrite = true + this } } -@Experimental -@Since("1.6.0") -object Writer { - /** Option key to control overwrite. */ - val Overwrite: String = "overwrite" -} - /** * Trait for classes that provide [[Writer]]. 
*/ diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index e1e67e02a3c6..1fe717b50239 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -63,7 +63,7 @@ public void testDefaultReadWrite() throws IOException { } catch (IOException e) { // expected } - instance.write().options("overwrite", "true").to(outputPath); + instance.write().overwrite().to(outputPath); MyParams newInstance = MyParams.load().from(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 3dba486a1981..18335d486551 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -40,8 +40,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => intercept[IOException] { instance.write.to(path) } - instance.write.options("overwrite" -> "true").to(path) - + instance.write.overwrite().to(path) val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] val newInstance = loader.from(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 9b2de148d2b7..2742026a69c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.File -import org.apache.spark.util.Utils import org.scalatest.{BeforeAndAfterAll, Suite} +import org.apache.spark.util.Utils + /** * Trait that creates a temporary directory before all tests and deletes it after all. */ From a41053860ee0540bfff97d8e7e13100c2110c8df Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 6 Nov 2015 08:13:20 -0800 Subject: [PATCH 7/9] address comments --- .../org/apache/spark/ml/util/ReadWrite.scala | 21 ++++++++++++------- .../ml/util/JavaDefaultReadWriteSuite.java | 6 ++++-- .../spark/ml/util/DefaultReadWriteTest.scala | 6 +++--- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 6f4db5a4e777..3a590ef544b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -17,13 +17,8 @@ package org.apache.spark.ml.util -import java.{util => ju} import java.io.IOException -import scala.annotation.varargs -import scala.collection.mutable -import scala.collection.JavaConverters._ - import org.apache.hadoop.fs.{FileSystem, Path} import org.json4s._ import org.json4s.JsonDSL._ @@ -88,6 +83,9 @@ abstract class Writer extends BaseReadWrite { shouldOverwrite = true this } + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) } /** @@ -121,6 +119,9 @@ abstract class Reader[T] extends BaseReadWrite { * Loads the ML component from the input path, the same as [[from()]]. 
*/ def load(path: String): T = from(path) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) } /** @@ -130,7 +131,7 @@ abstract class Reader[T] extends BaseReadWrite { @Experimental @Since("1.6.0") trait Readable[T] { - + /** * Returns a [[Reader]] instance for this class. */ @@ -139,7 +140,9 @@ trait Readable[T] { } /** - * Default [[Writer]] implementation for non-meta transformers and estimators. + * Default [[Writer]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). * @param instance object to save */ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { @@ -180,7 +183,9 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg } /** - * Default [[Reader]] implementation for non-meta transformers and estimators. + * Default [[Reader]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). * @tparam T ML instance type */ private[ml] class DefaultParamsReader[T] extends Reader[T] { diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 1fe717b50239..4ea954fe059a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -26,6 +26,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; public class JavaDefaultReadWriteSuite { @@ -63,8 +64,9 @@ public void testDefaultReadWrite() throws IOException { } catch (IOException e) { // expected } - instance.write().overwrite().to(outputPath); - MyParams newInstance = MyParams.load().from(outputPath); + SQLContext sqlContext = new SQLContext(jsc); + instance.write().context(sqlContext).overwrite().to(outputPath); + MyParams newInstance = MyParams.read().from(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", 2, newInstance.getOrDefault(newInstance.intParam())); diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 18335d486551..fb2f6ed9e6ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -49,7 +49,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => if (instance.isDefined(p)) { (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { case (Array(values), Array(newValues)) => - assert(values !== newValues, s"Values do not match on param ${p.name}.") + assert(values === newValues, s"Values do not match on param ${p.name}.") case (value, newValue) => assert(value === newValue, s"Values do not match on param ${p.name}.") } @@ -89,8 +89,8 @@ class MyParams(override val uid: String) extends Params with Writable { override def write: Writer = new DefaultParamsWriter(this) } -object MyParams { - def load: Reader[MyParams] = new 
DefaultParamsReader[MyParams] +object MyParams extends Readable[MyParams] { + override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From f862b6a997faeaa3df770779549193eae20f38d7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 6 Nov 2015 08:48:11 -0800 Subject: [PATCH 8/9] remove from/to --- .../org/apache/spark/ml/util/ReadWrite.scala | 39 ++++++++++--------- .../ml/util/JavaDefaultReadWriteSuite.java | 8 ++-- .../spark/ml/util/DefaultReadWriteTest.scala | 14 +++++-- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 3a590ef544b3..ea790e0dddc7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -63,22 +63,16 @@ abstract class Writer extends BaseReadWrite { protected var shouldOverwrite: Boolean = false /** - * Saves the ML instance to the input path. + * Saves the ML instances to the input path. */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def to(path: String): Unit - - /** - * Saves the ML instances to the input path, the same as [[to()]]. - */ - @Since("1.6.0") - @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String): Unit = to(path) + def save(path: String): Unit /** * Overwrites if the output path already exists. */ + @Since("1.6.0") def overwrite(): this.type = { shouldOverwrite = true this @@ -99,6 +93,13 @@ trait Writable { */ @Since("1.6.0") def write: Writer + + /** + * Saves this ML instance to the input path, a shortcut of `write.save(path)`. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = write.save(path) } /** @@ -113,12 +114,7 @@ abstract class Reader[T] extends BaseReadWrite { * Loads the ML component from the input path. */ @Since("1.6.0") - def from(path: String): T - - /** - * Loads the ML component from the input path, the same as [[from()]]. - */ - def load(path: String): T = from(path) + def load(path: String): T // override for Java compatibility override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) @@ -137,6 +133,12 @@ trait Readable[T] { */ @Since("1.6.0") def read: Reader[T] + + /** + * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + */ + @Since("1.6.0") + def load(path: String): T = read.load(path) } /** @@ -150,7 +152,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg /** * Saves the ML component to the input path. */ - override def to(path: String): Unit = { + override def save(path: String): Unit = { val sc = sqlContext.sparkContext val hadoopConf = sc.hadoopConfiguration @@ -159,10 +161,11 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg if (fs.exists(p)) { if (shouldOverwrite) { logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. fs.delete(p, true) } else { throw new IOException( - s"Path $path already exists. Please set overwrite=true to overwrite it.") + s"Path $path already exists. 
Please use write.overwrite().save(path) to overwrite it.") } } @@ -193,7 +196,7 @@ private[ml] class DefaultParamsReader[T] extends Reader[T] { /** * Loads the ML component from the input path. */ - override def from(path: String): T = { + override def load(path: String): T = { implicit val format = DefaultFormats val sc = sqlContext.sparkContext val metadataPath = new Path(path, "metadata").toString diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 4ea954fe059a..c39538014be8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -56,17 +56,17 @@ public void testDefaultReadWrite() throws IOException { MyParams instance = new MyParams(uid); instance.set(instance.intParam(), 2); String outputPath = new File(tempDir, uid).getPath(); - instance.write().to(outputPath); + instance.save(outputPath); try { - instance.write().to(outputPath); + instance.save(outputPath); Assert.fail( "Write without overwrite enabled should fail if the output directory already exists."); } catch (IOException e) { // expected } SQLContext sqlContext = new SQLContext(jsc); - instance.write().context(sqlContext).overwrite().to(outputPath); - MyParams newInstance = MyParams.read().from(outputPath); + instance.write().context(sqlContext).overwrite().save(outputPath); + MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", 2, newInstance.getOrDefault(newInstance.intParam())); diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index fb2f6ed9e6ae..e542d89c7605 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -36,13 +36,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val uid = instance.uid val path = new File(tempDir, uid).getPath - instance.write.to(path) + instance.save(path) intercept[IOException] { - instance.write.to(path) + instance.save(path) } - instance.write.overwrite().to(path) + instance.write.overwrite().save(path) val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] - val newInstance = loader.from(path) + val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) instance.params.foreach { p => @@ -57,6 +57,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } } + + val another = instance.getClass.getMethod("load", classOf[String]).invoke(path).asInstanceOf[T] + assert(another.uid === instance.uid) } } @@ -90,7 +93,10 @@ class MyParams(override val uid: String) extends Params with Writable { } object MyParams extends Readable[MyParams] { + override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + + override def load(path: String): MyParams = read.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From 7952bd40dd98304aadda25df349627ad938e0cb5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 6 Nov 2015 11:18:45 -0800 Subject: [PATCH 9/9] fix test --- 
.../scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index e542d89c7605..4545b0f281f5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -58,7 +58,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } - val another = instance.getClass.getMethod("load", classOf[String]).invoke(path).asInstanceOf[T] + val load = instance.getClass.getMethod("load", classOf[String]) + val another = load.invoke(instance, path).asInstanceOf[T] assert(another.uid === instance.uid) } }
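
Taken together, the series ends with a fluent persistence API patterned on the DataFrame
reader/writer: instance.write returns a Writer whose save(path) throws an IOException when
the path already exists unless overwrite() is chained first, and companion objects mix in
Readable so that load(path) is a shortcut for read.load(path). Below is a minimal
round-trip sketch against the final state of these patches; it assumes an active
SparkContext, and the output path and threshold value are illustrative, not taken from the
patches themselves.

    import org.apache.spark.ml.feature.Binarizer

    val binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(0.1)

    // A plain save(path) fails with an IOException if the directory already exists,
    // so opt in to replacement explicitly. The path here is illustrative.
    binarizer.write.overwrite().save("/tmp/binarizer")

    // Binarizer.load(path) is the Readable shortcut for Binarizer.read.load(path).
    val restored = Binarizer.load("/tmp/binarizer")
    assert(restored.uid === binarizer.uid)

Note the design shift across the series: the string-keyed options(...) builders from patch 1
were removed in patch 6 in favor of the single explicit overwrite() flag, and what
DefaultParamsWriter persists is a one-line JSON metadata record (class, timestamp, uid,
paramMap) under <path>/metadata, which DefaultParamsReader parses to reinstantiate the class
by uid and replay its params.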