[SPARK-12811] [ML] Estimator for Generalized Linear Models(GLMs) #11136

Closed
wants to merge 17 commits into from

yanboliang
Contributor

Estimator for Generalized Linear Models (GLMs), which will be solved by IRLS (iteratively reweighted least squares).

cc @mengxr

Instance(label, weight, features)
}

if ($(family) == "gaussian" && $(link) == "identity") {
Contributor Author

For the gaussian family with the identity link, we use only WeightedLeastSquares to train the model.
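
A minimal sketch of why this shortcut is sound, assuming the textbook IRLS working response and working weight. The `Glm` case class and the function names are hypothetical and are not this PR's API. With the identity link and unit variance, the working response is just the label and the working weight is just the instance weight, so a single weighted least squares solve is enough:

```scala
object IrlsGaussianIdentitySketch {
  // Hypothetical minimal description of a GLM; not the classes from this PR.
  final case class Glm(
      linkDeriv: Double => Double, // g'(mu)
      variance: Double => Double)  // V(mu)

  // IRLS working response: z = eta + (y - mu) * g'(mu)
  def workingResponse(glm: Glm, y: Double, mu: Double, eta: Double): Double =
    eta + (y - mu) * glm.linkDeriv(mu)

  // IRLS working weight: w / (V(mu) * g'(mu)^2)
  def workingWeight(glm: Glm, w: Double, mu: Double): Double = {
    val d = glm.linkDeriv(mu)
    w / (glm.variance(mu) * d * d)
  }

  // Gaussian family with identity link: g(mu) = mu, g'(mu) = 1, V(mu) = 1.
  val gaussianIdentity: Glm = Glm(_ => 1.0, _ => 1.0)

  def main(args: Array[String]): Unit = {
    val y = 3.5
    val w = 2.0
    val mu = 1.25 // any current estimate of the mean
    val eta = mu  // identity link: eta equals mu
    println(workingResponse(gaussianIdentity, y, mu, eta)) // 3.5, the label itself
    println(workingWeight(gaussianIdentity, w, mu))        // 2.0, the instance weight itself
  }
}
```

Because neither quantity depends on the current estimate, iterating would reproduce the same weighted least squares problem every time.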

override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))

override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
}
Contributor Author

The link functions need further refinement to guarantee that the endogenous variable does not contain invalid values.

Contributor

I think the restriction on the endogenous variable should go into the Family class, since it is really the distribution of Y that restricts the values. This is how R does it.

Contributor Author

Yes, I have added the restriction in Family. We use the clean function to trim invalid data.
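
To illustrate the boundary problem discussed in this thread, here is a small self-contained sketch of the logit link that reuses the `deriv` and `unlink` formulas from the diff above. The `link` method and the object name are assumptions added for illustration. The point is only that a mean of exactly 0 or 1 produces non-finite values, which is why the means have to be restricted before the link is applied:

```scala
object LogitLinkSketch {
  // Logit link: eta = g(mu) = log(mu / (1 - mu)), defined for mu in (0, 1).
  def link(mu: Double): Double = math.log(mu / (1.0 - mu))

  // g'(mu) = 1 / (mu * (1 - mu)), same formula as in the diff.
  def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))

  // Inverse link: mu = 1 / (1 + exp(-eta)), same formula as in the diff.
  def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))

  def main(args: Array[String]): Unit = {
    // Inside the open interval the link and its inverse round-trip.
    println(unlink(link(0.3)))
    // At the boundary both the link and its derivative are infinite.
    println(link(1.0).isInfinite)  // true
    println(deriv(1.0).isInfinite) // true
  }
}
```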

@SparkQA

SparkQA commented Feb 9, 2016

Test build #50978 has finished for PR 11136 at commit 5af604e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)

@SparkQA

SparkQA commented Feb 9, 2016

Test build #50980 has finished for PR 11136 at commit 3082686.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* The default link for the Gamma family is the inverse link.
* @param link a link function instance
*/
private[ml] class Gamma(link: Link = Log) extends Family(link) {
Contributor

Typo in the default link. Log => Inverse
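
A hedged sketch of what the corrected default might look like. The `Link` trait, the two link objects, and the simplified `Gamma` class below are stand-ins written for illustration, not the PR's actual classes. The inverse link is the canonical link for the Gamma family:

```scala
object GammaDefaultLinkSketch {
  // Hypothetical minimal link interface, mirroring the method names seen in the diff.
  trait Link {
    def link(mu: Double): Double
    def deriv(mu: Double): Double
    def unlink(eta: Double): Double
  }

  // Inverse link: eta = 1 / mu, g'(mu) = -1 / mu^2, mu = 1 / eta.
  object Inverse extends Link {
    def link(mu: Double): Double = 1.0 / mu
    def deriv(mu: Double): Double = -1.0 / (mu * mu)
    def unlink(eta: Double): Double = 1.0 / eta
  }

  // Log link: eta = log(mu), g'(mu) = 1 / mu, mu = exp(eta).
  object Log extends Link {
    def link(mu: Double): Double = math.log(mu)
    def deriv(mu: Double): Double = 1.0 / mu
    def unlink(eta: Double): Double = math.exp(eta)
  }

  // Simplified stand-in for the family class, with the default matching the doc comment.
  class Gamma(val link: Link = Inverse)

  def main(args: Array[String]): Unit = {
    val gamma = new Gamma()
    println(gamma.link eq Inverse)               // true
    println(gamma.link.unlink(gamma.link.link(4.0))) // 4.0
  }
}
```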

@yanboliang
Contributor Author

Jenkins, test this please.

@SparkQA

SparkQA commented Feb 11, 2016

Test build #51101 has started for PR 11136 at commit 97c3f6a.

@shaneknapp
Contributor

jenkins, test this please

@SparkQA

SparkQA commented Feb 11, 2016

Test build #51109 has finished for PR 11136 at commit 97c3f6a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Feb 14, 2016

Test build #51264 has finished for PR 11136 at commit cc10147.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Feb 15, 2016

Test build #51306 has finished for PR 11136 at commit 4a27970.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


def variance(mu: Double): Double = 1.0

override def clean(mu: Double): Double = {
Contributor Author

Here we constrain mu to the valid range using a different method from R. In R, if mu or eta is invalid, it shrinks the coefficients until validmu and valideta pass. I think that slows convergence. I'm looking forward to hearing others' thoughts.
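
The two strategies can be contrasted in a toy, one-coefficient setting. This is only a sketch under stated assumptions: the validity rule (mu must be positive, as for a Poisson-style family), the `epsilon` value, the identity link, and all names are hypothetical, and the step-halving function is a simplified version of the kind of scheme the comment attributes to R, not R's actual code.

```scala
object CleanVersusStepHalvingSketch {
  // Hypothetical lower bound used when clamping.
  val epsilon: Double = 1e-16

  // Assumed validity rule: the mean must be strictly positive.
  def validMu(mu: Double): Boolean = mu > 0.0

  // Clamping approach: force each mean into the valid range independently.
  def clean(mu: Double): Double = math.max(mu, epsilon)

  // Step-halving approach: move the new coefficient halfway back toward the
  // previous one until every mean x * beta is valid (identity link assumed).
  @annotation.tailrec
  def stepHalve(xs: Seq[Double], betaOld: Double, betaNew: Double, maxHalvings: Int = 30): Double =
    if (xs.forall(x => validMu(x * betaNew)) || maxHalvings == 0) betaNew
    else stepHalve(xs, betaOld, (betaNew + betaOld) / 2.0, maxHalvings - 1)

  def main(args: Array[String]): Unit = {
    val xs = Seq(1.0, 2.0)
    // An update from beta = 1.0 to beta = -3.0 makes every mean negative.
    println(xs.map(x => clean(x * -3.0)))             // every mean clamped to epsilon
    println(stepHalve(xs, betaOld = 1.0, betaNew = -3.0)) // 0.5, reached after three halvings
  }
}
```

Clamping keeps the full update and repairs only the offending means, while step halving shrinks the whole update, which is the extra cost the comment is worried about.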

@mengxr
Contributor

mengxr commented Feb 22, 2016

I'm making a pass.

*/
@Since("2.0.0")
final val family: Param[String] = new Param(this, "family",
"the name of family which is a description of the error distribution to be used in the model",
Contributor

  • Include supported options and the default value in the param doc (and the ScalaDoc).
  • Shall we make "gaussian" the default?
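
A plain-Scala sketch of the first suggestion, with no Spark dependency. The list of supported families and the helper names are assumptions made for illustration; the thread itself only names the gaussian and Gamma families explicitly. The idea is to build the doc string from the same list that is used for validation, so the two cannot drift apart:

```scala
object FamilyParamDocSketch {
  // Assumed list of options; treat it as hypothetical.
  val supportedFamilies: Seq[String] = Seq("gaussian", "binomial", "poisson", "gamma")

  // Default proposed in the review comment.
  val defaultFamily: String = "gaussian"

  private val options: String = supportedFamilies.mkString(", ")

  // Doc string that states both the supported options and the default.
  val familyDoc: String =
    "The name of the family, which describes the error distribution to be used in the model. " +
      s"Supported options: $options. Default: $defaultFamily."

  // Validation driven by the same list that feeds the doc string.
  def isSupported(name: String): Boolean = supportedFamilies.contains(name)

  def main(args: Array[String]): Unit = {
    println(familyDoc)
    println(isSupported("gaussian")) // true
    println(isSupported("weibull"))  // false
  }
}
```

In Spark itself, a check like `isSupported` would more likely be passed to the `Param` constructor as a validator such as `ParamValidators.inArray`, but that wiring is not shown here.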

@mengxr
Contributor

mengxr commented Feb 25, 2016

Only some minor comments on the implementation. I will make a pass on the tests tomorrow. @dbtsai It would be great if you can make a pass too.

@SparkQA

SparkQA commented Feb 25, 2016

Test build #51963 has finished for PR 11136 at commit 2ebcef7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* In order to take the normal equation approach efficiently, [[WeightedLeastSquares]]
* only supports the number of features is no more than 4096.
*/
val MaxNumFeatures: Int = 4096
Member

For constants, do we have a naming convention? Like MAX_NUM_FEATURES?

Contributor

This is not specified in the Spark code style guide, and the Scala style guide recommends MaxNumFeatures. But I do like MAX_NUM_FEATURES better.

Contributor Author

OK, I will update it to MAX_NUM_FEATURES after collecting other comments. Thanks!

@dbtsai
Member

dbtsai commented Feb 26, 2016

Gonna do another detailed pass of the code tomorrow.

@SparkQA

SparkQA commented Feb 26, 2016

Test build #52044 has finished for PR 11136 at commit c05a948.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)

val testData =
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, seed)
Contributor

It would be good to say addIntercept = true instead of just true.
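
A short sketch of the named-argument style being suggested. The `generateInput` function is a hypothetical stand-in written for this example; it is not the test helper from the diff and does not share its signature:

```scala
object NamedArgumentSketch {
  // Hypothetical generator: returns nPoints random feature rows, optionally
  // with a leading constant 1.0 column for the intercept.
  def generateInput(
      coefficients: Array[Double],
      addIntercept: Boolean,
      nPoints: Int,
      seed: Int): Seq[Array[Double]] = {
    val rnd = new scala.util.Random(seed)
    Seq.fill(nPoints) {
      val x = Array.fill(coefficients.length)(rnd.nextGaussian())
      if (addIntercept) 1.0 +: x else x
    }
  }

  def main(args: Array[String]): Unit = {
    val coefficients = Array(0.5, -1.5)
    // Positional: the reader has to look up what `true` means.
    val a = generateInput(coefficients, true, 3, 42)
    // Named: the call site documents itself.
    val b = generateInput(coefficients, addIntercept = true, nPoints = 3, seed = 42)
    println(a.head.length == b.head.length) // true, both rows have 3 columns
  }
}
```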

@mengxr
Contributor

mengxr commented Mar 1, 2016

I made one pass on the tests, only some minor comments.

@SparkQA

SparkQA commented Mar 1, 2016

Test build #52211 has finished for PR 11136 at commit 314b562.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 1, 2016

Test build #52214 has finished for PR 11136 at commit 314b562.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 1, 2016

Test build #52215 has finished for PR 11136 at commit 314b562.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 1, 2016

Test build #52216 has finished for PR 11136 at commit 31a912c.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.


val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol)))
.map { case Row(label: Double, weight: Double, features: Vector) =>
Contributor

Use .rdd.map instead of .map. This is needed because of recent DataFrame API changes.

@SparkQA

SparkQA commented Mar 1, 2016

Test build #52227 has finished for PR 11136 at commit 007a4ec.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@mengxr
Contributor

mengxr commented Mar 1, 2016

LGTM. Merged into master. Thanks! I created SPARK-13597 for the Python API.

@asfgit asfgit closed this in 5ed48dd Mar 1, 2016
@yanboliang yanboliang deleted the spark-12811 branch March 2, 2016 02:40
roygao94 pushed a commit to roygao94/spark that referenced this pull request Mar 22, 2016
Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#11136 from yanboliang/spark-12811.