This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Merge remote-tracking branch 'origin/master' into vision
dirkgr committed Nov 3, 2020
2 parents 81892db + b7cec51 commit b48347b
Showing 20 changed files with 260 additions and 456 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/master.yml
@@ -151,8 +151,7 @@ jobs:
- name: Clean up
if: always()
run: |
pip uninstall --yes allennlp
pip uninstall --yes allennlp_models
pip uninstall --yes allennlp allennlp-models
# Builds package distribution files for PyPI.
build_package:
@@ -252,6 +251,11 @@ jobs:
- name: Install core package
run: |
pip install $(ls dist/*.whl)
# TODO(epwalsh): In PyTorch 1.7, dataclasses is an unconditional dependency, when it should
# only be a conditional dependency for Python < 3.7.
# This has been fixed on PyTorch master branch, so we should be able to
# remove this check with the next PyTorch release.
pip uninstall -y dataclasses
- name: Pip freeze
run: |
3 changes: 1 addition & 2 deletions .github/workflows/pull_request.yml
@@ -159,5 +159,4 @@ jobs:
- name: Clean up
if: always()
run: |
pip uninstall --yes allennlp
pip uninstall --yes allennlp_models
pip uninstall --yes allennlp allennlp-models
21 changes: 17 additions & 4 deletions CHANGELOG.md
@@ -43,6 +43,16 @@ data loaders. Those are coming soon.

## Unreleased (1.x branch)

### Fixed

- Fixed the computation of saliency maps in the Interpret code when using mismatched indexing.
  Previously, we would compute gradients from the top of the transformer, after aggregation from
  wordpieces to tokens, which gave results that were not very informative. Now, we compute gradients
  with respect to the embedding layer, and aggregate wordpieces to tokens separately.


## [v1.2.0](https://github.com/allenai/allennlp/releases/tag/v1.2.0) - 2020-10-29

### Changed

- Enforced stricter typing requirements around the use of `Optional[T]` types.
@@ -58,6 +68,11 @@ data loaders. Those are coming soon.

- Made it possible to instantiate `TrainerCallback` from config files.
- Fixed the remaining broken internal links in the API docs.
- Fixed a bug where Hotflip would crash with a model that had multiple TokenIndexers and the input
used rare vocabulary items.
- Fixed a bug where `BeamSearch` would fail if `max_steps` was equal to 1.
- Fixed `BasicTextFieldEmbedder` so that it no longer raises a `ConfigurationError` when it contains `EmptyEmbedder`s whose keys are absent from the input.


## [v1.2.0rc1](https://github.com/allenai/allennlp/releases/tag/v1.2.0rc1) - 2020-10-22

@@ -87,10 +102,7 @@ data loaders. Those are coming soon.
- Added logging for the main process when running in distributed mode.
- Added a `TrainerCallback` object to support state sharing between batch and epoch-level training callbacks.
- Added support for .tar.gz in PretrainedModelInitializer.
- Added classes: `nn/samplers/samplers.py` with `MultinomialSampler`, `TopKSampler`, and `TopPSampler` for
sampling indices from log probabilities
- Made `BeamSearch` registrable.
- Added `top_k_sampling` and `type_p_sampling` `BeamSearch` implementations.
- Made `BeamSearch` instantiable `from_params`.
- Pass `serialization_dir` to `Model` and `DatasetReader`.
- Added an optional `include_in_archive` parameter to the top-level of configuration files. When specified, `include_in_archive` should be a list of paths relative to the serialization directory which will be bundled up with the final archived model from a training run.

@@ -142,6 +154,7 @@ data loaders. Those are coming soon.
- Fixed `allennlp.nn.util.add_sentence_boundary_token_ids()` to use `device` parameter of input tensor.
- Be sure to close the TensorBoard writer even when training doesn't finish.
- Fixed the docstring for `PyTorchSeq2VecWrapper`.
- Fixed intra-word tokenization for `PretrainedTransformerTokenizer` when the fast tokenizer is disabled.


## [v1.1.0](https://github.com/allenai/allennlp/releases/tag/v1.1.0) - 2020-09-08
5 changes: 5 additions & 0 deletions Dockerfile
@@ -21,6 +21,11 @@ WORKDIR /stage/allennlp
# Install the wheel of AllenNLP.
COPY dist dist/
RUN pip install $(ls dist/*.whl)
# TODO(epwalsh): In PyTorch 1.7, dataclasses is an unconditional dependency, when it should
# only be a conditional dependency for Python < 3.7.
# This has been fixed on PyTorch master branch, so we should be able to
# remove this check with the next PyTorch release.
RUN pip uninstall -y dataclasses

# Copy wrapper script to allow beaker to run resumable training workloads.
COPY scripts/ai2_internal/resumable_train.sh /stage/allennlp
2 changes: 2 additions & 0 deletions README.md
@@ -97,6 +97,8 @@ And others on the [AI2 AllenNLP blog](https://medium.com/ai2-blog/allennlp/home)

AllenNLP requires Python 3.6.1 or later. The preferred way to install AllenNLP is via `pip`. Just run `pip install allennlp` in your Python environment and you're good to go!

> ⚠️ If you're using Python 3.7 or greater, you should ensure that you don't have the PyPI version of `dataclasses` installed after running the above command, as this could cause issues on certain platforms. You can quickly check this by running `pip freeze | grep dataclasses`. If you see something like `dataclasses==0.6` in the output, then just run `pip uninstall -y dataclasses`.
If you need pointers on setting up an appropriate Python environment or would like to install AllenNLP using a different method, see below.

We support AllenNLP on Mac and Linux environments. We presently do not support Windows but are open to contributions.
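As an aside (not part of this commit), the check described in the README note above can also be done programmatically. A minimal sketch, assuming Python 3.8+ so that `importlib.metadata` is available; only the PyPI backport registers distribution metadata, so a standard-library-only environment raises `PackageNotFoundError`:

```python
# Minimal sketch: detect whether the PyPI "dataclasses" backport is installed
# alongside Python >= 3.7, mirroring the `pip freeze | grep dataclasses` check
# suggested in the README note above. Assumes Python 3.8+ for importlib.metadata.
import sys
from importlib.metadata import PackageNotFoundError, version

if sys.version_info >= (3, 7):
    try:
        backport_version = version("dataclasses")  # only the PyPI backport has metadata
        print(
            f"Found dataclasses=={backport_version} from PyPI; "
            "consider `pip uninstall -y dataclasses`."
        )
    except PackageNotFoundError:
        print("Only the standard-library dataclasses module is present; nothing to do.")
```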
@@ -351,7 +351,6 @@ def _intra_word_tokenize(
return_tensors=None,
return_offsets_mapping=False,
return_attention_mask=False,
return_token_type_ids=False,
)
wp_ids = wordpieces["input_ids"]

2 changes: 1 addition & 1 deletion allennlp/interpret/attackers/hotflip.py
@@ -361,7 +361,7 @@ def _first_order_taylor(self, grad: numpy.ndarray, token_idx: torch.Tensor, sign
# This happens when we've truncated our fake embedding matrix. We need to do a dot
# product with the word vector of the current token; if that token is out of
# vocabulary for our truncated matrix, we need to run it through the embedding layer.
inputs = self._make_embedder_input([self.vocab.get_token_from_index(token_idx)])
inputs = self._make_embedder_input([self.vocab.get_token_from_index(token_idx.item())])
word_embedding = self.embedding_layer(inputs)[0]
else:
word_embedding = torch.nn.functional.embedding(
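For context on the one-line fix above (illustration only, not part of the diff): `token_idx` arrives as a 0-d tensor, while the vocabulary's index-to-token mapping is keyed by plain Python ints, so the tensor has to be unwrapped with `.item()` before the lookup. A toy sketch with a plain dict standing in for the vocabulary mapping:

```python
# Toy illustration (a plain dict stands in for the vocabulary's index-to-token map):
# a 0-d tensor is not usable as an int dict key, which is why the fix calls .item().
import torch

index_to_token = {0: "[PAD]", 1: "the", 2: "cat"}
token_idx = torch.tensor(2)

print(type(token_idx.item()))            # <class 'int'>
print(index_to_token[token_idx.item()])  # 'cat'
# index_to_token[token_idx] would raise KeyError: tensors hash by identity, not by value.
```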
30 changes: 22 additions & 8 deletions allennlp/interpret/saliency_interpreters/integrated_gradient.py
@@ -2,6 +2,7 @@
from typing import List, Dict, Any

import numpy
import torch

from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance
@@ -38,7 +39,7 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:

return sanitize(instances_with_grads)

def _register_forward_hook(self, alpha: int, embeddings_list: List):
def _register_hooks(self, alpha: int, embeddings_list: List, token_offsets: List):
"""
Register a forward hook on the embedding layer which scales the embeddings by alpha. Used
for one term in the Integrated Gradients sum.
@@ -50,15 +51,23 @@ def _register_forward_hook(self, alpha: int, embeddings_list: List):
def forward_hook(module, inputs, output):
# Save the input for later use. Only do so on first call.
if alpha == 0:
embeddings_list.append(output.squeeze(0).clone().detach().numpy())
embeddings_list.append(output.squeeze(0).clone().detach())

# Scale the embedding by alpha
output.mul_(alpha)

# Register the hook
def get_token_offsets(module, inputs, outputs):
offsets = util.get_token_offsets_from_text_field_inputs(inputs)
if offsets is not None:
token_offsets.append(offsets)

# Register the hooks
handles = []
embedding_layer = util.find_embedding_layer(self.predictor._model)
handle = embedding_layer.register_forward_hook(forward_hook)
return handle
handles.append(embedding_layer.register_forward_hook(forward_hook))
text_field_embedder = util.find_text_field_embedder(self.predictor._model)
handles.append(text_field_embedder.register_forward_hook(get_token_offsets))
return handles

def _integrate_gradients(self, instance: Instance) -> Dict[str, numpy.ndarray]:
"""
@@ -67,18 +76,21 @@ def _integrate_gradients(self, instance: Instance) -> Dict[str, numpy.ndarray]:
ig_grads: Dict[str, Any] = {}

# List of Embedding inputs
embeddings_list: List[numpy.ndarray] = []
embeddings_list: List[torch.Tensor] = []
token_offsets: List[torch.Tensor] = []

# Use 10 terms in the summation approximation of the integral in integrated grad
steps = 10

# Exclude the endpoint because we do a left point integral approximation
for alpha in numpy.linspace(0, 1.0, num=steps, endpoint=False):
handles = []
# Hook for modifying embedding value
handle = self._register_forward_hook(alpha, embeddings_list)
handles = self._register_hooks(alpha, embeddings_list, token_offsets)

grads = self.predictor.get_gradients([instance])[0]
handle.remove()
for handle in handles:
handle.remove()

# Running sum of gradients
if ig_grads == {}:
@@ -93,6 +105,8 @@ def _integrate_gradients(self, instance: Instance) -> Dict[str, numpy.ndarray]:

# Gradients come back in the reverse order that they were sent into the network
embeddings_list.reverse()
token_offsets.reverse()
embeddings_list = self._aggregate_token_embeddings(embeddings_list, token_offsets)

# Element-wise multiply average gradient by the input
for idx, input_embedding in enumerate(embeddings_list):
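For readers unfamiliar with the approximation being set up above (illustration only, not AllenNLP code): integrated gradients estimate the integral of the gradient along the straight path from a baseline to the input with a left-point Riemann sum, which is what the `alpha` loop and the scaling hook implement. A self-contained toy sketch, with a scalar function standing in for the model and an all-zeros baseline:

```python
# Toy sketch of the left-point Riemann approximation used above (not AllenNLP code).
# A simple differentiable function stands in for the model, `x` for the embedded
# input, and the baseline is all zeros.
import torch


def model(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).sum()  # stand-in for the model's output score


x = torch.tensor([1.0, -2.0, 3.0])
steps = 10
accumulated_grads = torch.zeros_like(x)

# Exclude the right endpoint, matching numpy.linspace(0, 1.0, num=steps, endpoint=False).
for alpha in torch.arange(steps) / steps:
    scaled = (alpha * x).detach().requires_grad_(True)  # same effect as the alpha-scaling hook
    model(scaled).backward()
    accumulated_grads += scaled.grad

# Average the gradients over the steps and multiply element-wise by the input.
integrated_gradients = (accumulated_grads / steps) * x
print(integrated_gradients)  # tensor([0.9000, 3.6000, 8.1000]), approximating the exact x**2
```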
32 changes: 32 additions & 0 deletions allennlp/interpret/saliency_interpreters/saliency_interpreter.py
@@ -1,5 +1,11 @@
from typing import List

import numpy
import torch

from allennlp.common import Registrable
from allennlp.common.util import JsonDict
from allennlp.nn import util
from allennlp.predictors import Predictor


@@ -30,3 +36,29 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
`{grad_input_1: ..., grad_input_2: ... }`.
"""
raise NotImplementedError("Implement this for saliency interpretations")

@staticmethod
def _aggregate_token_embeddings(
embeddings_list: List[torch.Tensor], token_offsets: List[torch.Tensor]
) -> List[numpy.ndarray]:
if len(token_offsets) == 0:
return [embeddings.numpy() for embeddings in embeddings_list]
aggregated_embeddings = []
# NOTE: This is assuming that embeddings and offsets come in the same order, which may not
# be true. But, the intersection of using multiple TextFields with mismatched indexers is
# currently zero, so we'll delay handling this corner case until it actually causes a
# problem. In practice, both of these lists will always be of size one at the moment.
for embeddings, offsets in zip(embeddings_list, token_offsets):
span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
span_mask = span_mask.unsqueeze(-1)
span_embeddings *= span_mask # zero out paddings

span_embeddings_sum = span_embeddings.sum(2)
span_embeddings_len = span_mask.sum(2)
# Shape: (batch_size, num_orig_tokens, embedding_size)
embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

# All the places where the span length is zero, write in zeros.
embeddings[(span_embeddings_len == 0).expand(embeddings.shape)] = 0
aggregated_embeddings.append(embeddings.numpy())
return aggregated_embeddings
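To make the shapes in `_aggregate_token_embeddings` concrete, here is a pure-PyTorch toy of the same wordpiece-to-token mean pooling (illustration only; the real code above uses `util.batched_span_select` and handles batching, masking, and empty spans):

```python
# Toy illustration (pure PyTorch, not the AllenNLP helper): mean-pool wordpiece
# embeddings into token embeddings given inclusive (start, end) wordpiece spans,
# which is what _aggregate_token_embeddings does in batched form above.
import torch

# 1 sequence, 5 wordpieces, embedding size 2.
wordpiece_embeddings = torch.arange(10.0).view(1, 5, 2)
# Inclusive wordpiece spans for 3 original tokens, e.g. "playing" -> ["play", "##ing"].
offsets = torch.tensor([[[0, 0], [1, 2], [3, 4]]])

token_embeddings = torch.stack(
    [wordpiece_embeddings[0, start : end + 1].mean(dim=0) for start, end in offsets[0].tolist()]
).unsqueeze(0)
print(token_embeddings.shape)  # torch.Size([1, 3, 2]) -- (batch_size, num_orig_tokens, embedding_size)
```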
36 changes: 25 additions & 11 deletions allennlp/interpret/saliency_interpreters/simple_gradient.py
@@ -2,6 +2,7 @@

from typing import List
import numpy
import torch

from allennlp.common.util import JsonDict, sanitize
from allennlp.interpret.saliency_interpreters.saliency_interpreter import SaliencyInterpreter
@@ -21,44 +22,57 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
"""
labeled_instances = self.predictor.json_to_labeled_instances(inputs)

# List of embedding inputs, used for multiplying gradient by the input for normalization
embeddings_list: List[numpy.ndarray] = []

instances_with_grads = dict()
for idx, instance in enumerate(labeled_instances):
# List of embedding inputs, used for multiplying gradient by the input for normalization
embeddings_list: List[torch.Tensor] = []
token_offsets: List[torch.Tensor] = []

# Hook used for saving embeddings
handle = self._register_forward_hook(embeddings_list)
handles = self._register_hooks(embeddings_list, token_offsets)
grads = self.predictor.get_gradients([instance])[0]
handle.remove()
for handle in handles:
handle.remove()

# Gradients come back in the reverse order that they were sent into the network
embeddings_list.reverse()
token_offsets.reverse()
embeddings_list = self._aggregate_token_embeddings(embeddings_list, token_offsets)

for key, grad in grads.items():
# Get number at the end of every gradient key (they look like grad_input_[int],
# we're getting this [int] part and subtracting 1 for zero-based indexing).
# This is then used as an index into the reversed input array to match up the
# gradient and its respective embedding.
input_idx = int(key[-1]) - 1
# The [0] here is undo-ing the batching that happens in get_gradients.
emb_grad = numpy.sum(grad[0] * embeddings_list[input_idx], axis=1)
emb_grad = numpy.sum(grad[0] * embeddings_list[input_idx][0], axis=1)
norm = numpy.linalg.norm(emb_grad, ord=1)
normalized_grad = [math.fabs(e) / norm for e in emb_grad]
grads[key] = normalized_grad

instances_with_grads["instance_" + str(idx + 1)] = grads
return sanitize(instances_with_grads)

def _register_forward_hook(self, embeddings_list: List):
def _register_hooks(self, embeddings_list: List, token_offsets: List):
"""
Finds all of the TextFieldEmbedders, and registers a forward hook onto them. When forward()
is called, embeddings_list is filled with the embedding values. This is necessary because
our normalization scheme multiplies the gradient by the embedding value.
"""

def forward_hook(module, inputs, output):
embeddings_list.append(output.squeeze(0).clone().detach().numpy())
embeddings_list.append(output.squeeze(0).clone().detach())

embedding_layer = util.find_embedding_layer(self.predictor._model)
handle = embedding_layer.register_forward_hook(forward_hook)
def get_token_offsets(module, inputs, outputs):
offsets = util.get_token_offsets_from_text_field_inputs(inputs)
if offsets is not None:
token_offsets.append(offsets)

return handle
# Register the hooks
handles = []
embedding_layer = util.find_embedding_layer(self.predictor._model)
handles.append(embedding_layer.register_forward_hook(forward_hook))
text_field_embedder = util.find_text_field_embedder(self.predictor._model)
handles.append(text_field_embedder.register_forward_hook(get_token_offsets))
return handles
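Both interpreters rely on the same PyTorch mechanism: a forward hook captures the embedding output while `predictor.get_gradients` runs, and the handles are removed afterwards. A minimal self-contained sketch of that capture-and-remove pattern, together with the L1-normalized gradient-times-input score computed above (toy two-layer model, not AllenNLP):

```python
# Minimal sketch (toy model, not AllenNLP): capture embeddings with a forward hook,
# remove the hook afterwards, and compute an L1-normalized gradient-times-input saliency.
import torch

embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
classifier = torch.nn.Linear(4, 1)

captured = []
handle = embedding.register_forward_hook(
    lambda module, inputs, output: captured.append(output.squeeze(0).clone().detach())
)

token_ids = torch.tensor([[1, 5, 7]])  # (batch_size=1, num_tokens=3)
embedded = embedding(token_ids)
embedded.retain_grad()  # keep .grad on a non-leaf tensor
classifier(embedded).sum().backward()
handle.remove()  # always clean up, as the interpreters do after get_gradients

# Gradient times input, summed over the embedding dimension, then L1-normalized.
emb_grad = (embedded.grad.squeeze(0) * captured[0]).sum(dim=1)
saliency = emb_grad.abs() / emb_grad.abs().sum()
print(saliency)  # one non-negative score per token, summing to 1
```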
19 changes: 17 additions & 2 deletions allennlp/modules/text_field_embedders/basic_text_field_embedder.py
@@ -9,6 +9,7 @@
from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.token_embedders import EmptyEmbedder


@TextFieldEmbedder.register("basic")
@@ -53,18 +54,32 @@ def get_output_dim(self) -> int:
def forward(
self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs
) -> torch.Tensor:
if self._token_embedders.keys() != text_field_input.keys():
if sorted(self._token_embedders.keys()) != sorted(text_field_input.keys()):
message = "Mismatched token keys: %s and %s" % (
str(self._token_embedders.keys()),
str(text_field_input.keys()),
)
raise ConfigurationError(message)
embedder_keys = set(self._token_embedders.keys())
input_keys = set(text_field_input.keys())
if embedder_keys > input_keys and all(
isinstance(embedder, EmptyEmbedder)
for name, embedder in self._token_embedders.items()
if name in embedder_keys - input_keys
):
# Allow extra embedders that are only in the token embedders (but not input) and are empty to pass
# config check
pass
else:
raise ConfigurationError(message)

embedded_representations = []
for key in self._ordered_embedder_keys:
# Note: need to use getattr here so that the pytorch voodoo
# with submodules works with multiple GPUs.
embedder = getattr(self, "token_embedder_{}".format(key))
if isinstance(embedder, EmptyEmbedder):
# Skip empty embedders
continue
forward_params = inspect.signature(embedder.forward).parameters
forward_params_values = {}
missing_tensor_args = set()
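The new branch above lets the embedder carry extra `EmptyEmbedder` entries that have no matching key in the input, while still rejecting genuine mismatches. A stripped-down sketch of just that set logic (illustration only; `EmptyEmbedder` is stubbed out and `ValueError` stands in for `ConfigurationError`):

```python
# Stripped-down sketch of the key check above (not AllenNLP code); EmptyEmbedder is
# stubbed out and ValueError stands in for ConfigurationError.
from typing import Dict


class EmptyEmbedder:  # stand-in for allennlp.modules.token_embedders.EmptyEmbedder
    pass


class RealEmbedder:
    pass


def check_keys(token_embedders: Dict[str, object], text_field_input: Dict[str, dict]) -> None:
    if sorted(token_embedders.keys()) == sorted(text_field_input.keys()):
        return
    embedder_keys = set(token_embedders.keys())
    input_keys = set(text_field_input.keys())
    extra_keys = embedder_keys - input_keys
    if embedder_keys > input_keys and all(
        isinstance(token_embedders[name], EmptyEmbedder) for name in extra_keys
    ):
        return  # every extra embedder is empty, so the mismatch is harmless
    raise ValueError(f"Mismatched token keys: {embedder_keys} and {input_keys}")


check_keys({"tokens": RealEmbedder(), "unused": EmptyEmbedder()}, {"tokens": {}})  # passes
try:
    check_keys({"tokens": RealEmbedder(), "extra": RealEmbedder()}, {"tokens": {}})
except ValueError as err:
    print(err)  # Mismatched token keys: {...} and {'tokens'}
```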
