Skip to content

Commit

Permalink
GeneralizedLinearRegression support save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Mar 2, 2016
1 parent e42724b commit 3448bdf
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.regression

import breeze.stats.distributions.{Gaussian => GD}
import org.apache.hadoop.fs.Path

import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.{Experimental, Since}
Expand All @@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.optim._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
Expand Down Expand Up @@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase with Logging {
with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {

import GeneralizedLinearRegression._

Expand Down Expand Up @@ -236,23 +237,26 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}

@Since("2.0.0")
private[ml] object GeneralizedLinearRegression {
object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] {

@Since("2.0.0")
override def load(path: String): GeneralizedLinearRegression = super.load(path)

/** Set of family and link pairs that GeneralizedLinearRegression supports. */
lazy val supportedFamilyAndLinkPairs = Set(
private[ml] lazy val supportedFamilyAndLinkPairs = Set(
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
)

/** Set of family names that GeneralizedLinearRegression supports. */
lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)

/** Set of link names that GeneralizedLinearRegression supports. */
lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)

val epsilon: Double = 1E-16
private[ml] val epsilon: Double = 1E-16

/**
* Wrapper of family and link combination used in the model.
Expand Down Expand Up @@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("2.0.0") val intercept: Double)
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase {
with GeneralizedLinearRegressionBase with MLWritable {

import GeneralizedLinearRegression._

Expand All @@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] (
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
.setParent(parent)
}

@Since("2.0.0")
override def write: MLWriter =
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
}

@Since("2.0.0")
object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] {

@Since("2.0.0")
override def read: MLReader[GeneralizedLinearRegressionModel] =
new GeneralizedLinearRegressionModelReader

@Since("2.0.0")
override def load(path: String): GeneralizedLinearRegressionModel = super.load(path)

/** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */
private[GeneralizedLinearRegressionModel]
class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel)
extends MLWriter with Logging {

private case class Data(intercept: Double, coefficients: Vector)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class GeneralizedLinearRegressionModelReader
extends MLReader[GeneralizedLinearRegressionModel] {

/** Checked against metadata when loading model */
private val className = classOf[GeneralizedLinearRegressionModel].getName

override def load(path: String): GeneralizedLinearRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("intercept", "coefficients").head()
val intercept = data.getDouble(0)
val coefficients = data.getAs[Vector](1)

val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
import org.apache.spark.mllib.random._
Expand All @@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}

class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

private val seed: Int = 42
@transient var datasetGaussianIdentity: DataFrame = _
Expand Down Expand Up @@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
}
}
}

test("read/write") {
def checkModelData(
model: GeneralizedLinearRegressionModel,
model2: GeneralizedLinearRegressionModel): Unit = {
assert(model.intercept === model2.intercept)
assert(model.coefficients.toArray === model2.coefficients.toArray)
}

val glr = new GeneralizedLinearRegression()
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}
}

object GeneralizedLinearRegressionSuite {

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"family" -> "poisson",
"link" -> "log",
"fitIntercept" -> true,
"maxIter" -> 2, // intentionally small
"tol" -> 0.8,
"regParam" -> 0.01,
"predictionCol" -> "myPrediction")

def generateGeneralizedLinearRegressionInput(
intercept: Double,
coefficients: Array[Double],
Expand Down

0 comments on commit 3448bdf

Please sign in to comment.