Pass serialization_dir to Model, DatasetReader, and support `include_in_archive` (#4713)

* Allow usage of .tar.gz with PretrainedModelInitializer

* changelog

* typing

* simplify code with contextmanager

* make the code cleaner

* fix

* change cleanup

* fix

* fix

* fix typing

* fix according to review

* black and typing

* fix

* load dataset readers in `load_archive`

* black

* fix

* lint

* remove redundant test

* fix

* black

* empty commit

* review changes

* change `include_in_archive` to be a top-level param

* Update CHANGELOG.md

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
Co-authored-by: Elad Segal <eladsegal@users.noreply.github.com>
Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>
4 people committed Oct 21, 2020
1 parent 1f29f35 commit 01644ca
Showing 9 changed files with 117 additions and 22 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -37,6 +37,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
sampling indices from log probabilities
- Made `BeamSearch` registrable.
- Added `top_k_sampling` and `type_p_sampling` `BeamSearch` implementations.
- Pass `serialization_dir` to `Model` and `DatasetReader`.
- Added an optional `include_in_archive` parameter to the top-level of configuration files. When specified, `include_in_archive` should be a list of paths relative to the serialization directory which will be bundled up with the final archived model from a training run.

### Changed

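For illustration only (not part of the diff), here is a minimal sketch of how the new top-level `include_in_archive` parameter described in the CHANGELOG entry above might be used from Python. The reader and model types, data paths, and serialization directory are placeholders, not values taken from this commit:

from allennlp.common import Params
from allennlp.commands.train import train_model

# Hypothetical training configuration; only `include_in_archive` is the
# feature introduced by this commit. Patterns are resolved with glob,
# relative to the serialization directory.
params = Params(
    {
        "dataset_reader": {"type": "my_reader"},        # placeholder reader
        "train_data_path": "data/train.jsonl",          # placeholder path
        "model": {"type": "my_model"},                   # placeholder model
        "data_loader": {"batch_size": 8},
        "trainer": {"num_epochs": 2, "optimizer": {"type": "adam"}},
        "include_in_archive": ["metrics_epoch_*.json"],
    }
)

# `train_model` pops `include_in_archive`, checks it with
# `verify_include_in_archive`, and forwards it to `archive_model`.
# train_model(params, serialization_dir="/tmp/my_run")  # assumed path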
6 changes: 4 additions & 2 deletions allennlp/commands/find_learning_rate.py
@@ -165,7 +165,7 @@ def find_learning_rate_model(
# See https://github.com/allenai/allennlp/issues/3658
assert not distributed_params, "find-lr is not compatible with DistributedDataParallel."

all_datasets = datasets_from_params(params)
all_datasets = datasets_from_params(params, serialization_dir=serialization_dir)
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

for dataset in datasets_for_vocab_creation:
@@ -188,7 +188,9 @@ def find_learning_rate_model(

train_data = all_datasets["train"]
train_data.index_with(vocab)
model = Model.from_params(vocab=vocab, params=params.pop("model"))
model = Model.from_params(
vocab=vocab, params=params.pop("model"), serialization_dir=serialization_dir
)
data_loader = DataLoader.from_params(dataset=train_data, params=params.pop("data_loader"))

trainer_params = params.pop("trainer")
17 changes: 12 additions & 5 deletions allennlp/commands/train.py
@@ -24,7 +24,7 @@
from allennlp.common.plugins import import_plugins
from allennlp.data import DatasetReader, Vocabulary
from allennlp.data import DataLoader
from allennlp.models.archival import archive_model, CONFIG_NAME
from allennlp.models.archival import archive_model, CONFIG_NAME, verify_include_in_archive
from allennlp.models.model import _DEFAULT_WEIGHTS, Model
from allennlp.training.trainer import Trainer
from allennlp.training import util as training_util
@@ -226,6 +226,9 @@ def train_model(
training_util.create_serialization_dir(params, serialization_dir, recover, force)
params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

include_in_archive = params.pop("include_in_archive", None)
verify_include_in_archive(include_in_archive)

distributed_params = params.params.pop("distributed", None)
# If distributed isn't in the config and the config contains strictly
# one cuda device, we just run a single training process.
@@ -240,7 +243,7 @@
)

if not dry_run:
archive_model(serialization_dir)
archive_model(serialization_dir, include_in_archive=include_in_archive)
return model

# Otherwise, we are running multiple processes for training.
@@ -315,13 +318,14 @@ def train_model(
world_size,
device_ids,
file_friendly_logging,
include_in_archive,
),
nprocs=num_procs,
)
if dry_run:
return None
else:
archive_model(serialization_dir)
archive_model(serialization_dir, include_in_archive=include_in_archive)
model = Model.load(params, serialization_dir)
return model

@@ -338,6 +342,7 @@ def _train_worker(
world_size: int = 1,
distributed_device_ids: List[int] = None,
file_friendly_logging: bool = False,
include_in_archive: List[str] = None,
) -> Optional[Model]:
"""
Helper to train the configured model/experiment. In distributed mode, this is spawned as a
@@ -372,6 +377,8 @@
file_friendly_logging : `bool`, optional (default=`False`)
If `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.
include_in_archive : `List[str]`, optional
Paths relative to `serialization_dir` that should be archived in addition to the default ones.
# Returns
@@ -462,7 +469,7 @@ def _train_worker(
"Training interrupted by the user. Attempting to create "
"a model archive using the current best epoch weights."
)
archive_model(serialization_dir)
archive_model(serialization_dir, include_in_archive=include_in_archive)
raise

if master:
@@ -659,7 +666,7 @@ def from_partial_objects(
vocabulary_ = vocabulary.construct(instances=instance_generator)
if not vocabulary_:
vocabulary_ = Vocabulary.from_instances(instance_generator)
model_ = model.construct(vocab=vocabulary_)
model_ = model.construct(vocab=vocabulary_, serialization_dir=serialization_dir)

# Initializing the model can have side effect of expanding the vocabulary.
# Save the vocab only in the master. In the degenerate non-distributed
10 changes: 7 additions & 3 deletions allennlp/common/testing/model_test_case.py
@@ -22,12 +22,14 @@ class ModelTestCase(AllenNlpTestCase):
with added methods for testing [`Model`](../../models/model.md) subclasses.
"""

def set_up_model(self, param_file, dataset_file):
def set_up_model(self, param_file, dataset_file, serialization_dir=None):

self.param_file = param_file
params = Params.from_file(self.param_file)

reader = DatasetReader.from_params(params["dataset_reader"])
reader = DatasetReader.from_params(
params["dataset_reader"], serialization_dir=serialization_dir
)
# The dataset reader might be lazy, but a lazy list here breaks some of our tests.
instances = reader.read(str(dataset_file))
# Use parameters for vocabulary if they are present in the config file, so that choices like
@@ -40,7 +42,9 @@ def set_up_model(self, param_file, dataset_file):
self.vocab = vocab
self.instances = instances
self.instances.index_with(vocab)
self.model = Model.from_params(vocab=self.vocab, params=params["model"])
self.model = Model.from_params(
vocab=self.vocab, params=params["model"], serialization_dir=serialization_dir
)

# TODO(joelgrus) get rid of these
# (a lot of the model tests use them, so they'll have to be changed)
5 changes: 5 additions & 0 deletions allennlp/data/dataset_readers/dataset_reader.py
@@ -140,6 +140,9 @@ class DatasetReader(Registrable):
within `_read()`**. In that case you should set `manual_multi_process_sharding`
to `True`.
serialization_dir: `str`, optional (default=`None`)
The directory in which the training output is saved, or the directory the model is loaded from.
"""

CACHE_FILE_LOCK_TIMEOUT: int = 10
@@ -154,6 +157,7 @@ def __init__(
max_instances: Optional[int] = None,
manual_distributed_sharding: bool = False,
manual_multi_process_sharding: bool = False,
serialization_dir: Optional[str] = None,
) -> None:
self.lazy = lazy
self.max_instances = max_instances
@@ -163,6 +167,7 @@
os.makedirs(self._cache_directory, exist_ok=True)
self.manual_distributed_sharding = manual_distributed_sharding
self.manual_multi_process_sharding = manual_multi_process_sharding
self.serialization_dir = serialization_dir

def read(self, file_path: Union[Path, str]) -> Union[AllennlpDataset, AllennlpLazyDataset]:
"""
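As a hedged sketch of what this change enables (the reader class and registered name below are hypothetical, not part of this commit), a custom `DatasetReader` can simply forward the new constructor argument and later read it back from `self.serialization_dir`:

from typing import Iterable, Optional

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WhitespaceTokenizer


@DatasetReader.register("my_lines_reader")  # hypothetical name
class MyLinesReader(DatasetReader):
    def __init__(self, serialization_dir: Optional[str] = None, **kwargs) -> None:
        # Forward the new argument; the base class stores it as
        # `self.serialization_dir`.
        super().__init__(serialization_dir=serialization_dir, **kwargs)
        self._tokenizer = WhitespaceTokenizer()
        self._token_indexers = {"tokens": SingleIdTokenIndexer()}

    def _read(self, file_path: str) -> Iterable[Instance]:
        with open(file_path) as data_file:
            for line in data_file:
                yield self.text_to_instance(line.strip())

    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        tokens = self._tokenizer.tokenize(text)
        return Instance({"tokens": TextField(tokens, self._token_indexers)})

When `DatasetReader.from_params` is called with `serialization_dir=...`, as `load_archive` and `datasets_from_params` now do in this commit, the value reaches this constructor automatically.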
39 changes: 34 additions & 5 deletions allennlp/models/archival.py
@@ -2,14 +2,15 @@
Helper functions for archiving models and restoring archived models.
"""
from os import PathLike
from typing import NamedTuple, Union, Dict, Any
from typing import NamedTuple, Union, Dict, Any, List, Optional
import logging
import os
import tempfile
import tarfile
import shutil
from pathlib import Path
from contextlib import contextmanager
import glob

from torch.nn import Module

@@ -92,10 +93,22 @@ def extract_module(self, path: str, freeze: bool = True) -> Module:
_WEIGHTS_NAME = "weights.th"


def verify_include_in_archive(include_in_archive: Optional[List[str]] = None):
if include_in_archive is None:
return
saved_names = [CONFIG_NAME, _WEIGHTS_NAME, _DEFAULT_WEIGHTS, "vocabulary"]
for archival_target in include_in_archive:
if archival_target in saved_names:
raise ConfigurationError(
f"{', '.join(saved_names)} are saved names and cannot be used for include_in_archive."
)


def archive_model(
serialization_dir: Union[str, PathLike],
weights: str = _DEFAULT_WEIGHTS,
archive_path: Union[str, PathLike] = None,
include_in_archive: Optional[List[str]] = None,
) -> None:
"""
Archive the model weights, its training configuration, and its vocabulary to `model.tar.gz`.
@@ -110,6 +123,8 @@ def archive_model(
A full path to serialize the model to. The default is "model.tar.gz" inside the
serialization_dir. If you pass a directory here, we'll serialize the model
to "model.tar.gz" inside the directory.
include_in_archive : `List[str]`, optional, (default = `None`)
Paths relative to `serialization_dir` that should be archived in addition to the default ones.
"""
weights_file = os.path.join(serialization_dir, weights)
if not os.path.exists(weights_file):
@@ -132,6 +147,14 @@ def archive_model(
archive.add(weights_file, arcname=_WEIGHTS_NAME)
archive.add(os.path.join(serialization_dir, "vocabulary"), arcname="vocabulary")

if include_in_archive is not None:
for archival_target in include_in_archive:
archival_target_path = os.path.join(serialization_dir, archival_target)
for path in glob.glob(archival_target_path):
if os.path.exists(path):
arcname = path[len(os.path.join(serialization_dir, "")) :]
archive.add(path, arcname=arcname)


def load_archive(
archive_file: Union[str, Path],
@@ -179,7 +202,9 @@ def load_archive(
config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME), overrides)

# Instantiate model and dataset readers. Use a duplicate of the config, as it will get consumed.
dataset_reader, validation_dataset_reader = _load_dataset_readers(config.duplicate())
dataset_reader, validation_dataset_reader = _load_dataset_readers(
config.duplicate(), serialization_dir
)
model = _load_model(config.duplicate(), weights_path, serialization_dir, cuda_device)
finally:
if tempdir is not None:
Expand All @@ -194,7 +219,7 @@ def load_archive(
)


def _load_dataset_readers(config):
def _load_dataset_readers(config, serialization_dir):
dataset_reader_params = config.get("dataset_reader")

# Try to use the validation dataset reader if there is one - otherwise fall back
@@ -203,8 +228,12 @@ def _load_dataset_readers(config, serialization_dir):
"validation_dataset_reader", dataset_reader_params.duplicate()
)

dataset_reader = DatasetReader.from_params(dataset_reader_params)
validation_dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)
dataset_reader = DatasetReader.from_params(
dataset_reader_params, serialization_dir=serialization_dir
)
validation_dataset_reader = DatasetReader.from_params(
validation_dataset_reader_params, serialization_dir=serialization_dir
)

return dataset_reader, validation_dataset_reader

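A small usage sketch of the archival changes above. The serialization directory and glob pattern are assumptions, and the directory is taken to contain a completed training run:

from allennlp.common.checks import ConfigurationError
from allennlp.models.archival import archive_model, verify_include_in_archive

include_in_archive = ["metrics_epoch_*.json"]  # placeholder glob pattern

# Reserved names (config.json, weights.th, best.th, vocabulary) are rejected
# up front so they cannot clash with the files archive_model always writes.
try:
    verify_include_in_archive(["config.json"])
except ConfigurationError as err:
    print(err)

# Valid patterns are expanded with glob relative to the serialization
# directory, and the matching files are added to model.tar.gz alongside
# the default contents.
verify_include_in_archive(include_in_archive)
archive_model("/tmp/my_run", include_in_archive=include_in_archive)  # assumed path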
14 changes: 12 additions & 2 deletions allennlp/models/model.py
@@ -66,15 +66,23 @@ class Model(torch.nn.Module, Registrable):
separately.
regularizer: `RegularizerApplicator`, optional
If given, the `Trainer` will use this to regularize model parameters.
serialization_dir: `str`, optional
The directory in which the training output is saved, or the directory the model is loaded from.
"""

_warn_for_unseparable_batches: Set[str] = set()
default_predictor: Optional[str] = None

def __init__(self, vocab: Vocabulary, regularizer: RegularizerApplicator = None) -> None:
def __init__(
self,
vocab: Vocabulary,
regularizer: RegularizerApplicator = None,
serialization_dir: Optional[str] = None,
) -> None:
super().__init__()
self.vocab = vocab
self._regularizer = regularizer
self.serialization_dir = serialization_dir

def get_regularization_penalty(self) -> Optional[torch.Tensor]:
"""
@@ -293,7 +301,9 @@ def _load(
# stored in our model. We don't need any pretrained weight file or initializers anymore,
# and we don't want the code to look for it, so we remove it from the parameters here.
remove_keys_from_params(model_params)
model = Model.from_params(vocab=vocab, params=model_params)
model = Model.from_params(
vocab=vocab, params=model_params, serialization_dir=serialization_dir
)

# Force model to cpu or gpu, as appropriate, to make sure that the embeddings are
# in sync with the weights
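Similarly, a hedged sketch (hypothetical model class, not part of this commit) of a `Model` subclass that forwards `serialization_dir`, making it available as `self.serialization_dir`, for example to write auxiliary outputs next to the checkpoints:

from typing import Dict, Optional

import torch

from allennlp.data import Vocabulary
from allennlp.models import Model


@Model.register("my_tagger")  # hypothetical name
class MyTagger(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        serialization_dir: Optional[str] = None,
        **kwargs,
    ) -> None:
        # Forward the new argument; the base class stores it as
        # `self.serialization_dir`.
        super().__init__(vocab, serialization_dir=serialization_dir, **kwargs)
        self._projection = torch.nn.Linear(4, 2)

    def forward(self, **inputs) -> Dict[str, torch.Tensor]:  # type: ignore
        # A real model would compute logits and a loss here; this stub only
        # illustrates the updated constructor signature.
        return {"loss": self._projection.weight.sum() * 0.0}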
15 changes: 11 additions & 4 deletions allennlp/training/util.py
@@ -124,7 +124,11 @@ def read_all_datasets(


def datasets_from_params(
params: Params, train: bool = True, validation: bool = True, test: bool = True
params: Params,
train: bool = True,
validation: bool = True,
test: bool = True,
serialization_dir: Optional[Union[str, PathLike]] = None,
) -> Dict[str, Union["AllennlpDataset", "AllennlpLazyDataset"]]:
"""
Load datasets specified by the config.
@@ -139,7 +143,9 @@
return datasets

dataset_reader_params = params.pop("dataset_reader")
dataset_reader = DatasetReader.from_params(dataset_reader_params)
dataset_reader = DatasetReader.from_params(
dataset_reader_params, serialization_dir=serialization_dir
)

if train:
train_data_path = params.pop("train_data_path")
@@ -157,7 +163,7 @@
if validation_dataset_reader_params is not None:
logger.info("Using a separate dataset reader to load validation and test data.")
validation_and_test_dataset_reader = DatasetReader.from_params(
validation_dataset_reader_params
validation_dataset_reader_params, serialization_dir=serialization_dir
)

if validation:
@@ -464,14 +470,15 @@ def make_vocab_from_params(
if datasets_for_vocab_creation is None:
# If `datasets_for_vocab_creation` was not specified, we'll use all datasets
# from the config.
datasets = datasets_from_params(params)
datasets = datasets_from_params(params, serialization_dir=serialization_dir)
else:
for dataset_name in datasets_for_vocab_creation:
data_path = f"{dataset_name}_data_path"
if data_path not in params:
raise ConfigurationError(f"invalid 'datasets_for_vocab_creation' {dataset_name}")
datasets = datasets_from_params(
params,
serialization_dir=serialization_dir,
train=("train" in datasets_for_vocab_creation),
validation=("validation" in datasets_for_vocab_creation),
test=("test" in datasets_for_vocab_creation),
31 changes: 30 additions & 1 deletion tests/models/archival_test.py
@@ -1,12 +1,17 @@
import copy
import os
import tempfile
import tarfile

import pytest
import torch

from allennlp.commands.train import train_model
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import DatasetReader
from allennlp.models.archival import archive_model, load_archive
from allennlp.models.archival import archive_model, load_archive, CONFIG_NAME


def assert_models_equal(model, model2):
@@ -126,3 +131,27 @@ def test_can_load_from_archive_model(self):
# In this case, the parameters are definitely different, no need for the above
# check.
pass

def test_include_in_archive(self):
self.params["include_in_archive"] = ["metrics_epoch_*.json"]

serialization_dir = self.TEST_DIR / "serialization"
# Train a model
train_model(self.params, serialization_dir=serialization_dir)

# Assert that the additional targets were archived
with tempfile.TemporaryDirectory() as tempdir:
with tarfile.open(serialization_dir / "model.tar.gz", "r:gz") as archive:
archive.extractall(tempdir)
assert os.path.isfile(os.path.join(tempdir, "metrics_epoch_0.json"))
assert os.path.isfile(os.path.join(tempdir, "metrics_epoch_1.json"))
assert not os.path.isfile(os.path.join(tempdir, "metrics.json"))

def test_invalid_include_in_archive(self):
self.params["include_in_archive"] = [CONFIG_NAME]

serialization_dir = self.TEST_DIR / "serialization"

with pytest.raises(ConfigurationError) as exc:
train_model(self.params, serialization_dir=serialization_dir)
assert "are saved names and cannot be used" in str(exc.value)
