-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Standardize error when calling transform/predict before fit for pipel…
…ines (#1048) * init * add metaclass subclasses * remove stored err * add test * update file hierarchy
- Loading branch information
1 parent
2ecf3bd
commit c0ad9f8
Showing
10 changed files
with
162 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
|
||
|
||
from functools import wraps | ||
|
||
from evalml.exceptions import ComponentNotYetFittedError | ||
from evalml.utils.base_meta import BaseMeta | ||
|
||
|
||
class ComponentBaseMeta(BaseMeta): | ||
"""Metaclass that overrides creating a new component by wrapping methods with validators and setters""" | ||
|
||
@classmethod | ||
def check_for_fit(cls, method): | ||
"""`check_for_fit` wraps a method that validates if `self._is_fitted` is `True`. | ||
It raises an exception if `False` and calls and returns the wrapped method if `True`. | ||
""" | ||
@wraps(method) | ||
def _check_for_fit(self, X=None, y=None): | ||
klass = type(self).__name__ | ||
if not self._is_fitted and self.needs_fitting: | ||
raise ComponentNotYetFittedError(f'This {klass} is not fitted yet. You must fit {klass} before calling {method.__name__}.') | ||
elif X is None and y is None: | ||
return method(self) | ||
elif y is None: | ||
return method(self, X) | ||
else: | ||
return method(self, X, y) | ||
return _check_for_fit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
|
||
|
||
from functools import wraps | ||
|
||
from evalml.exceptions import PipelineNotYetFittedError | ||
from evalml.utils.base_meta import BaseMeta | ||
|
||
|
||
class PipelineBaseMeta(BaseMeta): | ||
"""Metaclass that overrides creating a new pipeline by wrapping methods with validators and setters""" | ||
|
||
@classmethod | ||
def check_for_fit(cls, method): | ||
"""`check_for_fit` wraps a method that validates if `self._is_fitted` is `True`. | ||
It raises an exception if `False` and calls and returns the wrapped method if `True`. | ||
""" | ||
@wraps(method) | ||
def _check_for_fit(self, X=None, y=None): | ||
klass = type(self).__name__ | ||
if not self._is_fitted: | ||
raise PipelineNotYetFittedError(f'This {klass} is not fitted yet. You must fit {klass} before calling {method.__name__}.') | ||
elif X is None and y is None: | ||
return method(self) | ||
elif y is None: | ||
return method(self, X) | ||
else: | ||
return method(self, X, y) | ||
return _check_for_fit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
|
||
|
||
from abc import ABCMeta | ||
from functools import wraps | ||
|
||
|
||
class BaseMeta(ABCMeta): | ||
"""Metaclass that overrides creating a new component or pipeline by wrapping methods with validators and setters""" | ||
|
||
@classmethod | ||
def set_fit(cls, method): | ||
@wraps(method) | ||
def _set_fit(self, X, y=None): | ||
return_value = method(self, X, y) | ||
self._is_fitted = True | ||
return return_value | ||
return _set_fit | ||
|
||
def __new__(cls, name, bases, dct): | ||
if 'predict' in dct: | ||
dct['predict'] = cls.check_for_fit(dct['predict']) | ||
if 'predict_proba' in dct: | ||
dct['predict_proba'] = cls.check_for_fit(dct['predict_proba']) | ||
if 'transform' in dct: | ||
dct['transform'] = cls.check_for_fit(dct['transform']) | ||
if 'feature_importance' in dct: | ||
fi = dct['feature_importance'] | ||
new_fi = property(cls.check_for_fit(fi.__get__), fi.__set__, fi.__delattr__) | ||
dct['feature_importance'] = new_fi | ||
if 'fit' in dct: | ||
dct['fit'] = cls.set_fit(dct['fit']) | ||
if 'fit_transform' in dct: | ||
dct['fit_transform'] = cls.set_fit(dct['fit_transform']) | ||
return super().__new__(cls, name, bases, dct) |