checkpoint hash idea handle wip
alexbrillant committed Sep 5, 2019
1 parent d995377 commit 407d7b3
Showing 3 changed files with 71 additions and 31 deletions.
42 changes: 40 additions & 2 deletions neuraxle/base.py
@@ -29,14 +29,38 @@
 from neuraxle.hyperparams.space import HyperparameterSpace, HyperparameterSamples
 
 
-class BaseStep(ABC):
+class Hasher(ABC):
+    @abstractmethod
+    def hash(self, data_inputs: Any):
+        return hash(data_inputs)
+
+    @abstractmethod
+    def rehash(self, ids, data_inputs: Any):
+        return self.hash(data_inputs)
+
+
+class NullHasher(Hasher):
+    def hash(self, data_inputs: Any):
+        pass
+
+    def rehash(self, ids, data_inputs: Any):
+        return ids
+
+
+class HasherByIndex(Hasher):
+    def hash(self, data_inputs: Any):
+        return range(len(data_inputs))
+
+
+class BaseStep(ABC):
     def __init__(
             self,
             hyperparams: HyperparameterSamples = None,
             hyperparams_space: HyperparameterSpace = None,
-            name: str = None
+            name: str = None,
+            hasher: Hasher = NullHasher()
     ):
+        self.hasher = hasher
         if hyperparams is None:
             hyperparams = dict()
         if hyperparams_space is None:
@@ -82,6 +106,20 @@ def set_hyperparams_space(self, hyperparams_space: HyperparameterSpace) -> 'BaseStep':
     def get_hyperparams_space(self, flat=False) -> HyperparameterSpace:
         return self.hyperparams_space
 
+    @abstractmethod
+    def handle_fit_transform(self, ids, data_inputs, expected_outputs) -> ('BaseStep', Any):
+        return self.fit_transform(data_inputs, expected_outputs)
+
+    @abstractmethod
+    def handle_transform(self, ids, data_inputs) -> Any:
+        return self.transform(data_inputs)
+
+    def hash(self, data_inputs: Any):
+        return self.hasher.hash(data_inputs)
+
+    def rehash(self, ids, data_inputs: Any):
+        return self.hasher.rehash(ids, data_inputs)
+
     def fit_transform(self, data_inputs, expected_outputs=None) -> ('BaseStep', Any):
         new_self = self.fit(data_inputs, expected_outputs)
         out = new_self.transform(data_inputs)
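
The Hasher classes above give every BaseStep a way to derive ids for its data inputs (hash) and to re-derive them after a step has changed the data (rehash); NullHasher opts out, while HasherByIndex simply uses positions. A minimal, self-contained sketch of the idea follows; ContentHasher and the rehash method on HasherByIndex are illustrative additions, not part of this commit.

from abc import ABC, abstractmethod
from typing import Any, Iterable
import hashlib


class Hasher(ABC):
    @abstractmethod
    def hash(self, data_inputs: Any) -> Iterable:
        raise NotImplementedError()

    @abstractmethod
    def rehash(self, ids: Iterable, data_inputs: Any) -> Iterable:
        raise NotImplementedError()


class HasherByIndex(Hasher):
    """Ids are just the positions of the data inputs."""

    def hash(self, data_inputs: Any) -> Iterable:
        return range(len(data_inputs))

    def rehash(self, ids: Iterable, data_inputs: Any) -> Iterable:
        # Positions do not change when the values change.
        return ids


class ContentHasher(Hasher):
    """Hypothetical hasher that derives ids from the content itself."""

    def hash(self, data_inputs: Any) -> Iterable:
        return [hashlib.md5(repr(di).encode()).hexdigest() for di in data_inputs]

    def rehash(self, ids: Iterable, data_inputs: Any) -> Iterable:
        # The content changed, so recompute the ids from the new values.
        return self.hash(data_inputs)


if __name__ == "__main__":
    data = [1, 2, 3]
    print(list(HasherByIndex().hash(data)))   # [0, 1, 2]
    print(len(ContentHasher().hash(data)))    # 3 digests, one per input

Keeping id derivation on a separate object means a checkpoint step can later key its saved state by id without each step re-implementing the hashing logic.
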
25 changes: 16 additions & 9 deletions neuraxle/checkpoints.py
@@ -35,6 +35,14 @@ def __init__(self, force_checkpoint_name: str = None):
         BaseStep.__init__(self)
         self.force_checkpoint_name = force_checkpoint_name
 
+    def handle_transform(self, ids, data_inputs) -> Any:
+        self.save_checkpoint(ids, data_inputs)
+        return data_inputs
+
+    def handle_fit_transform(self, ids, data_inputs, expected_outputs) -> ('BaseStep', Any):
+        self.save_checkpoint(ids, data_inputs)
+        return self, data_inputs
+
     def fit(self, data_inputs, expected_outputs=None) -> 'BaseCheckpointStep':
         """
         Save checkpoint for data inputs and expected outputs so that it can
@@ -43,8 +51,6 @@ def fit(self, data_inputs, expected_outputs=None) -> 'BaseCheckpointStep':
         :param data_inputs: data inputs to save
         :return: self
         """
-        self.save_checkpoint(data_inputs)
-
         return self
 
     def transform(self, data_inputs):
@@ -54,8 +60,6 @@ def transform(self, data_inputs):
         :param data_inputs: data inputs to save
         :return: data_inputs
         """
-        self.save_checkpoint(data_inputs)
-
         return data_inputs
 
     @abstractmethod
@@ -67,18 +71,20 @@ def set_checkpoint_path(self, path):
         raise NotImplementedError()
 
     @abstractmethod
-    def read_checkpoint(self, data_inputs) -> Any:
+    def read_checkpoint(self, data_inputs: Any) -> Any:
         """
         Read checkpoint data to get the data inputs and expected output.
         :param data_inputs: data inputs to save
         :return: data_inputs_checkpoint
         """
         raise NotImplementedError()
 
     @abstractmethod
-    def save_checkpoint(self, data_inputs):
+    def save_checkpoint(self, ids, data_inputs: Any):
         """
         Save checkpoint for data inputs and expected outputs so that it can
         be loaded by the checkpoint pipeline runner on the next pipeline run
+        :param ids: data inputs ids
         :param data_inputs: data inputs to save
         :return:
         """
@@ -103,14 +109,15 @@ def read_checkpoint(self, data_inputs):
"""
data_inputs_checkpoint_file_name = self.checkpoint_path
with open(self.get_checkpoint_file_path(data_inputs_checkpoint_file_name), 'rb') as file:
data_inputs = pickle.load(file)
checkpoint = pickle.load(file)

return data_inputs
return checkpoint

def save_checkpoint(self, data_inputs):
def save_checkpoint(self, ids, data_inputs):
"""
Save pickle files for data inputs and expected output
to create a checkpoint
:param ids: data inputs ids
:param data_inputs: data inputs to be saved in a pickle file
:return:
"""
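
With ids threaded into save_checkpoint and handle_transform, a concrete checkpoint step can persist the ids together with the data and hand both back when the pipeline resumes. A rough, self-contained sketch of that behaviour using pickle; the class name and file path below are illustrative, not the commit's exact pickle checkpoint step.

import os
import pickle
from typing import Any, Iterable, Tuple


class SketchPickleCheckpoint:
    """Illustrative id-aware checkpoint, independent of Neuraxle's classes."""

    def __init__(self, checkpoint_path: str):
        self.checkpoint_path = checkpoint_path

    def save_checkpoint(self, ids: Iterable, data_inputs: Any) -> None:
        # Persist ids together with the data so a later run can resume with both.
        os.makedirs(os.path.dirname(self.checkpoint_path) or ".", exist_ok=True)
        with open(self.checkpoint_path, "wb") as f:
            pickle.dump((list(ids), data_inputs), f)

    def read_checkpoint(self) -> Tuple[Iterable, Any]:
        # Returns the (ids, data_inputs) pair written by save_checkpoint.
        with open(self.checkpoint_path, "rb") as f:
            return pickle.load(f)

    def handle_transform(self, ids: Iterable, data_inputs: Any) -> Any:
        # As in the diff above, checkpointing is a pass-through that records state.
        self.save_checkpoint(ids, data_inputs)
        return data_inputs


if __name__ == "__main__":
    cp = SketchPickleCheckpoint("/tmp/neuraxle_sketch/checkpoint.pickle")
    cp.handle_transform(range(3), [10, 20, 30])
    ids, data = cp.read_checkpoint()
    print(ids, data)  # [0, 1, 2] [10, 20, 30]
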
35 changes: 15 additions & 20 deletions neuraxle/pipeline.py
@@ -22,23 +22,13 @@
 from abc import ABC, abstractmethod
 from typing import Any, Tuple
 
-from neuraxle.base import BaseStep, TruncableSteps, NamedTupleList, ResumableStepMixin
+from neuraxle.base import BaseStep, TruncableSteps, NamedTupleList, ResumableStepMixin, HasherByIndex, Hasher
 from neuraxle.checkpoints import BaseCheckpointStep
 
 
-class DataObject:
-    def __init__(self, i, x):
-        self.i = i
-        self.x = x
-
-    def __hash__(self):
-        return hash((self.i, self.x))
-
-
 class BasePipeline(TruncableSteps, ABC):
-
-    def __init__(self, steps: NamedTupleList):
-        BaseStep.__init__(self)
+    def __init__(self, steps: NamedTupleList, hasher: Hasher = HasherByIndex()):
+        BaseStep.__init__(self, None, None, None, hasher)
         TruncableSteps.__init__(self, steps)
 
     @abstractmethod
@@ -103,11 +93,13 @@ def fit_transform_steps(self, data_inputs, expected_outputs):
         :param expected_outputs: the expected data output to fit on
         :return: the pipeline itself
         """
-        steps_left_to_do, data_inputs = self.read_checkpoint(data_inputs)
+        steps_left_to_do, ids, data_inputs = self.resume_pipeline(data_inputs)
+
         new_steps_as_tuple: NamedTupleList = []
 
         for step_name, step in steps_left_to_do:
-            step, data_inputs = step.fit_transform(data_inputs, expected_outputs)
+            step, data_inputs = step.handle_fit_transform(ids, data_inputs, expected_outputs)
+            ids = step.rehash(ids, data_inputs)
             new_steps_as_tuple.append((step_name, step))
 
         self.steps_as_tuple = self.steps_as_tuple[:len(self.steps_as_tuple) - len(steps_left_to_do)] + \
@@ -121,9 +113,11 @@ def transform(self, data_inputs):
         :param data_inputs: the data input to fit on
         :return: transformed data inputs
         """
-        steps_left_to_do, data_inputs = self.read_checkpoint(data_inputs)
+        steps_left_to_do, ids, data_inputs = self.resume_pipeline(data_inputs)
+
         for step_name, step in steps_left_to_do:
-            data_inputs = step.transform(data_inputs)
+            data_inputs = step.handle_transform(ids, data_inputs)
+            ids = step.rehash(ids, data_inputs)
 
         return data_inputs
 
@@ -137,16 +131,17 @@ def inverse_transform_processed_outputs(self, processed_outputs) -> Any:
             processed_outputs = step.transform(processed_outputs)
         return processed_outputs
 
-    def read_checkpoint(self, data_inputs) -> Tuple[list, Any]:
+    def resume_pipeline(self, data_inputs) -> Tuple[NamedTupleList, iter, Any]:
         new_data_inputs = data_inputs
         new_starting_step_index = self.find_starting_step_index(data_inputs)
+        ids = self.hash(data_inputs)
 
         step = self.steps_as_tuple[new_starting_step_index]
         if isinstance(step, BaseCheckpointStep):
             checkpoint = step.read_checkpoint(data_inputs)
-            new_data_inputs = checkpoint
+            ids, new_data_inputs = checkpoint
 
-        return self.steps_as_tuple[new_starting_step_index:], new_data_inputs
+        return self.steps_as_tuple[new_starting_step_index:], ids, new_data_inputs
 
     def find_starting_step_index(self, data_inputs) -> int:
         """
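
Taken together, the pipeline now hashes the inputs once, lets each remaining step handle (and possibly re-hash) them, and restarts from whatever the latest checkpoint step can read back. A rough sketch of that control flow with plain functions rather than Neuraxle's actual API; all names below are illustrative.

from typing import Any, Callable, Iterable, List, Optional, Tuple

# A "step" here is just a callable taking (ids, data_inputs) and returning new data_inputs.
Step = Callable[[Iterable, Any], Any]


def resume_pipeline(steps: List[Tuple[str, Step]],
                    data_inputs: Any,
                    checkpoint: Optional[Tuple[int, Iterable, Any]] = None
                    ) -> Tuple[List[Tuple[str, Step]], Iterable, Any]:
    """Return the steps left to do, the current ids and the current data.

    `checkpoint` stands in for a saved state: (index to resume at, ids, data).
    """
    if checkpoint is None:
        return steps, range(len(data_inputs)), data_inputs  # hash by index
    start_index, ids, saved_data = checkpoint
    return steps[start_index:], ids, saved_data


def transform(steps: List[Tuple[str, Step]], data_inputs: Any, checkpoint=None) -> Any:
    steps_left_to_do, ids, data_inputs = resume_pipeline(steps, data_inputs, checkpoint)
    for step_name, step in steps_left_to_do:
        data_inputs = step(ids, data_inputs)  # handle_transform
        ids = range(len(data_inputs))         # rehash after the step ran
    return data_inputs


if __name__ == "__main__":
    double = ("double", lambda ids, x: [v * 2 for v in x])
    add_one = ("add_one", lambda ids, x: [v + 1 for v in x])

    print(transform([double, add_one], [1, 2, 3]))  # full run: [3, 5, 7]
    # Resuming after "double" from a saved (ids, data) pair gives the same result:
    print(transform([double, add_one], [1, 2, 3], checkpoint=(1, range(3), [2, 4, 6])))
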
