From 56ad42dc6638ccde953b848139a6b1fe5d3d8176 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 10 Apr 2016 21:00:25 +0800 Subject: [PATCH 1/4] GLM support link prediction --- .../GeneralizedLinearRegression.scala | 58 +++++++++++++- .../GeneralizedLinearRegressionSuite.scala | 80 +++++++++++-------- 2 files changed, 104 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 00cf25dc54d11..24b21884b9072 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -78,6 +78,19 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getLink: String = $(link) + /** + * Param for link prediction (linear predictor) column name. + * @group param + */ + @Since("2.0.0") + final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol", + "link prediction (linear predictor) column name") + setDefault(linkPredictionCol, "") + + /** @group getParam */ + @Since("2.0.0") + def getLinkPredictionCol: String = $(linkPredictionCol) + import GeneralizedLinearRegression._ @Since("2.0.0") @@ -93,7 +106,12 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + s"with ${$(family)} family does not support ${$(link)} link function.") } - super.validateAndTransformSchema(schema, fitting, featuresDataType) + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if ($(linkPredictionCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) + } else { + newSchema + } } } @@ -196,6 +214,13 @@ class GeneralizedLinearRegression @Since("2.0.0") 
(@Since("2.0.0") override val def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "irls") + /** + * Sets the link prediction (linear predictor) column name. + * @group setParam + */ + @Since("2.0.0") + def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { val familyObj = Family.fromName($(family)) val linkObj = if (isDefined(link)) { @@ -664,6 +689,13 @@ class GeneralizedLinearRegressionModel private[ml] ( extends RegressionModel[Vector, GeneralizedLinearRegressionModel] with GeneralizedLinearRegressionBase with MLWritable { + /** + * Sets the link prediction (linear predictor) column name. + * @group setParam + */ + @Since("2.0.0") + def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) + import GeneralizedLinearRegression._ lazy val familyObj = Family.fromName($(family)) @@ -675,10 +707,32 @@ class GeneralizedLinearRegressionModel private[ml] ( lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) override protected def predict(features: Vector): Double = { - val eta = BLAS.dot(features, coefficients) + intercept + val eta = predictLink(features) familyAndLink.fitted(eta) } + protected def predictLink(features: Vector): Double = { + BLAS.dot(features, coefficients) + intercept + } + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictLinkUDF = udf { (features: Vector) => predictLink(features) } + var output = dataset + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + if ($(linkPredictionCol).nonEmpty) { + output = output.withColumn($(linkPredictionCol), 
predictLinkUDF(col($(featuresCol)))) + } + output + } + private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4905f3e0687f2..7ab27f6c77acf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -247,20 +247,24 @@ class GeneralizedLinearRegressionSuite ("inverse", datasetGaussianInverse))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"gaussian family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 
relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } @@ -358,7 +362,7 @@ class GeneralizedLinearRegressionSuite ("cloglog", datasetBinomial))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), model.coefficients(2), model.coefficients(3)) @@ -366,13 +370,17 @@ class GeneralizedLinearRegressionSuite s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"binomial family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with binomial family, $link link and fitIntercept = 
$fitIntercept.") + } idx += 1 } @@ -427,20 +435,24 @@ class GeneralizedLinearRegressionSuite ("sqrt", datasetPoissonSqrt))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) - .setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"poisson family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Prediction mismatch: " + + s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } @@ -495,20 +507,24 @@ class GeneralizedLinearRegressionSuite ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { for (fitIntercept <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) - 
.setFitIntercept(fitIntercept) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + s"$link link and fitIntercept = $fitIntercept.") val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val eta = BLAS.dot(features, model.coefficients) + model.intercept - val prediction2 = familyLink.fitted(eta) - assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + - s"gamma family, $link link and fitIntercept = $fitIntercept.") - } + model.transform(dataset).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Prediction mismatch: " + + s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.") + } idx += 1 } From e2020394750e324a44064fb263c8db26bd604e91 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 10 Apr 2016 21:28:36 +0800 Subject: [PATCH 2/4] fix doc of test --- .../ml/regression/GeneralizedLinearRegressionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 7ab27f6c77acf..65de170eefb5e 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -450,7 +450,7 @@ class GeneralizedLinearRegressionSuite val linkPrediction2 = eta assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + s"poisson family, $link link and fitIntercept = $fitIntercept.") - assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Prediction mismatch: " + + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.") } @@ -522,7 +522,7 @@ class GeneralizedLinearRegressionSuite val linkPrediction2 = eta assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + s"gamma family, $link link and fitIntercept = $fitIntercept.") - assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Prediction mismatch: " + + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.") } From e5aea09d3eedaa37fb3d3b873be602bcf5308c0f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 12 Apr 2016 07:02:22 -0700 Subject: [PATCH 3/4] fix transform return DF --- .../spark/ml/regression/GeneralizedLinearRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 24b21884b9072..26114708c7f46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -730,7 +730,7 @@ class GeneralizedLinearRegressionModel private[ml] ( if ($(linkPredictionCol).nonEmpty) { output = 
output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) } - output + output.toDF } private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None From cb1b5e6a5fbea84d8493d649aa8b931c1bb6841f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 20 Apr 2016 01:58:43 -0700 Subject: [PATCH 4/4] add documents --- .../spark/ml/regression/GeneralizedLinearRegression.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 26114708c7f46..fe375ed695a43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -80,6 +80,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for link prediction (linear predictor) column name. + * Default is empty, which means we do not output link prediction. * @group param */ @Since("2.0.0") @@ -711,7 +712,10 @@ class GeneralizedLinearRegressionModel private[ml] ( familyAndLink.fitted(eta) } - protected def predictLink(features: Vector): Double = { + /** + * Calculate the link prediction (linear predictor) of the given instance. + */ + private def predictLink(features: Vector): Double = { BLAS.dot(features, coefficients) + intercept }