This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Improve performance of attention modules (#1235)
* Related to #768
This is an initial POC for getting feedback before proceeding with the entire refactor.
In this PR:
1. I added implementations of MatrixAttention: there is one for dot product, one for cosine similarity, etc.
2. There is an implementation of MatrixAttention called LegacyMatrixAttention, used to support backwards compatibility and already-trained models.
3. There is a class called AttentionMatrixFactory which can be used to create an instance of MatrixAttention. It determines whether to use the legacy implementation or the newer one based on the name.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* changes following review.

* add test

* add linear attention module to get feedback on it.

* split tests into separate files.

* add linear tests.

* adding linear and dot product attention.

* adding linear and dot product attention.

* adding linear and dot product attention.

* adding linear and dot product attention.

add cosine attention.
fix typing issue.
add doc

* changes following PR.
1. Changing documentation layout
2. Moving files
3. cleanup

* changes following PR.
1. Changing documentation layout
2. Moving files
3. cleanup

* changes following PR.
1. Changing documentation layout
2. Moving files
3. cleanup

* included in this commit:
1. remove an incorrect comment
2. import the base class into the __init__ files in the same package and also in the parent package.

* Update matrix_attention.py

* Moved tests after merging master

* Pylint (it wasn't run previously because there was no __init__.py)
murphp15 authored and matt-gardner committed May 27, 2018
1 parent 69b9af3 commit 986cf17
Showing 32 changed files with 654 additions and 78 deletions.
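
As a reading aid for the file-by-file diff below, here is a minimal usage sketch of the refactored Attention API. The class names, module paths, and registered names ("dot_product", "cosine", "linear", "legacy") come from the files in this commit; the tensor shapes, config values, and example flow are illustrative assumptions, not part of the change.

```python
import torch

from allennlp.common import Params
from allennlp.modules.attention import Attention, DotProductAttention

# Construct by registered name, dispatched through Attention.from_params ...
attention = Attention.from_params(Params({"type": "dot_product"}))
# ... or instantiate a concrete subclass directly.
attention = DotProductAttention()

batch_size, num_rows, embedding_dim = 2, 5, 7
vector = torch.rand(batch_size, embedding_dim)            # one query vector per batch element
matrix = torch.rand(batch_size, num_rows, embedding_dim)  # rows to attend over

# Returns a (batch_size, num_rows) tensor of attention weights; it is
# softmax-normalized because normalize defaults to True in the base class.
weights = attention(vector, matrix)
```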
5 changes: 3 additions & 2 deletions allennlp/models/decomposable_attention.py
@@ -6,8 +6,9 @@
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward, MatrixAttention
from allennlp.modules import FeedForward
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum
from allennlp.training.metrics import CategoricalAccuracy
@@ -74,7 +75,7 @@ def __init__(self, vocab: Vocabulary,

self._text_field_embedder = text_field_embedder
self._attend_feedforward = TimeDistributed(attend_feedforward)
self._matrix_attention = MatrixAttention(similarity_function)
self._matrix_attention = LegacyMatrixAttention(similarity_function)
self._compare_feedforward = TimeDistributed(compare_feedforward)
self._aggregate_feedforward = aggregate_feedforward
self._premise_encoder = premise_encoder
5 changes: 3 additions & 2 deletions allennlp/models/encoder_decoders/simple_seq2seq.py
@@ -12,7 +12,8 @@
from allennlp.common import Params
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules import Attention, TextFieldEmbedder, Seq2SeqEncoder
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.modules.attention import LegacyAttention
from allennlp.modules.similarity_functions import SimilarityFunction
from allennlp.modules.token_embedders import Embedding
from allennlp.models.model import Model
@@ -94,7 +95,7 @@ def __init__(self,
target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()
self._target_embedder = Embedding(num_classes, target_embedding_dim)
if self._attention_function:
self._decoder_attention = Attention(self._attention_function)
self._decoder_attention = LegacyAttention(self._attention_function)
# The output of attention, a weighted average over encoder outputs, will be
# concatenated to the input vector of the decoder at each time step.
self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim
5 changes: 3 additions & 2 deletions allennlp/models/reading_comprehension/bidaf.py
@@ -9,8 +9,9 @@
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Highway, MatrixAttention
from allennlp.modules import Highway
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy, SquadEmAndF1

@@ -81,7 +82,7 @@ def __init__(self, vocab: Vocabulary,
self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
num_highway_layers))
self._phrase_layer = phrase_layer
self._matrix_attention = MatrixAttention(attention_similarity_function)
self._matrix_attention = LegacyMatrixAttention(attention_similarity_function)
self._modeling_layer = modeling_layer
self._span_end_encoder = span_end_encoder

4 changes: 2 additions & 2 deletions allennlp/models/semantic_parsing/nlvr/nlvr_decoder_step.py
@@ -11,7 +11,7 @@

from allennlp.common import util as common_util
from allennlp.models.semantic_parsing.nlvr.nlvr_decoder_state import NlvrDecoderState
from allennlp.modules import Attention
from allennlp.modules.attention.legacy_attention import LegacyAttention
from allennlp.modules.similarity_functions import SimilarityFunction
from allennlp.nn.decoding import DecoderStep, RnnState, ChecklistState
from allennlp.nn import util as nn_util
@@ -37,7 +37,7 @@ def __init__(self,
dropout: float = 0.0,
use_coverage: bool = False) -> None:
super(NlvrDecoderStep, self).__init__()
self._input_attention = Attention(attention_function)
self._input_attention = LegacyAttention(attention_function)

# Decoder output dim needs to be the same as the encoder output dim since we initialize the
# hidden state of the decoder with the final hidden state of the encoder.
@@ -12,7 +12,8 @@
from allennlp.common import util as common_util
from allennlp.common.checks import check_dimensions_match
from allennlp.models.semantic_parsing.wikitables.wikitables_decoder_state import WikiTablesDecoderState
from allennlp.modules import Attention, FeedForward
from allennlp.modules import FeedForward
from allennlp.modules.attention.legacy_attention import LegacyAttention
from allennlp.modules.similarity_functions import SimilarityFunction
from allennlp.modules.token_embedders import Embedding
from allennlp.nn import util
@@ -49,7 +50,7 @@ def __init__(self,
super(WikiTablesDecoderStep, self).__init__()
self._mixture_feedforward = mixture_feedforward
self._entity_type_embedding = Embedding(num_entity_types, action_embedding_dim)
self._input_attention = Attention(attention_function)
self._input_attention = LegacyAttention(attention_function)

self._num_start_types = num_start_types
self._start_type_predictor = Linear(encoder_output_dim, num_start_types)
4 changes: 2 additions & 2 deletions allennlp/modules/__init__.py
@@ -5,13 +5,11 @@
:class:`~allennlp.models.model.Model` s.
"""

from allennlp.modules.attention import Attention
from allennlp.modules.conditional_random_field import ConditionalRandomField
from allennlp.modules.elmo import Elmo
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.highway import Highway
from allennlp.modules.layer_norm import LayerNorm
from allennlp.modules.matrix_attention import MatrixAttention
from allennlp.modules.maxout import Maxout
from allennlp.modules.scalar_mix import ScalarMix
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
@@ -21,3 +19,5 @@
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.modules.token_embedders import TokenEmbedder, Embedding
from allennlp.modules.matrix_attention import MatrixAttention
from allennlp.modules.attention import Attention
6 changes: 6 additions & 0 deletions allennlp/modules/attention/__init__.py
@@ -0,0 +1,6 @@

from allennlp.modules.attention.attention import Attention
from allennlp.modules.attention.linear_attention import LinearAttention
from allennlp.modules.attention.dot_product_attention import DotProductAttention
from allennlp.modules.attention.legacy_attention import LegacyAttention
from allennlp.modules.attention.cosine_attention import CosineAttention
@@ -4,21 +4,19 @@
"""

import torch
from overrides import overrides

from overrides import overrides
from allennlp.common.registrable import Registrable
from allennlp.common import Params
from allennlp.modules.similarity_functions import DotProductSimilarity, SimilarityFunction
from allennlp.nn.util import masked_softmax


class Attention(torch.nn.Module):
class Attention(torch.nn.Module, Registrable):
"""
This ``Module`` takes two inputs: a (batched) vector and a matrix, plus an optional mask on the
An ``Attention`` takes two inputs: a (batched) vector and a matrix, plus an optional mask on the
rows of the matrix. We compute the similarity between the vector and each row in the matrix,
and then (optionally) perform a softmax over rows using those computed similarities.
By default similarity is computed with a dot product, but you can alternatively use a
parameterized similarity function if you wish.
Inputs:
@@ -32,38 +30,31 @@ class Attention(torch.nn.Module):
Parameters
----------
similarity_function : ``SimilarityFunction``, optional (default=``DotProductSimilarity``)
The similarity function to use when computing the attention.
normalize : ``bool``, optional (default: ``True``)
If true, we normalize the computed similarities with a softmax, to return a probability
distribution for your attention. If false, this is just computing a similarity score.
"""
def __init__(self,
similarity_function: SimilarityFunction = None,
normalize: bool = True) -> None:
super(Attention, self).__init__()

self._similarity_function = similarity_function or DotProductSimilarity()
self._normalize = normalize

@overrides
def forward(self, # pylint: disable=arguments-differ
vector: torch.Tensor,
matrix: torch.Tensor,
matrix_mask: torch.Tensor = None) -> torch.Tensor:
tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
matrix.size()[1],
vector.size()[1])
similarities = self._similarity_function(tiled_vector, matrix)
similarities = self._forward_internal(vector, matrix, matrix_mask)
if self._normalize:
return masked_softmax(similarities, matrix_mask)
else:
return similarities

def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor,
matrix_mask: torch.Tensor = None) -> torch.Tensor:
raise NotImplementedError

@classmethod
def from_params(cls, params: Params) -> 'Attention':
similarity_function = SimilarityFunction.from_params(params.pop('similarity_function', {}))
normalize = params.pop_bool('normalize', True)
params.assert_empty(cls.__name__)
return cls(similarity_function=similarity_function,
normalize=normalize)
clazz = cls.by_name(params.pop_choice("type", cls.list_available()))
return clazz.from_params(params)
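
With this refactor, a new attention variant only needs to override `_forward_internal` and register itself; masking and the optional softmax stay in the base class. A hypothetical sketch of that extension point (the class and registered name below are not part of this commit):

```python
import math

import torch
from overrides import overrides

from allennlp.modules.attention.attention import Attention


@Attention.register("scaled_dot_product")  # hypothetical name, not added by this PR
class ScaledDotProductAttention(Attention):
    @overrides
    def _forward_internal(self,
                          vector: torch.Tensor,
                          matrix: torch.Tensor,
                          matrix_mask: torch.Tensor = None) -> torch.Tensor:
        # Same contraction as DotProductAttention, scaled by sqrt(embedding_dim);
        # the base class handles the mask and the optional softmax.
        return matrix.bmm(vector.unsqueeze(-1)).squeeze(-1) / math.sqrt(vector.size(-1))
```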
27 changes: 27 additions & 0 deletions allennlp/modules/attention/cosine_attention.py
@@ -0,0 +1,27 @@

import torch
from overrides import overrides
from allennlp.common import Params
from allennlp.modules.attention.legacy_attention import Attention


@Attention.register("cosine")
class CosineAttention(Attention):
"""
Computes attention between a vector and a matrix using cosine similarity.
"""

@overrides
def _forward_internal(self,
vector: torch.Tensor,
matrix: torch.Tensor,
matrix_mask: torch.Tensor = None) -> torch.Tensor:
a_norm = vector / (vector.norm(p=2, dim=-1, keepdim=True) + 1e-13)
b_norm = matrix / (matrix.norm(p=2, dim=-1, keepdim=True) + 1e-13)
return torch.bmm(a_norm.unsqueeze(dim=1), b_norm.transpose(-1, -2)).squeeze(1)

@classmethod
def from_params(cls, params: Params):
normalize = params.pop_bool('normalize', True)
params.assert_empty(cls.__name__)
return CosineAttention(normalize)
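
A small sketch of CosineAttention on toy tensors (the shapes and values are illustrative); with normalize=False it returns the raw cosine similarities rather than a probability distribution:

```python
import torch

from allennlp.modules.attention import CosineAttention

attention = CosineAttention(normalize=False)  # raw cosine similarities in [-1, 1]
vector = torch.tensor([[3.0, 0.0]])           # (1, 2); rescaling this vector does not change the output
matrix = torch.tensor([[[1.0, 0.0],
                        [0.0, 1.0],
                        [1.0, 1.0]]])         # (1, 3, 2)
print(attention(vector, matrix))              # approximately [[1.0000, 0.0000, 0.7071]]
```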
25 changes: 25 additions & 0 deletions allennlp/modules/attention/dot_product_attention.py
@@ -0,0 +1,25 @@

import torch
from overrides import overrides
from allennlp.common import Params
from allennlp.modules.attention.legacy_attention import Attention


@Attention.register("dot_product")
class DotProductAttention(Attention):
"""
Computes attention between a vector and a matrix using dot product.
"""

@overrides
def _forward_internal(self,
vector: torch.Tensor,
matrix: torch.Tensor,
matrix_mask: torch.Tensor = None) -> torch.Tensor:
return matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)

@classmethod
def from_params(cls, params: Params):
normalize = params.pop_bool('normalize', True)
params.assert_empty(cls.__name__)
return DotProductAttention(normalize)
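
DotProductAttention computes the same scores as LegacyAttention with its default DotProductSimilarity, but with a single bmm instead of tiling the query vector across every row of the matrix, which is where the memory saving comes from. A sketch with illustrative shapes:

```python
import torch

from allennlp.modules.attention import DotProductAttention, LegacyAttention

vector = torch.rand(4, 10)     # (batch_size, embedding_dim)
matrix = torch.rand(4, 6, 10)  # (batch_size, num_rows, embedding_dim)

new_weights = DotProductAttention()(vector, matrix)  # (4, 6)
old_weights = LegacyAttention()(vector, matrix)      # same values, computed via tiling
print((new_weights - old_weights).abs().max())       # ~0: the two implementations agree
```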
39 changes: 39 additions & 0 deletions allennlp/modules/attention/legacy_attention.py
@@ -0,0 +1,39 @@

import torch

from overrides import overrides
from allennlp.common import Params
from allennlp.modules.attention.attention import Attention
from allennlp.modules.similarity_functions import DotProductSimilarity, SimilarityFunction


@Attention.register("legacy")
class LegacyAttention(Attention):
"""
Computes attention between a vector and a matrix using a similarity function.
This should be considered deprecated, as it consumes more memory than the specialized attention modules.
"""

def __init__(self,
similarity_function: SimilarityFunction = None,
normalize: bool = True) -> None:
super().__init__(normalize)
self._similarity_function = similarity_function or DotProductSimilarity()

@overrides
def _forward_internal(self,
vector: torch.Tensor,
matrix: torch.Tensor,
matrix_mask: torch.Tensor = None) -> torch.Tensor:
tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
matrix.size()[1],
vector.size()[1])
return self._similarity_function(tiled_vector, matrix)

@classmethod
def from_params(cls, params: Params) -> 'Attention':
similarity_function = SimilarityFunction.from_params(params.pop('similarity_function', {}))
normalize = params.pop_bool('normalize', True)
params.assert_empty(cls.__name__)
return cls(similarity_function=similarity_function,
normalize=normalize)
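
A sketch of the backwards-compatible path: a configuration that specifies an explicit similarity_function is routed through the "legacy" type, and the updated models in this commit construct LegacyAttention directly. The similarity-function config shown here is an assumption for illustration:

```python
from allennlp.common import Params
from allennlp.modules.attention import Attention, LegacyAttention
from allennlp.modules.similarity_functions import DotProductSimilarity

# From a config file: an explicit similarity function handled by LegacyAttention.
attention = Attention.from_params(Params({
    "type": "legacy",
    "similarity_function": {"type": "dot_product"},
}))

# Or constructed in code, as the updated models in this commit now do.
attention = LegacyAttention(DotProductSimilarity())
```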
93 changes: 93 additions & 0 deletions allennlp/modules/attention/linear_attention.py
@@ -0,0 +1,93 @@
import math

import torch
from torch.nn import Parameter
from overrides import overrides
from allennlp.modules.attention.legacy_attention import Attention
from allennlp.nn import util
from allennlp.nn.activations import Activation
from allennlp.common.params import Params


@Attention.register("linear")
class LinearAttention(Attention):
"""
This ``Attention`` module performs a dot product between a vector of weights and some
combination of the two input vectors, followed by an (optional) activation function. The
combination used is configurable.
If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``,
``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed
elementwise. You can list as many combinations as you want, comma separated. For example, you
might give ``x,y,x*y`` as the ``combination`` parameter to this class. The computed similarity
function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a
bias parameter, and ``[;]`` is vector concatenation.
Note that if you want a bilinear similarity function with a diagonal weight matrix W, where the
similarity function is computed as `x * w * y + b` (with `w` the diagonal of `W`), you can
accomplish that with this class by using "x*y" for `combination`.
Parameters
----------
tensor_1_dim : ``int``
The dimension of the first tensor, ``x``, described above. This is ``x.size()[-1]`` - the
length of the vector that will go into the similarity computation. We need this so we can
build weight vectors correctly.
tensor_2_dim : ``int``
The dimension of the second tensor, ``y``, described above. This is ``y.size()[-1]`` - the
length of the vector that will go into the similarity computation. We need this so we can
build weight vectors correctly.
combination : ``str``, optional (default="x,y")
Described above.
activation : ``Activation``, optional (default=linear (i.e. no activation))
An activation function applied after the ``w^T * [x;y] + b`` calculation. Default is no
activation.
"""

def __init__(self,
tensor_1_dim: int,
tensor_2_dim: int,
combination: str = 'x,y',
activation: Activation = Activation.by_name('linear')(),
normalize: bool = True) -> None:
super().__init__(normalize)
self._combination = combination
combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
self._weight_vector = Parameter(torch.Tensor(combined_dim))
self._bias = Parameter(torch.Tensor(1))
self._activation = activation
self.reset_parameters()

def reset_parameters(self):
std = math.sqrt(6 / (self._weight_vector.size(0) + 1))
self._weight_vector.data.uniform_(-std, std)
self._bias.data.fill_(0)

@overrides
def _forward_internal(self,
vector: torch.Tensor,
matrix: torch.Tensor,
matrix_mask: torch.Tensor = None) -> torch.Tensor:
# TODO(mattg): Remove the need for this tiling.
# https://github.com/allenai/allennlp/pull/1235#issuecomment-391540133
tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
matrix.size()[1],
vector.size()[1])

combined_tensors = util.combine_tensors(self._combination, [tiled_vector, matrix])
dot_product = torch.matmul(combined_tensors, self._weight_vector)
return self._activation(dot_product + self._bias)

@classmethod
def from_params(cls, params: Params) -> 'Attention':
tensor_1_dim = params.pop_int("tensor_1_dim")
tensor_2_dim = params.pop_int("tensor_2_dim")
combination = params.pop("combination", "x,y")
activation = Activation.by_name(params.pop("activation", "linear"))()
normalize = params.pop_bool('normalize', True)
params.assert_empty(cls.__name__)
return cls(normalize=normalize,
tensor_1_dim=tensor_1_dim,
tensor_2_dim=tensor_2_dim,
combination=combination,
activation=activation)
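
A sketch of LinearAttention using the "x,y,x*y" combination described in the docstring above; the dimensions and activation are illustrative:

```python
import torch

from allennlp.modules.attention import LinearAttention
from allennlp.nn.activations import Activation

attention = LinearAttention(tensor_1_dim=5,
                            tensor_2_dim=5,
                            combination="x,y,x*y",
                            activation=Activation.by_name("tanh")())

vector = torch.rand(3, 5)            # (batch_size, tensor_1_dim)
matrix = torch.rand(3, 8, 5)         # (batch_size, num_rows, tensor_2_dim)
weights = attention(vector, matrix)  # (3, 8): tanh(w^T [x; y; x*y] + b) per row, then softmax
```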
6 changes: 6 additions & 0 deletions allennlp/modules/matrix_attention/__init__.py
@@ -0,0 +1,6 @@

from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention
from allennlp.modules.matrix_attention.cosine_matrix_attention import CosineMatrixAttention
from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
