# Self Attention
The concept of `attention` grew out of efforts to improve `RNN`s on longer sequences or sentences. For example, consider translating a sentence from one language to another. Translating a sentence `word-by-word` does not work well, because word order and grammatical structure differ between languages.

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/sentence.png"/>

To overcome this issue, `attention mechanisms` were introduced to give access to all sequence elements at each time step. The key is to be selective and determine which words are most important in a specific context. 

Here, we focus on the `Scaled Dot-Product Attention Mechanism (Self Attention)`, which is a popular and widely used attention mechanism in practice. Related work such as `FlashAttention` computes the same attention more efficiently rather than defining a different mechanism.

## Embedding an Input Sentence
First, define an input sentence to pass through the `self-attention` mechanism, and build a dictionary that maps each word to an integer index.

In [1]:
sentence = 'Hi, My name is Raghvender Changotra and I am 21 years old'

dc = {s: i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'21': 0, 'Changotra': 1, 'Hi': 2, 'I': 3, 'My': 4, 'Raghvender': 5, 'am': 6, 'and': 7, 'is': 8, 'name': 9, 'old': 10, 'years': 11}


In [2]:
# Use this dict to assign an integer index to each word
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([ 2,  4,  9,  8,  5,  1,  7,  3,  6,  0, 11, 10])


Now, we use an `Embedding Layer` to encode the inputs as real-valued vectors. We use a 16-dimensional embedding, so each input word is represented by a `16-dimensional` vector. Since there are 12 words, this results in a `12x16` embedding.

In [3]:
torch.manual_seed(123)
emb = torch.nn.Embedding(12, 16)
emb_sentence = emb(sentence_int).detach()

print(emb_sentence)
print(emb_sentence.shape)

tensor([[-1.3250e+00,  1.7843e-01, -2.1338e+00,  1.0524e+00, -3.8848e-01,
         -9.3435e-01, -4.9914e-01, -1.0867e+00,  8.8054e-01,  1.5542e+00,
          6.2662e-01, -1.7549e-01,  9.8284e-02, -9.3507e-02,  2.6621e-01,
         -5.8504e-01],
        [ 5.1463e-01,  9.9376e-01, -2.5873e-01, -1.0826e+00, -4.4382e-02,
          1.6236e+00, -2.3229e+00,  1.0878e+00,  6.7155e-01,  6.9330e-01,
         -9.4872e-01, -7.6507e-02, -1.5264e-01,  1.1674e-01,  4.4026e-01,
         -1.4465e+00],
        [-1.2743e+00,  4.5128e-01, -2.2801e-01,  9.2238e-01,  2.0561e-01,
         -4.9696e-01,  5.8206e-01,  2.0532e-01, -3.0177e-01, -6.7030e-01,
         -6.1710e-01, -8.3339e-01,  4.8387e-01, -1.3493e-01,  2.1187e-01,
         -8.7140e-01],
        [-2.5822e-01, -2.0407e+00, -8.0156e-01, -8.1830e-01, -1.1820e+00,
         -2.8774e-01, -6.0430e-01,  6.0024e-01, -1.4053e+00, -5.9217e-01,
         -2.5479e-01,  1.1517e+00, -1.7858e-02,  4.2640e-01, -7.6574e-01,
         -5.4514e-02],
        [ 2.5529e-01
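As a side note on what the embedding layer does: a lookup simply selects rows of a trainable weight matrix. The sketch below illustrates this with a small, freshly initialized layer; the names `token_ids`, `looked_up` and `indexed` are illustrative and not part of the original notebook.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable

vocab_size, emb_dim = 12, 16
emb = torch.nn.Embedding(vocab_size, emb_dim)

# Three example token indices (hypothetical, any values in [0, vocab_size) work)
token_ids = torch.tensor([2, 4, 9])

looked_up = emb(token_ids).detach()       # forward pass of the layer
indexed = emb.weight[token_ids].detach()  # direct row indexing into the weight matrix

# Both routes return the same rows of the (12, 16) weight matrix
same_rows = torch.equal(looked_up, indexed)
print(looked_up.shape, same_rows)
```

The weight matrix has one row per vocabulary entry, so a sentence of `T` tokens always yields a `T x emb_dim` tensor.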

## Defining Weight Matrices
`Self Attention` uses three weight matrices, referred to as $U_q$, $U_k$ and $U_v$, which are adjusted as model parameters during training. These matrices project the inputs into the `query`, `key` and `value` components of the sequence.

These are obtained via matrix multiplication between the `weight` matrices $U$ and the `embedded` inputs $x$:
- Query Sequence $q^{(i)} = U_q x^{(i)}$ for $i \in [1, T]$
- Key Sequence $k^{(i)} = U_k x^{(i)}$ for $i \in [1, T]$
- Value Sequence $v^{(i)} = U_v x^{(i)}$ for $i \in [1, T]$

The index `i` refers to the token index position in the input sequence, which has length $T$.

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/attention-matrices.png"/>

Here, both $q^{(i)}$ and $k^{(i)}$ are vectors of dimension $d_k$. The projection matrices $U_q$ and $U_k$ have shape $d_k \times d$, while $U_v$ has shape $d_v \times d$, where $d$ is the size of each word vector $x$. Since we compute the `dot-product` between the `query` and `key` vectors, they must contain the same number of elements ($d_q = d_k$). However, the number of elements in the value vector $v^{(i)}$, which determines the size of the `resulting context vector`, is arbitrary.

So, here we will set $d_q = d_k = 24$ and use $d_v = 28$ when initializing the projection matrices.

In [4]:
torch.manual_seed(123)  # fix the seed so the random weights are reproducible
d = emb_sentence.shape[1]
d_q, d_k, d_v = 24, 24, 28

U_q = torch.randn(d_q, d)
U_k = torch.randn(d_k, d)
U_v = torch.randn(d_v, d)

## Computing Unnormalized Attention Weights
As an example, suppose we want to compute the `attention` vector for the second input element. In that case, $x^{(2)}$ acts as the query.

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/query.png"/>

In [5]:
# Second Input Element as the query
x_2 = emb_sentence[1]
q_2 = U_q.matmul(x_2)
k_2 = U_k.matmul(x_2)
v_2 = U_v.matmul(x_2)

print(q_2.shape)
print(k_2.shape)
print(v_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


We can generalize this to compute the `keys` and `values` for all inputs, since we need them to compute the unnormalized attention weights $\omega$.

In [6]:
keys = U_k.matmul(emb_sentence.T).T
values = U_v.matmul(emb_sentence.T).T

print('Keys: ', keys.shape)
print('Values: ', values.shape)

Keys:  torch.Size([12, 24])
Values:  torch.Size([12, 28])
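The expression `U_k.matmul(emb_sentence.T).T` follows the column-vector notation of the formulas. As a minimal sketch, the snippet below checks that it is equivalent to the row-major form `X @ U.T`, which is often easier to read. The tensors `X` and `W_k` are random stand-ins (assumptions) with the same shapes as `emb_sentence` and `U_k`.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable

T, d, d_k = 12, 16, 24       # sequence length, embedding size, key size
X = torch.randn(T, d)        # stand-in for emb_sentence
W_k = torch.randn(d_k, d)    # stand-in for U_k

keys_a = W_k.matmul(X.T).T   # form used in the cell above: (U X^T)^T
keys_b = X @ W_k.T           # equivalent row-major form: X U^T

# Row i of either result is the key vector of token i
keys_are_equal = torch.allclose(keys_a, keys_b, atol=1e-4)
print(keys_a.shape, keys_are_equal)
```

The same identity applies to the queries and values, so all three projections can be computed for the whole sequence with one matrix product each.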


Now, we have the `keys` and `values` to proceed with computation of unnormalized attention weights $\omega$ as illustrated here

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/omega.png"/>
We compute $\omega_{ij} = {q^{(i)}}^\top k^{(j)}$ as the `dot product` between the query and key sequences. For example, we can compute the unnormalized attention weight between the query and the 5th input element (index position 4):

In [7]:
omega24 = q_2.dot(keys[4])
print(omega24)

tensor(-100.8584)


In [8]:
# Unnormalized attention weights between the query and all input tokens
omega_2 = q_2 @ keys.T
print(omega_2)

tensor([  63.5880,   95.5014,   -3.3820,  -46.4590, -100.8584,  -98.1709,
          -7.3639,    9.3997,   33.1963,   83.1533,   -8.4415,   14.9591])
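The same product extends from one query to every query at once. The sketch below, which uses random stand-in tensors (`X`, `W_q`, `W_k` are assumptions, not the notebook's variables), builds the full $T \times T$ matrix of unnormalized weights and checks that one of its rows matches the single-query computation.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable

T, d, d_k = 12, 16, 24
X = torch.randn(T, d)        # stand-in for emb_sentence
W_q = torch.randn(d_k, d)    # stand-in for U_q
W_k = torch.randn(d_k, d)    # stand-in for U_k

queries = X @ W_q.T          # (T, d_k), one query per token
keys = X @ W_k.T             # (T, d_k), one key per token

# omega[i, j] is the dot product between query i and key j
omega = queries @ keys.T     # (T, T)

# Row 1 corresponds to using the second token as the query
omega_row_1 = queries[1] @ keys.T
rows_match = torch.allclose(omega[1], omega_row_1, rtol=1e-4, atol=1e-3)
print(omega.shape, rows_match)
```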


## Computing the Attention Scores

The next step in self-attention is to turn the unnormalized attention weights $\omega$ into the normalized attention weights $\alpha$ by applying the `softmax` function. Additionally, $\omega$ is scaled by $1/\sqrt{d_k}$ before it is normalized through `softmax`, as shown here

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/attention-scores.png"/>

The scaling by $\sqrt{d_k}$ ensures that the `Euclidean length` of the weight vectors will be approximately of the same `magnitude`. This helps prevent the attention weights from becoming too `small` or too `large`, which could lead to numerical instability and might affect the model's ability to `converge` during training.
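To make the effect of the scaling concrete, here is a small numerical sketch. It assumes the query and key components are independent with zero mean and unit variance (an idealization, not a property of the tensors in this notebook). Under that assumption a dot product of $d_k$ terms has standard deviation close to $\sqrt{d_k}$, and dividing by $\sqrt{d_k}$ brings it back to roughly 1.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable

d_k = 24
n = 10_000                    # number of random query/key pairs to sample

q = torch.randn(n, d_k)       # idealized queries: zero mean, unit variance
k = torch.randn(n, d_k)       # idealized keys: zero mean, unit variance

dots = (q * k).sum(dim=1)     # n independent dot products

raw_std = dots.std().item()                     # expected near sqrt(d_k)
scaled_std = (dots / d_k**0.5).std().item()     # expected near 1

print(f"raw std: {raw_std:.2f}, scaled std: {scaled_std:.2f}")
```

Without the scaling, larger $d_k$ pushes the logits further apart, so the `softmax` output concentrates on a single element and its gradients become very small.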

In [9]:
import torch.nn.functional as F

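# Normalized Attention Weights
# Note: this cell divides by sqrt(d) with d = 16, and the printed output reflects that;
# to match the sqrt(d_k) scaling in the formula above, use d_k**0.5 (d_k = 24) instead
attn_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)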
print(attn_weights_2)

tensor([3.2773e-04, 9.5604e-01, 1.7553e-11, 3.6925e-16, 4.5812e-22, 8.9696e-22,
        6.4866e-12, 4.2865e-10, 1.6435e-07, 4.3631e-02, 4.9547e-12, 1.7207e-09])


Now, the last step is to compute the `context` vector $z^{(2)}$, which is an attention-weighted version of our original input query $x^{(2)}$, including all other input elements as its context via `attention weights`.

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/context-vector.png"/>

In [10]:
context_2 = attn_weights_2 @ values

print(context_2.shape)
print(context_2)

torch.Size([28])
tensor([ 0.6245,  3.5794, -3.6488, -4.2545,  1.1219,  2.6817, -3.5003, -6.5867,
         1.1002, -5.2203,  0.6476, -0.0411,  0.3919, -0.5379, -2.5635,  1.4992,
         0.8038,  0.0308, -4.3862,  5.8357,  0.4138, -0.7649, -3.2749,  0.8404,
        -6.0157, -6.9075, -2.3196,  0.3939])
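The steps above handle a single query. In matrix form, the context vectors of all tokens are $Z = \mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$, with the `softmax` applied row by row. The function below is a minimal sketch of that formula; `self_attention` and the random stand-in tensors are illustrative names and data, not part of the original notebook, and the sketch scales by $\sqrt{d_k}$ as in the formula.

```python
import torch
import torch.nn.functional as F


def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (illustrative sketch).

    X: (T, d) embedded inputs; W_q, W_k: (d_k, d); W_v: (d_v, d).
    Returns the context vectors (T, d_v) and the attention weights (T, T).
    """
    queries = X @ W_q.T                            # (T, d_k)
    keys = X @ W_k.T                               # (T, d_k)
    values = X @ W_v.T                             # (T, d_v)
    d_k = keys.shape[-1]
    omega = queries @ keys.T                       # (T, T) unnormalized weights
    alpha = F.softmax(omega / d_k**0.5, dim=-1)    # each row sums to 1
    return alpha @ values, alpha


torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable
T, d, d_k, d_v = 12, 16, 24, 28
X = torch.randn(T, d)          # stand-in for emb_sentence
W_q = torch.randn(d_k, d)      # stand-in for U_q
W_k = torch.randn(d_k, d)      # stand-in for U_k
W_v = torch.randn(d_v, d)      # stand-in for U_v

Z, alpha = self_attention(X, W_q, W_k, W_v)
print(Z.shape, alpha.shape)
```

Row `Z[1]` plays the role of `context_2`: it is the attention-weighted sum of all value vectors, using the second token as the query.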


# Multi Head Attention

In Multi Head Attention, several attention heads run in parallel. Each head is a single self-attention module with its own `query`, `key` and `value` matrices, which transform the input sequence.

This image illustrates a `Single head Attention` 

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/single-head.png"/>

Now, `MultiHead Attention` contains multiple such heads, each consisting of `query`, `key` and `value` matrices like this.

<img src="https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/multi-head.png"/>

Now, suppose we have `3` attention heads, so we extend the $d' \times d$ dimensional weight matrices into $3 \times d' \times d$ (where $d'$ stands for $d_q$, $d_k$ or $d_v$)

In [11]:
n_heads = 3
mh_U_q = torch.randn(n_heads, d_q, d) # Multihead U_q
mh_U_k = torch.randn(n_heads, d_k, d) # Multihead U_k
mh_U_v = torch.randn(n_heads, d_v, d) # Multihead U_v

Each query element is now $3 \times d_q$ dimensional, where $d_q = 24$

In [13]:
# Focus on the 2nd input element (index position 1), as before
mh_q_2 = mh_U_q @ x_2

print(mh_q_2.shape)

torch.Size([3, 24])


In [14]:
# We obtain keys and values using this
mh_k_2 = mh_U_k @ x_2 # keys
mh_v_2 = mh_U_v @ x_2 # values

These `keys` and `values` are specific to the `query` element. However, we also need the `keys` and `values` of all other sequence elements in order to compute the `attention` scores for the query. We can do this by repeating the input sequence `embeddings` along a new leading dimension of size `3` (No. of Heads)

In [16]:
stacked_inp = emb_sentence.T.repeat(3, 1, 1)

print(stacked_inp.shape)

torch.Size([3, 16, 12])


Now we can use `torch.bmm()` for batch matrix multiplication which would enable us to compute all `keys` and `values`

In [18]:
mh_keys = torch.bmm(mh_U_k, stacked_inp)
mh_values = torch.bmm(mh_U_v, stacked_inp)

print('MultiHead Keys: ', mh_keys.shape)
print('MultiHead Values: ', mh_values.shape)

MultiHead Keys:  torch.Size([3, 24, 12])
MultiHead Values:  torch.Size([3, 28, 12])
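Repeating the input is one way to line up the shapes for `torch.bmm()`. As a hedged alternative sketch, `torch.matmul()` broadcasts a 2-D matrix over the leading head dimension, and `torch.einsum()` can write the same contraction with explicit index names. The tensors `X` and `W_k` below are random stand-ins (assumptions) for `emb_sentence` and `mh_U_k`.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable

n_heads, T, d, d_k = 3, 12, 16, 24
X = torch.randn(T, d)                  # stand-in for emb_sentence
W_k = torch.randn(n_heads, d_k, d)     # stand-in for mh_U_k

# Approach from the cell above: copy the input once per head, then bmm
stacked = X.T.repeat(n_heads, 1, 1)    # (3, 16, 12)
keys_bmm = torch.bmm(W_k, stacked)     # (3, 24, 12)

# Broadcasting: X.T is shared across heads without an explicit copy
keys_bcast = torch.matmul(W_k, X.T)    # (3, 24, 12)

# einsum: h = head, k = key dimension, d = embedding dimension, t = token
keys_einsum = torch.einsum('hkd,td->htk', W_k, X)   # (3, 12, 24)

same_bcast = torch.allclose(keys_bmm, keys_bcast, atol=1e-4)
same_einsum = torch.allclose(keys_bmm.permute(0, 2, 1), keys_einsum, atol=1e-4)
print(keys_bcast.shape, keys_einsum.shape, same_bcast, same_einsum)
```

The `einsum` variant directly produces the (heads, words, size) layout, so no separate dimension swap is needed afterwards.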


Now, we have tensors whose first dimension represents the `three` attention heads. The `2nd` and `3rd` dimensions refer to the `projection size` ($d_k$ or $d_v$) and the `no. of words`, respectively. We swap the `2nd` and `3rd` dimensions so that the tensors have the same dimensional structure as the original input sequence `emb_sentence`, with one row per word.

In [19]:
mh_keys = mh_keys.permute(0, 2, 1)
mh_values = mh_values.permute(0, 2, 1)

print('MultiHead Keys: ', mh_keys.shape)
print('MultiHead Values: ', mh_values.shape)

MultiHead Keys:  torch.Size([3, 12, 24])
MultiHead Values:  torch.Size([3, 12, 28])


Now, we follow the same steps as above for each head: compute the `unnormalized attention weights` $\omega$, apply the `scaled-softmax` to obtain the `normalized attention weights` $\alpha$, and take the weighted sum of the values. This yields an $h \times d_v$ `dimensional context vector` $z$ for the input element $x^{(2)}$, where $h$ is the number of heads (here $3 \times 28$).
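These remaining steps can be sketched as follows. This is one possible way to write them, using random stand-in tensors (`X`, `W_q`, `W_k`, `W_v` are assumptions with the same shapes as `emb_sentence`, `mh_U_q`, `mh_U_k` and `mh_U_v`) and scaling by $\sqrt{d_k}$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, used only to make this sketch repeatable

n_heads, T, d, d_k, d_v = 3, 12, 16, 24, 28
X = torch.randn(T, d)                    # stand-in for emb_sentence
W_q = torch.randn(n_heads, d_k, d)       # stand-in for mh_U_q
W_k = torch.randn(n_heads, d_k, d)       # stand-in for mh_U_k
W_v = torch.randn(n_heads, d_v, d)       # stand-in for mh_U_v

x_2 = X[1]                               # second input element as the query
q_2 = W_q @ x_2                          # (3, 24), one query per head

keys = torch.matmul(W_k, X.T).permute(0, 2, 1)     # (3, 12, 24)
values = torch.matmul(W_v, X.T).permute(0, 2, 1)   # (3, 12, 28)

# omega[h, j] = dot product of head h's query with head h's key for token j
omega = torch.bmm(keys, q_2.unsqueeze(-1)).squeeze(-1)    # (3, 12)

# Scaled softmax over the tokens, separately for each head
alpha = F.softmax(omega / d_k**0.5, dim=-1)               # (3, 12)

# z[h] = attention-weighted sum of head h's value vectors
z = torch.bmm(alpha.unsqueeze(1), values).squeeze(1)      # (3, 28)

print(z.shape)
```

Each row of `z` is the context vector produced by one head, so the result has shape $h \times d_v$.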