This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Removing last_dim_*softmax #1687

Merged · 3 commits · Aug 29, 2018
10 changes: 5 additions & 5 deletions allennlp/models/biaffine_dependency_parser.py
@@ -16,7 +16,7 @@
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator, Activation
from allennlp.nn.util import get_text_field_mask, get_range_vector
from allennlp.nn.util import get_device_of, last_dim_log_softmax, get_lengths_from_binary_sequence_mask
from allennlp.nn.util import get_device_of, masked_log_softmax, get_lengths_from_binary_sequence_mask
from allennlp.nn.decoding.chu_liu_edmonds import decode_mst
from allennlp.training.metrics import AttachmentScores

@@ -351,13 +351,13 @@ def _construct_loss(self,
# shape (batch_size, 1)
range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
# shape (batch_size, sequence_length, sequence_length)
normalised_arc_logits = last_dim_log_softmax(attended_arcs,
mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
normalised_arc_logits = masked_log_softmax(attended_arcs,
mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

# shape (batch_size, sequence_length, num_head_tags)
head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
normalised_head_tag_logits = last_dim_log_softmax(head_tag_logits,
mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
# index matrix with shape (batch, sequence_length)
timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
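Not part of the diff, just an illustration: a minimal sketch (toy shapes, random scores) of why this is a drop-in swap. The new ``masked_log_softmax`` unsqueezes a ``(batch_size, sequence_length)`` mask until it broadcasts against the 3D arc scores, reproducing the old last-dimension normalisation.

```python
import torch
from allennlp.nn.util import masked_log_softmax

batch_size, seq_len = 2, 4
attended_arcs = torch.randn(batch_size, seq_len, seq_len)   # toy arc scores
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])                          # (batch_size, seq_len) padding mask
float_mask = mask.float()

# The 2D mask is unsqueezed to (batch_size, 1, seq_len) inside masked_log_softmax,
# so the log-softmax over the last (head) dimension ignores padded positions.
arc_log_probs = masked_log_softmax(attended_arcs, mask)
print(arc_log_probs.exp().sum(-1))   # rows sum to ~1.0; padded heads get ~0 probability

# As in the hunk above, rows and columns for padded tokens are then zeroed out.
normalised_arc_logits = arc_log_probs * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
```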
2 changes: 1 addition & 1 deletion allennlp/models/biattentive_classification_network.py
@@ -234,7 +234,7 @@ def forward(self, # type: ignore

# Compute biattention. This is a special case since the inputs are the same.
attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
attention_weights = util.last_dim_softmax(attention_logits, text_mask)
attention_weights = util.masked_softmax(attention_logits, text_mask)
encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

# Build the input to the integrator
4 changes: 2 additions & 2 deletions allennlp/models/constituency_parser.py
@@ -13,7 +13,7 @@
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.util import last_dim_softmax, get_lengths_from_binary_sequence_mask
from allennlp.nn.util import masked_softmax, get_lengths_from_binary_sequence_mask
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training.metrics import EvalbBracketingScorer, DEFAULT_EVALB_DIR
from allennlp.common.checks import ConfigurationError
@@ -210,7 +210,7 @@ def forward(self, # type: ignore
span_representations = self.feedforward_layer(span_representations)

logits = self.tag_projection_layer(span_representations)
class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))
class_probabilities = masked_softmax(logits, span_mask.unsqueeze(-1))

output_dict = {
"class_probabilities": class_probabilities,
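Not part of the diff: a sketch (made-up shapes) of the ``span_mask.unsqueeze(-1)`` call above. Because the mask covers whole rows rather than individual classes, padded spans come back as all-zero probability rows, which is the behaviour the updated ``masked_softmax`` docstring documents.

```python
import torch
from allennlp.nn.util import masked_softmax

batch_size, num_spans, num_classes = 2, 3, 5
logits = torch.randn(batch_size, num_spans, num_classes)
span_mask = torch.tensor([[1, 1, 0],
                          [1, 0, 0]])                 # (batch_size, num_spans)

# unsqueeze(-1) gives (batch_size, num_spans, 1): padded spans are entirely masked,
# so masked_softmax returns an all-zero distribution for them.
class_probabilities = masked_softmax(logits, span_mask.unsqueeze(-1))
print(class_probabilities.sum(-1))    # ~1.0 for real spans, 0.0 for padded spans
```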
2 changes: 1 addition & 1 deletion allennlp/models/coreference_resolution/coref.py
@@ -276,7 +276,7 @@ def forward(self, # type: ignore
# probability assigned to all valid antecedents. This is a valid objective for
# clustering as we don't mind which antecedent is predicted, so long as they are in
# the same coreference cluster.
coreference_log_probs = util.last_dim_log_softmax(coreference_scores, top_span_mask)
coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

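Not part of the diff: a toy sketch of the marginal log-likelihood described in the comment above, with made-up scores and gold labels and with ``top_span_mask`` assumed to have shape ``(batch_size, num_spans_to_keep, 1)``.

```python
import torch
from allennlp.nn import util

# One document, two kept spans, three antecedent choices each (including the dummy).
coreference_scores = torch.randn(1, 2, 3)
top_span_mask = torch.ones(1, 2, 1)                   # assumed (batch, num_spans_to_keep, 1)
gold_antecedent_labels = torch.tensor([[[1., 0., 1.],
                                         [0., 1., 0.]]])

coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
# gold_antecedent_labels.log() sends non-gold entries to -inf, so the logsumexp
# marginalises only over the valid gold antecedents for each span.
correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()
print(negative_marginal_log_likelihood)               # a non-negative scalar loss
```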
6 changes: 3 additions & 3 deletions allennlp/models/decomposable_attention.py
@@ -9,7 +9,7 @@
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum
from allennlp.nn.util import get_text_field_mask, masked_softmax, weighted_sum
from allennlp.training.metrics import CategoricalAccuracy


@@ -139,12 +139,12 @@ def forward(self, # type: ignore
similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

# Shape: (batch_size, premise_length, hypothesis_length)
p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
# Shape: (batch_size, premise_length, embedding_dim)
attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

# Shape: (batch_size, hypothesis_length, premise_length)
h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
# Shape: (batch_size, hypothesis_length, embedding_dim)
attended_premise = weighted_sum(embedded_premise, h2p_attention)

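Not part of the diff: a sketch (toy tensors) of the attend step above; the same pattern appears in the ESIM hunk below. The ``(batch_size, hypothesis_length)`` mask broadcasts so every premise position attends only over real hypothesis tokens.

```python
import torch
from allennlp.nn.util import masked_softmax, weighted_sum

batch_size, premise_length, hypothesis_length, dim = 2, 5, 4, 8
similarity_matrix = torch.randn(batch_size, premise_length, hypothesis_length)
embedded_hypothesis = torch.randn(batch_size, hypothesis_length, dim)
hypothesis_mask = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 0, 0]])        # (batch_size, hypothesis_length)

# The mask is unsqueezed to (batch_size, 1, hypothesis_length), masking the last dimension.
p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)
print(attended_hypothesis.shape)                      # torch.Size([2, 5, 8])
```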
6 changes: 3 additions & 3 deletions allennlp/models/esim.py
@@ -9,7 +9,7 @@
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TextFieldEmbedder
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum, replace_masked_values
from allennlp.nn.util import get_text_field_mask, masked_softmax, weighted_sum, replace_masked_values
from allennlp.training.metrics import CategoricalAccuracy


@@ -142,12 +142,12 @@ def forward(self, # type: ignore
similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

# Shape: (batch_size, premise_length, hypothesis_length)
p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
# Shape: (batch_size, premise_length, embedding_dim)
attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

# Shape: (batch_size, hypothesis_length, premise_length)
h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
# Shape: (batch_size, hypothesis_length, embedding_dim)
attended_premise = weighted_sum(encoded_premise, h2p_attention)

2 changes: 1 addition & 1 deletion allennlp/models/reading_comprehension/bidaf.py
@@ -186,7 +186,7 @@ def forward(self, # type: ignore
# Shape: (batch_size, passage_length, question_length)
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

7 changes: 3 additions & 4 deletions allennlp/models/reading_comprehension/dialog_qa.py
@@ -239,7 +239,7 @@ def forward(self, # type: ignore
# Shape: (batch_size * max_qa_count, passage_length, question_length)
passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
# Shape: (batch_size * max_qa_count, passage_length, question_length)
passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

@@ -250,8 +250,7 @@ def forward(self, # type: ignore
-1e7)

question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
question_passage_attention = util.last_dim_softmax(question_passage_similarity,
repeated_passage_mask)
question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
# Shape: (batch_size * max_qa_count, encoding_dim)
question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
@@ -277,7 +276,7 @@ def forward(self, # type: ignore
self_mask = self_mask.resize(1, passage_length, passage_length)
mask = mask * (1 - self_mask)

self_attention_probs = util.last_dim_softmax(self_attention_matrix, mask)
self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

# (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
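Not part of the diff: a sketch of the diagonal-masking trick above, with toy tensors. The pairwise ``mask`` built here (an outer product of the passage mask) and the ``torch.eye`` self-mask are assumptions standing in for the model's own construction earlier in ``forward``.

```python
import torch
from allennlp.nn import util

batch_size, passage_length = 2, 5
self_attention_matrix = torch.randn(batch_size, passage_length, passage_length)
passage_mask = torch.tensor([[1., 1., 1., 1., 0.],
                             [1., 1., 1., 0., 0.]])

# Assumed pairwise (query, key) mask: both positions must be real tokens.
mask = passage_mask.unsqueeze(2) * passage_mask.unsqueeze(1)

# Zero the diagonal so a token cannot attend to itself, then normalise.
self_mask = torch.eye(passage_length).unsqueeze(0)    # (1, passage_length, passage_length)
mask = mask * (1 - self_mask)
self_attention_probs = util.masked_softmax(self_attention_matrix, mask)
print(self_attention_probs.diagonal(dim1=1, dim2=2))  # all zeros: no self-attention
```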
@@ -107,7 +107,7 @@ def forward(self, tokens: torch.Tensor, mask: torch.Tensor): # pylint: disable=
similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

# Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
intra_sentence_attention = util.last_dim_softmax(similarity_matrix.contiguous(), mask)
intra_sentence_attention = util.masked_softmax(similarity_matrix.contiguous(), mask)

# Shape: (batch_size, sequence_length, projection_dim)
output_token_representation = self._projection(tokens)
@@ -2,7 +2,7 @@
import torch
from torch.nn import Dropout, Linear

from allennlp.nn.util import last_dim_softmax, weighted_sum
from allennlp.nn.util import masked_softmax, weighted_sum
from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder


@@ -129,7 +129,7 @@ def forward(self, # pylint: disable=arguments-differ

# shape (num_heads * batch_size, timesteps, timesteps)
# Normalise the distributions, using the same mask for all heads.
attention = last_dim_softmax(scaled_similarities, mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
attention = masked_softmax(scaled_similarities, mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
attention = self._attention_dropout(attention)

# Take a weighted sum of the values with respect to the attention
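Not part of the diff: a sketch (toy shapes) of how the single ``(batch_size, timesteps)`` padding mask is tiled so every attention head reuses it before the masked softmax.

```python
import torch
from allennlp.nn.util import masked_softmax

batch_size, timesteps, num_heads = 2, 4, 3
scaled_similarities = torch.randn(batch_size * num_heads, timesteps, timesteps)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])                   # (batch_size, timesteps)

# repeat + view tiles the mask to one row per (batch element, head) pair.
head_mask = mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps)

attention = masked_softmax(scaled_similarities, head_mask)
print(attention.shape)        # (batch_size * num_heads, timesteps, timesteps)
print(attention[0].sum(-1))   # each row is a distribution over the unmasked timesteps
```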
@@ -94,7 +94,7 @@ def forward(self,
span_indices,
flat_span_indices).squeeze(-1)
# Shape: (batch_size, num_spans, max_batch_span_width)
span_attention_weights = util.last_dim_softmax(span_attention_logits, span_mask)
span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

# Do a weighted sum of the embedded spans with
# respect to the normalised attention distributions.
70 changes: 38 additions & 32 deletions allennlp/nn/util.py
@@ -3,10 +3,11 @@
"""
# pylint: disable=too-many-lines
from collections import defaultdict
from typing import Dict, List, Optional, Any, Tuple, Callable
from typing import Dict, List, Optional, Any, Tuple
import logging

import math
import warnings

import torch

from allennlp.common.checks import ConfigurationError
@@ -176,35 +177,44 @@ def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.Tenso
return dropout_mask


def masked_softmax(vector, mask):
def masked_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
``torch.nn.functional.softmax(vector)`` does not work if some elements of ``vector`` should be
masked. This performs a softmax on just the non-masked portions of ``vector``. Passing
``None`` in for the mask is also acceptable; you'll just get a regular softmax.

We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``.
``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is
broadcastable to ``vector's`` shape. If ``mask`` has fewer dimensions than ``vector``, we will
unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask,
do it yourself before passing the mask into this function.

In the case that the input vector is completely masked, this function returns an array
of ``0.0``. This behavior may cause ``NaN`` if this is used as the last layer of a model
that uses categorical cross-entropy loss.
"""
if mask is None:
result = torch.nn.functional.softmax(vector, dim=-1)
result = torch.nn.functional.softmax(vector, dim=dim)
else:
mask = mask.float()
while mask.dim() < vector.dim():
mask = mask.unsqueeze(1)
# To limit numerical errors from large vector elements outside the mask, we zero these out.
result = torch.nn.functional.softmax(vector * mask, dim=-1)
result = torch.nn.functional.softmax(vector * mask, dim=dim)
result = result * mask
result = result / (result.sum(dim=1, keepdim=True) + 1e-13)
result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
return result


def masked_log_softmax(vector, mask):
def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
``torch.nn.functional.log_softmax(vector)`` does not work if some elements of ``vector`` should be
masked. This performs a log_softmax on just the non-masked portions of ``vector``. Passing
``None`` in for the mask is also acceptable; you'll just get a regular log_softmax.

We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``.
``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is
broadcastable to ``vector's`` shape. If ``mask`` has fewer dimensions than ``vector``, we will
unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask,
do it yourself before passing the mask into this function.

In the case that the input vector is completely masked, the return value of this function is
arbitrary, but not ``nan``. You should be masking the result of whatever computation comes out
@@ -217,13 +227,16 @@ def masked_log_softmax(vector, mask):
extreme, you've got bigger problems than this.
"""
if mask is not None:
mask = mask.float()
while mask.dim() < vector.dim():
mask = mask.unsqueeze(1)
# vector + mask.log() is an easy way to zero out masked elements in logspace, but it
# results in nans when the whole vector is masked. We need a very small value instead of a
# zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
# just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
# becomes 0 - this is just the smallest value we can actually use.
vector = vector + (mask + 1e-45).log()
return torch.nn.functional.log_softmax(vector, dim=1)
return torch.nn.functional.log_softmax(vector, dim=dim)


def masked_max(vector: torch.Tensor,
@@ -424,43 +437,36 @@ def get_text_field_mask(text_field_tensors: Dict[str, torch.Tensor],
else:
raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim))

def _last_dimension_applicator(function_to_apply: Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
tensor: torch.Tensor,
mask: Optional[torch.Tensor] = None):
"""
Takes a tensor with 3 or more dimensions and applies a function over the last dimension. We
assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given)
has shape ``(batch_size, sequence_length)``. We first unsqueeze and expand the mask so that it
has the same shape as the tensor, then flatten them both to be 2D, pass them through
the function and put the tensor back in its original shape.
"""
tensor_shape = tensor.size()
reshaped_tensor = tensor.view(-1, tensor.size()[-1])
if mask is not None:
while mask.dim() < tensor.dim():
mask = mask.unsqueeze(1)
mask = mask.expand_as(tensor).contiguous().float()
mask = mask.view(-1, mask.size()[-1])
reshaped_result = function_to_apply(reshaped_tensor, mask)
return reshaped_result.view(*tensor_shape)


def last_dim_softmax(tensor: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Takes a tensor with 3 or more dimensions and does a masked softmax over the last dimension. We
assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given)
has shape ``(batch_size, sequence_length)``.

.. deprecated:: 0.6.1
``last_dim_softmax`` was deprecated in favor of just using ``masked_softmax`` in version
0.6.1. It will be removed in version 0.8.
"""
return _last_dimension_applicator(masked_softmax, tensor, mask)
warnings.warn("``last_dim_softmax`` was deprecated in favor of just using ``masked_softmax`` "
"in version 0.6.1. It will be removed in version 0.8.", DeprecationWarning)
return masked_softmax(tensor, mask, dim=-1)


def last_dim_log_softmax(tensor: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Takes a tensor with 3 or more dimensions and does a masked log softmax over the last dimension.
We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given)
has shape ``(batch_size, sequence_length)``.

.. deprecated:: 0.6.1
``last_dim_log_softmax`` was deprecated in favor of just using ``masked_log_softmax`` in
version 0.6.1. It will be removed in version 0.8.
"""
return _last_dimension_applicator(masked_log_softmax, tensor, mask)
warnings.warn("``last_dim_log_softmax`` was deprecated in favor of just using "
"``masked_log_softmax`` in version 0.6.1. It will be removed in version 0.8.",
DeprecationWarning)
return masked_log_softmax(tensor, mask, dim=-1)


def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
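Not part of the diff: a quick sketch of the new behaviour with toy tensors, covering the mask broadcasting, the default ``dim=-1``, the all-masked case, and the deprecation shim.

```python
import warnings
import torch
from allennlp.nn.util import masked_softmax, masked_log_softmax, last_dim_softmax

vector = torch.randn(2, 3, 4)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]])                   # (batch_size, sequence_length)

# The 2D mask is unsqueezed to (2, 1, 4) and broadcast, so with the default dim=-1
# these calls reproduce the old last_dim_softmax / last_dim_log_softmax behaviour.
probs = masked_softmax(vector, mask)
log_probs = masked_log_softmax(vector, mask)
print(probs.sum(-1), log_probs.exp().sum(-1))         # rows sum to ~1.0

# A completely masked input comes back as all zeros from masked_softmax.
print(masked_softmax(torch.randn(1, 4), torch.zeros(1, 4)))   # tensor([[0., 0., 0., 0.]])

# The deprecated wrappers still work, but now warn and delegate.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    last_dim_softmax(vector, mask)
print(caught[0].category)                             # <class 'DeprecationWarning'>
```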