Skip to content
Permalink
Browse files

checkpoint hash idea handle wip

  • Loading branch information...
alexbrillant committed Sep 5, 2019
1 parent d995377 commit 407d7b3b6d114e668469d0c51dfcabe28a83e152
Showing with 71 additions and 31 deletions.
  1. +40 −2 neuraxle/base.py
  2. +16 −9 neuraxle/checkpoints.py
  3. +15 −20 neuraxle/pipeline.py
@@ -29,14 +29,38 @@
from neuraxle.hyperparams.space import HyperparameterSpace, HyperparameterSamples


class BaseStep(ABC):
class Hasher(ABC):
    """Strategy interface for deriving ids ("hashes") from pipeline data inputs.

    Subclasses must override both methods; the bodies below act as default
    implementations that an override may reach through ``super()``.
    """

    @abstractmethod
    def hash(self, data_inputs: Any):
        # Default: Python's built-in hash of the data inputs.
        # NOTE(review): this requires data_inputs to be hashable — a list or
        # ndarray would raise TypeError here; confirm intended input types.
        return hash(data_inputs)

    @abstractmethod
    def rehash(self, ids, data_inputs: Any):
        # Default: discard the previous ids and recompute from the data.
        return self.hash(data_inputs)


class NullHasher(Hasher):
    """No-op hashing strategy: produces no ids and leaves existing ids untouched."""

    def hash(self, data_inputs: Any):
        # Deliberately produces nothing.
        return None

    def rehash(self, ids, data_inputs: Any):
        # Existing ids pass through unchanged.
        return ids


class HasherByIndex(Hasher):
    """Hashing strategy that ids each data input by its position (0..n-1).

    BUG FIX: ``Hasher.rehash`` is declared ``@abstractmethod`` and the original
    class did not override it, so ``HasherByIndex()`` could never be
    instantiated (ABC raises TypeError) — even though ``BasePipeline`` uses
    ``HasherByIndex()`` as its default hasher. A concrete ``rehash`` is
    provided here.
    """

    def hash(self, data_inputs: Any):
        # Ids are simply the positional indices of the data inputs.
        return range(len(data_inputs))

    def rehash(self, ids, data_inputs: Any):
        # Recompute index ids for the (possibly transformed) data inputs,
        # mirroring the base-class default of re-hashing from scratch.
        return self.hash(data_inputs)


class BaseStep(ABC):
def __init__(
self,
hyperparams: HyperparameterSamples = None,
hyperparams_space: HyperparameterSpace = None,
name: str = None
name: str = None,
hasher: Hasher = NullHasher()
):
self.hasher = hasher
if hyperparams is None:
hyperparams = dict()
if hyperparams_space is None:
@@ -82,6 +106,20 @@ def set_hyperparams_space(self, hyperparams_space: HyperparameterSpace) -> 'Base
def get_hyperparams_space(self, flat=False) -> HyperparameterSpace:
return self.hyperparams_space

@abstractmethod
def handle_fit_transform(self, ids, data_inputs, expected_outputs) -> ('BaseStep', Any):
    """Fit-transform hook that also receives the current data ids.

    Default behavior (reachable via ``super()``) ignores the ids and
    delegates to ``fit_transform``.
    NOTE(review): adding a new abstract method to BaseStep forces every
    existing concrete subclass to override it — confirm that is intended.
    """
    return self.fit_transform(data_inputs, expected_outputs)

@abstractmethod
def handle_transform(self, ids, data_inputs) -> Any:
    """Transform hook that also receives the current data ids.

    Default behavior (reachable via ``super()``) ignores the ids and
    delegates to ``transform``.
    """
    return self.transform(data_inputs)

def hash(self, data_inputs: Any):
    """Compute ids for the given data inputs by delegating to this step's hasher."""
    return self.hasher.hash(data_inputs)

def rehash(self, ids, data_inputs: Any):
    """Recompute ids after a transformation by delegating to this step's hasher."""
    return self.hasher.rehash(ids, data_inputs)

def fit_transform(self, data_inputs, expected_outputs=None) -> ('BaseStep', Any):
new_self = self.fit(data_inputs, expected_outputs)
out = new_self.transform(data_inputs)
@@ -35,6 +35,14 @@ def __init__(self, force_checkpoint_name: str = None):
BaseStep.__init__(self)
self.force_checkpoint_name = force_checkpoint_name

def handle_transform(self, ids, data_inputs) -> Any:
    """Persist a checkpoint keyed by the current ids; data passes through unchanged."""
    self.save_checkpoint(ids, data_inputs)
    return data_inputs

def handle_fit_transform(self, ids, data_inputs, expected_outputs) -> ('BaseStep', Any):
    """Persist a checkpoint keyed by the current ids; no fitting occurs.

    Returns self unchanged along with the untouched data inputs.
    NOTE(review): expected_outputs is not checkpointed here — confirm
    that only data inputs need saving.
    """
    self.save_checkpoint(ids, data_inputs)
    return self, data_inputs

def fit(self, data_inputs, expected_outputs=None) -> 'BaseCheckpointStep':
"""
Save checkpoint for data inputs and expected outputs so that it can
@@ -43,8 +51,6 @@ def fit(self, data_inputs, expected_outputs=None) -> 'BaseCheckpointStep':
:param data_inputs: data inputs to save
:return: self
"""
self.save_checkpoint(data_inputs)

return self

def transform(self, data_inputs):
@@ -54,8 +60,6 @@ def transform(self, data_inputs):
:param data_inputs: data inputs to save
:return: data_inputs
"""
self.save_checkpoint(data_inputs)

return data_inputs

@abstractmethod
@@ -67,18 +71,20 @@ def set_checkpoint_path(self, path):
raise NotImplementedError()

@abstractmethod
def read_checkpoint(self, data_inputs) -> Any:
def read_checkpoint(self, data_inputs: Any) -> Any:
"""
Read checkpoint data to get the data inputs and expected output.
:param data_inputs: data inputs to save
:return: data_inputs_checkpoint
"""
raise NotImplementedError()

@abstractmethod
def save_checkpoint(self, data_inputs):
def save_checkpoint(self, ids, data_inputs: Any):
"""
Save checkpoint for data inputs and expected outputs so that it can
be loaded by the checkpoint pipeline runner on the next pipeline run
:param ids: data inputs ids
:param data_inputs: data inputs to save
:return:
"""
@@ -103,14 +109,15 @@ def read_checkpoint(self, data_inputs):
"""
data_inputs_checkpoint_file_name = self.checkpoint_path
with open(self.get_checkpoint_file_path(data_inputs_checkpoint_file_name), 'rb') as file:
data_inputs = pickle.load(file)
checkpoint = pickle.load(file)

return data_inputs
return checkpoint

def save_checkpoint(self, data_inputs):
def save_checkpoint(self, ids, data_inputs):
"""
Save pickle files for data inputs and expected output
to create a checkpoint
:param ids: data inputs ids
:param data_inputs: data inputs to be saved in a pickle file
:return:
"""
@@ -22,23 +22,13 @@
from abc import ABC, abstractmethod
from typing import Any, Tuple

from neuraxle.base import BaseStep, TruncableSteps, NamedTupleList, ResumableStepMixin
from neuraxle.base import BaseStep, TruncableSteps, NamedTupleList, ResumableStepMixin, HasherByIndex, Hasher
from neuraxle.checkpoints import BaseCheckpointStep


class DataObject:
    """Pairs a data id ``i`` with a data value ``x``.

    BUG FIX: the original defined ``__hash__`` without ``__eq__``, violating
    the eq/hash contract — two DataObjects with identical content hashed
    alike yet compared unequal (identity equality), breaking dict/set usage.
    """

    def __init__(self, i, x):
        self.i = i
        self.x = x

    def __eq__(self, other):
        # Content equality, consistent with __hash__ below.
        if not isinstance(other, DataObject):
            return NotImplemented
        return (self.i, self.x) == (other.i, other.x)

    def __hash__(self):
        return hash((self.i, self.x))


class BasePipeline(TruncableSteps, ABC):

def __init__(self, steps: NamedTupleList, hasher: Hasher = None):
    """Build a pipeline from (name, step) tuples.

    :param steps: named steps of the pipeline
    :param hasher: strategy used to derive data ids; defaults to an
        index-based ``HasherByIndex``.
    """
    # BUG FIX: the original default ``hasher: Hasher = HasherByIndex()`` is
    # evaluated once at function-definition time, so every pipeline would
    # share one Hasher instance (mutable-default anti-pattern) — and the
    # instantiation happens at import time. Use a None sentinel and create
    # a fresh hasher per pipeline instead; passing a hasher explicitly is
    # unaffected.
    if hasher is None:
        hasher = HasherByIndex()
    BaseStep.__init__(self, None, None, None, hasher)
    TruncableSteps.__init__(self, steps)

@abstractmethod
@@ -103,11 +93,13 @@ def fit_transform_steps(self, data_inputs, expected_outputs):
:param expected_outputs: the expected data output to fit on
:return: the pipeline itself
"""
steps_left_to_do, data_inputs = self.read_checkpoint(data_inputs)
steps_left_to_do, ids, data_inputs = self.resume_pipeline(data_inputs)

new_steps_as_tuple: NamedTupleList = []

for step_name, step in steps_left_to_do:
step, data_inputs = step.fit_transform(data_inputs, expected_outputs)
step, data_inputs = step.handle_fit_transform(ids, data_inputs, expected_outputs)
ids = step.rehash(ids, data_inputs)
new_steps_as_tuple.append((step_name, step))

self.steps_as_tuple = self.steps_as_tuple[:len(self.steps_as_tuple) - len(steps_left_to_do)] + \
def transform(self, data_inputs):
    """Transform the data inputs through every remaining step, resuming
    from the latest checkpoint when one exists.

    :param data_inputs: the data input to transform
    :return: transformed data inputs
    """
    steps_left_to_do, ids, data_inputs = self.resume_pipeline(data_inputs)

    for step_name, step in steps_left_to_do:
        data_inputs = step.handle_transform(ids, data_inputs)
        # BUG FIX: rehash takes (ids, data_inputs); the original called
        # step.rehash(data_inputs) with a single argument, which raises
        # TypeError and also diverges from the fit_transform_steps path,
        # which passes both.
        ids = step.rehash(ids, data_inputs)

    return data_inputs

@@ -137,16 +131,17 @@ def inverse_transform_processed_outputs(self, processed_outputs) -> Any:
processed_outputs = step.transform(processed_outputs)
return processed_outputs

def resume_pipeline(self, data_inputs) -> Tuple[NamedTupleList, Any, Any]:
    """Locate the step to resume from and, when it is a checkpoint step,
    load the previously saved (ids, data inputs) from it.

    :param data_inputs: original pipeline data inputs
    :return: (steps left to do, ids, data inputs to feed them)
    """
    # Annotation fix: the original hinted ``iter`` (a builtin function,
    # not a type) for the ids element; use Any.
    new_data_inputs = data_inputs
    new_starting_step_index = self.find_starting_step_index(data_inputs)
    ids = self.hash(data_inputs)

    # BUG FIX: steps_as_tuple holds (name, step) pairs — see the
    # ``for step_name, step in ...`` loops elsewhere — so the original
    # ``step = self.steps_as_tuple[index]`` bound the tuple itself and the
    # isinstance check below could never succeed, meaning checkpoints were
    # never read. Unpack the pair first.
    step_name, step = self.steps_as_tuple[new_starting_step_index]
    if isinstance(step, BaseCheckpointStep):
        checkpoint = step.read_checkpoint(data_inputs)
        ids, new_data_inputs = checkpoint

    return self.steps_as_tuple[new_starting_step_index:], ids, new_data_inputs

def find_starting_step_index(self, data_inputs) -> int:
"""

0 comments on commit 407d7b3

Please sign in to comment.
You can’t perform that action at this time.