REF: Hardcode Keras params #47
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master      #47      +/-   ##
==========================================
- Coverage   99.53%   99.51%   -0.02%
==========================================
  Files           3        3
  Lines         426      414      -12
==========================================
- Hits          424      412      -12
  Misses          2        2

Continue to review full report at Codecov.
Force-pushed from 5ce3332 to a9b8715.
class TestUnsetParameter:
    """Tests for appropriate error on unfitted models."""

    @pytest.mark.filterwarnings("ignore::FutureWarning")
    def test_unset_input_parameter(self):
        class ClassBuildFnClf(wrappers.KerasClassifier):
            def __init__(self, input_param):
                # does not set input_param
                super().__init__()

            def _keras_build_fn(self, hidden_dim):
                return build_fn_clf(hidden_dim)

        with pytest.raises(RuntimeError):
            ClassBuildFnClf(input_param=10)
Removing this check and test. It is really a redundant feature that no other estimators have. It was much more helpful previously when the interface was less defined.
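For context, a minimal sketch of the scikit-learn convention the removed check was enforcing (hypothetical classes, not code from this repo): every argument accepted by __init__ is stored verbatim under an attribute of the same name, which is what lets get_params, set_params, and clone work without extra machinery.

class ConventionalEstimator:
    # Follows the scikit-learn convention: each __init__ argument is
    # stored as-is under the same name, with no extra logic.
    def __init__(self, hidden_dim=10, optimizer="rmsprop"):
        self.hidden_dim = hidden_dim
        self.optimizer = optimizer


class NonConventionalEstimator:
    # The pattern the removed check used to flag with a RuntimeError:
    # an __init__ argument that is never stored on the instance.
    def __init__(self, input_param=10):
        pass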
return (
    k
    for k in self.__dict__
    if not k.endswith("_") and not k.startswith("_")
)
This is borrowed from Skorch. For where we are headed, I think this makes more sense than storing _sk_params. It'll be up to users to make sure that any attribute they don't want to show up in get_params is prefixed or suffixed with _.
I added k.startswith("_") to accommodate parameters like self._random_state. So the rule will be that all model parameters (i.e. the __init__ ones) cannot start or end with _, at least for now. We'll have to see how that plays in with my overwriting proposal in #47 (comment).
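As a standalone illustration of that rule (a hypothetical class, not the actual wrapper, though the generator mirrors the diff above): only attributes that neither start nor end with an underscore are reported as parameters.

class ParamNameDemo:
    def __init__(self, hidden_dim=10):
        self.hidden_dim = hidden_dim   # plain name: exposed as a parameter
        self._random_state = None      # leading "_": internal state, hidden
        self.model_ = None             # trailing "_": fitted artifact, hidden

    def _get_param_names(self):
        return (
            k
            for k in self.__dict__
            if not k.endswith("_") and not k.startswith("_")
        )

    def get_params(self, deep=True):
        # scikit-learn's BaseEstimator builds get_params on top of
        # _get_param_names in essentially this way
        return {k: getattr(self, k) for k in self._get_param_names()}


print(ParamNameDemo().get_params())  # {'hidden_dim': 10}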
self.initial_epoch = initial_epoch

# Unpack kwargs
vars(self).update(**kwargs)
Also borrowed from Skorch. Much cleaner with the new _get_param_names implementation.

Tagging @stsievert in case you want to take a look.
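To make the kwargs unpacking concrete, here is a small hypothetical sketch (not the wrapper itself): any extra keyword argument becomes an instance attribute via vars(self).update(**kwargs) and is then picked up by the _get_param_names generator shown earlier.

class KwargsDemo:
    def __init__(self, epochs=1, **kwargs):
        self.epochs = epochs
        # Unpack kwargs: every extra keyword becomes an instance attribute
        vars(self).update(**kwargs)

    def _get_param_names(self):
        return (
            k
            for k in self.__dict__
            if not k.endswith("_") and not k.startswith("_")
        )


est = KwargsDemo(epochs=5, hidden_dim=20)
print(sorted(est._get_param_names()))  # ['epochs', 'hidden_dim']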
@stsievert do you think we are good to go here or is there anything else you see that should be changed?
This looks good. Will returning un-compiled models from build_fn be supported in the next release, as per #50 (comment)? Otherwise I'm not sure some parameters here make any sense (e.g., loss, optimizer).
I've glanced at the Keras.fit API. There might be some missing parameters:

- use_multiprocessing, workers. This might be important for performance, though I suspect threading will have the same or higher performance because loading data is I/O bound. It might be good to leave it to the user though.
  - I think Dask will require use_multiprocessing=False, but I haven't verified.
- steps_per_epoch, sample_weight, class_weight.
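For reference, this is roughly how those arguments look on a plain tf.keras fit call; a sketch assuming TF 2.x-era Keras (where fit still accepts workers/use_multiprocessing), not SciKeras code:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="rmsprop", loss="binary_crossentropy")

X = np.random.rand(64, 10).astype(np.float32)
y = np.random.randint(0, 2, size=(64,))

model.fit(
    X,
    y,
    epochs=2,
    steps_per_epoch=None,            # default: infer from the data and batch_size
    class_weight={0: 1.0, 1: 2.0},   # data-dependent, so fit-time makes sense
    workers=1,                       # threading-based data loading
    use_multiprocessing=False,       # what Dask may require (unverified above)
    verbose=0,
)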
scikeras/wrappers.py (outdated)
verbose=1,
steps=None,
callbacks=None,
epochs=1,
Maybe rename to max_epochs for the fit implementation. I think partial_fit should always have epochs=1.
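A rough sketch of the distinction being proposed, using a hypothetical wrapper (none of these names come from the PR itself): fit honors max_epochs, while partial_fit always does a single pass.

import tensorflow as tf


class EpochNamingSketch:
    def __init__(self, max_epochs=10):
        self.max_epochs = max_epochs

    def _build_model(self):
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer="rmsprop", loss="mse")
        return model

    def fit(self, X, y):
        self.model_ = self._build_model()
        self.model_.fit(X, y, epochs=self.max_epochs, verbose=0)
        return self

    def partial_fit(self, X, y):
        if not hasattr(self, "model_"):
            self.model_ = self._build_model()
        # scikit-learn's partial_fit contract: a single pass over the given data
        self.model_.fit(X, y, epochs=1, verbose=0)
        return self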
scikeras/wrappers.py (outdated)
epochs=1,
validation_split=0.0,
shuffle=True,
initial_epoch=0,
This only makes sense if a pretrained model is passed to build_fn. I'd be okay removing support for this, especially if max_epochs is supported.
scikeras/wrappers.py (outdated)
optimizer="rmsprop",
loss=None,
metrics=None,
run_eagerly=None,
Does SciKeras have tests for run_eagerly in [True, False]? It seems run_eagerly will affect the serialization.
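A parametrized test along these lines could cover both settings; this sketch exercises run_eagerly on a plain tf.keras model rather than the wrapper, so it shows only the pattern, not the serialization concern:

import numpy as np
import pytest
import tensorflow as tf


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_fit_with_run_eagerly(run_eagerly):
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
    model.compile(
        optimizer="rmsprop",
        loss="binary_crossentropy",
        run_eagerly=run_eagerly,
    )
    X = np.random.rand(16, 4).astype(np.float32)
    y = np.random.randint(0, 2, size=(16,)).astype(np.float32)
    history = model.fit(X, y, epochs=1, verbose=0)
    assert "loss" in history.history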
No, we do not. Maybe I'll just remove that parameter for now; I don't want to go down that rabbit hole in this PR...
*,
random_state=None,
optimizer="rmsprop",
loss=None,
Where are optimizer and loss used? Are they passed to the compile method?
As it stands right now, it would be up to the user to declare their build_fn to use them, i.e.:

def build_fn(optimizer=..., loss=...):
    model = Model()
    ...
    model.compile(loss=loss, optimizer=optimizer)
    return model  # build_fn must return the (compiled) model
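And a hypothetical call site, assuming the wrapper forwards its optimizer/loss parameters to build_fn arguments of the same name (that routing is exactly what is being discussed here, not a confirmed API):

from scikeras.wrappers import KerasClassifier

clf = KerasClassifier(
    build_fn=build_fn,           # the build_fn sketched above
    optimizer="adam",            # would be routed to build_fn's optimizer argument
    loss="binary_crossentropy",  # would be routed to build_fn's loss argument
)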
The reasons they are not being included are:
Regarding ...
Yes. Scikit-Learn's SGDClassifier/SGDRegressor/MLPClassifier do the same thing. Here's the docstring for MLPClassifier.partial_fit: "Update the model with a single iteration over the given data."

I think the difference is semantic. I don't think it's super important. I mention it because Scikit-Learn calls this parameter ...
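For reference, the scikit-learn behavior being quoted, as a runnable snippet (standard scikit-learn API, unrelated to this repo's code):

import numpy as np
from sklearn.neural_network import MLPClassifier

X = np.random.rand(32, 4)
y = np.random.randint(0, 2, size=32)

clf = MLPClassifier(hidden_layer_sizes=(8,))
# Each call performs a single iteration over the given data,
# matching the docstring quoted above.
for _ in range(5):
    clf.partial_fit(X, y, classes=[0, 1])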
Hmm okay. I think what I'd like to do is keep this PR smallish and leave the following parameters for a future PR (that can also contain any logic related to their handling, like what you mention about ...).
This implements hardcoding of common Keras parameters in BaseWrapper.__init__ as discussed in #37. The parameters chosen come from Model.compile, Model.predict and Model.fit. Parameters were not hardcoded if:

- ... (ex: steps_per_epoch).
- They would be passed to fit directly because they depend on the data (ex: sample_weights, class_weights).

I think validation_data also falls into this category, but I am not sure if this makes much sense to implement, since AFAIK this is not used anywhere else in Scikit-Learn. I am going to leave this for another PR, where I'll also deprecate fit's **kwargs.