Skip to content

Commit

Permalink
Merge pull request #366 from alexbrillant/setup-with-context
Browse files Browse the repository at this point in the history
Feature: Pass context in setup #365
  • Loading branch information
alexbrillant committed Jul 20, 2020
2 parents 0fd53cb + 62eb169 commit 21215e2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 16 deletions.
15 changes: 9 additions & 6 deletions neuraxle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ def save(self, context: ExecutionContext, full_dump=False) -> 'BaseTransformer':

def _initialize_if_needed(step):
if not step.is_initialized:
step.setup()
step.setup(context=context)
return RecursiveDict()

def _invalidate(step):
Expand Down Expand Up @@ -2078,13 +2078,14 @@ def __init__(
self.is_initialized = False
self.is_train: bool = True

def setup(self) -> 'BaseTransformer':
def setup(self, context: ExecutionContext = None) -> 'BaseTransformer':
"""
Initialize the step before it runs. Only from here and not before that heavy things should be created
(e.g.: things inside GPU), and NOT in the constructor.
The setup method is called for each step before any fit, or fit_transform.
:param context: execution context
:return: self
"""
self.is_initialized = True
Expand Down Expand Up @@ -2426,14 +2427,15 @@ def set_step(self, step: BaseTransformer) -> BaseStep:
self.wrapped: BaseTransformer = _sklearn_to_neuraxle_step(step)
return self

def setup(self) -> BaseStep:
def setup(self, context: ExecutionContext = None) -> BaseStep:
"""
Initialize step before it runs. Also initialize the wrapped step.
:param context: execution context
:return: self
"""
super().setup()
self.wrapped.setup()
super().setup(context=context)
self.wrapped.setup(context=context)
self.is_initialized = True
return self

Expand Down Expand Up @@ -2873,10 +2875,11 @@ def set_steps(self, steps_as_tuple: NamedTupleList):
self.steps_as_tuple: NamedTupleList = self._patch_missing_names(steps_as_tuple)
self._refresh_steps()

def setup(self) -> 'BaseTransformer':
def setup(self, context: ExecutionContext = None) -> 'BaseTransformer':
"""
Initialize step before it runs.
:param context: execution context
:return: self
"""
if self.is_initialized:
Expand Down
7 changes: 4 additions & 3 deletions neuraxle/distributed/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,20 @@ def _will_process(self, data_container: DataContainer, context: ExecutionContext
:param context: execution context
:return:
"""
self.setup()
self.setup(context=context)
return data_container, context

def setup(self) -> 'BaseTransformer':
def setup(self, context: ExecutionContext = None) -> 'BaseTransformer':
"""
Connect the queued workers together so that the data can correctly flow through the pipeline.
:param context: execution context
:return: step
:rtype: BaseStep
"""
if not self.is_initialized:
self.connect_queued_pipeline()
super().setup()
super().setup(context=context)
return self

def fit_transform_data_container(self, data_container: DataContainer, context: ExecutionContext) -> ('Pipeline', DataContainer):
Expand Down
12 changes: 6 additions & 6 deletions neuraxle/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
:return: tuple(pipeline, data_container)
"""
steps_left_to_do, data_container = self._load_checkpoint(data_container, context)
self.setup()
self.setup(context=context)

index_last_step = len(steps_left_to_do) - 1

new_steps_as_tuple: NamedTupleList = []

for index, (step_name, step) in enumerate(steps_left_to_do):
step.setup()
step.setup(context=context)

if index != index_last_step:
step, data_container = step.handle_fit_transform(data_container, context)
Expand All @@ -174,12 +174,12 @@ def _fit_transform_data_container(self, data_container: DataContainer, context:
:return: tuple(pipeline, data_container)
"""
steps_left_to_do, data_container = self._load_checkpoint(data_container, context)
self.setup()
self.setup(context=context)

new_steps_as_tuple: NamedTupleList = []

for step_name, step in steps_left_to_do:
step.setup()
step.setup(context=context)
step, data_container = step.handle_fit_transform(data_container, context)
new_steps_as_tuple.append((step_name, step))

Expand Down Expand Up @@ -399,7 +399,7 @@ def fit_data_container(self, data_container: DataContainer, context: ExecutionCo
index_start = 0

for sub_pipeline in sub_pipelines:
sub_pipeline.setup()
sub_pipeline.setup(context=context)

barrier = sub_pipeline[-1]
sub_pipeline, data_container = barrier.join_fit_transform(
Expand Down Expand Up @@ -430,7 +430,7 @@ def fit_transform_data_container(self, data_container: DataContainer, context: E
index_start = 0

for sub_pipeline in sub_pipelines:
sub_pipeline.setup()
sub_pipeline.setup(context=context)

barrier = sub_pipeline[-1]
sub_pipeline, data_container = barrier.join_fit_transform(
Expand Down
2 changes: 1 addition & 1 deletion testing/test_pipeline_setup_teardown.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self):
SomeStep.__init__(self)
self.called_with = None

def setup(self) -> 'BaseStep':
def setup(self, context: ExecutionContext = None) -> 'BaseStep':
self.is_initialized = True
return self

Expand Down

0 comments on commit 21215e2

Please sign in to comment.