diff --git a/neuraxle/base.py b/neuraxle/base.py index 986e7379..d0c9e89e 100644 --- a/neuraxle/base.py +++ b/neuraxle/base.py @@ -1651,7 +1651,7 @@ def save(self, context: ExecutionContext, full_dump=False) -> 'BaseTransformer': def _initialize_if_needed(step): if not step.is_initialized: - step.setup() + step.setup(context=context) return RecursiveDict() def _invalidate(step): @@ -2078,13 +2078,14 @@ def __init__( self.is_initialized = False self.is_train: bool = True - def setup(self) -> 'BaseTransformer': + def setup(self, context: ExecutionContext = None) -> 'BaseTransformer': """ Initialize the step before it runs. Only from here and not before that heavy things should be created (e.g.: things inside GPU), and NOT in the constructor. The setup method is called for each step before any fit, or fit_transform. + :param context: execution context :return: self """ self.is_initialized = True @@ -2426,14 +2427,15 @@ def set_step(self, step: BaseTransformer) -> BaseStep: self.wrapped: BaseTransformer = _sklearn_to_neuraxle_step(step) return self - def setup(self) -> BaseStep: + def setup(self, context: ExecutionContext = None) -> BaseStep: """ Initialize step before it runs. Also initialize the wrapped step. + :param context: execution context :return: self """ - super().setup() - self.wrapped.setup() + super().setup(context=context) + self.wrapped.setup(context=context) self.is_initialized = True return self @@ -2873,10 +2875,11 @@ def set_steps(self, steps_as_tuple: NamedTupleList): self.steps_as_tuple: NamedTupleList = self._patch_missing_names(steps_as_tuple) self._refresh_steps() - def setup(self) -> 'BaseTransformer': + def setup(self, context: ExecutionContext = None) -> 'BaseTransformer': """ Initialize step before it runs. + :param context: execution context :return: self """ if self.is_initialized: diff --git a/neuraxle/distributed/streaming.py b/neuraxle/distributed/streaming.py index 6ad9ca46..e9120560 100644 --- a/neuraxle/distributed/streaming.py +++ b/neuraxle/distributed/streaming.py @@ -414,19 +414,20 @@ def _will_process(self, data_container: DataContainer, context: ExecutionContext :param context: execution context :return: """ - self.setup() + self.setup(context=context) return data_container, context - def setup(self) -> 'BaseTransformer': + def setup(self, context: ExecutionContext = None) -> 'BaseTransformer': """ Connect the queued workers together so that the data can correctly flow through the pipeline. + :param context: execution context :return: step :rtype: BaseStep """ if not self.is_initialized: self.connect_queued_pipeline() - super().setup() + super().setup(context=context) return self def fit_transform_data_container(self, data_container: DataContainer, context: ExecutionContext) -> ('Pipeline', DataContainer): diff --git a/neuraxle/pipeline.py b/neuraxle/pipeline.py index a1784f23..60220599 100644 --- a/neuraxle/pipeline.py +++ b/neuraxle/pipeline.py @@ -144,14 +144,14 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC :return: tuple(pipeline, data_container) """ steps_left_to_do, data_container = self._load_checkpoint(data_container, context) - self.setup() + self.setup(context=context) index_last_step = len(steps_left_to_do) - 1 new_steps_as_tuple: NamedTupleList = [] for index, (step_name, step) in enumerate(steps_left_to_do): - step.setup() + step.setup(context=context) if index != index_last_step: step, data_container = step.handle_fit_transform(data_container, context) @@ -174,12 +174,12 @@ def _fit_transform_data_container(self, data_container: DataContainer, context: :return: tuple(pipeline, data_container) """ steps_left_to_do, data_container = self._load_checkpoint(data_container, context) - self.setup() + self.setup(context=context) new_steps_as_tuple: NamedTupleList = [] for step_name, step in steps_left_to_do: - step.setup() + step.setup(context=context) step, data_container = step.handle_fit_transform(data_container, context) new_steps_as_tuple.append((step_name, step)) @@ -399,7 +399,7 @@ def fit_data_container(self, data_container: DataContainer, context: ExecutionCo index_start = 0 for sub_pipeline in sub_pipelines: - sub_pipeline.setup() + sub_pipeline.setup(context=context) barrier = sub_pipeline[-1] sub_pipeline, data_container = barrier.join_fit_transform( @@ -430,7 +430,7 @@ def fit_transform_data_container(self, data_container: DataContainer, context: E index_start = 0 for sub_pipeline in sub_pipelines: - sub_pipeline.setup() + sub_pipeline.setup(context=context) barrier = sub_pipeline[-1] sub_pipeline, data_container = barrier.join_fit_transform( diff --git a/testing/test_pipeline_setup_teardown.py b/testing/test_pipeline_setup_teardown.py index 57e781d6..08d3e219 100644 --- a/testing/test_pipeline_setup_teardown.py +++ b/testing/test_pipeline_setup_teardown.py @@ -33,7 +33,7 @@ def __init__(self): SomeStep.__init__(self) self.called_with = None - def setup(self) -> 'BaseStep': + def setup(self, context: ExecutionContext = None) -> 'BaseStep': self.is_initialized = True return self