Skip to content

Commit

Permalink
evaluate_loop and train_loop executed only when respective input …
Browse files Browse the repository at this point in the history
…data is implemeted and not None
  • Loading branch information
ahmed-shariff committed Jul 29, 2019
1 parent 7b91be2 commit bd3b646
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 59 deletions.
126 changes: 70 additions & 56 deletions mlpipeline/_pipeline_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,73 +196,87 @@ def _experiment_main_loop(file_path, whitelist_versions=None, blacklist_versions
input_fn = dataloader.get_train_input(mode=ExecutionModeKeys.TRAIN)
except NotImplementedError:
log('`get_train_input` not implemented for training. Ignoring')
input_fn = None

try:
train_output = current_experiment.train_loop(
input_fn=input_fn)
except Exception as e:
train_results = "Training loop failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
if isinstance(train_output, MetricContainer):
train_output = train_output.log_metrics(log_to_file=False, complete_epoch=True)
if isinstance(train_output, str):
log("Experiment traning loop output: {0}".format(train_output))
log(log_special_tokens.TRAINING_COMPLETE)
if input_fn is not None:
try:
train_output = current_experiment.train_loop(
input_fn=input_fn)
if isinstance(train_output, MetricContainer):
train_output = train_output.log_metrics(log_to_file=False, complete_epoch=True)
if isinstance(train_output, str):
log("Experiment traning loop output: {0}".format(train_output))
log(log_special_tokens.TRAINING_COMPLETE)
except NotImplementedError:
log("`train_loop` not implemeted.")
except Exception as e:
train_results = "Training loop failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
else:
log("`train_loop` not executed as training input is None")

try:
input_fn = dataloader.get_train_input(mode=ExecutionModeKeys.TEST)
except NotImplementedError:
log('`get_train_input` not implemented for evaluation. Ignoring')
try:
log("Training evaluation started: {0} steps"
.format(train_eval_steps if train_eval_steps is not None else 'unspecified'))
train_results = current_experiment.evaluate_loop(
input_fn=input_fn)
log("Eval on train set: ")
if isinstance(train_results, MetricContainer):
train_results = train_results.log_metrics(complete_epoch=True, name_prefix="TRAIN_")
elif isinstance(train_results, str):
log("{0}".format(train_results))
else:
raise ValueError("The output of `evaluate_loop` should be"
" a string or a `MetricContainer`")
except NotImplementedError:
log('`evaluate_loop` not implemented. Ignoring')
except Exception as e:
train_results = "Training evaluation failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
log('`get_train_input` not implemented for evaluation. Ignoring.')
input_fn = None
if input_fn is not None:
try:
log("Training evaluation started: {0} steps"
.format(train_eval_steps if train_eval_steps is not None else 'unspecified'))
train_results = current_experiment.evaluate_loop(
input_fn=input_fn)
log("Eval on train set: ")
if isinstance(train_results, MetricContainer):
train_results = train_results.log_metrics(complete_epoch=True, name_prefix="TRAIN_")
elif isinstance(train_results, str):
log("{0}".format(train_results))
else:
raise ValueError("The output of `evaluate_loop` should be"
" a string or a `MetricContainer`")
except NotImplementedError:
log('`evaluate_loop` not implemented. Ignoring')
except Exception as e:
train_results = "Training evaluation failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
else:
log('Not executing `evaluate_loop` as training input data is `None`')

try:
input_fn = dataloader.get_test_input()
except NotImplementedError:
log('`get_test_input` not implemented.')
input_fn = None
try:
log("Testing evaluation started: {0} steps".
format(test__eval_steps if test__eval_steps is not None else 'unspecified'))
eval_results = current_experiment.evaluate_loop(input_fn=input_fn)
log("Eval on train set:")
if isinstance(eval_results, MetricContainer):
eval_results = eval_results.log_metrics(complete_epoch=True, name_prefix="TEST_")
elif isinstance(eval_results, str):
log("{0}".format(eval_results))
else:
raise ValueError("The output of `evaluate_loop` should"
" be a string or a `MetricContainer`")
except NotImplementedError:
log('`evaluate_loop` not implemented. Ignoring')
except Exception as e:
eval_results = "Test evaluation failed: {0}".format(str(e))
log(eval_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
if input_fn is not None:
try:
log("Testing evaluation started: {0} steps".
format(test__eval_steps if test__eval_steps is not None else 'unspecified'))
eval_results = current_experiment.evaluate_loop(input_fn=input_fn)
log("Eval on train set:")
if isinstance(eval_results, MetricContainer):
eval_results = eval_results.log_metrics(complete_epoch=True, name_prefix="TEST_")
elif isinstance(eval_results, str):
log("{0}".format(eval_results))
else:
raise ValueError("The output of `evaluate_loop` should"
" be a string or a `MetricContainer`")
except NotImplementedError:
log('`evaluate_loop` not implemented. Ignoring')
except Exception as e:
eval_results = "Test evaluation failed: {0}".format(str(e))
log(eval_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
else:
log('Not executing `evaluate_loop` as testing input data is `None`')

try:
current_experiment.post_execution_hook(mode=CONFIG.experiment_mode)
except NotImplementedError:
Expand Down
6 changes: 3 additions & 3 deletions mlpipeline/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ class Versions():
Also prvodes helper functions to define and add new parameter versions.
'''
def __init__(self,
dataloader,
batch_size,
epoch_count,
dataloader=None,
batch_size=None,
epoch_count=None,
**kwargs):
self._default_values = EasyDict(dataloader=dataloader,
batch_size=batch_size,
Expand Down

0 comments on commit bd3b646

Please sign in to comment.