This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Merge branch 'master' into vision
# Conflicts:
#	allennlp/data/dataset_readers/sharded_dataset_reader.py
dirkgr committed Dec 15, 2020
2 parents c8521d8 + 41c5224 commit 457e56e
Showing 50 changed files with 196 additions and 75 deletions.
3 changes: 3 additions & 0 deletions .dockerignore
@@ -3,3 +3,6 @@
**/__pycache__
.gitignore
.git
.coverage
.benchmarks
.mypy_cache
1 change: 1 addition & 0 deletions .gitignore
@@ -44,6 +44,7 @@ __pycache__

.coverage
.pytest_cache/
.benchmarks

# documentation build artifacts

23 changes: 21 additions & 2 deletions CHANGELOG.md
@@ -50,7 +50,26 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
### Added

- Added links to source code in docs.
- Added [Gaussian Error Linear Unit (GELU)](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) as an Activation.

### Changed

- Renamed module `allennlp.data.tokenizers.token` to `allennlp.data.tokenizers.token_class` to avoid
[this bug](https://github.com/allenai/allennlp/issues/4819).

### Fixed

- Fixed a lot of instances where tensors were first created and then sent to a device
with `.to(device)`. Instead, these tensors are now created directly on the target device
(see the sketch after this list).
- Fixed issue with `GradientDescentTrainer` when constructed with `validation_data_loader=None` and `learning_rate_scheduler!=None`.
- Fixed a bug when removing all handlers in root logger.
- `ShardedDatasetReader` now inherits parameters from `base_reader` when required.
- Fixed an issue in `FromParams` where parameters in the `params` object used to a construct a class
were not passed to the constructor if the value of the parameter was equal to the default value.
This caused bugs in some edge cases where a subclass that takes `**kwargs` needs to inspect
`kwargs` before passing them to its superclass.
- Improved the band-aid solution for segmentation faults and the "ImportError: dlopen: cannot load any more object with static TLS"
by adding a `transformers` import.
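
A minimal sketch of the device-placement pattern behind the first fix above (illustrative only, not code from this commit): building a tensor directly on the target device avoids allocating it on the CPU first and then copying it over with `.to(device)`.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Old pattern: allocate on the CPU, then copy to the target device.
noise_old = torch.randn(3, 4).to(device) * 0.1

# New pattern: allocate directly on the target device.
noise_new = torch.randn(3, 4, device=device) * 0.1

assert noise_old.device == noise_new.device
```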


## [v1.2.2](https://github.com/allenai/allennlp/releases/tag/v1.2.2) - 2020-11-17
@@ -141,7 +160,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
- Added ability to pass additional key word arguments to `cached_transformers.get()`, which will be passed on to `AutoModel.from_pretrained()`.
- Added an `overrides` argument to `Predictor.from_path()`.
- Added a `cached-path` command.
- Added a function `inspect_cache` to `common.file_utils` that prints useful information about the cache. This can also
be used from the `cached-path` command with `allennlp cached-path --inspect`.
- Added a function `remove_cache_entries` to `common.file_utils` that removes any cache entries matching the given
glob patterns. This can be used from the `cached-path` command with `allennlp cached-path --remove some-files-*`.
2 changes: 1 addition & 1 deletion allennlp/__init__.py
@@ -15,7 +15,7 @@
try:
# On some systems this prevents the dreaded
# ImportError: dlopen: cannot load any more object with static TLS
import spacy, torch, numpy # noqa
import transformers, spacy, torch, numpy # noqa

except ModuleNotFoundError:
print(
6 changes: 4 additions & 2 deletions allennlp/common/from_params.py
@@ -195,16 +195,18 @@ def create_kwargs(
# and an __args__ field indicating `(str, int)`. We capture both.
annotation = remove_optional(param.annotation)

explicitly_set = param_name in params
constructed_arg = pop_and_construct_arg(
cls.__name__, param_name, annotation, param.default, params, **extras
)

# If we just ended up constructing the default value for the parameter, we can just omit it.
# If the param wasn't explicitly set in `params` and we just ended up constructing
# the default value for the parameter, we can just omit it.
# Leaving it in can cause issues with **kwargs in some corner cases, where you might end up
# with multiple values for a single parameter (e.g., the default value gives you lazy=False
# for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes
# lazy=True - the superclass sees both lazy=True and lazy=False in its constructor).
if constructed_arg is not param.default:
if explicitly_set or constructed_arg is not param.default:
kwargs[param_name] = constructed_arg

if accepts_kwargs:
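
To make the corner case above concrete, here is a standalone sketch with made-up classes (not AllenNLP code): a subclass that hard-codes an argument and forwards `**kwargs` breaks if the framework also forwards the default value for that argument, which is why a constructed default is now kept only when the user explicitly set it in `params`.

```python
class BaseReader:
    def __init__(self, lazy: bool = False):
        self.lazy = lazy


class HardCodedReader(BaseReader):
    def __init__(self, **kwargs):
        # This subclass always wants lazy=True and passes everything else through.
        super().__init__(lazy=True, **kwargs)


# If the default lazy=False is forwarded even though the user never set it,
# the superclass receives two values for `lazy`.
try:
    HardCodedReader(lazy=False)
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'lazy'

# Hence the new condition: keep the argument if it was explicitly set in `params`,
# or if the constructed value differs from the declared default.
```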
3 changes: 1 addition & 2 deletions allennlp/common/logging.py
@@ -93,8 +93,7 @@ def prepare_global_logging(

# Remove the already set handlers in root logger.
# Not doing this will result in duplicate log messages
for handler in root_logger.handlers:
root_logger.removeHandler(handler)
root_logger.handlers.clear()

if os.environ.get("ALLENNLP_DEBUG"):
LEVEL = logging.DEBUG
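
For context on why the loop was replaced (a minimal standalone sketch, not AllenNLP code): removing handlers while iterating over the same `handlers` list mutates it mid-iteration, so entries get skipped, stale handlers survive, and log messages end up duplicated. Clearing the list in place removes them all.

```python
import logging

logger = logging.getLogger("handler_cleanup_demo")
logger.addHandler(logging.StreamHandler())
logger.addHandler(logging.NullHandler())

# Buggy pattern: the list shrinks while we iterate, so one handler is skipped.
for handler in logger.handlers:
    logger.removeHandler(handler)
print(len(logger.handlers))  # 1

# Fixed pattern: clear the list in place.
logger.handlers.clear()
print(len(logger.handlers))  # 0
```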
2 changes: 1 addition & 1 deletion allennlp/common/util.py
@@ -70,7 +70,7 @@ def sanitize(x: Any) -> Any:
can be serialized into JSON.
"""
# Import here to avoid circular references
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token

if isinstance(x, (str, float, int, bool)):
# x is already serializable
3 changes: 1 addition & 2 deletions allennlp/data/__init__.py
@@ -12,8 +12,7 @@
from allennlp.data.instance import Instance
from allennlp.data.samplers import BatchSampler, PyTorchSampler, PyTorchBatchSampler
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers import Token, Tokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.batch import Batch

2 changes: 1 addition & 1 deletion allennlp/data/dataset_readers/dataset_utils/span_utils.py
@@ -2,7 +2,7 @@
import warnings

from allennlp.common.checks import ConfigurationError
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token


TypedSpan = Tuple[int, Tuple[int, int]]
10 changes: 10 additions & 0 deletions allennlp/data/dataset_readers/sharded_dataset_reader.py
@@ -30,13 +30,23 @@ class ShardedDatasetReader(DatasetReader):
Registered as a `DatasetReader` with name "sharded".
This class accepts all additional parameters of any `DatasetReader` class via `**kwargs`.
We give priority to the values set in the constructor for the instance of this class.
Optionally, we will automatically inherit attributes from the `base_reader` when required.
# Parameters
base_reader : `DatasetReader`
Reader with a read method that accepts a single file.
"""

def __init__(self, base_reader: DatasetReader, **kwargs) -> None:
# ShardedDatasetReader is a wrapper for the original base_reader so some of the parameters like 'lazy'
# can be safely inherited. However, ShardedDatasetReader is a class instance of a DatasetReader as well.
# So we give priority to the parameters for the current instance stored in 'kwargs'.
# If not present, we check the ones in the base reader
kwargs["lazy"] = kwargs.get("lazy", base_reader.lazy)

super().__init__(
manual_distributed_sharding=True, manual_multi_process_sharding=True, **kwargs
)
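
A tiny illustration of the priority rule in the comment above (plain values, not real reader objects): an explicit `lazy` passed to the `ShardedDatasetReader` constructor wins, and only when it is absent does the wrapper fall back to the value stored on the wrapped `base_reader`.

```python
def resolve_lazy(kwargs: dict, base_reader_lazy: bool) -> bool:
    # Mirrors: kwargs["lazy"] = kwargs.get("lazy", base_reader.lazy)
    return kwargs.get("lazy", base_reader_lazy)


print(resolve_lazy({}, base_reader_lazy=True))               # True  (inherited from base_reader)
print(resolve_lazy({"lazy": False}, base_reader_lazy=True))  # False (explicit value wins)
```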
2 changes: 1 addition & 1 deletion allennlp/data/fields/namespace_swapping_field.py
@@ -5,7 +5,7 @@

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.fields.field import Field


2 changes: 1 addition & 1 deletion allennlp/data/fields/text_field.py
@@ -13,7 +13,7 @@

from allennlp.common.checks import ConfigurationError
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn import util
2 changes: 1 addition & 1 deletion allennlp/data/token_indexers/elmo_indexer.py
@@ -4,7 +4,7 @@
import torch

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
from allennlp.data.vocabulary import Vocabulary

@@ -6,8 +6,7 @@
from overrides import overrides

from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token, PretrainedTransformerTokenizer
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList

logger = logging.getLogger(__name__)
@@ -6,7 +6,7 @@

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers import PretrainedTransformerIndexer, TokenIndexer
from allennlp.data.token_indexers.token_indexer import IndexedTokenList

2 changes: 1 addition & 1 deletion allennlp/data/token_indexers/single_id_token_indexer.py
@@ -4,7 +4,7 @@
from overrides import overrides

from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList


2 changes: 1 addition & 1 deletion allennlp/data/token_indexers/spacy_indexer.py
@@ -7,7 +7,7 @@

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList


3 changes: 1 addition & 2 deletions allennlp/data/token_indexers/token_characters_indexer.py
@@ -8,8 +8,7 @@
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token, CharacterTokenizer
from allennlp.data.vocabulary import Vocabulary


2 changes: 1 addition & 1 deletion allennlp/data/token_indexers/token_indexer.py
@@ -4,7 +4,7 @@

from allennlp.common import Registrable
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary

# An indexed token list represents the arguments that will be passed to a TokenEmbedder
3 changes: 2 additions & 1 deletion allennlp/data/tokenizers/__init__.py
@@ -3,7 +3,8 @@
tokenization.
"""

from allennlp.data.tokenizers.tokenizer import Token, Tokenizer
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from allennlp.data.tokenizers.letters_digits_tokenizer import LettersDigitsTokenizer
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
2 changes: 1 addition & 1 deletion allennlp/data/tokenizers/character_tokenizer.py
@@ -2,7 +2,7 @@

from overrides import overrides

from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer


2 changes: 1 addition & 1 deletion allennlp/data/tokenizers/letters_digits_tokenizer.py
@@ -3,7 +3,7 @@

from overrides import overrides

from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer


@@ -6,7 +6,7 @@
from transformers import PreTrainedTokenizer

from allennlp.common.util import sanitize_wordpiece
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion allennlp/data/tokenizers/spacy_tokenizer.py
@@ -5,7 +5,7 @@
from spacy.tokens import Doc

from allennlp.common.util import get_spacy_model
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer


allennlp/data/tokenizers/token.py → allennlp/data/tokenizers/token_class.py: file renamed without changes.
2 changes: 1 addition & 1 deletion allennlp/data/tokenizers/tokenizer.py
@@ -2,7 +2,7 @@
import logging

from allennlp.common import Registrable
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token


logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion allennlp/data/tokenizers/whitespace_tokenizer.py
@@ -2,7 +2,7 @@

from overrides import overrides

from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer


@@ -58,7 +58,7 @@ def _register_forward_hook(self, stdev: float):
def forward_hook(module, inputs, output):
# Random noise = N(0, stdev * (max-min))
scale = output.detach().max() - output.detach().min()
noise = torch.randn(output.shape).to(output.device) * stdev * scale
noise = torch.randn(output.shape, device=output.device) * stdev * scale

# Add the random noise
output.add_(noise)
2 changes: 1 addition & 1 deletion allennlp/modules/sampled_softmax_loss.py
@@ -155,7 +155,7 @@ def forward(

if embeddings.shape[0] == 0:
# empty batch
return torch.tensor(0.0).to(embeddings.device)
return torch.tensor(0.0, device=embeddings.device)

if not self.training:
return self._forward_eval(embeddings, targets)
1 change: 1 addition & 0 deletions allennlp/nn/activations.py
@@ -99,4 +99,5 @@ def _get_name(self):
"softsign": (torch.nn.Softsign, None),
"tanhshrink": (torch.nn.Tanhshrink, None),
"selu": (torch.nn.SELU, None),
"gelu": (torch.nn.GELU, None),
}
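
Assuming the new entry is looked up the same way as the existing activations (via the `Registrable` name registry), a hedged usage sketch would be:

```python
import torch
from allennlp.nn import Activation

# "gelu" should now resolve to torch.nn.GELU through the registry above.
gelu = Activation.by_name("gelu")()

x = torch.tensor([-1.0, 0.0, 1.0])
print(gelu(x))  # approximately tensor([-0.1587, 0.0000, 0.8413])
```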
2 changes: 0 additions & 2 deletions allennlp/nn/util.py
@@ -1548,7 +1548,6 @@ def add_sentence_boundary_token_ids(
The new mask for the tensor, taking into account the appended tokens
marking the beginning and end of the sentence.
"""
# TODO: matthewp, profile this transfer
sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
tensor_shape = list(tensor.data.shape)
new_shape = list(tensor_shape)
@@ -1603,7 +1602,6 @@ def remove_sentence_boundaries(
new_mask : `torch.BoolTensor`
The new mask for the tensor of shape `(batch_size, timesteps - 2)`.
"""
# TODO: matthewp, profile this transfer
sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
tensor_shape = list(tensor.data.shape)
new_shape = list(tensor_shape)
4 changes: 2 additions & 2 deletions allennlp/training/metrics/attachment_scores.py
@@ -88,8 +88,8 @@ def __call__( # type: ignore
dist.all_reduce(unlabeled_exact_match, op=dist.ReduceOp.SUM)
dist.all_reduce(correct_labels_and_indices, op=dist.ReduceOp.SUM)
dist.all_reduce(labeled_exact_match, op=dist.ReduceOp.SUM)
total_sentences = torch.tensor(total_sentences).to(device)
total_words = torch.tensor(total_words).to(device)
total_sentences = torch.tensor(total_sentences, device=device)
total_words = torch.tensor(total_words, device=device)
dist.all_reduce(total_sentences, op=dist.ReduceOp.SUM)
dist.all_reduce(total_words, op=dist.ReduceOp.SUM)
total_sentences = total_sentences.item()
4 changes: 2 additions & 2 deletions allennlp/training/metrics/average.py
@@ -32,8 +32,8 @@ def __call__(self, value):
_count = 1
if is_distributed():
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
count = torch.tensor(_count).to(device)
total_value = torch.tensor(_total_value).to(device)
count = torch.tensor(_count, device=device)
total_value = torch.tensor(_total_value, device=device)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
dist.all_reduce(total_value, op=dist.ReduceOp.SUM)
_count = count.item()
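
The distributed branch above sums each worker's running total and count across the process group and then divides; an arithmetic sketch with assumed per-worker values (illustrative numbers only, no real process group involved):

```python
# Assumed partial results on two workers.
worker_totals = [3.0, 5.0]  # sum of observed values on worker 0 and worker 1
worker_counts = [2, 3]      # number of observed values on each worker

# What dist.all_reduce(..., op=ReduceOp.SUM) produces on every worker:
reduced_total = sum(worker_totals)
reduced_count = sum(worker_counts)

print(reduced_total / reduced_count)  # 1.6, the global average
```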
8 changes: 4 additions & 4 deletions allennlp/training/metrics/bleu.py
@@ -127,8 +127,8 @@ def __call__(
predictions, gold_targets, ngram_size
)
if is_distributed():
_precision_matches = torch.tensor(precision_matches).to(device)
_precision_totals = torch.tensor(precision_totals).to(device)
_precision_matches = torch.tensor(precision_matches, device=device)
_precision_totals = torch.tensor(precision_totals, device=device)
dist.all_reduce(_precision_matches, op=dist.ReduceOp.SUM)
dist.all_reduce(_precision_totals, op=dist.ReduceOp.SUM)
precision_matches = _precision_matches.item() / world_size
@@ -150,8 +150,8 @@ def __call__(
_reference_lengths = valid_gold_targets_mask.sum().item()

if is_distributed():
prediction_lengths = torch.tensor(_prediction_lengths).to(device)
reference_lengths = torch.tensor(_reference_lengths).to(device)
prediction_lengths = torch.tensor(_prediction_lengths, device=device)
reference_lengths = torch.tensor(_reference_lengths, device=device)
dist.all_reduce(prediction_lengths, op=dist.ReduceOp.SUM)
dist.all_reduce(reference_lengths, op=dist.ReduceOp.SUM)
_prediction_lengths = prediction_lengths.item()
8 changes: 4 additions & 4 deletions allennlp/training/metrics/covariance.py
@@ -111,10 +111,10 @@ def __call__(

# # Note: this gives an approximate aggregation of the covariance.
# device = gold_labels.device
# delta_mean_prediction = torch.tensor(delta_mean_prediction).to(device)
# delta_mean_label = torch.tensor(delta_mean_label).to(device)
# delta_co_moment = torch.tensor(delta_co_moment).to(device)
# _total_count = torch.tensor(updated_count).to(device)
# delta_mean_prediction = torch.tensor(delta_mean_prediction, device=device)
# delta_mean_label = torch.tensor(delta_mean_label, device=device)
# delta_co_moment = torch.tensor(delta_co_moment, device=device)
# _total_count = torch.tensor(updated_count, device=device)
# dist.all_reduce(delta_mean_prediction, op=dist.ReduceOp.SUM)
# dist.all_reduce(delta_mean_label, op=dist.ReduceOp.SUM)
# dist.all_reduce(delta_co_moment, op=dist.ReduceOp.SUM)
2 changes: 1 addition & 1 deletion allennlp/training/metrics/entropy.py
@@ -43,7 +43,7 @@ def __call__(
_count = 1

if is_distributed():
count = torch.tensor(_count).to(device)
count = torch.tensor(_count, device=device)
dist.all_reduce(_entropy, op=dist.ReduceOp.SUM)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
_count = count.item()