Skip to content

Commit

Permalink
Allow an train and eval loop to execute even if input is None
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmed-shariff committed Sep 12, 2019
1 parent 68d2834 commit 02dc7fb
Showing 1 changed file with 47 additions and 44 deletions.
91 changes: 47 additions & 44 deletions mlpipeline/_pipeline_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,58 +197,61 @@ def _experiment_main_loop(file_path, whitelist_versions=None, blacklist_versions
try:
input_fn = dataloader.get_train_input(mode=ExecutionModeKeys.TRAIN)
except NotImplementedError:
log('`get_train_input` not implemented for training. Ignoring')
log('`get_train_input` not implemented for training. Setting training input to `None`.',
level=logging.WARNING)
input_fn = None

if input_fn is not None:
try:
train_output = current_experiment.train_loop(
input_fn=input_fn)
if isinstance(train_output, MetricContainer):
train_output = train_output.log_metrics(log_to_file=False, complete_epoch=True)
if isinstance(train_output, str):
log("Experiment traning loop output: {0}".format(train_output))
log(log_special_tokens.TRAINING_COMPLETE)
except NotImplementedError:
log("`train_loop` not implemeted.")
except Exception as e:
train_results = "Training loop failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
else:
log("`train_loop` not executed as training input is None")
if input_fn is None:
log("input to `train_loop` is `None`",
level=logging.WARNING)
try:
train_output = current_experiment.train_loop(
input_fn=input_fn)
if isinstance(train_output, MetricContainer):
train_output = train_output.log_metrics(log_to_file=False, complete_epoch=True)
if isinstance(train_output, str):
log("Experiment traning loop output: {0}".format(train_output))
log(log_special_tokens.TRAINING_COMPLETE)
except NotImplementedError:
log("`train_loop` not implemeted.")
except Exception as e:
train_results = "Training loop failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise

try:
input_fn = dataloader.get_train_input(mode=ExecutionModeKeys.TEST)
except NotImplementedError:
log('`get_train_input` not implemented for evaluation. Ignoring.')
log('`get_train_input` not implemented for evaluation. Setting training input to `None`.',
level=logging.WARNING)
input_fn = None

if input_fn is not None:
try:
log("Training evaluation started: {0} steps"
.format(train_eval_steps if train_eval_steps is not None else 'unspecified'))
train_results = current_experiment.evaluate_loop(
input_fn=input_fn)
log("Eval on train set: ")
if isinstance(train_results, MetricContainer):
train_results = train_results.log_metrics(complete_epoch=True, name_prefix="TRAIN_")
elif isinstance(train_results, str):
log("{0}".format(train_results))
else:
raise ValueError("The output of `evaluate_loop` should be"
" a string or a `MetricContainer`")
except NotImplementedError:
log('`evaluate_loop` not implemented. Ignoring')
except Exception as e:
train_results = "Training evaluation failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise
else:
log('Not executing `evaluate_loop` as training input data is `None`')
log('Input to `evaluate_loop` is `None` for training input data.',
level=logging.WARNING)
try:
log("Training evaluation started: {0} steps"
.format(train_eval_steps if train_eval_steps is not None else 'unspecified'))
train_results = current_experiment.evaluate_loop(
input_fn=input_fn)
log("Eval on train set: ")
if isinstance(train_results, MetricContainer):
train_results = train_results.log_metrics(complete_epoch=True, name_prefix="TRAIN_")
elif isinstance(train_results, str):
log("{0}".format(train_results))
else:
raise ValueError("The output of `evaluate_loop` should be"
" a string or a `MetricContainer`")
except NotImplementedError:
log('`evaluate_loop` not implemented. Ignoring')
except Exception as e:
train_results = "Training evaluation failed: {0}".format(str(e))
log(train_results, logging.ERROR)
log(traceback.format_exc(), logging.ERROR)
if CONFIG.experiment_mode == ExperimentModeKeys.TEST:
raise

try:
input_fn = dataloader.get_test_input()
Expand Down

0 comments on commit 02dc7fb

Please sign in to comment.