
Combined recovery logic, made it not crash on beaker (#925)
matt-gardner committed Feb 26, 2018
1 parent de5df62 commit 97bc7a2
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions allennlp/commands/train.py
@@ -103,20 +103,18 @@ def train_model_from_args(args: argparse.Namespace):
    for package_name in args.include_package:
        import_submodules(package_name)

    if not args.recover and os.path.exists(args.serialization_dir):
        raise ConfigurationError(f"Serialization directory ({args.serialization_dir}) already exists. "
                                 f"Specify --recover to recover training from existing output.")
    elif args.recover and not os.path.exists(args.serialization_dir):
        raise ConfigurationError(f"--recover specified but serialization_dir ({args.serialization_dir}) does not "
                                 f"exist. There is nothing to recover from.")

    train_model_from_file(args.param_path, args.serialization_dir, args.overrides, args.file_friendly_logging)
    train_model_from_file(args.param_path,
                          args.serialization_dir,
                          args.overrides,
                          args.file_friendly_logging,
                          args.recover)


def train_model_from_file(parameter_filename: str,
                          serialization_dir: str,
                          overrides: str = "",
                          file_friendly_logging: bool = False) -> Model:
                          file_friendly_logging: bool = False,
                          recover: bool = False) -> Model:
"""
A wrapper around :func:`train_model` which loads the params from a file.
@@ -132,10 +130,14 @@ def train_model_from_file(parameter_filename: str,
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we make our output more friendly to saved model files. We just pass this
        along to :func:`train_model`.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory. This is only intended for use when something actually crashed during the middle
        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
    """
    # Load the experiment config from a file and pass it to ``train_model``.
    params = Params.from_file(parameter_filename, overrides)
    return train_model(params, serialization_dir, file_friendly_logging)
    return train_model(params, serialization_dir, file_friendly_logging, recover)
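
As a quick, hedged illustration of the widened signature (the config path and output directory below are hypothetical, not taken from the commit), recovering a crashed run through this wrapper looks roughly like:

    from allennlp.commands.train import train_model_from_file

    # recover=True is what the train command's --recover flag feeds through; it only
    # makes sense when the serialization directory already holds a crashed run's output.
    model = train_model_from_file("experiments/bidaf.json",  # hypothetical config file
                                  "/tmp/bidaf_run",          # hypothetical output directory
                                  recover=True)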


def datasets_from_params(params: Params) -> Dict[str, Iterable[Instance]]:
@@ -170,19 +172,30 @@ def datasets_from_params(params: Params) -> Dict[str, Iterable[Instance]]:

    return datasets

def create_serialization_dir(params: Params, serialization_dir: str) -> None:
def create_serialization_dir(params: Params, serialization_dir: str, recover: bool) -> None:
    """
    This function creates the serialization directory if it doesn't exist. If it already exists,
    then it verifies that we're recovering from a training with an identical configuration.
    Parameters
    ----------
    params: Params, required.
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: str, required
    serialization_dir: ``str``
        The directory in which to save results and logs.
    recover: ``bool``
        If ``True``, we will try to recover from an existing serialization directory, and crash if
        the directory doesn't exist, or doesn't match the configuration we're given.
    """
    if os.path.exists(serialization_dir):
        if serialization_dir == '/output':
            # Special-casing the beaker output directory, which will already exist when training
            # starts.
            return
        if not recover:
            raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists. "
                                     f"Specify --recover to recover training from existing output.")

        logger.info(f"Recovering from prior training at {serialization_dir}.")

        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
@@ -192,8 +205,8 @@ def create_serialization_dir(params: Params, serialization_dir: str) -> None:
        else:
            loaded_params = Params.from_file(recovered_config_file)

        # Check whether any of the training configuration differs from the configuration we are resuming.
        # If so, warn the user that training may fail.
        # Check whether any of the training configuration differs from the configuration we are
        # resuming. If so, warn the user that training may fail.
        fail = False
        flat_params = params.as_flat_dict()
        flat_loaded = loaded_params.as_flat_dict()
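
The loop that actually fills in ``fail`` is collapsed in this view; conceptually it walks the two flattened dicts key by key, along the lines of the standalone sketch below (the keys are made up, and this is an approximation of the idea rather than the exact code in the file):

    # Stand-ins for params.as_flat_dict() and loaded_params.as_flat_dict().
    flat_params = {"trainer.num_epochs": 40, "trainer.cuda_device": 1}
    flat_loaded = {"trainer.num_epochs": 40, "trainer.cuda_device": 0}

    fail = False
    for key in flat_params.keys() | flat_loaded.keys():
        if flat_params.get(key) != flat_loaded.get(key):
            # A key missing on one side, or one whose value changed, could make the
            # resumed run inconsistent with the checkpoints already on disk.
            print(f"Configuration mismatch for {key}: "
                  f"{flat_params.get(key)} != {flat_loaded.get(key)}")
            fail = True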
@@ -215,10 +228,16 @@ def create_serialization_dir(params: Params, serialization_dir: str) -> None:
raise ConfigurationError("Training configuration does not match the configuration we're "
"recovering from.")
else:
if recover:
raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
"does not exist. There is nothing to recover from.")
os.makedirs(serialization_dir)
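
Since ``create_serialization_dir`` now owns both failure modes that used to live in ``train_model_from_args``, it can be exercised in isolation. A small sketch, assuming the import paths below hold for this version of the codebase (the directory names are arbitrary):

    import os
    import shutil
    import tempfile

    from allennlp.commands.train import create_serialization_dir
    from allennlp.common.checks import ConfigurationError
    from allennlp.common.params import Params

    params = Params({"trainer": {"num_epochs": 1}})  # deliberately tiny, made-up config
    workdir = tempfile.mkdtemp()

    existing = os.path.join(workdir, "existing_run")
    os.makedirs(existing)
    try:
        # A pre-existing directory without recover is now rejected here,
        # rather than in train_model_from_args.
        create_serialization_dir(params, existing, recover=False)
    except ConfigurationError as error:
        print(error)

    try:
        # recover=True with nothing on disk is rejected as well.
        create_serialization_dir(params, os.path.join(workdir, "fresh_run"), recover=True)
    except ConfigurationError as error:
        print(error)

    shutil.rmtree(workdir)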


def train_model(params: Params, serialization_dir: str, file_friendly_logging: bool = False) -> Model:
def train_model(params: Params,
                serialization_dir: str,
                file_friendly_logging: bool = False,
                recover: bool = False) -> Model:
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results in ``serialization_dir``.
@@ -232,10 +251,14 @@ def train_model(params: Params, serialization_dir: str, file_friendly_logging: b
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory. This is only intended for use when something actually crashed during the middle
        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
    """
    prepare_environment(params)

    create_serialization_dir(params, serialization_dir)
    create_serialization_dir(params, serialization_dir, recover)

    # TODO(mattg): pull this block out into a separate function (maybe just add this to
    # `prepare_environment`?)
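
For completeness, a sketch of driving ``train_model`` directly with the new keyword (the paths are hypothetical; a fresh ``Params`` object is loaded for the resumed run because params are consumed as they are read):

    from allennlp.commands.train import train_model
    from allennlp.common.params import Params

    # Fresh run: the serialization directory must not exist yet
    # (except for the special-cased /output directory on beaker).
    train_model(Params.from_file("experiments/bidaf.json"), "/tmp/bidaf_run")

    # ...the job crashes part-way through...

    # Resume: the directory now exists, so recover=True is required, and the config
    # must match the one saved alongside the checkpoints in that directory.
    train_model(Params.from_file("experiments/bidaf.json"),
                "/tmp/bidaf_run",
                recover=True)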
