Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/master' into vision
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr committed Sep 11, 2020
2 parents 191b641 + 2df364f commit f886fd0
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 30 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added the ability to ignore certain missing keys when loading a model from an archive. This is done
by adding a class-level variable called `authorized_missing_keys` to any PyTorch module that a `Model` uses.
If defined, `authorized_missing_keys` should be a list of regex string patterns.

### Changed

- `transformers` dependency updated to version 3.1.0.

### Fixed

- Ignore *args when constructing classes with `FromParams`.

## [v1.1.0](https://github.com/allenai/allennlp/releases/tag/v1.1.0) - 2020-09-08

### Fixed

- Fixed handling of some edge cases when constructing classes with `FromParams` where the class
accepts `**kwargs`.
- Fixed division by zero error when there are zero-length spans in the input to a
`PretrainedTransformerMismatchedIndexer`.
- Improved robustness of `cached_path` when extracting archives so that the cache won't be corrupted
if a failure occurs during extraction.
- Fixed a bug with the `average` and `evalb_bracketing_score` metrics in distributed training.

### Added

Expand Down
53 changes: 37 additions & 16 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,32 @@ def cached_path(
if isinstance(url_or_filename, PathLike):
url_or_filename = str(url_or_filename)

file_path: str

# If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here.
exclamation_index = url_or_filename.find("!")
if extract_archive and exclamation_index >= 0:
archive_path = url_or_filename[:exclamation_index]
archive_path = cached_path(archive_path, cache_dir, True, force_extract)
if not os.path.isdir(archive_path):
file_name = url_or_filename[exclamation_index + 1 :]

# Call 'cached_path' recursively now to get the local path to the archive itself.
cached_archive_path = cached_path(archive_path, cache_dir, True, force_extract)
if not os.path.isdir(cached_archive_path):
raise ValueError(
f"{url_or_filename} uses the ! syntax, but does not specify an archive file."
)
return os.path.join(archive_path, url_or_filename[exclamation_index + 1 :])

# Now return the full path to the desired file within the extracted archive,
# provided it exists.
file_path = os.path.join(cached_archive_path, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"file {file_name} not found within {archive_path}")

return file_path

url_or_filename = os.path.expanduser(url_or_filename)
parsed = urlparse(url_or_filename)

file_path: str
extraction_path: Optional[str] = None

if parsed.scheme in ("http", "https", "s3"):
Expand All @@ -161,29 +172,39 @@ def cached_path(

elif parsed.scheme == "":
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))
raise FileNotFoundError(f"file {url_or_filename} not found")

else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")

if extraction_path is not None:
# No need to extract again.
# If the extracted directory already exists (and is non-empty), then no
# need to extract again unless `force_extract=True`.
if os.path.isdir(extraction_path) and os.listdir(extraction_path) and not force_extract:
return extraction_path

# Extract it.
with FileLock(file_path + ".lock"):
shutil.rmtree(extraction_path, ignore_errors=True)
os.makedirs(extraction_path)
if is_zipfile(file_path):
with ZipFile(file_path, "r") as zip_file:
zip_file.extractall(extraction_path)
zip_file.close()
else:
tar_file = tarfile.open(file_path)
tar_file.extractall(extraction_path)
tar_file.close()

# We extract first to a temporary directory in case something goes wrong
# during the extraction process so we don't end up with a corrupted cache.
tmp_extraction_dir = tempfile.mkdtemp(dir=os.path.split(extraction_path)[0])
try:
if is_zipfile(file_path):
with ZipFile(file_path, "r") as zip_file:
zip_file.extractall(tmp_extraction_dir)
zip_file.close()
else:
tar_file = tarfile.open(file_path)
tar_file.extractall(tmp_extraction_dir)
tar_file.close()
# Extraction was successful, rename temp directory to final
# cache directory.
os.replace(tmp_extraction_dir, extraction_path)
finally:
shutil.rmtree(tmp_extraction_dir, ignore_errors=True)

return extraction_path

Expand Down
8 changes: 6 additions & 2 deletions allennlp/common/from_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,22 @@ def remove_optional(annotation: type):


def infer_params(cls: Type[T], constructor: Callable[..., T] = None) -> Dict[str, Any]:
if cls == FromParams:
return {}
if constructor is None:
constructor = cls.__init__

signature = inspect.signature(constructor)
parameters = dict(signature.parameters)

has_kwargs = False
var_positional_key = None
for param in parameters.values():
if param.kind == param.VAR_KEYWORD:
has_kwargs = True
elif param.kind == param.VAR_POSITIONAL:
var_positional_key = param.name

if var_positional_key:
del parameters[var_positional_key]

if not has_kwargs:
return parameters
Expand Down
7 changes: 7 additions & 0 deletions allennlp/common/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import importlib
import logging
import os
import sys
from typing import Iterable

from allennlp.common.util import push_python_path, import_module_and_submodules
Expand Down Expand Up @@ -46,6 +47,12 @@ def import_plugins() -> None:
"""
Imports the plugins found with `discover_plugins()`.
"""

# Workaround for a presumed Python issue where spawned processes can't find modules in the current directory.
cwd = os.getcwd()
if cwd not in sys.path:
sys.path.append(cwd)

for module_name in DEFAULT_PLUGINS:
try:
# For default plugins we recursively import everything.
Expand Down
6 changes: 4 additions & 2 deletions allennlp/common/testing/distributed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ def run_distributed_test(
func: `Callable`
`func` needs to be global for spawning the processes, so that it can be pickled.
"""

check_for_gpu(device_ids)
# "fork" start method is the default and should be preferred, except when we're
# running the tests on GPU, in which case we need to use "spawn".
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
nprocs = world_size = len(device_ids)
mp.start_processes(
init_process,
args=(device_ids, world_size, func, args, kwargs),
nprocs=nprocs,
start_method="fork",
start_method=start_method,
)
31 changes: 30 additions & 1 deletion allennlp/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
from os import PathLike
import re
from typing import Dict, List, Set, Type, Optional, Union

import numpy
Expand Down Expand Up @@ -309,8 +310,36 @@ def _load(
# If vocab and model embeddings are in sync, following would be just a no-op.
model.extend_embedder_vocab()

# Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
# if the state dict is missing keys because we handle this case below.
model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
model.load_state_dict(model_state)
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

# Modules might define a class variable called `authorized_missing_keys`,
# a list of regex patterns, that tells us to ignore missing keys that match
# any of the patterns.
# We sometimes need this in order to load older models with newer versions of AllenNLP.

def filter_out_authorized_missing_keys(module, prefix=""):
nonlocal missing_keys
for pat in getattr(module.__class__, "authorized_missing_keys", None) or []:
missing_keys = [
k
for k in missing_keys
if k.startswith(prefix) and re.search(pat[len(prefix) :], k) is None
]
for name, child in module._modules.items():
if child is not None:
filter_out_authorized_missing_keys(child, prefix + name + ".")

filter_out_authorized_missing_keys(model)

if unexpected_keys or missing_keys:
raise RuntimeError(
f"Error loading state dict for {model.__class__.__name__}\n\t"
f"Missing keys: {missing_keys}\n\t"
f"Unexpected keys: {unexpected_keys}"
)

return model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
Enable or disable gradient checkpointing.
"""

authorized_missing_keys = [r"position_ids$"]

def __init__(
self,
model_name: str,
Expand Down
3 changes: 1 addition & 2 deletions allennlp/training/metrics/attachment_scores.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List, Union
from typing import Optional, List

from overrides import overrides
import torch
Expand Down Expand Up @@ -105,7 +105,6 @@ def __call__( # type: ignore
def get_metric(
self,
reset: bool = False,
cuda_device: Union[int, torch.device] = torch.device("cpu"),
):
"""
# Returns
Expand Down
2 changes: 1 addition & 1 deletion allennlp/training/metrics/average.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __call__(self, value):
_total_value = list(self.detach_tensors(value))[0]
_count = 1
if is_distributed():
device = torch.device("cpu")
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
count = torch.tensor(_count).to(device)
total_value = torch.tensor(_total_value).to(device)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
Expand Down
3 changes: 1 addition & 2 deletions allennlp/training/metrics/evalb_bracketing_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None:
shutil.rmtree(tempdir)

if is_distributed():
# Setting the device to CPU since this metric is not expected to run on GPUs.
device = torch.device("cpu")
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
correct_predicted_brackets = torch.tensor(_correct_predicted_brackets).to(device)
predicted_brackets = torch.tensor(_predicted_brackets).to(device)
gold_brackets = torch.tensor(_gold_brackets).to(device)
Expand Down
2 changes: 0 additions & 2 deletions allennlp/training/metrics/pearson_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(self) -> None:
self._predictions_labels_covariance = Covariance()
self._predictions_variance = Covariance()
self._labels_variance = Covariance()
self._device = torch.device("cpu")

def __call__(
self,
Expand All @@ -64,7 +63,6 @@ def __call__(
A tensor of the same shape as `predictions`.
"""
predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
self._device = gold_labels.device
if not is_distributed():
self._predictions_labels_covariance(predictions, gold_labels, mask)
self._predictions_variance(predictions, predictions, mask)
Expand Down
2 changes: 1 addition & 1 deletion allennlp/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
_MINOR = "1"
# On master and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "0rc4"
_PATCH = "0"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = os.environ.get("ALLENNLP_VERSION_SUFFIX", "")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"scikit-learn",
"scipy",
"pytest",
"transformers>=3.0,<3.1",
"transformers>=3.1,<3.2",
"jsonpickle",
"dataclasses;python_version<'3.7'",
"filelock>=3.0,<3.1",
Expand Down
19 changes: 19 additions & 0 deletions tests/common/from_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,3 +953,22 @@ def __init__(self, a: int, b: str = None, **kwargs) -> None:
assert foo.a == 2
assert foo.b == "hi"
assert foo.c == {"2": "3"}

def test_from_params_child_has_kwargs_base_implicit_constructor(self):
    """
    A subclass whose constructor takes `**kwargs` can be built with
    `from_params` even when its base class defines no constructor of its own.
    """

    class Foo(FromParams):
        pass

    class Bar(Foo):
        def __init__(self, a: int, **kwargs) -> None:
            self.a = a

    params = Params({"a": 2})
    instance = Bar.from_params(params)
    assert instance.a == 2

def test_from_params_has_args(self):
    """
    A class whose constructor accepts `*args` can still be constructed
    with `from_params`.
    """

    class Foo(FromParams):
        def __init__(self, a: int, *args) -> None:
            self.a = a

    params = Params({"a": 2})
    instance = Foo.from_params(params)
    assert instance.a == 2
28 changes: 28 additions & 0 deletions tests/training/metrics/average_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from allennlp.common.testing import (
AllenNlpTestCase,
multi_device,
run_distributed_test,
global_distributed_metric,
)
from allennlp.training.metrics import Average


class AverageTest(AllenNlpTestCase):
    """Tests for the `Average` metric in a distributed setting."""

    def setup_method(self):
        super().setup_method()
        self.metric = Average()

    @multi_device
    def test_distributed_average(self, device: str):
        # Two workers either way: both on CPU (device id -1), or on GPUs 0 and 1.
        if device == "cpu":
            device_ids = [-1, -1]
        else:
            device_ids = [0, 1]

        # NOTE(review): presumably `global_distributed_metric` hands one entry of
        # "value" to each worker, so the expected global average is 1.5 -- confirm
        # against the helper's implementation.
        run_distributed_test(
            device_ids,
            global_distributed_metric,
            self.metric,
            {"value": [1.0, 2.0]},
            1.5,
            exact=True,
        )

0 comments on commit f886fd0

Please sign in to comment.