This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

get the dimensions correct for the transformer (#941)
* get the dimensions correct for the transformer

* remove last layer norm, lint

* update param descriptions

* PR feedback

* whitespace
DeNeutoy committed Mar 2, 2018
1 parent c12d435 commit f81e27a
Showing 4 changed files with 58 additions and 70 deletions.
104 changes: 46 additions & 58 deletions allennlp/modules/seq2seq_encoders/multi_head_self_attention.py
@@ -2,10 +2,8 @@

from torch.autograd import Variable
from torch.nn import Dropout, Linear
-from torch.nn import Parameter
-from torch.nn import init

-from allennlp.nn.util import last_dim_softmax
+from allennlp.nn.util import last_dim_softmax, weighted_sum
from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from allennlp.common.params import Params

@@ -29,11 +27,11 @@ class MultiHeadSelfAttention(Seq2SeqEncoder):
input_dim : ``int``, required.
The size of the last dimension of the input tensor.
attention_dim ``int``, required.
-The dimension of the query and key projections which comprise the
-dot product attention function.
+The total dimension of the query and key projections which comprise the
+dot product attention function. Must be divisible by ``num_heads``.
values_dim : ``int``, required.
-The dimension which the input is projected to for representing the values,
-which are combined using the attention.
+The total dimension which the input is projected to for representing the values,
+which are combined using the attention. Must be divisible by ``num_heads``.
output_projection_dim : ``int``, optional (default = None)
The dimensionality of the final output projection. If this is not passed
explicitly, the projection has size `input_size`.
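Aside (not part of the diff): the updated docstring's divisibility requirement just means the projection sizes are totals that get split evenly across heads. A minimal sketch with made-up sizes:

num_heads = 3
attention_dim = 6   # total across heads; must be divisible by num_heads
values_dim = 9      # total across heads; must be divisible by num_heads
assert attention_dim % num_heads == 0 and values_dim % num_heads == 0
attention_dim_per_head = attention_dim // num_heads   # 2
values_dim_per_head = values_dim // num_heads         # 3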
@@ -56,26 +54,20 @@ def __init__(self,
self._attention_dim = attention_dim
self._values_dim = values_dim

-self._query_projections = Parameter(torch.FloatTensor(num_heads, input_dim, attention_dim))
-self._key_projections = Parameter(torch.FloatTensor(num_heads, input_dim, attention_dim))
-self._value_projections = Parameter(torch.FloatTensor(num_heads, input_dim, values_dim))
+if attention_dim % num_heads != 0:
+    raise ValueError(f"Key size ({attention_dim}) must be divisible by the number of "
+                     f"attention heads ({num_heads}).")
+
+if values_dim % num_heads != 0:
+    raise ValueError(f"Value size ({values_dim}) must be divisible by the number of "
+                     f"attention heads ({num_heads}).")

-self._scale = input_dim ** 0.5
-self._output_projection = Linear(num_heads * values_dim,
-                                 self._output_dim)
-self._attention_dropout = Dropout(attention_dropout_prob)

-self.reset_parameters()
+self._combined_projection = Linear(input_dim, 2 * attention_dim + values_dim)

-def reset_parameters(self) -> None:
-    # Because we are doing so many torch.bmm calls, which is fast but unstable,
-    # it is critically important to intitialise the parameters correctly such
-    # that these matrix multiplications are well conditioned initially.
-    # Without this initialisation, this (non-deterministically) produces
-    # NaNs and overflows.
-    init.xavier_normal(self._query_projections)
-    init.xavier_normal(self._key_projections)
-    init.xavier_normal(self._value_projections)
+self._scale = input_dim ** 0.5
+self._output_projection = Linear(values_dim, self._output_dim)
+self._attention_dropout = Dropout(attention_dropout_prob)

def get_input_dim(self):
return self._input_dim
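Aside (not part of the diff): the constructor now builds one Linear layer whose output is split into queries, keys and values downstream, instead of three per-head parameter tensors. A self-contained sketch of the dimensionality, with assumed example sizes:

import torch
from torch.nn import Linear

input_dim, attention_dim, values_dim = 5, 6, 9          # assumed sizes
combined_projection = Linear(input_dim, 2 * attention_dim + values_dim)

inputs = torch.randn(2, 12, input_dim)                  # (batch_size, timesteps, input_dim)
projected = combined_projection(inputs)                 # (2, 12, 2 * 6 + 9) == (2, 12, 21)
queries, keys, *values = projected.split(attention_dim, -1)
values = torch.cat(values, -1)
assert queries.shape == (2, 12, attention_dim)
assert keys.shape == (2, 12, attention_dim)
assert values.shape == (2, 12, values_dim)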
@@ -101,33 +93,34 @@ def forward(self, # pylint: disable=arguments-differ
"""
num_heads = self._num_heads

-batch_size, timesteps, hidden_dim = inputs.size()
+batch_size, timesteps, _ = inputs.size()
if mask is None:
mask = Variable(inputs.data.new(batch_size, timesteps).fill_(1.0))

-# Treat the queries, keys and values each as a ``num_heads`` size batch.
-# shape (num_heads, batch_size * timesteps, hidden_dim)
-inputs_per_head = inputs.repeat(num_heads, 1, 1).view(num_heads,
-                                                      batch_size * timesteps,
-                                                      hidden_dim)
-# Do the projections for all the heads at once.
-# Then reshape the result as though it had a
-# (num_heads * batch_size) sized batch.
-queries_per_head = torch.bmm(inputs_per_head, self._query_projections)
-# shape (num_heads * batch_size, timesteps, attention_dim)
-queries_per_head = queries_per_head.view(num_heads * batch_size,
-                                         timesteps,
-                                         self._attention_dim)
-
-keys_per_head = torch.bmm(inputs_per_head, self._key_projections)
-# shape (num_heads * batch_size, timesteps, attention_dim)
-keys_per_head = keys_per_head.view(num_heads * batch_size,
-                                   timesteps,
-                                   self._attention_dim)
-
-values_per_head = torch.bmm(inputs_per_head, self._value_projections)
-# shape (num_heads * batch_size, timesteps, attention_dim)
-values_per_head = values_per_head.view(num_heads * batch_size, timesteps, self._values_dim)
+# Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
+combined_projection = self._combined_projection(inputs)
+
+# split by attention dim - if values_dim > attention_dim, we will get more
+# than 3 elements returned. All of the rest are the values vector, so we
+# just concatenate them back together again below.
+queries, keys, *values = combined_projection.split(self._attention_dim, -1)
+queries = queries.contiguous()
+keys = keys.contiguous()
+values = torch.cat(values, -1).contiguous()
+# Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
+values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim/num_heads))
+values_per_head = values_per_head.transpose(1, 2).contiguous()
+values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim/num_heads))
+
+# Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
+queries_per_head = queries.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
+queries_per_head = queries_per_head.transpose(1, 2).contiguous()
+queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))
+
+# Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
+keys_per_head = keys.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
+keys_per_head = keys_per_head.transpose(1, 2).contiguous()
+keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

# shape (num_heads * batch_size, timesteps, timesteps)
scaled_similarities = torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale
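Aside (not part of the diff): the view/transpose/view pattern above splits the total attention_dim and values_dim evenly across heads. A standalone shape check with assumed sizes:

import torch

batch_size, timesteps, num_heads, attention_dim = 2, 12, 3, 6
queries = torch.randn(batch_size, timesteps, attention_dim)
queries_per_head = queries.view(batch_size, timesteps, num_heads, attention_dim // num_heads)
queries_per_head = queries_per_head.transpose(1, 2).contiguous()
queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, attention_dim // num_heads)
assert queries_per_head.shape == (6, 12, 2)   # (num_heads * batch_size, timesteps, attention_dim / num_heads)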
@@ -136,16 +129,11 @@ def forward(self, # pylint: disable=arguments-differ
# Normalise the distributions, using the same mask for all heads.
attention = last_dim_softmax(scaled_similarities, mask.repeat(num_heads, 1))
attention = self._attention_dropout(attention)
-# This is doing the following batch-wise matrix multiplication:
-# (num_heads * batch_size, timesteps, timesteps) *
-# (num_heads * batch_size, timesteps, values_dim)
-# which is equivalent to a weighted sum of the values with respect to
-# the attention distributions for each element in the num_heads * batch_size
-# dimension.
-# shape (num_heads * batch_size, timesteps, values_dim)
-outputs = torch.bmm(attention, values_per_head)
-
-# Reshape back to original shape (batch_size, timesteps, num_heads * values_dim)
+# Take a weighted sum of the values with respect to the attention
+# distributions for each element in the num_heads * batch_size dimension.
+# shape (num_heads * batch_size, timesteps, values_dim/num_heads)
+outputs = weighted_sum(values_per_head, attention)
+# Reshape back to original shape (batch_size, timesteps, values_dim)
# Note that we _cannot_ use a reshape here, because this tensor was created
# with num_heads being the first dimension, so reshaping naively would not
# throw an error, but give an incorrect result.
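Aside (not part of the diff): the comments above say the weighted sum is a batched matrix product and that the head split cannot be undone with a single reshape. A standalone sketch with assumed sizes, showing one correct way to undo the split:

import torch

batch_size, num_heads, timesteps, head_dim = 2, 3, 4, 5
values = torch.randn(batch_size, timesteps, num_heads * head_dim)

# Split into heads with the same view/transpose/view pattern as above.
values_per_head = values.view(batch_size, timesteps, num_heads, head_dim)
values_per_head = values_per_head.transpose(1, 2).contiguous()
values_per_head = values_per_head.view(batch_size * num_heads, timesteps, head_dim)

# Weighted sum of the values under each attention distribution; for 3-D inputs
# this is a batched matrix multiplication, as the removed comment spelled out.
attention = torch.softmax(torch.randn(batch_size * num_heads, timesteps, timesteps), dim=-1)
outputs = torch.bmm(attention, values_per_head)          # (6, 4, 5)

# Undoing the split: reintroduce the heads dimension next to batch, transpose it
# back past timesteps, then merge. A single outputs.view(batch_size, timesteps,
# num_heads * head_dim) would have the right shape but would silently interleave
# heads and timesteps.
restored = outputs.view(batch_size, num_heads, timesteps, head_dim)
restored = restored.transpose(1, 2).contiguous()
restored = restored.view(batch_size, timesteps, num_heads * head_dim)
assert restored.shape == values.shape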
16 changes: 8 additions & 8 deletions allennlp/modules/seq2seq_encoders/stacked_self_attention.py
@@ -84,7 +84,7 @@ def __init__(self,
self.add_module(f"feedforward_{i}", feedfoward)
self._feedfoward_layers.append(feedfoward)

-feedforward_layer_norm = LayerNorm(feedfoward.get_input_dim())
+feedforward_layer_norm = LayerNorm(feedfoward.get_output_dim())
self.add_module(f"feedforward_layer_norm_{i}", feedforward_layer_norm)
self._feed_forward_layer_norm_layers.append(feedforward_layer_norm)

@@ -95,7 +95,7 @@ def __init__(self,
self.add_module(f"self_attention_{i}", self_attention)
self._attention_layers.append(self_attention)

-layer_norm = LayerNorm(self_attention.get_input_dim())
+layer_norm = LayerNorm(self_attention.get_output_dim())
self.add_module(f"layer_norm_{i}", layer_norm)
self._layer_norm_layers.append(layer_norm)

@@ -104,7 +104,6 @@ def __init__(self,
self.dropout = Dropout(dropout_prob)
self._input_dim = input_dim
self._output_dim = self._attention_layers[-1].get_output_dim()
-self._output_layer_norm = LayerNorm(self._output_dim)

@overrides
def get_input_dim(self) -> int:
@@ -130,16 +129,17 @@ def forward(self, inputs: torch.Tensor, mask: torch.Tensor): # pylint: disable=a
# Project output of attention encoder through a feedforward
# network and back to the input size for the next layer.
# shape (batch_size, timesteps, input_size)
-feedforward_output = feedforward(feedforward_layer_norm(output))
+feedforward_output = feedforward(output)
feedforward_output = self.dropout(feedforward_output)
if feedforward_output.size() == cached_input.size():
# First layer might have the wrong size for highway
# layers, so we exclude it here.
-feedforward_output += cached_input
+feedforward_output = feedforward_layer_norm(feedforward_output + cached_input)
# shape (batch_size, sequence_length, hidden_dim)
-attention_output = attention(layer_norm(feedforward_output), mask)
-output = self.dropout(attention_output) + feedforward_output
-return self._output_layer_norm(output)
+attention_output = attention(feedforward_output, mask)
+output = layer_norm(self.dropout(attention_output) + feedforward_output)
+
+return output

@classmethod
def from_params(cls, params: Params):
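Aside (not part of the diff): a condensed sketch of the residual ordering the new forward pass uses, where the residual connection is added first and the layer norm is applied to the result; this is also why the constructor now sizes each LayerNorm by get_output_dim. torch.nn.LayerNorm, the Linear stand-ins and the toy sizes are assumptions for illustration; the repository uses its own LayerNorm, FeedForward and attention modules, and dropout and masking are omitted here.

import torch
from torch.nn import LayerNorm, Linear

hidden_dim = 12                                   # assumed size
feedforward = Linear(hidden_dim, hidden_dim)      # stand-in for the FeedForward block
attention = Linear(hidden_dim, hidden_dim)        # stand-in for the self-attention block
feedforward_layer_norm = LayerNorm(hidden_dim)    # sized to the sub-layer output
layer_norm = LayerNorm(hidden_dim)

cached_input = torch.randn(2, 7, hidden_dim)
# Residual connection first, then layer norm, mirroring the new forward pass.
feedforward_output = feedforward_layer_norm(feedforward(cached_input) + cached_input)
attention_output = attention(feedforward_output)
output = layer_norm(attention_output + feedforward_output)
assert output.shape == cached_input.shape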
@@ -11,7 +11,7 @@
class MultiHeadSelfAttentionTest(AllenNlpTestCase):

def test_multi_head_self_attention_can_build_from_params(self):
params = Params({"num_heads": 3, "input_dim": 2, "attention_dim": 5, "values_dim": 5})
params = Params({"num_heads": 3, "input_dim": 2, "attention_dim": 3, "values_dim": 6})

encoder = MultiHeadSelfAttention.from_params(params)
assert isinstance(encoder, MultiHeadSelfAttention)
@@ -21,15 +21,15 @@ def test_multi_head_self_attention_can_build_from_params(self):
def test_multi_head_self_attention_runs_forward(self):
attention = MultiHeadSelfAttention(num_heads=3,
input_dim=5,
-attention_dim=7,
+attention_dim=6,
values_dim=9)
inputs = Variable(torch.randn(2, 12, 5))
assert list(attention(inputs).size()) == [2, 12, 5]

def test_multi_head_self_attention_respects_masking(self):
attention = MultiHeadSelfAttention(num_heads=3,
input_dim=5,
-attention_dim=7,
+attention_dim=6,
values_dim=9,
attention_dropout_prob=0.0)
tensor = Variable(torch.randn(2, 12, 5))
@@ -10,7 +10,7 @@ class TestStackedSelfAttention(AllenNlpTestCase):
def test_get_dimension_is_correct(self):
encoder = StackedSelfAttentionEncoder(input_dim=9,
hidden_dim=12,
-projection_dim=7,
+projection_dim=6,
feedforward_hidden_dim=5,
num_layers=3,
num_attention_heads=3)
