[SPARK-12811] [ML] Estimator for Generalized Linear Models(GLMs) #11136
Instance(label, weight, features)
}

if ($(family) == "gaussian" && $(link) == "identity") {
For the `gaussian` family with the `identity` link, we only use `WeightedLeastSquares` to train the model.
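To illustrate why the Gaussian family with the identity link needs no iteration, here is a minimal sketch of weighted least squares solved through the normal equations for one feature plus an intercept. `WlsSketch`, `Point`, and `fit` are hypothetical names for illustration only; this is not Spark's `WeightedLeastSquares` implementation.

```scala
object WlsSketch {
  // Simplified stand-in for a (label, weight, features) instance with one feature.
  final case class Point(label: Double, weight: Double, x: Double)

  /** Fits y ~ b0 + b1 * x by solving the 2x2 weighted normal equations directly. */
  def fit(data: Seq[Point]): (Double, Double) = {
    val sw   = data.map(_.weight).sum
    val swx  = data.map(p => p.weight * p.x).sum
    val swy  = data.map(p => p.weight * p.label).sum
    val swxx = data.map(p => p.weight * p.x * p.x).sum
    val swxy = data.map(p => p.weight * p.x * p.label).sum

    // Determinant of [[sw, swx], [swx, swxx]].
    val det = sw * swxx - swx * swx
    require(det != 0.0, "normal equations are singular")

    val b1 = (sw * swxy - swx * swy) / det
    val b0 = (swy - b1 * swx) / sw
    (b0, b1)
  }
}
```

Because the identity link makes the model linear in the coefficients, a single solve like this gives the answer, whereas other family and link combinations need repeated reweighted solves.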
5af604e to 3082686
override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))

override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
}
The link functions need further refinement to guarantee that the endogenous variable does not contain invalid values.
I think the restriction on the endogenous variable should go into the `Family` class, since it is truly the distribution on Y that restricts the values. This is how R does it.
Yes, I have added the restriction in `Family`. We use the `clean` function to trim invalid data.
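As a rough illustration of the link interface discussed in this thread, the sketch below defines a logit link with the same `deriv` and `unlink` formulas shown in the diff. The `Link` trait shape and the `link` method are assumptions made for this example, not the PR's actual class hierarchy.

```scala
object LinkSketch {
  // Assumed shape of a link function: g(mu), g'(mu), and the inverse g^-1(eta).
  trait Link {
    def link(mu: Double): Double
    def deriv(mu: Double): Double
    def unlink(eta: Double): Double
  }

  object Logit extends Link {
    // g(mu) = log(mu / (1 - mu)); only finite for mu strictly between 0 and 1.
    def link(mu: Double): Double = math.log(mu / (1.0 - mu))
    // g'(mu) = 1 / (mu * (1 - mu)), as in the diff above.
    def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))
    // g^-1(eta) is the logistic sigmoid, as in the diff above.
    def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-eta))
  }
}
```

Note that `link` and `deriv` are not finite at `mu = 0` or `mu = 1`, which is why some component has to keep the mean inside the valid range before the link is evaluated.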
Test build #50978 has finished for PR 11136 at commit
Test build #50980 has finished for PR 11136 at commit
 * The default link for the Gamma family is the inverse link.
 * @param link a link function instance
 */
private[ml] class Gamma(link: Link = Log) extends Family(link) {
Typo in the default link: `Log` => `Inverse`.
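A minimal sketch of what the suggested fix might look like, with the default argument matching the doc comment. The `Link`, `Inverse`, `Log`, and `Gamma` definitions here are simplified, hypothetical stand-ins rather than the PR's classes.

```scala
object GammaSketch {
  sealed trait Link {
    def link(mu: Double): Double
    def unlink(eta: Double): Double
  }

  // Inverse link: eta = 1 / mu.
  case object Inverse extends Link {
    def link(mu: Double): Double = 1.0 / mu
    def unlink(eta: Double): Double = 1.0 / eta
  }

  // Log link: eta = log(mu).
  case object Log extends Link {
    def link(mu: Double): Double = math.log(mu)
    def unlink(eta: Double): Double = math.exp(eta)
  }

  /** Gamma family; the default link is the inverse link, as the doc comment states. */
  class Gamma(val link: Link = Inverse) {
    // Variance function of the Gamma family: V(mu) = mu^2.
    def variance(mu: Double): Double = mu * mu
  }
}
```

Callers who want the log link can still pass it explicitly, for example `new GammaSketch.Gamma(GammaSketch.Log)`.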
Jenkins, test this please.
Test build #51101 has started for PR 11136 at commit
jenkins, test this please
Test build #51109 has finished for PR 11136 at commit
Test build #51264 has finished for PR 11136 at commit
Test build #51306 has finished for PR 11136 at commit

def variance(mu: Double): Double = 1.0

override def clean(mu: Double): Double = {
Here we constrain `mu` to the valid range using a different method from R. In R, if `mu` or `eta` is invalid, the coefficients are shrunk until `validmu` and `valideta` pass. I think that will slow down convergence. I'm looking forward to hearing others' thoughts.
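To make the clamping approach concrete, here is a small sketch for a binomial-style mean, which must stay strictly between 0 and 1. The object name, the `cleanBinomial` helper, and the `Epsilon` value are hypothetical choices for this example; the PR's actual bounds may differ.

```scala
object CleanSketch {
  // Hypothetical margin that keeps mu away from the boundary values 0 and 1.
  val Epsilon: Double = 1e-12

  /** Clamps mu into [Epsilon, 1 - Epsilon]; values already in range pass through. */
  def cleanBinomial(mu: Double): Double =
    if (mu < Epsilon) Epsilon
    else if (mu > 1.0 - Epsilon) 1.0 - Epsilon
    else mu
}
```

Clamping costs a single comparison per value, while the R-style alternative described above re-evaluates the whole fit with smaller coefficients each time a check fails.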
I'm making a pass.
 */
@Since("2.0.0")
final val family: Param[String] = new Param(this, "family",
"the name of family which is a description of the error distribution to be used in the model",
- Include supported options and the default value in the param doc (and the ScalaDoc).
- Shall we make "gaussian" the default?
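One way the param doc could list supported options and the default is sketched below in plain Scala, without the Spark `Param` API. The names are hypothetical, and the list of families is an assumption for illustration; only `gaussian` and the Gamma family are named explicitly in this thread.

```scala
object FamilyParamSketch {
  // Assumed set of supported family names, for illustration only.
  val SupportedFamilies: Seq[String] = Seq("gaussian", "binomial", "poisson", "gamma")

  // "gaussian" as the default, as proposed in the review comment.
  val DefaultFamily: String = "gaussian"

  // Doc string that spells out the options and the default.
  val familyDoc: String =
    "The name of the family, which describes the error distribution used in the model. " +
      s"Supported options: ${SupportedFamilies.mkString(", ")}. Default: $DefaultFamily."

  /** Returns true only for names in the supported list. */
  def isValidFamily(name: String): Boolean = SupportedFamilies.contains(name)
}
```

Building the doc string from the same list used for validation keeps the documentation and the accepted values from drifting apart.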
Only some minor comments on the implementation. I will make a pass on the tests tomorrow. @dbtsai It would be great if you can make a pass too.
Test build #51963 has finished for PR 11136 at commit
 * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]]
 * only supports the number of features is no more than 4096.
 */
val MaxNumFeatures: Int = 4096
For constants, do we have a naming convention? Like `MAX_NUM_FEATURES`?
This is not specified in the Spark code style guide, and the Scala style guide recommends `MaxNumFeatures`. But I do like `MAX_NUM_FEATURES` better.
OK, I will update it to `MAX_NUM_FEATURES` after collecting other comments. Thanks!
Gonna do another detail pass of the code tomorrow.
Test build #52044 has finished for PR 11136 at commit
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)

val testData =
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, seed)
It would be good to say `addIntercept = true` instead of just `true`.
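The readability gain from a named argument can be shown with a small sketch. The `generate` function below is a hypothetical stand-in for a test-data generator, not the helper used in the PR's test suite.

```scala
object NamedArgSketch {
  // Hypothetical generator: draws nPoints Gaussian values, shifted by 1.0 when
  // addIntercept is set. It only exists to show the two call styles.
  def generate(nPoints: Int, addIntercept: Boolean, seed: Int): Seq[Double] = {
    val rng = new scala.util.Random(seed)
    val offset = if (addIntercept) 1.0 else 0.0
    Seq.fill(nPoints)(offset + rng.nextGaussian())
  }

  // Positional call: the reader has to look up what `true` means.
  val opaque: Seq[Double] = generate(5, true, 42)

  // Named call: the meaning of each argument is visible at the call site.
  val clear: Seq[Double] = generate(nPoints = 5, addIntercept = true, seed = 42)
}
```

Both calls behave identically; only the second documents itself.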
I made one pass on the tests, only some minor comments.
Test build #52211 has finished for PR 11136 at commit
Test build #52214 has finished for PR 11136 at commit
Test build #52215 has finished for PR 11136 at commit
Test build #52216 has finished for PR 11136 at commit

val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol)))
.map { case Row(label: Double, weight: Double, features: Vector) =>
Use `.rdd.map` instead of `.map`. This is caused by recent DataFrame API changes.
Test build #52227 has finished for PR 11136 at commit
LGTM. Merged into master. Thanks! I created SPARK-13597 for the Python API.
Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS. cc mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes apache#11136 from yanboliang/spark-12811.
Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS.
cc @mengxr
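For readers unfamiliar with IRLS (iteratively reweighted least squares), the sketch below fits a binomial GLM with a logit link on one feature plus an intercept. Each iteration computes working weights and a working response from the current fit, then solves a weighted least squares problem. All names and defaults (`IrlsSketch`, `maxIter`, `tol`) are assumptions for illustration; this is not the estimator added by this PR, and it omits safeguards such as clamping the mean.

```scala
object IrlsSketch {
  final case class Point(label: Double, x: Double)

  def sigmoid(eta: Double): Double = 1.0 / (1.0 + math.exp(-eta))

  /** IRLS for y ~ Bernoulli(sigmoid(b0 + b1 * x)); returns (b0, b1). */
  def fit(data: Seq[Point], maxIter: Int = 25, tol: Double = 1e-8): (Double, Double) = {
    var b0 = 0.0
    var b1 = 0.0
    var iter = 0
    var converged = false
    while (iter < maxIter && !converged) {
      // Accumulate the sums needed for the 2x2 weighted normal equations.
      var sw = 0.0; var swx = 0.0; var swxx = 0.0; var swz = 0.0; var swxz = 0.0
      data.foreach { p =>
        val eta = b0 + b1 * p.x
        val mu = sigmoid(eta)
        val w = mu * (1.0 - mu)            // working weight for binomial + logit
        val z = eta + (p.label - mu) / w   // working response: eta + (y - mu) * g'(mu)
        sw += w; swx += w * p.x; swxx += w * p.x * p.x
        swz += w * z; swxz += w * p.x * z
      }
      val det = sw * swxx - swx * swx
      require(det != 0.0, "weighted normal equations are singular")
      val nextB1 = (sw * swxz - swx * swz) / det
      val nextB0 = (swz - nextB1 * swx) / sw
      // Stop once the coefficients no longer change noticeably.
      converged = math.abs(nextB0 - b0) + math.abs(nextB1 - b1) < tol
      b0 = nextB0
      b1 = nextB1
      iter += 1
    }
    (b0, b1)
  }
}
```

At convergence the fitted coefficients should make the score equations vanish, that is, the sums of `y - mu` and `x * (y - mu)` over the data should both be close to zero.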