Skip to content

Commit

Permalink
Merge branch 'master' of github.com:Neuraxio/Neuraxle into setup-with…
Browse files Browse the repository at this point in the history
…-context
  • Loading branch information
alexbrillant committed Jul 20, 2020
2 parents f379109 + 0fd53cb commit 62eb169
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions neuraxle/steps/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __init__(
self,
wrapped_sklearn_predictor,
hyperparams_space: HyperparameterSpace = None,
return_all_sklearn_default_params_on_get=False
return_all_sklearn_default_params_on_get: bool = False,
use_partial_fit: bool = False
):
if not isinstance(wrapped_sklearn_predictor, BaseEstimator):
raise ValueError("The wrapped_sklearn_predictor must be an instance of scikit-learn's BaseEstimator.")
Expand All @@ -50,14 +51,15 @@ def __init__(
BaseStep.__init__(self, hyperparams=params, hyperparams_space=hyperparams_space)
self.return_all_sklearn_default_params_on_get = return_all_sklearn_default_params_on_get
self.name += "_" + wrapped_sklearn_predictor.__class__.__name__
self.partial_fit: bool = use_partial_fit

def fit_transform(self, data_inputs, expected_outputs=None) -> ('BaseStep', Any):

if hasattr(self.wrapped_sklearn_predictor, 'fit_transform'):
if expected_outputs is None or len(inspect.getfullargspec(self.wrapped_sklearn_predictor.fit).args) < 3:
out = self.wrapped_sklearn_predictor.fit_transform(data_inputs)
out = self._sklearn_fit_transform_without_expected_outputs(data_inputs)
else:
out = self.wrapped_sklearn_predictor.fit_transform(data_inputs, expected_outputs)
out = self._sklearn_fit_transform_with_expected_outputs(data_inputs, expected_outputs)
return self, out

self.fit(data_inputs, expected_outputs)
Expand All @@ -66,13 +68,41 @@ def fit_transform(self, data_inputs, expected_outputs=None) -> ('BaseStep', Any)
return self, self.wrapped_sklearn_predictor.predict(data_inputs)
return self, self.wrapped_sklearn_predictor.transform(data_inputs)

def _sklearn_fit_transform_with_expected_outputs(self, data_inputs, expected_outputs):
if self.partial_fit:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.partial_fit(data_inputs, expected_outputs)
out = self.wrapped_sklearn_predictor.transform(data_inputs)
else:
out = self.wrapped_sklearn_predictor.fit_transform(data_inputs, expected_outputs)
return out

def _sklearn_fit_transform_without_expected_outputs(self, data_inputs):
if self.partial_fit:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.partial_fit(data_inputs)
out = self.wrapped_sklearn_predictor.transform(data_inputs)
else:
out = self.wrapped_sklearn_predictor.fit_transform(data_inputs)
return out

def fit(self, data_inputs, expected_outputs=None) -> 'SKLearnWrapper':
if expected_outputs is None or len(inspect.getfullargspec(self.wrapped_sklearn_predictor.fit).args) < 3:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.fit(data_inputs)
self._sklearn_fit_without_expected_outputs(data_inputs)
else:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.fit(data_inputs, expected_outputs)
self._sklearn_fit_with_expected_outputs(data_inputs, expected_outputs)
return self

def _sklearn_fit_with_expected_outputs(self, data_inputs, expected_outputs):
if self.partial_fit:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.partial_fit(data_inputs, expected_outputs)
else:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.fit(data_inputs, expected_outputs)

def _sklearn_fit_without_expected_outputs(self, data_inputs):
if self.partial_fit:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.partial_fit(data_inputs)
else:
self.wrapped_sklearn_predictor = self.wrapped_sklearn_predictor.fit(data_inputs)

def transform(self, data_inputs):
if hasattr(self.wrapped_sklearn_predictor, 'predict'):
return self.wrapped_sklearn_predictor.predict(data_inputs)
Expand Down

0 comments on commit 62eb169

Please sign in to comment.