Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds needs_fitting property to ComponentBase #1044

Merged
merged 10 commits into from Aug 13, 2020
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Expand Up @@ -11,6 +11,7 @@ Release Notes
* Updated TextFeaturizer component to no longer require an internet connection to run :pr:`1022`
* Fixed non-deterministic element of TextFeaturizer transformations :pr:`1022`
* Changes
* Added `needs_fitting` property to ComponentBase :pr:`1044`
* Updated references to data types to use datatype lists defined in `evalml.utils.gen_utils` :pr:`1039`
* Documentation Changes
* Update setup.py URL to point to the github repo :pr:`1037`
Expand Down
9 changes: 6 additions & 3 deletions evalml/pipelines/components/component_base.py
Expand Up @@ -22,8 +22,6 @@ class ComponentBaseMeta(ABCMeta):
"""Metaclass that overrides creating a new component by wrapping method with validators and setters"""
from evalml.exceptions import ComponentNotYetFittedError

NO_FITTING_REQUIRED = ['DropColumns', 'SelectColumns']

@classmethod
def set_fit(cls, method):
@wraps(method)
Expand All @@ -41,7 +39,7 @@ def check_for_fit(cls, method):
@wraps(method)
def _check_for_fit(self, X=None, y=None):
klass = type(self).__name__
if not self._is_fitted and klass not in cls.NO_FITTING_REQUIRED:
if not self._is_fitted and self.needs_fitting:
raise ComponentNotYetFittedError(f'This {klass} is not fitted yet. You must fit {klass} before calling {method.__name__}.')
elif X is None and y is None:
return method(self)
Expand Down Expand Up @@ -91,6 +89,11 @@ def name(cls):
def model_family(cls):
"""Returns ModelFamily of this component"""

@classproperty
def needs_fitting(self):
"""Returns boolean determining if component needs fitting before calling predict, predict_proba, transform, or feature_importances."""
angela97lin marked this conversation as resolved.
Show resolved Hide resolved
return True

@property
def parameters(self):
"""Returns the parameters which were used to initialize the component"""
Expand Down
2 changes: 2 additions & 0 deletions evalml/pipelines/components/transformers/column_selectors.py
Expand Up @@ -85,6 +85,7 @@ class DropColumns(ColumnSelector):
"""Drops specified columns in input data."""
name = "Drop Columns Transformer"
hyperparameter_ranges = {}
needs_fitting = False

def _modify_columns(self, cols, X, y=None):
return X.drop(columns=cols, axis=1)
Expand All @@ -106,6 +107,7 @@ class SelectColumns(ColumnSelector):
"""Selects specified columns in input data."""
name = "Select Columns Transformer"
hyperparameter_ranges = {}
needs_fitting = False

def _modify_columns(self, cols, X, y=None):
return X[cols]
Expand Down
15 changes: 11 additions & 4 deletions evalml/tests/component_tests/test_components.py
Expand Up @@ -15,7 +15,6 @@
from evalml.model_family import ModelFamily
from evalml.pipelines.components import (
ComponentBase,
ComponentBaseMeta,
DropColumns,
ElasticNetClassifier,
ElasticNetRegressor,
Expand Down Expand Up @@ -645,10 +644,18 @@ def transform(self, X):
transformer_subclass.transform(X)


def test_all_transformers_needs_fitting():
for component_class in _all_transformers + _all_estimators:
if component_class.__name__ in ['DropColumns', 'SelectColumns']:
assert not component_class.needs_fitting
else:
assert component_class.needs_fitting
angela97lin marked this conversation as resolved.
Show resolved Hide resolved


def test_all_transformers_check_fit(X_y_binary):
X, y = X_y_binary
for component_class in _all_transformers:
if component_class.__name__ in ComponentBaseMeta.NO_FITTING_REQUIRED:
if not component_class.needs_fitting:
continue

component = component_class()
Expand All @@ -666,7 +673,7 @@ def test_all_transformers_check_fit(X_y_binary):
def test_all_estimators_check_fit(X_y_binary):
X, y = X_y_binary
for component_class in _all_estimators:
if component_class.__name__ in ComponentBaseMeta.NO_FITTING_REQUIRED:
if not component_class.needs_fitting:
continue

component = component_class()
Expand All @@ -692,7 +699,7 @@ def test_all_estimators_check_fit(X_y_binary):
def test_no_fitting_required_components(X_y_binary):
X, y = X_y_binary
for component_class in all_components:
if component_class.__name__ in ComponentBaseMeta.NO_FITTING_REQUIRED:
if not component_class.needs_fitting:
component = component_class()
if issubclass(component_class, Estimator):
component.predict(X)
Expand Down