From 6afde57ebf3d86ae3a0bb07cdb42875a7bbbd0ff Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 1 Sep 2015 23:09:26 -0700 Subject: [PATCH 01/13] switch to withColumn --- .../scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 884003eb38524..f1d78516e58c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -300,7 +300,7 @@ class LinearRegressionModel private[ml] ( private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { val t = udf { features: Vector => predict(features) } val predictionAndObservations = dataset - .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + .withColumn(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) } From 43420b7a14ff88f7d6aa3f288093537a8e52502f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 1 Sep 2015 23:35:43 -0700 Subject: [PATCH 02/13] Fix it to work and add a test --- .../spark/ml/regression/LinearRegression.scala | 2 +- .../spark/ml/regression/LinearRegressionSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f1d78516e58c1..f06972bbb0f51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -300,7 +300,7 @@ class LinearRegressionModel private[ml] ( private[regression] def evaluate(dataset: DataFrame): 
LinearRegressionSummary = { val t = udf { features: Vector => predict(features) } val predictionAndObservations = dataset - .withColumn(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + .withColumn($(predictionCol), t(col($(featuresCol)))) new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc734..7eb1b58aebd06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -460,6 +460,18 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { // Training results for the model should be available assert(model.hasSummary) + // Feature column should be equal + val expectedFeature = dataset.select("features").collect() + val resultFeature = model.summary.predictions.select("features").collect() + resultFeature.zip(expectedFeature).foreach{ case (r1, r2) => + val result1 = r1.getAs[DenseVector](0)(0) + val result2 = r1.getAs[DenseVector](0)(1) + val expected1 = r2.getAs[DenseVector](0)(0) + val expected2 = r2.getAs[DenseVector](0)(1) + assert(result1 ~== expected1 relTol 1E-5) + assert(result2 ~== expected2 relTol 1E-5) + } + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = dataset.select("features", "label") .map { case Row(features: DenseVector, label: Double) => From e633ac562792a2a302fb39af98da75b897bb7e28 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 13 Sep 2015 19:39:47 -0700 Subject: [PATCH 03/13] CR feedback, use transform on the dataset when calling evaluate and also change our test to just check the new expected schema --- .../spark/ml/regression/LinearRegression.scala | 6 +----- 
.../ml/regression/LinearRegressionSuite.scala | 14 +++----------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f06972bbb0f51..96a738f4f7779 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -298,11 +298,7 @@ class LinearRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { - val t = udf { features: Vector => predict(features) } - val predictionAndObservations = dataset - .withColumn($(predictionCol), t(col($(featuresCol)))) - - new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + new LinearRegressionSummary(transform(dataset), $(predictionCol), $(labelCol)) } override protected def predict(features: Vector): Double = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 7eb1b58aebd06..60960dc2738bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -460,17 +460,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { // Training results for the model should be available assert(model.hasSummary) - // Feature column should be equal - val expectedFeature = dataset.select("features").collect() - val resultFeature = model.summary.predictions.select("features").collect() - resultFeature.zip(expectedFeature).foreach{ case (r1, r2) => - val result1 = r1.getAs[DenseVector](0)(0) - val result2 = r1.getAs[DenseVector](0)(1) - val expected1 = 
r2.getAs[DenseVector](0)(0) - val expected2 = r2.getAs[DenseVector](0)(1) - assert(result1 ~== expected1 relTol 1E-5) - assert(result2 ~== expected2 relTol 1E-5) - } + // Schema should be a superset of the input dataset + assert(model.summary.predictions.schema.fieldNames.toSet === + dataset.schema.fieldNames.toSet ++ Set("prediction")) // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = dataset.select("features", "label") From 64e9f3f7a9404639b7be6f10f839f9737561221d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 14:01:02 -0700 Subject: [PATCH 04/13] murh --- .../org/apache/spark/ml/regression/LinearRegressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 60960dc2738bc..bbfb08942060e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -461,7 +461,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasSummary) // Schema should be a superset of the input dataset - assert(model.summary.predictions.schema.fieldNames.toSet === + assert(model.summary.predictions.schema.fieldNames.toSet contains dataset.schema.fieldNames.toSet ++ Set("prediction")) // Residuals in [[LinearRegressionResults]] should equal those manually computed From 9e96c133b6490eaabc0a099d5acfa82cb437930f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 14:02:03 -0700 Subject: [PATCH 05/13] use subsetOf for test --- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbfb08942060e..47a030221b6e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -461,8 +461,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasSummary) // Schema should be a superset of the input dataset - assert(model.summary.predictions.schema.fieldNames.toSet contains - dataset.schema.fieldNames.toSet ++ Set("prediction")) + assert((dataset.schema.fieldNames.toSet ++ Set("prediction")).subsetOf( + model.summary.predictions.schema.fieldNames.toSet)) // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = dataset.select("features", "label") From d8baa4784e3c512211ff63e2378f75ce64a1e419 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 14:02:58 -0700 Subject: [PATCH 06/13] Use single element add operation to set --- .../org/apache/spark/ml/regression/LinearRegressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 47a030221b6e1..8d64d1e1a8e26 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -461,7 +461,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasSummary) // Schema should be a superset of the input dataset - assert((dataset.schema.fieldNames.toSet ++ Set("prediction")).subsetOf( + assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf( model.summary.predictions.schema.fieldNames.toSet)) // Residuals in [[LinearRegressionResults]] should equal those 
manually computed From f6448971f7a3c82da8d763078ec0606784599361 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 24 Sep 2015 18:26:16 -0700 Subject: [PATCH 07/13] Partial progress, handle missing or invalid prediction columns in one pass (todo handle in evaluate call as well) --- .../spark/ml/regression/LinearRegression.scala | 14 ++++++++++++-- .../ml/regression/LinearRegressionSuite.scala | 9 +++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6c7c3174dac26..b10fa0001f045 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -267,9 +267,19 @@ class LinearRegression(override val uid: String) if (handlePersistence) instances.unpersist() val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) + // Handle possible missing or invalid prediction columns + val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") + val (summaryModel, predictionColName) = predictionColOpt match { + case Some(p) => (model, p) + case None => { + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + (copyValues(model).setPredictionCol(predictionColName), predictionColName) + } + } + val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset), - $(predictionCol), + summaryModel.transform(dataset), + predictionColName, $(labelCol), $(featuresCol), objectiveHistory) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index ce380442edb4b..48bf03372803f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -461,13 +461,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("linear regression model training summary") { val trainer = new LinearRegression val model = trainer.fit(dataset) + val trainerNoPredictionCol = trainer.setPredictionCol("") + val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset) + // Training results for the model should be available assert(model.hasSummary) + assert(modelNoPredictionCol.hasSummary) // Schema should be a superset of the input dataset assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf( model.summary.predictions.schema.fieldNames.toSet)) + // Validate that we re-insert a prediction column for evaluation + val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + modelNoPredictionColFieldNames.toSet)) + assert(!modelNoPredictionColFieldNames.find(s => s.startsWith("prediction_")).isEmpty) // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = dataset.select("features", "label") From f8758a04054109136400fde56ab3de6ea165d3b1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 25 Sep 2015 16:35:37 -0700 Subject: [PATCH 08/13] Do similar logic in the different places where we may be missing a prediction column and want to generate one --- .../ml/regression/LinearRegression.scala | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b10fa0001f045..fa0a7a2603f62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -179,9 +179,19 @@ class 
LinearRegression(override val uid: String) val intercept = yMean val model = new LinearRegressionModel(uid, weights, intercept) + // Handle possible missing or invalid prediction columns + val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") + val (summaryModel, predictionColName) = predictionColOpt match { + case Some(p) => (model, p) + case None => { + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + (copyValues(model).setPredictionCol(predictionColName), predictionColName) + } + } + val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset), - $(predictionCol), + summaryModel.transform(dataset), + predictionColName, $(labelCol), $(featuresCol), Array(0D)) @@ -331,7 +341,16 @@ class LinearRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { - new LinearRegressionSummary(transform(dataset), $(predictionCol), $(labelCol)) + // Handle possible missing or invalid prediction columns + val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") + val (summaryModel, predictionColName) = predictionColOpt match { + case Some(p) => (this, p) + case None => { + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + (copyValues(this).setPredictionCol(predictionColName), predictionColName) + } + } + new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol)) } override protected def predict(features: Vector): Double = { From b8208552df93f5dc513fd39f06fdf9907e72431c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 7 Oct 2015 11:59:16 -0700 Subject: [PATCH 09/13] Simplify checking for column with exists --- .../org/apache/spark/ml/regression/LinearRegressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 48bf03372803f..8f6ee557e3d67 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -476,7 +476,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames assert((dataset.schema.fieldNames.toSet).subsetOf( modelNoPredictionColFieldNames.toSet)) - assert(!modelNoPredictionColFieldNames.find(s => s.startsWith("prediction_")).isEmpty) + assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = dataset.select("features", "label") From 033f6bae777ba26f5ba4f05df42c9e387ad07ba1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 7 Oct 2015 12:04:57 -0700 Subject: [PATCH 10/13] Start making the linear regression panda --- .../spark/ml/regression/LinearRegression.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index fa0a7a2603f62..22e909e797a4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -299,6 +299,19 @@ class LinearRegression(override val uid: String) override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } +object LinearRegression { + /** + * Takes the current linear regression model and an option representing the + * prediction column. 
If the prediction column is set returns the current + * model and prediction column, otherwise generates a new column and sets + * it as the prediction column on a new copy of the input model. + */ + protected def findSummaryModelAndPredictionCol(model: LinearRegressionmodel, + predictionColOpt: Option[String]): (LinearRegressionModel, String) = { + + } +} + /** * :: Experimental :: * Model produced by [[LinearRegression]]. From e91cdc49ca53305031088f9bcca2b8a3857db441 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 7 Oct 2015 12:37:49 -0700 Subject: [PATCH 11/13] Simplify the dynamic prediction column creation code to be unified in a single place --- .../ml/regression/LinearRegression.scala | 49 +++++++------------ 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 22e909e797a4d..08de9d5a2419a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -180,14 +180,8 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) // Handle possible missing or invalid prediction columns - val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") - val (summaryModel, predictionColName) = predictionColOpt match { - case Some(p) => (model, p) - case None => { - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() - (copyValues(model).setPredictionCol(predictionColName), predictionColName) - } - } + val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)) + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), @@ -278,14 +272,8 @@ class 
LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) // Handle possible missing or invalid prediction columns - val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") - val (summaryModel, predictionColName) = predictionColOpt match { - case Some(p) => (model, p) - case None => { - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() - (copyValues(model).setPredictionCol(predictionColName), predictionColName) - } - } + val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)) + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), @@ -299,19 +287,6 @@ class LinearRegression(override val uid: String) override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } -object LinearRegression { - /** - * Takes the current linear regression model and an option representing the - * prediction column. If the prediction column is set returns the current - * model and prediction column, otherwise generates a new column and sets - * it as the prediction column on a new copy of the input model. - */ - protected def findSummaryModelAndPredictionCol(model: LinearRegressionmodel, - predictionColOpt: Option[String]): (LinearRegressionModel, String) = { - - } -} - /** * :: Experimental :: * Model produced by [[LinearRegression]]. 
@@ -355,17 +330,27 @@ class LinearRegressionModel private[ml] ( // TODO: decide on a good name before exposing to public API private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() + new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol)) + } + + /** + * If the prediction column is set returns the current model and prediction column, + * otherwise generates a new column and sets it as the prediction column on a new copy + * of the current model. + */ + private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") - val (summaryModel, predictionColName) = predictionColOpt match { + predictionColOpt match { case Some(p) => (this, p) case None => { val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() - (copyValues(this).setPredictionCol(predictionColName), predictionColName) + (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) } } - new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol)) } + override protected def predict(features: Vector): Double = { dot(features, weights) + intercept } From a247aa6e8a703dac9eb890d1f05661cf9243c44c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 7 Oct 2015 13:35:40 -0700 Subject: [PATCH 12/13] Remove some more dead code and simplify match --- .../apache/spark/ml/regression/LinearRegression.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 08de9d5a2419a..889589cde3e7e 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -180,7 +180,6 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) // Handle possible missing or invalid prediction columns - val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)) val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() val trainingSummary = new LinearRegressionTrainingSummary( @@ -272,7 +271,6 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) // Handle possible missing or invalid prediction columns - val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)) val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() val trainingSummary = new LinearRegressionTrainingSummary( @@ -340,13 +338,12 @@ class LinearRegressionModel private[ml] ( * of the current model. 
*/ private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { - val predictionColOpt = get(predictionCol).orElse(getDefault(predictionCol)).filter(_ != "") - predictionColOpt match { - case Some(p) => (this, p) - case None => { + $(predictionCol) match { + case "" => { val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) } + case p => (this, p) } } From 7fb3b1c1bfeba2b8543a5db5fe51a37e3196a254 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 7 Oct 2015 17:49:53 -0700 Subject: [PATCH 13/13] Remove extra braces (CR feedback) --- .../org/apache/spark/ml/regression/LinearRegression.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b1a2496006791..dd09667ef5a0f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -334,10 +334,9 @@ class LinearRegressionModel private[ml] ( */ private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { $(predictionCol) match { - case "" => { + case "" => val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) - } case p => (this, p) } }