# The Attention mechanism and its implementation in PyTorch

Computing self-attention of a sentence with GloVe embeddings and the `MultiheadAttention` class with PyTorch

Author: Pierre Nugues

## Modules

In [1]:
import os

import torch
import torch.nn
import torch.nn.functional as F


## Noncontextual embeddings

We load the GloVe embeddings (the 50-dimensional vectors of `glove.6B.50d.txt`)

In [2]:
def read_embeddings(file):
    """
    Read word embeddings stored in the GloVe text format.

    Each non-blank line holds a word followed by the components of its
    vector, separated by whitespace, e.g. ``ship 1.5213 0.1052 ...``.

    :param file: path to the embedding file (UTF-8 encoded)
    :return: a dictionary mapping each word (str) to its embedding,
             a 1-D float32 tensor
    """
    embeddings = {}
    with open(file, encoding='utf8') as glove:
        for line in glove:
            # split() with no argument also strips the trailing newline
            values = line.split()
            if not values:
                # Skip blank lines: values[0] would raise an IndexError
                continue
            word, *components = values
            embeddings[word] = torch.tensor(
                [float(value) for value in components], dtype=torch.float32)
    return embeddings


In [3]:
# Path to the 50-dimensional GloVe file. Set the GLOVE_FILE environment
# variable to point to your own copy; the author's local path is the fallback.
embedding_file = os.environ.get(
    'GLOVE_FILE', '/Users/pierre/Documents/Cours/EDAN20/corpus/glove.6B.50d.txt')
embeddings_dict = read_embeddings(embedding_file)


In [4]:
embeddings_dict['ship']


tensor([ 1.5213,  0.1052,  0.3816, -0.5080,  0.0324, -0.1348, -1.2474,  0.7981,
         0.8469, -1.1010,  0.8874,  1.3749,  0.4293,  0.6572, -0.2636, -0.4176,
        -0.4885,  0.9106, -1.7158, -0.4380,  0.7839,  0.1964, -0.4066, -0.5397,
         0.8244, -1.7434,  0.1428,  0.2804,  1.1688,  0.1690,  2.2271, -0.5827,
        -0.4572,  0.6281,  0.5444,  0.2846,  0.4448, -0.5534, -0.3649, -0.0164,
         0.4088, -0.8715,  1.5513, -0.8070, -0.1004, -0.2846, -0.3322, -0.5061,
         0.4827, -0.6620])

## Cosine similarity

Let us compute the cosine similarity of the words in a sentence:
> I must go back to my ship and to my crew

_Odyssey_, book I 

Remember that:
$$\cos(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u} \cdot \mathbf{v}}{||\mathbf{u}|| \cdot ||\mathbf{v} ||}$$

In [5]:
sentence_odyssey = 'I must go back to my ship and to my crew'
# in the most cost-efficient way possible
sentence_amazon = 'We process and ship your order'


In [6]:
words_a = sentence_amazon.lower().split()
words_o = sentence_odyssey.lower().split()
words_o


['i', 'must', 'go', 'back', 'to', 'my', 'ship', 'and', 'to', 'my', 'crew']

We build the embedding matrix

In [7]:
def embedding_matrix(words, embeddings_dict):
    """
    Stack the embeddings of a sequence of words into a matrix.

    :param words: list of words; each must be a key of embeddings_dict
                  (a missing word raises a KeyError)
    :param embeddings_dict: dictionary mapping words to 1-D tensors,
                            assumed to all have the same dimension
    :return: a tensor whose i-th row is the embedding of words[i]
    """
    return torch.stack([embeddings_dict[w] for w in words])


In [8]:
embeddings_mat_a = embedding_matrix(words_a, embeddings_dict)
embeddings_mat_o = embedding_matrix(words_o, embeddings_dict)


In [9]:
embeddings_mat_o.size()


torch.Size([11, 50])

In [10]:
embeddings_mat_a.size()


torch.Size([6, 50])

In [11]:
embeddings_mat_o[0][:10]


tensor([ 1.1891e-01,  1.5255e-01, -8.2073e-02, -7.4144e-01,  7.5917e-01,
        -4.8328e-01, -3.1009e-01,  5.1476e-01, -9.8708e-01,  6.1757e-04])

We compute the attention weights as the pairwise cosines of the word embeddings

In [12]:
def attn_cos_weights(embeddings_mat):
    """
    Compute the pairwise cosine similarities of a sequence of embeddings.

    :param embeddings_mat: tensor of shape (seq_len, dim), or
                           (..., seq_len, dim) for a batch of sequences
    :return: tensor of shape (..., seq_len, seq_len) whose entry [i, j] is
             the cosine of embeddings i and j
    """
    # Normalize every embedding (last dimension) to unit length, so that the
    # dot products below are cosines
    E_normed = F.normalize(embeddings_mat, dim=-1)
    # transpose(-2, -1) rather than .T so that batch dimensions are preserved
    attn_weights_cos = E_normed @ E_normed.transpose(-2, -1)
    return attn_weights_cos


In [13]:
def print_cos_weights(words, embeddings_dict):
    """
    Print the pairwise cosine similarities of the words as a tab-separated
    table, with the words as row and column headers (two decimals).

    :param words: list of words; each must be a key of embeddings_dict
    :param embeddings_dict: dictionary mapping words to embeddings
    """
    cos_matrix = attn_cos_weights(embedding_matrix(words, embeddings_dict))
    # Header row: an empty corner cell, then one column per word
    header = ''.join(word + '\t' for word in words)
    print('\t' + header)
    for row_word, row in zip(words, cos_matrix):
        cells = ''.join(f"{value:.2f}\t" for value in row)
        print(row_word + '\t' + cells)


In [14]:
print_cos_weights(words_o, embeddings_dict)


	i	must	go	back	to	my	ship	and	to	my	crew	
i	1.00	0.75	0.86	0.76	0.73	0.90	0.35	0.65	0.73	0.90	0.42	
must	0.75	1.00	0.85	0.68	0.87	0.69	0.42	0.69	0.87	0.69	0.45	
go	0.86	0.85	1.00	0.84	0.84	0.81	0.41	0.68	0.84	0.81	0.49	
back	0.76	0.68	0.84	1.00	0.83	0.76	0.49	0.77	0.83	0.76	0.51	
to	0.73	0.87	0.84	0.83	1.00	0.68	0.54	0.86	1.00	0.68	0.51	
my	0.90	0.69	0.81	0.76	0.68	1.00	0.38	0.63	0.68	1.00	0.44	
ship	0.35	0.42	0.41	0.49	0.54	0.38	1.00	0.46	0.54	0.38	0.78	
and	0.65	0.69	0.68	0.77	0.86	0.63	0.46	1.00	0.86	0.63	0.49	
to	0.73	0.87	0.84	0.83	1.00	0.68	0.54	0.86	1.00	0.68	0.51	
my	0.90	0.69	0.81	0.76	0.68	1.00	0.38	0.63	0.68	1.00	0.44	
crew	0.42	0.45	0.49	0.51	0.51	0.44	0.78	0.49	0.51	0.44	1.00	


In [15]:
print_cos_weights(words_a, embeddings_dict)


	we	process	and	ship	your	order	
we	1.00	0.64	0.70	0.36	0.75	0.64	
process	0.64	1.00	0.61	0.29	0.52	0.67	
and	0.70	0.61	1.00	0.46	0.58	0.69	
ship	0.36	0.29	0.46	1.00	0.37	0.52	
your	0.75	0.52	0.58	0.37	1.00	0.63	
order	0.64	0.67	0.69	0.52	0.63	1.00	


## Contextual embeddings

We design a new vector representation for _ship_ so that it receives an influence from _crew_ and the other words of its context. This influence will depend on the embeddings from the context. Let us use the cosine similarities as attention weights:

In [16]:
attn_cos_weights(embeddings_mat_o)[6]


tensor([0.3466, 0.4178, 0.4068, 0.4853, 0.5401, 0.3791, 1.0000, 0.4586, 0.5401,
        0.3791, 0.7848])

We compute the new embedding of _ship_ as the sum of the noncontextual embeddings of the sentence words, weighted by their cosine similarities with _ship_. This gives a contextual embedding.

In [17]:
# Weights copied by hand, rounded to two decimals, from the row of
# attn_cos_weights(embeddings_mat_o) for 'ship' (index 6), one weight per
# word of the sentence, in sentence order.
new_embeddings_ship = (0.35 * embeddings_dict['i'] +
                       0.42 * embeddings_dict['must'] +
                       0.41 * embeddings_dict['go'] +
                       0.49 * embeddings_dict['back'] +
                       0.54 * embeddings_dict['to'] +
                       0.38 * embeddings_dict['my'] +
                       1.00 * embeddings_dict['ship'] +
                       0.46 * embeddings_dict['and'] +
                       0.54 * embeddings_dict['to'] +
                       0.38 * embeddings_dict['my'] +
                       0.78 * embeddings_dict['crew'])
new_embeddings_ship


tensor([  3.2289,   0.6422,   1.4712,  -2.3538,   2.2414,  -0.4237,  -4.1052,
          2.6216,   0.1719,  -2.4324,   1.3882,   3.7241,  -1.9721,   1.1893,
          2.2511,   0.9502,  -0.7646,   1.0289,  -3.0553,  -3.6306,   0.8305,
          2.9299,   1.3221,  -0.7092,   2.9745, -10.5959,  -1.3168,   0.2059,
          3.5457,  -2.7711,  18.2672,   2.4817,  -3.5887,   0.3297,   1.2718,
          0.6539,   1.5873,   0.0195,   0.7724,  -1.4620,  -0.2067,  -1.2464,
          2.1504,  -0.1811,  -0.5026,  -0.2888,  -0.5060,  -1.9676,  -0.0605,
         -0.6725])

Exact computation with torch

In [18]:
(attn_cos_weights(embeddings_mat_o) @ embeddings_mat_o)[6]


tensor([ 3.2319e+00,  6.4082e-01,  1.4718e+00, -2.3434e+00,  2.2358e+00,
        -4.1877e-01, -4.1002e+00,  2.6211e+00,  1.8010e-01, -2.4360e+00,
         1.3923e+00,  3.7188e+00, -1.9603e+00,  1.1980e+00,  2.2394e+00,
         9.3763e-01, -7.7049e-01,  1.0349e+00, -3.0615e+00, -3.6259e+00,
         8.3401e-01,  2.9281e+00,  1.3165e+00, -7.1303e-01,  2.9667e+00,
        -1.0567e+01, -1.3099e+00,  2.0283e-01,  3.5362e+00, -2.7571e+00,
         1.8220e+01,  2.4698e+00, -3.5804e+00,  3.2604e-01,  1.2760e+00,
         6.5701e-01,  1.5889e+00,  1.1571e-02,  7.6620e-01, -1.4560e+00,
        -2.0362e-01, -1.2484e+00,  2.1550e+00, -1.8767e-01, -5.0253e-01,
        -2.9128e-01, -5.1006e-01, -1.9596e+00, -5.8853e-02, -6.7380e-01])

## Self-attention

Vaswani et al. (2017) defined attention as:
$$
\text{Attention}({Q}, {K}, {V}) = \text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})  {V},
$$
where
$$
\begin{array}{lcl}
{Q} &=& {X} {W}_Q,   \\
{K} &=& {X} {W}_K , \\
{V} &=& {X} {W}_V.\\
\end{array}
$$
and ${X}$ represents the complete input sequence (all the tokens).

$d_k$ is the dimension of the keys (here, the dimension of the input embeddings) and $\sqrt{d_k}$ is a scaling factor. The $\text{softmax}$ function is defined as:
$$
\text{softmax}(x_1, x_2, ..., x_j, ..., x_n) = (\frac{e^{x_1}}{\sum_{i=1}^n e^{x_i}}, \frac{e^{x_2}}{\sum_{i=1}^n e^{x_i}}, ..., \frac{e^{x_j}}{\sum_{i=1}^n e^{x_i}}, ..., \frac{e^{x_n}}{\sum_{i=1}^n e^{x_i}})
$$

We omit the weight matrices and we use the same embeddings for ${Q}$, ${K}$, and ${V}$: the GloVe embeddings.

For the sentence above, the self-attention weights of _ship_, $\text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})$, are:

In [19]:
# d_k: the embedding dimension, stored as a tensor so that torch.sqrt applies
dk = torch.tensor(embeddings_dict['i'].size(0))
dk


tensor(50)

In [20]:
attn_weights_o = F.softmax(
    embeddings_mat_o @ embeddings_mat_o.T/torch.sqrt(dk), dim=-1)
attn_weights_o[6]


tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

The scaled and normalized attention weights

In [21]:
def print_attn_weights(words, embeddings_dict):
    """
    Print the scaled dot-product self-attention weights of the words,
    softmax(E E^T / sqrt(d_k)), as a tab-separated table with the words as
    row and column headers (two decimals). E is the embedding matrix of the
    words and d_k the embedding dimension.

    :param words: list of words; each must be a key of embeddings_dict
    :param embeddings_dict: dictionary mapping words to embeddings
    """
    emb = embedding_matrix(words, embeddings_dict)
    n_words, dim = emb.size()
    scale = torch.sqrt(torch.tensor(dim))
    weights = F.softmax(emb @ emb.T / scale, dim=-1)
    # Header row: an empty corner cell, then one column per word
    print('\t' + ''.join(f"{words[k]}\t" for k in range(n_words)))
    for i in range(n_words):
        row = ''.join(f"{weights[i, j]:.2f}\t" for j in range(n_words))
        print(f"{words[i]}\t{row}")


In [22]:
print_attn_weights(words_o, embeddings_dict)


	i	must	go	back	to	my	ship	and	to	my	crew	
i	0.36	0.05	0.07	0.05	0.04	0.19	0.01	0.02	0.04	0.19	0.01	
must	0.14	0.20	0.10	0.06	0.11	0.10	0.03	0.05	0.11	0.10	0.02	
go	0.18	0.09	0.14	0.09	0.08	0.13	0.02	0.04	0.08	0.13	0.02	
back	0.14	0.05	0.09	0.19	0.08	0.12	0.03	0.06	0.08	0.12	0.03	
to	0.11	0.11	0.09	0.09	0.15	0.08	0.04	0.07	0.15	0.08	0.03	
my	0.19	0.03	0.05	0.04	0.03	0.29	0.01	0.02	0.03	0.29	0.01	
ship	0.03	0.03	0.03	0.04	0.05	0.03	0.55	0.03	0.05	0.03	0.13	
and	0.10	0.08	0.07	0.10	0.12	0.09	0.04	0.15	0.12	0.09	0.04	
to	0.11	0.11	0.09	0.09	0.15	0.08	0.04	0.07	0.15	0.08	0.03	
my	0.19	0.03	0.05	0.04	0.03	0.29	0.01	0.02	0.03	0.29	0.01	
crew	0.06	0.05	0.05	0.06	0.05	0.06	0.21	0.04	0.05	0.06	0.31	


For _ship:_

In [23]:
attn_weights_o[6]


tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

The weights are 55% for _ship_ and 13% for _crew_; the rest comes from the other words.

And the new contextual embedding for _ship_ is a linear combination:

In [24]:
# Weights copied by hand, rounded to two decimals, from attn_weights_o[6]:
# the scaled softmax attention weights of 'ship' over the words of the
# sentence, in sentence order.
self_attention_ship = (0.03 * embeddings_dict['i'] +
                       0.03 * embeddings_dict['must'] +
                       0.03 * embeddings_dict['go'] +
                       0.04 * embeddings_dict['back'] +
                       0.05 * embeddings_dict['to'] +
                       0.03 * embeddings_dict['my'] +
                       0.55 * embeddings_dict['ship'] +
                       0.03 * embeddings_dict['and'] +
                       0.05 * embeddings_dict['to'] +
                       0.03 * embeddings_dict['my'] +
                       0.13 * embeddings_dict['crew'])
self_attention_ship


tensor([ 1.0442,  0.0966,  0.3467, -0.4238,  0.2203, -0.0956, -0.9915,  0.6637,
         0.4368, -0.7943,  0.5639,  0.9838,  0.0240,  0.5066,  0.0732, -0.1740,
        -0.3322,  0.5614, -1.1613, -0.5717,  0.4356,  0.4120, -0.0659, -0.3336,
         0.6579, -1.7421, -0.0344,  0.1440,  0.8547, -0.1430,  2.6614, -0.0553,
        -0.5376,  0.3057,  0.4068,  0.2231,  0.3959, -0.2940, -0.1163, -0.1340,
         0.1709, -0.5332,  0.9552, -0.4178, -0.1058, -0.1715, -0.2251, -0.3923,
         0.2098, -0.3625])

Exact computation with PyTorch of the complete output matrix
$$
\text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})  {V} :
$$

In [25]:
self_attention_output_o = attn_weights_o @ embeddings_mat_o


The contextual embeddings for _ship:_

In [26]:
self_attention_output_o[6]


tensor([ 1.0387,  0.1033,  0.3426, -0.4320,  0.2237, -0.0958, -0.9926,  0.6662,
         0.4424, -0.7942,  0.5638,  0.9921,  0.0205,  0.5082,  0.0743, -0.1773,
        -0.3408,  0.5675, -1.1545, -0.5718,  0.4288,  0.4191, -0.0658, -0.3339,
         0.6682, -1.7473, -0.0485,  0.1531,  0.8642, -0.1447,  2.6571, -0.0545,
        -0.5343,  0.3160,  0.4041,  0.2277,  0.3958, -0.2916, -0.1126, -0.1385,
         0.1744, -0.5375,  0.9499, -0.4145, -0.1039, -0.1755, -0.2213, -0.3995,
         0.2119, -0.3610])

We can now write a `self_attention` function 

In [27]:
def self_attention(embeddings_mat):
    """
    Scaled dot-product self-attention without learned projections:
    Q = K = V = embeddings_mat.

    Computes softmax(Q K^T / sqrt(d_k)) V, where d_k is the embedding
    dimension (last dimension of embeddings_mat).

    :param embeddings_mat: tensor of shape (seq_len, d_k), or
                           (..., seq_len, d_k) for a batch of sequences
    :return: a pair (attn_output, attn_weights):
             attn_output has the same shape as embeddings_mat and holds the
             contextual embeddings; attn_weights has shape
             (..., seq_len, seq_len) and each of its rows sums to 1
    """
    dk = embeddings_mat.size(-1)
    # transpose(-2, -1) instead of .T so that batch dimensions are preserved
    scores = embeddings_mat @ embeddings_mat.transpose(-2, -1) / dk ** 0.5
    attn_weights = F.softmax(scores, dim=-1)
    attn_output = attn_weights @ embeddings_mat
    return attn_output, attn_weights


The word _ship_ in another context: _We process and ship your order_

In [28]:
attention_output_a, attn_weights_a = self_attention(embeddings_mat_a)


Attention weights for _ship:_

In [29]:
attn_weights_a[3]


tensor([0.0431, 0.0258, 0.0419, 0.7811, 0.0490, 0.0590])

In [30]:
print_attn_weights(words_a, embeddings_dict)


	we	process	and	ship	your	order	
we	0.61	0.06	0.06	0.02	0.20	0.05	
process	0.17	0.50	0.08	0.03	0.11	0.11	
and	0.22	0.12	0.30	0.08	0.15	0.13	
ship	0.04	0.03	0.04	0.78	0.05	0.06	
your	0.14	0.03	0.03	0.02	0.74	0.04	
order	0.16	0.13	0.10	0.09	0.18	0.34	


The new contextual embeddings for _ship:_

In [31]:
attention_output_a[3]


tensor([ 1.2758,  0.1034,  0.2720, -0.4776,  0.1746, -0.1060, -0.9901,  0.6328,
         0.6967, -0.8847,  0.7106,  1.2264,  0.2491,  0.5023, -0.1277, -0.2361,
        -0.3709,  0.6545, -1.2587, -0.5332,  0.6681,  0.1687, -0.2567, -0.4218,
         0.6960, -1.7077, -0.0052,  0.1572,  1.0763,  0.0410,  2.5467, -0.3418,
        -0.5414,  0.4175,  0.4147,  0.2666,  0.3770, -0.4228, -0.2462, -0.0377,
         0.3202, -0.7298,  1.2020, -0.5636, -0.0899, -0.1845, -0.2390, -0.4307,
         0.3828, -0.4905])

## PyTorch implementation
 
PyTorch has an implementation of self-attention encapsulated in the `MultiheadAttention` class. Before entering the attention module, the query, key, and value each go through a linear layer. The output also goes through a linear layer. The three input layers are initialized with Xavier's algorithm.

In [32]:
from torch.nn import MultiheadAttention

# One attention head over the 50-dimensional GloVe embeddings, without bias
# terms. The 2-D (seq_len, 50) sentence matrices passed below are processed
# as a single unbatched sequence, as the output shapes further down show.
att_layer = MultiheadAttention(50,
                               1,
                               bias=False,
                               batch_first=True)


In [33]:
(attn_output, attn_weights) = att_layer(
    embeddings_mat_o, embeddings_mat_o, embeddings_mat_o)


The attention weights for _ship:_

In [34]:
attn_weights[6]


tensor([0.0780, 0.0801, 0.0851, 0.0922, 0.0957, 0.0761, 0.1295, 0.0929, 0.0957,
        0.0761, 0.0985], grad_fn=<SelectBackward0>)

In [35]:
attn_output


tensor([[-4.6826e-01, -2.1499e-01,  2.1657e-01,  5.8600e-02,  2.0261e-01,
          1.2974e-01, -3.7535e-01, -3.0308e-01, -4.7069e-02,  2.2418e-01,
         -5.4904e-04, -5.0533e-02, -3.4317e-01, -2.1501e-01,  1.0109e-01,
          2.2577e-01, -7.3280e-02,  3.8556e-01, -4.2677e-01, -1.6214e-01,
         -9.7994e-02, -2.6659e-02, -3.0552e-01, -7.2621e-02, -2.6497e-01,
         -1.2757e-02, -6.1874e-01, -1.4302e-01,  2.1989e-01,  4.5054e-01,
         -1.0840e-01, -1.1712e-01, -4.3860e-02,  3.4409e-02,  1.4586e-02,
          3.4393e-01, -4.3443e-02,  1.3277e-01,  6.6110e-02,  8.9199e-02,
          5.0547e-02, -6.1725e-02,  5.0861e-02, -2.0772e-01,  3.7283e-01,
         -7.8241e-02,  3.7486e-01,  3.1713e-02,  9.0694e-02,  1.7461e-02],
        [-4.8627e-01, -1.9671e-01,  2.2992e-01,  5.6741e-02,  1.9986e-01,
          1.2246e-01, -3.8601e-01, -2.9922e-01, -4.7199e-02,  2.3012e-01,
         -7.6743e-03, -6.8458e-02, -3.5774e-01, -2.1314e-01,  1.1120e-01,
          2.5173e-01, -8.7428e-02,  4

### The initial dense layers

The initial weight values of the four matrices (query, key, value, and output):

In [36]:
att_layer.state_dict()


OrderedDict([('in_proj_weight',
              tensor([[ 0.1166, -0.1155,  0.1705,  ...,  0.0588, -0.0400, -0.1388],
                      [ 0.0843,  0.0309,  0.0920,  ..., -0.1133,  0.1074, -0.1378],
                      [ 0.0202, -0.0899,  0.1189,  ...,  0.1307, -0.1418,  0.0303],
                      ...,
                      [ 0.0987, -0.1202,  0.1206,  ...,  0.0667, -0.1011,  0.1717],
                      [-0.1227, -0.1660,  0.1384,  ..., -0.0425, -0.1331,  0.0680],
                      [ 0.0056, -0.0293, -0.0877,  ..., -0.1334, -0.1293, -0.0270]])),
             ('out_proj.weight',
              tensor([[-0.0418, -0.0286,  0.1178,  ...,  0.1233,  0.0926,  0.0739],
                      [ 0.0798, -0.0880,  0.0921,  ...,  0.1189,  0.0978, -0.0611],
                      [-0.0796,  0.0180, -0.0338,  ..., -0.1319,  0.1072,  0.0013],
                      ...,
                      [-0.0779, -0.0267, -0.0365,  ..., -0.0826,  0.0795,  0.0998],
                      [-0.1384,  0.012

The three input matrices are concatenated

In [37]:
att_layer.state_dict()['in_proj_weight'].size()


torch.Size([150, 50])

The output matrix

In [38]:
att_layer.state_dict()['out_proj.weight'].size()


torch.Size([50, 50])

### By-passing the dense layers

We set the weights of the dense layers to identity matrices so that they pass their inputs through unchanged. This way, we recover the attention output values and weights of our `self_attention()` function.

In [39]:
i_50 = torch.eye(50)
i_50


tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [40]:
# Set the output projection to the identity. Writing into the tensors
# returned by state_dict() only works because they share storage with the
# parameters; copying into the parameter under no_grad is the explicit way.
with torch.no_grad():
    att_layer.out_proj.weight.copy_(i_50)


In [41]:
att_layer.state_dict()['in_proj_weight'].size()


torch.Size([150, 50])

In [42]:
i_3_50 = torch.vstack((i_50, i_50, i_50))
i_3_50.size()


torch.Size([150, 50])

In [43]:
# Set the three stacked input projections (query, key, value) to identities.
# Copy into the parameter under no_grad rather than relying on state_dict()
# tensors sharing storage with it.
with torch.no_grad():
    att_layer.in_proj_weight.copy_(i_3_50)


In [44]:
att_layer.state_dict()


OrderedDict([('in_proj_weight',
              tensor([[1., 0., 0.,  ..., 0., 0., 0.],
                      [0., 1., 0.,  ..., 0., 0., 0.],
                      [0., 0., 1.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 1., 0., 0.],
                      [0., 0., 0.,  ..., 0., 1., 0.],
                      [0., 0., 0.,  ..., 0., 0., 1.]])),
             ('out_proj.weight',
              tensor([[1., 0., 0.,  ..., 0., 0., 0.],
                      [0., 1., 0.,  ..., 0., 0., 0.],
                      [0., 0., 1.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 1., 0., 0.],
                      [0., 0., 0.,  ..., 0., 1., 0.],
                      [0., 0., 0.,  ..., 0., 0., 1.]]))])

### Multihead attention without the dense layers

We now obtain the same results as the `self_attention()` function for _ship:_

We apply the attention layer to the sentence:

In [45]:
(attn_output, attn_weights) = att_layer(
    embeddings_mat_o, embeddings_mat_o, embeddings_mat_o)


The attention weights for _ship:_

In [46]:
attn_weights[6]


tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281], grad_fn=<SelectBackward0>)

The contextual embedding for _ship:_

In [47]:
attn_output[6]


tensor([ 1.0387,  0.1033,  0.3426, -0.4320,  0.2237, -0.0958, -0.9926,  0.6662,
         0.4424, -0.7942,  0.5638,  0.9921,  0.0205,  0.5082,  0.0743, -0.1773,
        -0.3408,  0.5674, -1.1545, -0.5718,  0.4288,  0.4191, -0.0658, -0.3339,
         0.6682, -1.7473, -0.0485,  0.1531,  0.8642, -0.1447,  2.6571, -0.0545,
        -0.5343,  0.3160,  0.4041,  0.2277,  0.3958, -0.2916, -0.1126, -0.1385,
         0.1744, -0.5375,  0.9499, -0.4145, -0.1039, -0.1755, -0.2213, -0.3995,
         0.2119, -0.3610], grad_fn=<SelectBackward0>)

## Multihead

In [105]:
# Five attention heads over the 50-dimensional embeddings. Unlike att_layer,
# this layer keeps the default bias=True, so its state dictionary also
# contains bias vectors (in_proj_bias).
att_layer_5 = MultiheadAttention(50, 5, batch_first=True)

In [100]:
attn_output, attn_weights = att_layer_5(embeddings_mat_o, 
                                        embeddings_mat_o, 
                                        embeddings_mat_o)

In [101]:
attn_output.size()

torch.Size([11, 50])

In [102]:
attn_weights.size()

torch.Size([11, 11])

In [106]:
att_layer_5.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[-0.1547,  0.0469,  0.0085,  ...,  0.0003,  0.0072, -0.1290],
                      [ 0.1376, -0.1614,  0.1675,  ...,  0.0844,  0.0071, -0.0405],
                      [ 0.1194,  0.1353, -0.1417,  ...,  0.0411,  0.0384,  0.1416],
                      ...,
                      [-0.0847, -0.0924, -0.1374,  ..., -0.1119,  0.0628,  0.1710],
                      [-0.0087,  0.0895, -0.1642,  ...,  0.1585,  0.0802,  0.0207],
                      [ 0.0314, -0.0630,  0.0014,  ..., -0.0353,  0.1654, -0.1638]])),
             ('in_proj_bias',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 

In [94]:
att_layer_5.state_dict()['in_proj_weight'].size()

torch.Size([150, 50])

In [95]:
att_layer_5.state_dict()['out_proj.weight'].size()

torch.Size([50, 50])

## Test with a simple matrix

Three words, dimension of embeddings: 4

In [48]:
test_input_sequence = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                                    [0.0, 1.5, 1.0, 1.0],
                                    [0.0, 1.0, 1.0, 1.0]])


In [49]:
test_input_sequence.size()


torch.Size([3, 4])

### Self-attention from the book

In [50]:
self_attention(test_input_sequence)


(tensor([[0.4519, 0.6852, 0.5481, 1.0000],
         [0.1045, 1.1609, 0.8955, 1.0000],
         [0.1387, 1.1034, 0.8613, 1.0000]]),
 tensor([[0.4519, 0.2741, 0.2741],
         [0.1045, 0.5307, 0.3648],
         [0.1387, 0.4842, 0.3771]]))

### Multihead attention from PyTorch

In [51]:
att_layer = MultiheadAttention(4,
                               1,
                               bias=False)


The multihead attention initializes its dense layers with random weights (Xavier initialization for the input layers). The results will therefore be different from those of `self_attention()`.

In [52]:
att_layer(test_input_sequence,
          test_input_sequence,
          test_input_sequence)


(tensor([[-0.1564, -0.2048, -0.3103, -0.0635],
         [-0.1507, -0.1998, -0.3017, -0.0638],
         [-0.1522, -0.2012, -0.3040, -0.0637]], grad_fn=<SqueezeBackward1>),
 tensor([[0.2791, 0.3755, 0.3453],
         [0.2931, 0.3663, 0.3406],
         [0.2895, 0.3691, 0.3414]], grad_fn=<SqueezeBackward1>))

Weights of the dense layers

In [53]:
att_layer.state_dict()


OrderedDict([('in_proj_weight',
              tensor([[ 0.4338, -0.1751,  0.3666,  0.4250],
                      [ 0.5379, -0.2714,  0.0357,  0.4922],
                      [ 0.1468, -0.1328,  0.0231,  0.3700],
                      [-0.0839, -0.1653, -0.5226, -0.1443],
                      [ 0.5546,  0.1980,  0.5521,  0.2782],
                      [ 0.0432,  0.0151,  0.2500,  0.5516],
                      [ 0.3102,  0.2209,  0.0617, -0.3252],
                      [-0.5573, -0.1553, -0.5935,  0.1095],
                      [ 0.5032, -0.2073, -0.4933,  0.4358],
                      [-0.1982, -0.2822, -0.2332, -0.3275],
                      [ 0.3820,  0.5708,  0.2153, -0.1596],
                      [ 0.5904,  0.5305,  0.4842, -0.4498]])),
             ('out_proj.weight',
              tensor([[ 0.1102, -0.2755, -0.4689, -0.1669],
                      [ 0.1113, -0.0404, -0.4188,  0.0349],
                      [ 0.3024,  0.2077, -0.2820,  0.0415],
                      [-0.0382, 

### By-passing the dense layers

We use weights of identity matrices

In [54]:
i_4 = torch.eye(4)


In [55]:
# Set the output projection to the identity. Copy into the parameter under
# no_grad rather than relying on state_dict() tensors sharing storage with it.
with torch.no_grad():
    att_layer.out_proj.weight.copy_(i_4)


In [56]:
i_3_4 = torch.vstack((i_4, i_4, i_4))
i_3_4.size()


torch.Size([12, 4])

We set these weights

In [57]:
# Set the three stacked input projections (query, key, value) to identities.
# Copy into the parameter under no_grad rather than relying on state_dict()
# tensors sharing storage with it.
with torch.no_grad():
    att_layer.in_proj_weight.copy_(i_3_4)


In [58]:
att_layer.state_dict()


OrderedDict([('in_proj_weight',
              tensor([[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]])),
             ('out_proj.weight',
              tensor([[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]]))])

Now we have the same results as with `self_attention()`

In [59]:
att_layer(test_input_sequence,
          test_input_sequence,
          test_input_sequence)


(tensor([[0.4519, 0.6852, 0.5481, 1.0000],
         [0.1045, 1.1609, 0.8955, 1.0000],
         [0.1387, 1.1034, 0.8613, 1.0000]], grad_fn=<SqueezeBackward1>),
 tensor([[0.4519, 0.2741, 0.2741],
         [0.1045, 0.5307, 0.3648],
         [0.1387, 0.4842, 0.3771]], grad_fn=<SqueezeBackward1>))