-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OneHotEncoder: expose categories detected for each feature during fit #1182
Conversation
Codecov Report
@@ Coverage Diff @@
## main #1182 +/- ##
=======================================
Coverage 99.92% 99.92%
=======================================
Files 196 196
Lines 11710 11729 +19
=======================================
+ Hits 11701 11720 +19
Misses 9 9
Continue to review full report at Codecov.
|
METHODS_TO_CHECK = ComponentBaseMeta.METHODS_TO_CHECK + ['categories'] | ||
|
||
|
||
class OneHotEncoder(Transformer, metaclass=OneHotEncoderMeta): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I generalized BaseMeta
so that we can add to the list of methods to check, for specific subclasses. Because in this case for one-hot encoder, if a user calls categories
before fit
, we want an error to get thrown.
It would be so cool if we found a way to roll this all up in a decorator... but I'm not sure how at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice change!
self._encoder = None | ||
super().__init__(parameters=parameters, | ||
component_obj=None, | ||
random_state=random_state) | ||
|
||
def _get_cat_cols(self, X): | ||
@staticmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated change, why not :)
@@ -72,24 +80,24 @@ def fit(self, X, y=None): | |||
if not isinstance(X, pd.DataFrame): | |||
X = pd.DataFrame(X) | |||
X_t = X | |||
cols_to_encode = self._get_cat_cols(X_t) | |||
self._cols_to_encode = self._get_cat_cols(X_t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now save _cols_to_encode
, so that in categories
we can index into the sklearn encoder's category list correctly.
property_orig = dct[attribute] | ||
dct[attribute] = property(cls.check_for_fit(property_orig.__get__), | ||
property_orig.__set__, | ||
property_orig.__delattr__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Building off @jeremyliweishih 's great code
81c3332
to
2704eaa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dsherry Looks good!
METHODS_TO_CHECK = ComponentBaseMeta.METHODS_TO_CHECK + ['categories'] | ||
|
||
|
||
class OneHotEncoder(Transformer, metaclass=OneHotEncoderMeta): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice change!
2704eaa
to
ab3b754
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Fix #1180
Added an API to the one-hot-encoder for accessing the list of categories associated with a given feature.
Also, generalized the
BaseMeta
abstraction to support subclasses overriding the list of methods to be validated.Will retarget against main once #1179 is merged