From 0fd854df04a031ee5e259fb0ad8351ea296001e7 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Thu, 12 Nov 2015 15:25:14 +0800 Subject: [PATCH 1/5] Update LinearRegression.scala --- .../ml/regression/LinearRegression.scala | 57 ++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e581983..d4bac0ffab23d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -341,6 +342,58 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + */ + override def write: Writer = new LinearRegressionWriter(this) +} + +/** [[Writer]] instance for [[LinearRegressionModel]] */ +private[regression] class LinearRegressionWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } +} + +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + override def read: Reader[LinearRegressionModel] = new LinearRegressionReader + + override def load(path: String): LinearRegressionModel = read.load(path) +} + + +private[regression] class LinearRegressionReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } } /** @@ -354,7 +407,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None From 4cc7472a1bd165221dde38f223f4fa1785fd1b40 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Thu, 12 Nov 2015 15:29:39 +0800 Subject: [PATCH 2/5] add linear regression model import/export feature --- .../ml/regression/LinearRegressionSuite.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8fedad..2d3e616a2d03c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + ignore("read/write") { // SPARK-11672 + // Set some Params to make sure set Params are serialized. + val linearRegression = new LinearRegression() + .setElasticNetParam(0.1) + .setMaxIter(2) + .fit(datasetWithDenseFeature) + val linearRegression2 = testDefaultReadWrite(linearRegression) + assert(linearRegression.intercept === linearRegression2.intercept) + assert(linearRegression.coefficients.toArray === linearRegression2.coefficients.toArray) + } } From 46ec4a143a0e01bb39a9efe33c7d802d845bd8f5 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 08:18:17 +0800 Subject: [PATCH 3/5] Update LinearRegressionSuite.scala --- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2d3e616a2d03c..1ce1d30c475e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -855,7 +855,7 @@ class LinearRegressionSuite model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } - + ignore("read/write") { // SPARK-11672 // Set some Params to make sure set Params are serialized. val linearRegression = new LinearRegression() @@ -865,5 +865,5 @@ class LinearRegressionSuite val linearRegression2 = testDefaultReadWrite(linearRegression) assert(linearRegression.intercept === linearRegression2.intercept) assert(linearRegression.coefficients.toArray === linearRegression2.coefficients.toArray) - } + } } From f1bebc3a2378da04cf5e9f544f5ae70708e9e37c Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 09:34:20 +0800 Subject: [PATCH 4/5] Update LinearRegression.scala --- .../ml/regression/LinearRegression.scala | 105 +++++++++--------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d4bac0ffab23d..89c8e8012f19c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -343,57 +343,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - /** - * Returns a [[Writer]] instance for this ML instance. - * - * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. - * An option to save [[summary]] may be added in the future. - */ - override def write: Writer = new LinearRegressionWriter(this) + override def write: Writer = new DefaultParamsWriter(this) } -/** [[Writer]] instance for [[LinearRegressionModel]] */ -private[regression] class LinearRegressionWriter(instance: LinearRegressionModel) - extends Writer with Logging { - - private case class Data(intercept: Double, coefficients: Vector) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients - val data = Data(instance.intercept, instance.coefficients) - val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) - } -} - -object LinearRegressionModel extends Readable[LinearRegressionModel] { - - override def read: Reader[LinearRegressionModel] = new LinearRegressionReader - - override def load(path: String): LinearRegressionModel = read.load(path) -} - - -private[regression] class LinearRegressionReader extends Reader[LinearRegressionModel] { - - /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.regression.LinearRegressionModel" - - override def load(path: String): LinearRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - - val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) - val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) - - DefaultParamsReader.getAndSetParams(model, metadata) - model - } +object LinearRegression extends Readable[LinearRegression] { + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] } /** @@ -475,8 +429,59 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + */ + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) } +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[regression] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + +private[regression] class LinearRegressionReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } +} + + /** * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the From 6e6de5b2e05388800978c9ffd340d09397fe777f Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 10:02:07 +0800 Subject: [PATCH 5/5] Update LinearRegression.scala