Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions autosklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def build_automl(self):
def fit(self, *args, **kwargs):
self._automl = self.build_automl()
self._automl.fit(*args, **kwargs)
return self

def fit_ensemble(self, y, task=None, metric=None, precision='32',
dataset_name=None, ensemble_nbest=None,
Expand Down Expand Up @@ -294,9 +295,10 @@ def fit_ensemble(self, y, task=None, metric=None, precision='32',
"""
if self._automl is None:
self._automl = self.build_automl()
return self._automl.fit_ensemble(y, task, metric, precision,
dataset_name, ensemble_nbest,
ensemble_size)
self._automl.fit_ensemble(y, task, metric, precision,
dataset_name, ensemble_nbest,
ensemble_size)
return self

def refit(self, X, y):
"""Refit all models found with fit to new data.
Expand All @@ -323,7 +325,9 @@ def refit(self, X, y):
self

"""
return self._automl.refit(X, y)
self._automl.refit(X, y)
return self


def predict(self, X, batch_size=None, n_jobs=1):
return self._automl.predict(X, batch_size=batch_size, n_jobs=n_jobs)
Expand Down Expand Up @@ -451,7 +455,7 @@ def fit(self, X, y,
self

"""
return super().fit(
super().fit(
X=X,
y=y,
X_test=X_test,
Expand All @@ -461,6 +465,8 @@ def fit(self, X, y,
dataset_name=dataset_name,
)

return self

def predict(self, X, batch_size=None, n_jobs=1):
"""Predict classes for X.

Expand Down Expand Up @@ -554,7 +560,7 @@ def fit(self, X, y,
"""
# Fit is supposed to be idempotent!
# But not if we use share_mode.
return super().fit(
super().fit(
X=X,
y=y,
X_test=X_test,
Expand All @@ -564,6 +570,10 @@ def fit(self, X, y,
dataset_name=dataset_name,
)

return self



def predict(self, X, batch_size=None, n_jobs=1):
"""Predict regression target for X.

Expand Down
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
#'reference_url': {
# 'autosklearn': None
#},
#'backreferences_dir': 'gen_modules/backreferences'
'backreferences_dir': False
}

# Add any paths that contain templates here, relative to this directory.
Expand Down
42 changes: 42 additions & 0 deletions test/test_automl/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,45 @@ def test_conversion_of_list_to_np(self, fit_ensemble, refit, fit):
automl.fit_ensemble(y)
self.assertEqual(fit_ensemble.call_count, 1)
self.assertIsInstance(fit_ensemble.call_args[0][0], np.ndarray)


class AutoSklearnClassifierTest(unittest.TestCase):
# Currently this class only tests that the methods of AutoSklearnClassifier
# which should return self actually return self.
def test_classification_methods_returns_self(self):
X_train, y_train, X_test, y_test = putil.get_dataset('iris')
automl = AutoSklearnClassifier(time_left_for_this_task=20,
per_run_time_limit=5,
ensemble_size=0)

automl_fitted = automl.fit(X_train, y_train)
self.assertIs(automl, automl_fitted)

automl_ensemble_fitted = automl.fit_ensemble(y_train, ensemble_size=5)
self.assertIs(automl, automl_ensemble_fitted)

automl_refitted = automl.refit(X_train.copy(), y_train.copy())
self.assertIs(automl, automl_refitted)


class AutoSklearnRegressorTest(unittest.TestCase):
# Currently this class only tests that the methods of AutoSklearnRegressor
# that should return self actually return self.
def test_regression_methods_returns_self(self):
X_train, y_train, X_test, y_test = putil.get_dataset('boston')
automl = AutoSklearnRegressor(time_left_for_this_task=20,
per_run_time_limit=5,
ensemble_size=0)

automl_fitted = automl.fit(X_train, y_train)
self.assertIs(automl, automl_fitted)

automl_ensemble_fitted = automl.fit_ensemble(y_train, ensemble_size=5)
self.assertIs(automl, automl_ensemble_fitted)

automl_refitted = automl.refit(X_train.copy(), y_train.copy())
self.assertIs(automl, automl_refitted)


if __name__=="__main__":
unittest.main()