Commit
Improve performance of attention modules (#1235)
* Related to #768. This is an initial POC for getting feedback before proceeding with the entire refactor. In this PR:
  1. I added implementations of MatrixAttention. There is one for dot product, one for cosine similarity, etc.
  2. There is an implementation of MatrixAttention called LegacyMatrixAttention, used to support backwards compatibility and already-trained models.
  3. There is a class called AttentionMatrixFactory which can be used to create an instance of MatrixAttention. It determines whether to use the legacy implementation or the newer one based on the name.
* Changes following review (repeated across several review rounds).
* Add a test.
* Add the linear attention module to get feedback on it.
* Split tests into separate files.
* Add linear tests.
* Add linear and dot product attention; add cosine attention; fix a typing issue; add documentation.
* Changes following PR review: 1. change documentation layout, 2. move files, 3. cleanup.
* Included in this commit: 1. remove an incorrect comment, 2. import the base class into the __init__ files in the same package and also in the parent package.
* Update matrix_attention.py
* Move tests after merging master.
* Pylint (it wasn't run previously because there was no __init__.py).
1 parent 69b9af3 · commit 986cf17
Showing 32 changed files with 654 additions and 78 deletions.
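Before the per-file diffs, a rough sketch (not part of this commit) of how the new specialized modules are meant to stand in for the similarity-function-based path. It assumes the Attention base class, which is not shown in this diff, exposes a forward(vector, matrix, matrix_mask=None) wrapper around _forward_internal with normalize defaulting to True; the shapes are illustrative only.

import torch

from allennlp.modules.attention import DotProductAttention, LegacyAttention

batch_size, num_rows, embedding_dim = 2, 5, 7
vector = torch.randn(batch_size, embedding_dim)             # the query vector
matrix = torch.randn(batch_size, num_rows, embedding_dim)   # the rows to attend over

# New specialized module: one batched matrix-vector product, no tiling of the query.
attention = DotProductAttention()
weights = attention(vector, matrix)          # assumed shape: (batch_size, num_rows)

# Backwards-compatible path for configs of already-trained models; it defaults to
# DotProductSimilarity but routes through a tiled (batch, num_rows, dim) intermediate.
legacy = LegacyAttention()
legacy_weights = legacy(vector, matrix)

Assuming the default DotProductSimilarity applies no scaling, the two weight tensors above should agree; the difference is that the legacy path builds a (batch_size, num_rows, embedding_dim) intermediate while the specialized module does not.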
allennlp/modules/attention/__init__.py
@@ -0,0 +1,6 @@

from allennlp.modules.attention.attention import Attention
from allennlp.modules.attention.linear_attention import LinearAttention
from allennlp.modules.attention.dot_product_attention import DotProductAttention
from allennlp.modules.attention.legacy_attention import LegacyAttention
from allennlp.modules.attention.cosine_attention import CosineAttention
allennlp/modules/attention/cosine_attention.py
@@ -0,0 +1,27 @@

import torch
from overrides import overrides
from allennlp.common import Params
from allennlp.modules.attention.legacy_attention import Attention


@Attention.register("cosine")
class CosineAttention(Attention):
    """
    Computes attention between a vector and a matrix using cosine similarity.
    """

    @overrides
    def _forward_internal(self,
                          vector: torch.Tensor,
                          matrix: torch.Tensor,
                          matrix_mask: torch.Tensor = None) -> torch.Tensor:
        a_norm = vector / (vector.norm(p=2, dim=-1, keepdim=True) + 1e-13)
        b_norm = matrix / (matrix.norm(p=2, dim=-1, keepdim=True) + 1e-13)
        return torch.bmm(a_norm.unsqueeze(dim=1), b_norm.transpose(-1, -2)).squeeze(1)

    @classmethod
    def from_params(cls, params: Params):
        normalize = params.pop_bool('normalize', True)
        params.assert_empty(cls.__name__)
        return CosineAttention(normalize)
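Not part of the diff, but as a sanity check on the math above, a standalone sketch with assumed toy shapes showing that the normalize-then-bmm computation matches torch's built-in cosine similarity row by row.

import torch
import torch.nn.functional as F

vector = torch.randn(2, 7)       # (batch_size, embedding_dim)
matrix = torch.randn(2, 5, 7)    # (batch_size, num_rows, embedding_dim)

# Same steps as _forward_internal: L2-normalize both inputs, then batched dot products.
a_norm = vector / (vector.norm(p=2, dim=-1, keepdim=True) + 1e-13)
b_norm = matrix / (matrix.norm(p=2, dim=-1, keepdim=True) + 1e-13)
scores = torch.bmm(a_norm.unsqueeze(1), b_norm.transpose(-1, -2)).squeeze(1)

expected = F.cosine_similarity(vector.unsqueeze(1).expand_as(matrix), matrix, dim=-1)
assert torch.allclose(scores, expected, atol=1e-5)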
allennlp/modules/attention/dot_product_attention.py
@@ -0,0 +1,25 @@

import torch
from overrides import overrides
from allennlp.common import Params
from allennlp.modules.attention.legacy_attention import Attention


@Attention.register("dot_product")
class DotProductAttention(Attention):
    """
    Computes attention between a vector and a matrix using dot product.
    """

    @overrides
    def _forward_internal(self,
                          vector: torch.Tensor,
                          matrix: torch.Tensor,
                          matrix_mask: torch.Tensor = None) -> torch.Tensor:
        return matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)

    @classmethod
    def from_params(cls, params: Params):
        normalize = params.pop_bool('normalize', True)
        params.assert_empty(cls.__name__)
        return DotProductAttention(normalize)
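Again not from the diff: a small sketch with assumed toy shapes of what the single bmm call above computes, namely one dot product between the query vector and each row of the matrix.

import torch

vector = torch.randn(2, 7)       # (batch_size, embedding_dim)
matrix = torch.randn(2, 5, 7)    # (batch_size, num_rows, embedding_dim)

scores = matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)      # (batch_size, num_rows)
expected = (matrix * vector.unsqueeze(1)).sum(dim=-1)      # multiply elementwise, then sum
assert torch.allclose(scores, expected, atol=1e-5)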
allennlp/modules/attention/legacy_attention.py
@@ -0,0 +1,39 @@

import torch

from overrides import overrides
from allennlp.common import Params
from allennlp.modules.attention.attention import Attention
from allennlp.modules.similarity_functions import DotProductSimilarity, SimilarityFunction


@Attention.register("legacy")
class LegacyAttention(Attention):
    """
    Computes attention between a vector and a matrix using a similarity function.
    This should be considered deprecated, as it consumes more memory than the
    specialized attention modules.
    """

    def __init__(self,
                 similarity_function: SimilarityFunction = None,
                 normalize: bool = True) -> None:
        super().__init__(normalize)
        self._similarity_function = similarity_function or DotProductSimilarity()

    @overrides
    def _forward_internal(self,
                          vector: torch.Tensor,
                          matrix: torch.Tensor,
                          matrix_mask: torch.Tensor = None) -> torch.Tensor:
        tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
                                                  matrix.size()[1],
                                                  vector.size()[1])
        return self._similarity_function(tiled_vector, matrix)

    @classmethod
    def from_params(cls, params: Params) -> 'Attention':
        similarity_function = SimilarityFunction.from_params(params.pop('similarity_function', {}))
        normalize = params.pop_bool('normalize', True)
        params.assert_empty(cls.__name__)
        return cls(similarity_function=similarity_function,
                   normalize=normalize)
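To make the memory caveat in the docstring concrete, a hedged standalone sketch (toy shapes; it assumes DotProductSimilarity reduces an elementwise product over the last dimension, which is what the unscaled default does) contrasting the tiled legacy computation with the specialized single-bmm path.

import torch

batch_size, num_rows, embedding_dim = 2, 5, 7
vector = torch.randn(batch_size, embedding_dim)
matrix = torch.randn(batch_size, num_rows, embedding_dim)

# Legacy path: tile the query (a stride-0 view), then the similarity function
# materializes a full (batch_size, num_rows, embedding_dim) intermediate product.
tiled_vector = vector.unsqueeze(1).expand(batch_size, num_rows, embedding_dim)
legacy_scores = (tiled_vector * matrix).sum(dim=-1)

# Specialized path: one batched matrix-vector product, no tiled intermediate.
fast_scores = matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(legacy_scores, fast_scores, atol=1e-5)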
allennlp/modules/attention/linear_attention.py
@@ -0,0 +1,93 @@
import math

import torch
from torch.nn import Parameter
from overrides import overrides
from allennlp.modules.attention.legacy_attention import Attention
from allennlp.nn import util
from allennlp.nn.activations import Activation
from allennlp.common.params import Params


@Attention.register("linear")
class LinearAttention(Attention):
    """
    This ``Attention`` module performs a dot product between a vector of weights and some
    combination of the two input vectors, followed by an (optional) activation function.
    The combination used is configurable.

    If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations:
    ``x``, ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations
    is performed elementwise.  You can list as many combinations as you want, comma
    separated.  For example, you might give ``x,y,x*y`` as the ``combination`` parameter to
    this class.  The computed similarity function would then be ``w^T [x; y; x*y] + b``,
    where ``w`` is a vector of weights, ``b`` is a bias parameter, and ``[;]`` is vector
    concatenation.

    Note that if you want a bilinear similarity function with a diagonal weight matrix ``W``,
    where the similarity is computed as ``x * w * y + b`` (with ``w`` the diagonal of ``W``),
    you can accomplish that with this class by using ``"x*y"`` for ``combination``.

    Parameters
    ----------
    tensor_1_dim : ``int``
        The dimension of the first tensor, ``x``, described above.  This is ``x.size()[-1]``
        - the length of the vector that will go into the similarity computation.  We need
        this so we can build weight vectors correctly.
    tensor_2_dim : ``int``
        The dimension of the second tensor, ``y``, described above.  This is ``y.size()[-1]``
        - the length of the vector that will go into the similarity computation.  We need
        this so we can build weight vectors correctly.
    combination : ``str``, optional (default="x,y")
        Described above.
    activation : ``Activation``, optional (default=linear (i.e. no activation))
        An activation function applied after the ``w^T * [x;y] + b`` calculation.  Default
        is no activation.
    """

    def __init__(self,
                 tensor_1_dim: int,
                 tensor_2_dim: int,
                 combination: str = 'x,y',
                 activation: Activation = Activation.by_name('linear')(),
                 normalize: bool = True) -> None:
        super().__init__(normalize)
        self._combination = combination
        combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
        self._weight_vector = Parameter(torch.Tensor(combined_dim))
        self._bias = Parameter(torch.Tensor(1))
        self._activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        std = math.sqrt(6 / (self._weight_vector.size(0) + 1))
        self._weight_vector.data.uniform_(-std, std)
        self._bias.data.fill_(0)

    @overrides
    def _forward_internal(self,
                          vector: torch.Tensor,
                          matrix: torch.Tensor,
                          matrix_mask: torch.Tensor = None) -> torch.Tensor:
        # TODO(mattg): Remove the need for this tiling.
        # https://github.com/allenai/allennlp/pull/1235#issuecomment-391540133
        tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
                                                  matrix.size()[1],
                                                  vector.size()[1])

        combined_tensors = util.combine_tensors(self._combination, [tiled_vector, matrix])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return self._activation(dot_product + self._bias)

    @classmethod
    def from_params(cls, params: Params) -> 'Attention':
        tensor_1_dim = params.pop_int("tensor_1_dim")
        tensor_2_dim = params.pop_int("tensor_2_dim")
        combination = params.pop("combination", "x,y")
        activation = Activation.by_name(params.pop("activation", "linear"))()
        normalize = params.pop_bool('normalize', True)
        params.assert_empty(cls.__name__)
        return cls(normalize=normalize,
                   tensor_1_dim=tensor_1_dim,
                   tensor_2_dim=tensor_2_dim,
                   combination=combination,
                   activation=activation)
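A standalone sketch (toy shapes and random parameters, not taken from the diff) of the similarity this module computes for the combination "x,y,x*y", i.e. w^T [x; y; x*y] + b for the query x against each matrix row y. It assumes util.combine_tensors concatenates the listed pieces in order, as the docstring describes, and it applies no activation (the linear default).

import torch

batch_size, num_rows, embedding_dim = 2, 5, 7
x = torch.randn(batch_size, embedding_dim)              # query vector
y = torch.randn(batch_size, num_rows, embedding_dim)    # matrix rows

weight = torch.randn(3 * embedding_dim)   # one weight block per combination piece
bias = torch.zeros(1)

x_tiled = x.unsqueeze(1).expand_as(y)                      # (batch, num_rows, dim)
combined = torch.cat([x_tiled, y, x_tiled * y], dim=-1)    # (batch, num_rows, 3 * dim)
scores = torch.matmul(combined, weight) + bias             # (batch, num_rows)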
allennlp/modules/matrix_attention/__init__.py
@@ -0,0 +1,6 @@

from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention
from allennlp.modules.matrix_attention.cosine_matrix_attention import CosineMatrixAttention
from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention