Component parameters: save all fields for builtin components #847
@@ -19,7 +19,7 @@ def __init__(self, columns=None, random_state=0):
                         random_state=random_state)

    def _check_input_for_columns(self, X):
        cols = self.parameters["columns"] or []
        cols = self.parameters.get("columns", [])
Just a style tweak, not important
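The two spellings in this hunk are close but not strictly interchangeable, which may matter once every argument (including a None default) is saved. A minimal sketch, using plain dicts as stand-ins for self.parameters (an assumption; the real attribute is not shown in this hunk):

```python
# Plain dicts standing in for a component's saved parameters (assumption).
saved_none = {"columns": None}  # key present, value None
missing = {}                    # key absent

# dict.get only falls back to the default when the key is absent.
get_when_none = saved_none.get("columns", [])
get_when_missing = missing.get("columns", [])

# `or []` falls back whenever the stored value is falsy, including None,
# but indexing would raise KeyError if the key were absent.
or_when_none = saved_none["columns"] or []

print(get_when_none, get_when_missing, or_when_none)
```

If columns=None is stored in the parameters dict, only the `or []` form yields an empty list here.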
OneHotEncoder,
PerColumnImputer,
RFClassifierSelectFromModel,
RFRegressorSelectFromModel,
SimpleImputer,
StandardScaler,
Transformer
It's good to remove these: they shouldn't show up in all_components() because they're base classes and would never be used in a pipeline directly. Right now, any subclass of ComponentBase which is imported in this file will be included in all_components().
@@ -47,7 +43,7 @@ def _components_dict():
    components = dict()
    for _, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass):
        params = inspect.getargspec(obj.__init__)
        if issubclass(obj, ComponentBase):
        if issubclass(obj, ComponentBase) and obj is not ComponentBase:
This excludes ComponentBase from all_components().
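To illustrate why both the `is not ComponentBase` check and the import cleanup are needed, here is a rough sketch of the discovery loop. The class bodies and the two fake modules are hypothetical stand-ins, not the library's code; only the filter condition is taken from the hunk above.

```python
import inspect
import types


class ComponentBase:
    """Hypothetical stand-in for the library's root component class."""


class Transformer(ComponentBase):
    """Hypothetical intermediate base class."""


class StandardScaler(Transformer):
    """Hypothetical concrete component."""


def components_dict(module):
    """Collect every ComponentBase subclass visible in `module`, minus the root."""
    components = {}
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, ComponentBase) and obj is not ComponentBase:
            components[name] = obj
    return components


# Fake module that "imports" the base classes alongside a concrete component.
with_bases = types.ModuleType("with_bases")
with_bases.ComponentBase = ComponentBase
with_bases.Transformer = Transformer
with_bases.StandardScaler = StandardScaler

# Fake module that imports only the concrete component.
concrete_only = types.ModuleType("concrete_only")
concrete_only.StandardScaler = StandardScaler

listed_with_bases = sorted(components_dict(with_bases))
listed_concrete_only = sorted(components_dict(concrete_only))
print(listed_with_bases)
print(listed_concrete_only)
```

In this sketch the identity check only removes the root class; an intermediate base such as Transformer still gets listed unless it is no longer imported into the module being scanned.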
assert defaults.pop(-1) == 0

expected_parameters = {arg: default for (arg, default) in zip(args, defaults)}
assert parameters == expected_parameters
This is the main test. It introspects on the __init__ of all components listed in all_components(), and expects they've saved all args in parameters except for random_state, which should be the last __init__ arg.
I'd love to find a way to get this code into a metaclass so it can be used to raise exceptions at class-definition-time when a user defines a custom component. I was working on that a couple months ago, but... it's hard!! 😂
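The introspection check described in this comment could look roughly like the following. This is a sketch under assumptions: ExampleComponent and expected_parameters are hypothetical names, and inspect.signature is used here in place of the inspect.getargspec call seen elsewhere in the diff, since getargspec was removed in Python 3.11.

```python
import inspect


class ExampleComponent:
    """Hypothetical component that saves every __init__ arg except random_state."""

    def __init__(self, strategy="mean", fill_value=None, random_state=0):
        self.parameters = {"strategy": strategy, "fill_value": fill_value}
        self.random_state = random_state


def expected_parameters(component_class):
    """Build the expected parameters dict from the __init__ signature."""
    signature = inspect.signature(component_class.__init__)
    # Drop `self`; the remaining entries keep declaration order.
    params = list(signature.parameters.values())[1:]
    # random_state is expected to be the last argument, with default 0.
    last = params.pop()
    assert last.name == "random_state" and last.default == 0
    return {p.name: p.default for p in params}


component = ExampleComponent()
expected = expected_parameters(ExampleComponent)
assert component.parameters == expected
```

A component that forgot to save fill_value, or that declared random_state somewhere other than last, would fail one of the two assertions.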
component2 = component_class(**parameters)
parameters2 = component2.parameters

assert parameters == parameters2
This test ensures that instantiating a component with no parameters and instantiating a component with its default parameters produce the same parameters on the instance.
We need to add default_parameters, will file that.
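The round trip this test relies on can be sketched as follows. ExampleEstimator is a hypothetical class written for illustration; it assumes parameters holds exactly the keyword arguments accepted by __init__ (minus random_state).

```python
class ExampleEstimator:
    """Hypothetical component whose parameters mirror its __init__ arguments."""

    def __init__(self, n_estimators=10, max_depth=None, random_state=0):
        self._parameters = {"n_estimators": n_estimators, "max_depth": max_depth}
        self.random_state = random_state

    @property
    def parameters(self):
        # Return a copy so callers cannot mutate the saved values.
        return dict(self._parameters)


# Instantiate with no arguments, then rebuild from the saved parameters.
component = ExampleEstimator()
parameters = component.parameters
component2 = ExampleEstimator(**parameters)
parameters2 = component2.parameters

assert parameters == parameters2
```

The rebuild only works if every key in parameters is a valid __init__ keyword, which is one reason to save all fields under their argument names.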
"""A component that fits and predicts given data"""
"""A component that fits and predicts given data

To implement a new Transformer, define your own class which is a subclass of Transformer, including
Do you mean Estimator here?
Yep, thanks, good eye, will fix!
cb_classifier = catboost.CatBoostClassifier(**parameters,
# catboost will choose an intelligent default for bootstrap_type, so only set if provided
cb_parameters = copy.copy(parameters)
if bootstrap_type is None:
I like this inversion!
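The hunk is cut off after the `if`, so the body is not visible. One plausible reading, sketched here with a hypothetical helper (build_catboost_kwargs is not a name from the PR), is to keep bootstrap_type in the saved parameters but drop it from the keyword arguments passed to catboost when it is None:

```python
import copy


def build_catboost_kwargs(parameters):
    """Hypothetical helper: omit bootstrap_type when it was not provided."""
    # Shallow copy so the saved parameters dict is left untouched.
    cb_parameters = copy.copy(parameters)
    if cb_parameters.get("bootstrap_type") is None:
        # Let catboost pick its own default by not passing the key at all.
        cb_parameters.pop("bootstrap_type", None)
    return cb_parameters


saved = {"n_estimators": 1000, "max_depth": 6, "bootstrap_type": None}
kwargs_default = build_catboost_kwargs(saved)

provided = {"n_estimators": 1000, "max_depth": 6, "bootstrap_type": "Bernoulli"}
kwargs_provided = build_catboost_kwargs(provided)
```

The resulting dict would then be unpacked into the catboost constructor; that call is left out so the sketch runs without catboost installed.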
Fix #522 fix #355
Changes
- ComponentBase.parameters getter, to keep the saved parameters immutable after component init
- ComponentBase.__init__ arguments optional with defaults: parameters (empty dict), component_obj (None) and random_state (0)
- Builtin components save all fields in parameters (except for random_state)
- all_components
- Remove ComponentBase, Transformer, Estimator and other transformer base classes from the all_components list
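A minimal sketch of the first two items in this list, assuming the getter is a read-only property that hands back a copy. Whether the PR copies or uses another mechanism is not shown here, and the private attribute names are hypothetical.

```python
import copy


class ComponentBase:
    """Sketch of a base class with optional init args and read-only parameters."""

    def __init__(self, parameters=None, component_obj=None, random_state=0):
        # Default to an empty dict; avoid a shared mutable default argument.
        self._parameters = copy.deepcopy(parameters) if parameters is not None else {}
        self._component_obj = component_obj
        self.random_state = random_state

    @property
    def parameters(self):
        # No setter is defined, and callers only ever receive a copy.
        return copy.deepcopy(self._parameters)


component = ComponentBase(parameters={"columns": ["a", "b"]})
leaked = component.parameters
leaked["columns"].append("c")  # mutates the copy only

try:
    component.parameters = {}
    reassigned = True
except AttributeError:
    reassigned = False
```

With this shape, neither mutating the returned dict nor assigning to the attribute changes what the component saved at init.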