From bbead70f6e7d3ed2f05423c077db9f09acff6869 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 21 Nov 2016 13:14:50 +0800 Subject: [PATCH 1/2] create pr --- .../spark/ml/regression/GeneralizedLinearRegression.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3f9de1fe74c9..f6cbde359707 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -226,7 +226,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val * @group setParam */ @Since("2.0.0") - def setSolver(value: String): this.type = set(solver, value) + def setSolver(value: String): this.type = { + require("irls" == value, + s"Solver $value was not supported. Supported options: irls") + set(solver, value) + } setDefault(solver -> "irls") /** From 1fe2f925aa7dbe614f65a76eb61ebe13fe67dd6a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 22 Nov 2016 10:48:56 +0800 Subject: [PATCH 2/2] update --- .../GeneralizedLinearRegression.scala | 13 +++++++--- .../ml/regression/LinearRegression.scala | 24 +++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index f6cbde359707..e9305331883b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -227,11 +227,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setSolver(value: String): this.type = { - require("irls" == value, - s"Solver $value was not supported. Supported options: irls") + require(supportedSolvers.contains(value), + s"Solver $value was not supported. Supported options: ${supportedSolvers.mkString(", ")}") set(solver, value) } - setDefault(solver -> "irls") + setDefault(solver -> IRLS) /** * Sets the link prediction (linear predictor) column name. @@ -301,6 +301,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine @Since("2.0.0") override def load(path: String): GeneralizedLinearRegression = super.load(path) + /** String name for "irls" solver. */ + private[regression] val IRLS = "irls" + + /** Set of solvers that GeneralizedLinearRegression supports. */ + private[regression] val supportedSolvers = Array(IRLS) + + /** Set of family and link pairs that GeneralizedLinearRegression supports. */ private[regression] lazy val supportedFamilyAndLinkPairs = Set( Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8ea5e1e6c453..c8bf7de1b0bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -77,6 +77,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with DefaultParamsWritable with Logging { + import LinearRegression._ + @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -174,11 +176,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String */ @Since("1.6.0") def setSolver(value: String): this.type = { - require(Set("auto", "l-bfgs", "normal").contains(value), - s"Solver $value was not supported. Supported options: auto, l-bfgs, normal") + require(supportedSolvers.contains(value), + s"Solver $value was not supported. Supported options: ${supportedSolvers.mkString(", ")}") set(solver, value) } - setDefault(solver -> "auto") + setDefault(solver -> AUTO) /** * Suggested depth for treeAggregate (>= 2). @@ -203,8 +205,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Instance(label, weight, features) } - if (($(solver) == "auto" && - numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { + if (($(solver) == AUTO && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) { // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) @@ -410,6 +412,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.6.0") object LinearRegression extends DefaultParamsReadable[LinearRegression] { + /** String name for "auto" solver. */ + private[regression] val AUTO = "auto" + + /** String name for "l-bfgs" solver. */ + private[regression] val LBFGS = "l-bfgs" + + /** String name for "normal" solver. */ + private[regression] val NORMAL = "normal" + + /** Set of solvers that LinearRegression supports. */ + private[regression] val supportedSolvers = Array(AUTO, LBFGS, NORMAL) + @Since("1.6.0") override def load(path: String): LinearRegression = super.load(path)