# Pay Attention
This notebook walks through an implementation of the self-attention mechanism.

![Alt text](media/self_attention.png)

## Embedding the Input
First, let's convert the sentence into a numeric form that a model can learn from.  
Start by building a dictionary that maps each token to an integer index.


In [2]:
sentence = 'the quick brown fox jumps over a lazy dog'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'a': 0, 'brown': 1, 'dog': 2, 'fox': 3, 'jumps': 4, 'lazy': 5, 'over': 6, 'quick': 7, 'the': 8}


Now convert the sentence into a vector of integer token indices:

In [3]:
import torch
sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([8, 7, 1, 3, 4, 6, 0, 5, 2])
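
As a quick sanity check on the encoding, the dictionary can be inverted to map the integers back to words. This is a minimal sketch using plain Python lists instead of a tensor; the names `inv_dc`, `ids` and `decoded` are illustrative and do not appear elsewhere in the notebook.

```python
sentence = 'the quick brown fox jumps over a lazy dog'
words = sentence.split()

# Same construction as above: tokens sorted alphabetically, then enumerated.
dc = {s: i for i, s in enumerate(sorted(words))}

# Encode the sentence, then invert the dictionary to decode it again.
ids = [dc[w] for w in words]
inv_dc = {i: s for s, i in dc.items()}  # illustrative name: index -> token
decoded = ' '.join(inv_dc[i] for i in ids)

print(ids)
print(decoded)
```

The inversion only works this cleanly because every word in the sentence is unique, so each index maps back to exactly one token.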


Next, make the representation dense: an embedding layer maps each integer index to a real-valued vector.  
Here we use a 16-dimensional embedding, so each input word is represented by a 16-dimensional vector:

In [4]:
torch.manual_seed(0)
embed = torch.nn.Embedding(len(dc), 16)  # first argument is the vocabulary size (9 here)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)

tensor([[-8.8338e-01, -4.1891e-01, -8.0483e-01,  5.6561e-01,  6.1036e-01,
          4.6688e-01,  1.9507e+00, -1.0631e+00, -7.7326e-02,  1.1640e-01,
         -5.9399e-01, -1.2439e+00, -1.0209e-01, -1.0335e+00, -3.1264e-01,
          2.4579e-01],
        [-2.1883e-01, -2.4351e+00, -7.2915e-02, -3.3987e-02,  9.6252e-01,
          3.4917e-01, -9.2146e-01, -5.6195e-02, -6.2270e-01, -4.6372e-01,
          1.9218e+00, -4.0255e-01,  1.2390e-01,  1.1648e+00,  9.2337e-01,
          1.3873e+00],
        [-1.3527e+00, -1.6959e+00,  5.6665e-01,  7.9351e-01,  5.9884e-01,
         -1.5551e+00, -3.4136e-01,  1.8530e+00,  7.5019e-01, -5.8550e-01,
         -1.7340e-01,  1.8348e-01,  1.3894e+00,  1.5863e+00,  9.4630e-01,
         -8.4368e-01],
        [-6.7309e-01,  8.7283e-01,  1.0554e+00,  1.7784e-01, -2.3034e-01,
         -3.9175e-01,  5.4329e-01, -3.9516e-01, -4.4622e-01,  7.4402e-01,
          1.5210e+00,  3.4105e+00, -1.5312e+00, -1.2341e+00,  1.8197e+00,
         -5.5153e-01],
        [-5.6925e-01
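
An embedding layer can be thought of as a trainable lookup table: calling it with integer indices selects the corresponding rows of its weight matrix. The sketch below illustrates this under the same seed and sizes as the cell above; `token_ids`, `looked_up` and `indexed` are illustrative names.

```python
import torch

torch.manual_seed(0)
vocab_size, embed_dim = 9, 16
embed = torch.nn.Embedding(vocab_size, embed_dim)

# Integer encoding of the sentence from the earlier cell.
token_ids = torch.tensor([8, 7, 1, 3, 4, 6, 0, 5, 2])

# Calling the layer and indexing its weight matrix select the same rows.
looked_up = embed(token_ids).detach()
indexed = embed.weight[token_ids].detach()

print(looked_up.shape)
print(torch.equal(looked_up, indexed))
```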

Define the weight matrices:   
- $W_q$ for the queries
- $W_k$ for the keys
- $W_v$ for the values

In [5]:
import torch.nn as nn
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 24, 24, 28

W_q = nn.Parameter(torch.randn(d_q, d))
W_k = nn.Parameter(torch.randn(d_k, d))
W_v = nn.Parameter(torch.randn(d_v, d))

Multiplying an input embedding by each weight matrix yields its query, key and value vectors.

First, compute them for a single input word:

In [6]:
# Pick a random position in the sentence (not a dictionary index).
idx = torch.randint(0, len(sentence_int), (1,)).item()
word = sentence.replace(',', '').split()[idx]

print(f"The input word for which we want to compute the attention is: '{word}' at position {idx} (dictionary index {dc[word]}) with embedding \n{embedded_sentence[idx]}")

# Project the selected embedding into query, key and value space.
x = embedded_sentence[idx]
query_x = W_q @ x
key_x = W_k @ x
value_x = W_v @ x
print(f"\nQuery: {query_x}\nKey: {key_x}\nValue: {value_x}")
print(f"\nQuery shape: {query_x.shape}\nKey shape: {key_x.shape}\nValue shape: {value_x.shape}")


The input word for which we want to compute the attention is: 'quick' at position 1 (dictionary index 7) with embedding 
tensor([-0.2188, -2.4351, -0.0729, -0.0340,  0.9625,  0.3492, -0.9215, -0.0562,
        -0.6227, -0.4637,  1.9218, -0.4025,  0.1239,  1.1648,  0.9234,  1.3873])

Query: tensor([ -2.1911,  -1.4461,   0.4942,  -0.9959,  -4.7663,  -4.8887,   9.3072,
          5.8991,  -3.8935,  -4.2999,  -1.3099,  -1.1042, -13.4091,  -2.6684,
         -1.8868,  -1.6594,   4.1670,  -4.1148,   0.0935,   9.3019,  -1.6530,
          2.6375,  -5.8936,  -0.0373], grad_fn=<MvBackward0>)
Key: tensor([ 7.7921, -3.4727,  3.3231,  0.8337,  7.1973,  2.5318, -3.6451,  4.7367,
        -6.4336,  0.5191, -8.1601, -8.5292,  1.0385, -2.9421,  3.7129, -0.4032,
         7.9608,  8.9526, -0.7115, -2.3812, -1.3893,  3.9426, -0.2862,  0.7675],
       grad_fn=<MvBackward0>)
Value: tensor([-0.3520, -3.7708, -3.8933, -4.7494, -3.1310, -2.4028,  4.3370,  0.2605,
        -1.4421, -2.1825, -7.1633, -3.5063,  2.2980,  2.6982, -

To compute the unnormalized self-attention weights, we also need the keys and values of all the other inputs:

In [7]:
keys = (W_k @ embedded_sentence.T).T
values = (W_v @ embedded_sentence.T).T
print(f"\nKeys shape: {keys.shape}\nValues shape: {values.shape}")


Keys shape: torch.Size([9, 24])
Values shape: torch.Size([9, 28])
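
The double transpose in `(W_k @ embedded_sentence.T).T` can also be written as a single product with the transposed weight matrix. The sketch below checks that both forms, and a per-token loop, give the same keys. It uses a random stand-in `X` for `embedded_sentence`, and the names `keys_a`, `keys_b` and `keys_c` are illustrative.

```python
import torch

torch.manual_seed(0)
n_tokens, d, d_k = 9, 16, 24
X = torch.randn(n_tokens, d)   # stand-in for embedded_sentence
W_k = torch.randn(d_k, d)      # same layout as W_k above: (d_k, d)

keys_a = (W_k @ X.T).T                          # form used in the notebook
keys_b = X @ W_k.T                              # same result without the outer transpose
keys_c = torch.stack([W_k @ x_i for x_i in X])  # one token at a time

print(keys_a.shape)
print(torch.allclose(keys_a, keys_b, atol=1e-4))
print(torch.allclose(keys_a, keys_c, atol=1e-4))
```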


Now compute the unnormalized attention weights $\omega_{i,j} = {q^{(i)}}^\top k^{(j)}$ between the query of input $i$ and the key of every input $j$:

In [23]:
omega_x = query_x @ keys.T
print(f"\nOmega_x: {omega_x}")
print(f"\nOmega_x shape: {omega_x.shape}")


Omega_x: tensor([  32.4191,  -44.9209,   30.4341, -121.9629,  -61.8859, -123.4025,
         171.9292,    4.6947,   79.1044], grad_fn=<SqueezeBackward4>)

Omega_x shape: torch.Size([9])
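
The cell above fills in one row of $\omega$, the one belonging to the chosen query. A minimal sketch of the full matrix for all query/key pairs follows, assuming random stand-ins for the embeddings and weights; `queries`, `omega` and the row index `i` are illustrative names.

```python
import torch

torch.manual_seed(0)
n_tokens, d, d_q, d_k = 9, 16, 24, 24
X = torch.randn(n_tokens, d)   # stand-in for embedded_sentence
W_q = torch.randn(d_q, d)
W_k = torch.randn(d_k, d)

queries = X @ W_q.T            # (9, 24), one query per token
keys = X @ W_k.T               # (9, 24), one key per token

# omega[i, j] is the dot product of query i with key j.
omega = queries @ keys.T       # (9, 9)

# Row i of the matrix matches the single-query computation.
i = 1
omega_i = queries[i] @ keys.T
print(omega.shape)
print(torch.allclose(omega[i], omega_i, rtol=1e-4, atol=1e-3))
```

Note that the product is only defined when `d_q == d_k`, which is why both are set to 24 in this notebook.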


To obtain the normalized attention weights, first scale the scores by $\frac{1}{\sqrt{d_k}}$, which keeps the dot products from growing with the key dimension and pushing the softmax into saturation, and then apply the softmax function:

In [22]:
import torch.nn.functional as F
attention_weights_x = F.softmax(omega_x/d_k**0.5, dim=0)
print(f"\nAttention weights for the input word: {attention_weights_x}")
print(f"\nAttention weights shape: {attention_weights_x.shape}")


Attention weights for the input word: tensor([4.2897e-13, 5.9736e-20, 2.8606e-13, 8.8403e-27, 1.8719e-21, 6.5893e-27,
        1.0000e+00, 1.4951e-15, 5.9031e-09], grad_fn=<SoftmaxBackward0>)

Attention weights shape: torch.Size([9])
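
As a worked check of the scaling and softmax step, the sketch below redoes the arithmetic by hand using only the standard library, with the scores copied from the `Omega_x` output above. The helper `softmax` is an illustrative function, not part of the notebook.

```python
import math

# Unnormalized scores copied from the Omega_x output above.
omega_x = [32.4191, -44.9209, 30.4341, -121.9629, -61.8859, -123.4025,
           171.9292, 4.6947, 79.1044]
d_k = 24

def softmax(scores):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scaled = softmax([s / math.sqrt(d_k) for s in omega_x])
unscaled = softmax(omega_x)

print(f"scaled:   max={max(scaled):.6f}, second={sorted(scaled)[-2]:.4e}")
print(f"unscaled: max={max(unscaled):.6f}, second={sorted(unscaled)[-2]:.4e}")
```

With these values both variants put almost all of the weight on index 6; scaling only makes the distribution slightly less peaked. A likely reason is that the weight matrices are drawn from an unscaled standard normal, which produces large raw scores.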


Finally, compute the context vector $z^{(x)}$: the attention-weighted sum of the value vectors of all inputs. It represents the query input in the context of every other element of the sequence.

In [10]:
context_vector_x = attention_weights_x @ values
print(f"\nContext vector for the input word: {context_vector_x}")
print(f"\nContext vector shape: {context_vector_x.shape}")


Context vector for the input word: tensor([-6.1658,  3.3317, -1.4784,  3.0280, -3.0778, -1.9382,  3.2093,  2.9632,
         4.7867,  2.5697, -1.9187, -0.8907,  3.5392, -0.1726, -2.6539,  5.6142,
        -1.1907,  2.2681, -6.4134,  2.0330,  3.2004, -8.4279, -5.9757, -6.8775,
         3.2998,  4.7060, -3.5087,  5.1399], grad_fn=<SqueezeBackward4>)

Context vector shape: torch.Size([28])
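
The steps above handle a single query. One way to collect them into a function that returns a context vector for every position at once is sketched below. The function name `self_attention` and the random stand-in inputs are assumptions for illustration, not the notebook's own code.

```python
import torch
import torch.nn.functional as F

def self_attention(X, W_q, W_k, W_v):
    """Scaled dot-product self-attention for every position in X.

    X: (n, d); W_q: (d_q, d); W_k: (d_k, d); W_v: (d_v, d), with d_q == d_k.
    """
    queries = X @ W_q.T                            # (n, d_q)
    keys = X @ W_k.T                               # (n, d_k)
    values = X @ W_v.T                             # (n, d_v)
    d_k = keys.shape[-1]
    omega = queries @ keys.T                       # (n, n) unnormalized scores
    weights = F.softmax(omega / d_k**0.5, dim=-1)  # each row sums to 1
    return weights @ values, weights

torch.manual_seed(0)
X = torch.randn(9, 16)  # stand-in for embedded_sentence
W_q, W_k, W_v = torch.randn(24, 16), torch.randn(24, 16), torch.randn(28, 16)

context, weights = self_attention(X, W_q, W_k, W_v)
print(context.shape)
print(weights.sum(dim=-1))
```

Row `i` of `context` corresponds to the single-query context vector computed above for position `i`.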


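## Multi-Head Self-Attention

Multi-head attention runs several attention heads in parallel, each with its own query, key and value weight matrices.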

In [11]:
heads = 3
multihead_W_q = nn.Parameter(torch.randn(heads, d_q, d))
multihead_W_k = nn.Parameter(torch.randn(heads, d_k, d))
multihead_W_v = nn.Parameter(torch.randn(heads, d_v, d))
print(f"\nMultihead query weight shape: {multihead_W_q.shape}, \nMultihead key weight shape: {multihead_W_k.shape}, \nMultihead value weight shape: {multihead_W_v.shape}")


Multihead query weight shape: torch.Size([3, 24, 16]), 
Multihead key weight shape: torch.Size([3, 24, 16]), 
Multihead value weight shape: torch.Size([3, 28, 16])


In [12]:
multihead_query_x = multihead_W_q @ x
print(f"\nMultihead query shape: {multihead_query_x.shape}")
multihead_key_x = multihead_W_k @ x
multihead_value_x = multihead_W_v @ x


Multihead query shape: torch.Size([3, 24])


In [13]:
stacked_inputs = embedded_sentence.T.repeat(heads, 1, 1)
print(f"\nStacked inputs shape: {stacked_inputs.shape}")


Stacked inputs shape: torch.Size([3, 16, 9])


In [14]:
multihead_keys = torch.bmm(multihead_W_k, stacked_inputs)
multihead_values = torch.bmm(multihead_W_v, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 9])
multihead_values.shape: torch.Size([3, 28, 9])
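
The `repeat` plus `torch.bmm` pattern above is not the only option: `@` broadcasts a 2-D matrix across the leading head dimension, and `torch.einsum` can express the same contraction. The sketch below compares the three under random stand-in inputs; `keys_bmm`, `keys_bcast` and `keys_einsum` are illustrative names.

```python
import torch

torch.manual_seed(0)
heads, d, d_k, n_tokens = 3, 16, 24, 9
X = torch.randn(n_tokens, d)                 # stand-in for embedded_sentence
multihead_W_k = torch.randn(heads, d_k, d)

# Notebook approach: copy the inputs once per head, then batch-multiply.
stacked = X.T.repeat(heads, 1, 1)            # (3, 16, 9)
keys_bmm = torch.bmm(multihead_W_k, stacked) # (3, 24, 9)

# Broadcasting: (3, 24, 16) @ (16, 9) -> (3, 24, 9), no explicit copies.
keys_bcast = multihead_W_k @ X.T

# einsum: contract over the embedding dimension d for every head.
keys_einsum = torch.einsum('hkd,nd->hkn', multihead_W_k, X)

print(keys_bmm.shape)
print(torch.allclose(keys_bmm, keys_bcast, atol=1e-4))
print(torch.allclose(keys_bmm, keys_einsum, atol=1e-4))
```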


In [15]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 9, 24])
multihead_values.shape: torch.Size([3, 9, 28])


In [18]:
omega_multihead_x = torch.bmm(multihead_query_x.unsqueeze(1), multihead_keys.transpose(1,2)).squeeze(1)
print(f"\nOmega_multihead_x shape: {omega_multihead_x.shape}")


Omega_multihead_x shape: torch.Size([3, 9])


In [21]:
multihead_attention_weights_x = F.softmax(omega_multihead_x/d_k**0.5, dim=1)
print(f'Multihead attention weights shape: {multihead_attention_weights_x.shape}')

Multihead attention weights shape: torch.Size([3, 9])


In [26]:
multihead_context_vector_x = torch.bmm(multihead_attention_weights_x.unsqueeze(1), multihead_values).squeeze(1)
print(f'Multihead context vector shape: {multihead_context_vector_x.shape}')

Multihead context vector shape: torch.Size([3, 28])
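
The notebook stops at one context vector per head. In the standard Transformer formulation, the heads are then concatenated and passed through a learned output projection. A minimal sketch of that step for all positions follows. The output matrix `W_o`, the einsum-based projections and the random inputs are assumptions for illustration and are not defined in the notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, n_tokens, d, d_k, d_v = 3, 9, 16, 24, 28
X = torch.randn(n_tokens, d)                # stand-in for embedded_sentence
W_q = torch.randn(heads, d_k, d)
W_k = torch.randn(heads, d_k, d)
W_v = torch.randn(heads, d_v, d)

# Per-head projections for every token.
Q = torch.einsum('hkd,nd->hnk', W_q, X)     # (3, 9, 24)
K = torch.einsum('hkd,nd->hnk', W_k, X)     # (3, 9, 24)
V = torch.einsum('hvd,nd->hnv', W_v, X)     # (3, 9, 28)

omega = Q @ K.transpose(1, 2)               # (3, 9, 9)
weights = F.softmax(omega / d_k**0.5, dim=-1)
context = weights @ V                       # (3, 9, 28)

# Concatenate the heads per token: (3, 9, 28) -> (9, 3, 28) -> (9, 84).
concat = context.permute(1, 0, 2).reshape(n_tokens, heads * d_v)

# Hypothetical output projection back to the embedding dimension.
W_o = torch.randn(d, heads * d_v)
output = concat @ W_o.T                     # (9, 16)

print(concat.shape)
print(output.shape)
```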


## Cross-Attention
In self-attention, the queries, keys and values all come from the same input sequence. In cross-attention, the queries come from one sequence while the keys and values come from a second one.
![Alt text](media/cross-attention.png)

In [28]:
embedded_sentence_2 = torch.rand(8,16)
keys = (W_k @ embedded_sentence_2.T).T
values = (W_v @ embedded_sentence_2.T).T
print(f"\nKeys shape: {keys.shape}\nValues shape: {values.shape}")


Keys shape: torch.Size([8, 24])
Values shape: torch.Size([8, 28])


Notice that, compared to self-attention, the keys and values now have 8 rows instead of 9, one per token of the second sequence. Everything else stays the same.

In [29]:
omega_x = query_x @ keys.T
print(f'\n Omega_x shape: {omega_x.shape}')
attention_weights_x = F.softmax(omega_x/d_k**0.5, dim=0)
print(f'\n Attention weights shape: {attention_weights_x.shape}')
context_vector_x = attention_weights_x @ values
print(f'\n Context vector shape: {context_vector_x.shape}')


 Omega_x shape: torch.Size([8])

 Attention weights shape: torch.Size([8])

 Context vector shape: torch.Size([28])
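
The same computation can be sketched for every query of the first sequence at once, which makes the shapes explicit: the attention matrix has one row per query token and one column per token of the second sequence. The function name `cross_attention` and the random stand-in inputs are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def cross_attention(X_1, X_2, W_q, W_k, W_v):
    """Queries come from X_1; keys and values come from X_2."""
    queries = X_1 @ W_q.T                   # (n_1, d_q)
    keys = X_2 @ W_k.T                      # (n_2, d_k)
    values = X_2 @ W_v.T                    # (n_2, d_v)
    omega = queries @ keys.T                # (n_1, n_2)
    weights = F.softmax(omega / keys.shape[-1]**0.5, dim=-1)
    return weights @ values, weights

torch.manual_seed(0)
X_1 = torch.randn(9, 16)  # stand-in for embedded_sentence (9 tokens)
X_2 = torch.rand(8, 16)   # second sequence (8 tokens), as in the cell above
W_q, W_k, W_v = torch.randn(24, 16), torch.randn(24, 16), torch.randn(28, 16)

context, weights = cross_attention(X_1, X_2, W_q, W_k, W_v)
print(weights.shape)
print(context.shape)
```

The two sequences may differ in length, but they must share the embedding dimension (16 here) so that the same weight shapes apply.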
