-
Notifications
You must be signed in to change notification settings - Fork 28k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10931][PYSPARK][ML] PySpark ML Models should contain Param values #14653
Conversation
CC @MLnick |
@@ -59,6 +59,16 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti | |||
... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() | |||
>>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") | |||
>>> model = lr.fit(df) | |||
>>> emap = lr.extractParamMap() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style:
emap
-> estimator_paramMap
mmap
-> model_paramMap
?
Should we start having
as done in the Scala side? |
@@ -243,7 +240,7 @@ def __init__(self, java_model=None): | |||
""" | |||
Initialize this instance with a Java model object. | |||
Subclasses should call this constructor, initialize params, | |||
and then call _transfer_params_from_java. | |||
and then call _transformer_params. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure you intended this change.
2f9417c
to
8d092e8
Compare
@MechCoder I made the changes for emap -> estimator_paramMap, mmap -> model_paramMap, and (param, value) -> param, value. |
@@ -336,6 +336,11 @@ def hasParam(self, paramName): | |||
return isinstance(p, Param) | |||
else: | |||
raise TypeError("hasParam(): paramName must be a string") | |||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this code is reachable, is this necessary?
Hey @evanyc15 , this is looking pretty good. I had a couple initial comments, but I'll have to look at it more in depth since there's a lot of changes. Mind resolving the conflicts first? |
8d7aedb
to
e12cbd7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@evanyc15 it looks good in general.
I think the doctests you have here would be better suited as unit tests because it's an area normal users would touch, and you could reduce some code duplication.
This PR might be easier for others to review if you pick a single estimator/model like LogisticRegression and demonstrate this change once before changing it everywhere.
>>> estimator_paramMap = lr.extractParamMap() | ||
>>> model_paramMap = model.extractParamMap() | ||
>>> all([estimator_paramMap[getattr(lr, param.name)] == value | ||
... for param, value in model_paramMap.items()]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comparison should be the other way around. Here you check that each param in the model is in the estimator, but it should be checking that each param in estimator made it to the model.
>>> all([param.parent == model.uid for param in model_paramMap]) | ||
True | ||
>>> [param.name for param in model.params] # doctest: +NORMALIZE_WHITESPACE | ||
['elasticNetParam', 'featuresCol', 'fitIntercept', 'labelCol', 'maxIter', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this test, it's too brittle and doesn't really add much from the tests above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for moving these kinds of tests to unit tests. Here, they make the documentation example confusing.
class LogisticRegressionModel(JavaModel, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, | ||
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, | ||
HasElasticNetParam, HasFitIntercept, HasStandardization, | ||
HasThresholds, JavaMLWritable, JavaMLReadable): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed this is missing a couple shared params like HasWeightCol
@@ -748,8 +785,9 @@ def _create_model(self, java_model): | |||
return RandomForestClassificationModel(java_model) | |||
|
|||
|
|||
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, | |||
JavaMLReadable): | |||
class RandomForestClassificationModel(TreeEnsembleModel, HasFeaturesCol, HasLabelCol, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this needs JavaClassificationModel
@@ -900,8 +947,8 @@ def getLossType(self): | |||
return self.getOrDefault(self.lossType) | |||
|
|||
|
|||
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, | |||
JavaMLReadable): | |||
class GBTClassificationModel(TreeEnsembleModel, HasFeaturesCol, HasLabelCol, HasPredictionCol, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this needs JavaPredictionModel
@@ -560,6 +586,7 @@ class TreeRegressorParams(Params): | |||
""" | |||
|
|||
supportedImpurities = ["variance"] | |||
# a placeholder to make it appear in the generated doc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this necessary? maybe just a empty line will do
Abstraction for Decision Tree models. | ||
class DecisionTreeModel(JavaModel, JavaPredictionModel, | ||
HasFeaturesCol, HasLabelCol, HasPredictionCol): | ||
"""Abstraction for Decision Tree models. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
newline after quotes
JavaMLReadable): | ||
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, HasFeaturesCol, | ||
HasLabelCol, HasPredictionCol, | ||
JavaMLWritable, JavaMLReadable): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you could probably merge these 2 lines
@@ -1116,7 +1183,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | |||
| 1.0| [1.0]| 1.0| 1.0| | |||
| 0.0|(1,[],[])| 0.0| 1.0| | |||
+-----+---------+------+----------+ | |||
... | |||
<BLANKLINE> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure why this is needed
@@ -200,7 +197,6 @@ def _create_model(self, java_model): | |||
def _fit_java(self, dataset): | |||
""" | |||
Fits a Java model to the input dataset. | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these blank lines should stay. Even though they are private functions, that's how it's usually done for Sphinx I believe
@evanyc15 it looks good in general. I think the doctests you have here would be better suited as unit tests because it's not an area normal users would care about and it could reduce code duplication It might also be easier for others to review this PR if you picked a single estimator/model and demonstrated this change once to get feedback before applying it everywhere. |
It might also be good to discuss grouping the param mixins, similar to how it's done in Scala, so that both the estimator and model can inherit from a single common trait. This way you could be sure they will contain the same shared params. |
ok to test Sorry for the delay on this, but it'd be great to fix now! |
Sounds good Joseph. I'll resolve the conflicts. |
fc11247
to
e706c7e
Compare
@jkbradley Hey Joseph, I've resolved the merge conflicts. Can you please test? |
Huh I'm not sure why jenkins isn't picking this up - @jkbradley or @davidnavas can you tell jenkins this is ok to test again? |
@holdenk happy to help if I can, is there something a mere mortal like myself can accomplish? [dunno how to poke jenkins hereabouts] |
Test build #3328 has finished for PR 14653 at commit
|
Looks like jenkins has picked it up. Maybe @evanyc15 can merge in master (or rebase on master) so jenkins re-runs and verify the tests? |
e706c7e
to
eb7ca31
Compare
@holdenk I just rebased and pushed again. Hopefully, Jenkins passes this time |
@davidnavas |
Sadly, I don't have that superpower :( Leastwise not that I know. Perhaps Holden's appeal to @jkbradley was what worked last time? |
jenkins test this please |
Test build #66953 has finished for PR 14653 at commit
|
@MLnick @jkbradley Do you mind merging the PR? Thank you |
@evanyc15 OK back for real now...sorry for the delay. @BryanCutler has a lot of good comments. Could you please address them? Regarding splitting this up into multiple PRs, I strongly +1 that in general, though I'm OK if you want to do this as a batch. I'll test this PR out now... |
Having some trouble b/c the doc build was apparently broken 19 days ago. Looking into a fix now. ~ Doc build is being fixed... |
eb7ca31
to
80303bf
Compare
One good set of unit tests might emulate |
Hey @jkbradley the checkParams method already exists in the Python side. It's defined in the tests.py DefaultValuesTests class and is being called by test_java_params. I'm removing the param testing from the Python Doctests now and will be implementing the Unit test in one of the classes for now. Once approved, I will then implement the Unit test in the remaining classes. |
Copied parameters over from Estimator to Transformer Estimator UID is being copied correctly to the Transformer model objects and params now, working on Doctests Changed the way parameters are copied from the Estimator to Transformer Checkpoint, switching back to inheritance method Working on DocTests Implemented Doctests for Recommendation, Clustering, Classification (except RandomForestClassifier), Evaluation, Tuning, Regression (except RandomRegression) Ready for Code Review Code Review changeset apache#1
80303bf
to
05c11f4
Compare
Can one of the admins verify this patch? |
@evanyc15 would you mind closing this PR? Thanks! |
What changes were proposed in this pull request?
Changed PySpark models to include the Param values.
Refer to the closed PR #10270 for additional information.
How was this patch tested?
Tested using Python doctests
Changesets:
Estimator UID is being copied correctly to the Transformer model objects and params now, working on Doctests
Changed the way parameters are copied from the Estimator to Transformer
Checkpoint, switching back to inheritance method
Working on DocTests
Implemented Doctests for Recommendation, Clustering, Classification (except RandomForestClassifier), Evaluation, Tuning, Regression (except RandomRegression)
Ready for Code Review
Code Review changeset #1