<a href="https://colab.research.google.com/github/adnaen/machine-learning-notes/blob/main/llm/transformers/multi_head_attention/multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **MULTI-HEAD ATTENTION WORKFLOW**

1. Input text preprocessing
    - Tokenize
    - Vocab indexing
    - Embedding each token

2. Compute Q, K, V by projecting the input embeddings
3. Split the Q, K, V projections across the attention heads
4. Compute scaled dot-product attention for each head
5. Concatenate the head outputs and pass them through a final linear layer.

In [88]:
import torch

## Text Preprocessing

In [127]:
# Toy input: a single short sentence.
text: str = "i love to code"

# Tokenize on whitespace, then give every distinct token an integer id in
# order of first appearance, so the ids follow the text instead of being
# hard-coded.
tokens = text.split()
vocab = {token: idx for idx, token in enumerate(dict.fromkeys(tokens))}
vocab_idx = torch.tensor([vocab[token] for token in tokens])  # vocab index

In [128]:
NUM_OF_HEADS: int = 2
EMBEDDING_DIM: int = 4  # a.k.a. d_model

# The embedding is cut into equal slices, one per head, so it must divide
# exactly; fail early with a clear message otherwise.
if EMBEDDING_DIM % NUM_OF_HEADS != 0:
    raise ValueError("EMBEDDING_DIM must be divisible by NUM_OF_HEADS")

# Per-head feature size. Floor division yields an int, which is required
# because D_K is later passed to Tensor.view as a dimension.
D_K: int = EMBEDDING_DIM // NUM_OF_HEADS

In [129]:
# Embed the input tokens: look up one learnable EMBEDDING_DIM-sized vector
# per token id.

# Size the table from the ids actually present (largest id + 1) rather than
# hard-coding the vocabulary size.
embedder = torch.nn.Embedding(
    num_embeddings=int(vocab_idx.max()) + 1,
    embedding_dim=EMBEDDING_DIM,
)
embedded_text = embedder(vocab_idx)
embedded_text

tensor([[ 0.5008, -1.6228,  0.5575,  0.8338],
        [-2.0577, -0.4445,  1.1370,  2.4174],
        [ 1.7751, -0.0243, -0.3921,  1.1995],
        [-0.0620, -2.4201, -0.4190,  1.7088]], grad_fn=<EmbeddingBackward0>)

## Calculate Q, K, V

In [93]:
# Three independent learnable projections (query, key, value), created in
# that order; each maps EMBEDDING_DIM features to EMBEDDING_DIM features.
Q, K, V = (
    torch.nn.Linear(in_features=EMBEDDING_DIM, out_features=EMBEDDING_DIM)
    for _ in range(3)
)

In [94]:
# Project the embedded tokens through each layer. Despite the `_w` suffix,
# these are the projected activations, not the layers' weight matrices.
q_w, k_w, v_w = (layer(embedded_text) for layer in (Q, K, V))

In [95]:
# Display the query projections (output of the Q layer).
q_w

tensor([[ 0.0357, -0.0050, -0.9573,  0.1623],
        [ 0.3952,  0.6331, -0.8375,  0.3189],
        [ 0.1392, -0.0054, -0.6463,  0.1741],
        [-0.7766,  0.4740,  0.2468,  0.4739]], grad_fn=<AddmmBackward0>)

In [96]:
# Display the key projections (output of the K layer).
k_w

tensor([[ 0.8485, -0.0183, -0.7542, -0.5700],
        [ 0.7656,  0.3594, -1.2888, -1.2898],
        [ 0.9554,  0.5377, -0.3983, -0.2228],
        [-0.9299, -0.6553,  0.5126,  1.5518]], grad_fn=<AddmmBackward0>)

In [97]:
# Display the value projections (output of the V layer).
v_w

tensor([[ 0.4343, -0.4957,  0.2706,  1.2963],
        [ 0.0052, -0.6249, -0.3505,  1.8123],
        [ 0.5499, -0.3414,  0.2242,  0.7206],
        [-0.9433,  0.7010,  0.4386, -0.3949]], grad_fn=<AddmmBackward0>)

## Split Q, K, V across multiple heads

In [98]:
# (seq_len, d_model) -> (seq_len, heads, d_k) -> (heads, seq_len, d_k).
# seq_len is read from the tensor itself so any input length works.
split_q = q_w.view(q_w.size(0), NUM_OF_HEADS, D_K).transpose(0, 1)
split_q

tensor([[[ 0.0357, -0.0050],
         [ 0.3952,  0.6331],
         [ 0.1392, -0.0054],
         [-0.7766,  0.4740]],

        [[-0.9573,  0.1623],
         [-0.8375,  0.3189],
         [-0.6463,  0.1741],
         [ 0.2468,  0.4739]]], grad_fn=<TransposeBackward0>)

In [124]:
# (seq_len, d_model) -> (seq_len, heads, d_k) -> (heads, seq_len, d_k).
# seq_len is read from the tensor itself so any input length works.
split_k = k_w.view(k_w.size(0), NUM_OF_HEADS, D_K).transpose(0, 1)
split_k

tensor([[[ 0.8485, -0.0183],
         [ 0.7656,  0.3594],
         [ 0.9554,  0.5377],
         [-0.9299, -0.6553]],

        [[-0.7542, -0.5700],
         [-1.2888, -1.2898],
         [-0.3983, -0.2228],
         [ 0.5126,  1.5518]]], grad_fn=<TransposeBackward0>)

In [100]:
# (seq_len, d_model) -> (seq_len, heads, d_k) -> (heads, seq_len, d_k).
# seq_len is read from the tensor itself so any input length works.
split_v = v_w.view(v_w.size(0), NUM_OF_HEADS, D_K).transpose(0, 1)
split_v

tensor([[[ 0.4343, -0.4957],
         [ 0.0052, -0.6249],
         [ 0.5499, -0.3414],
         [-0.9433,  0.7010]],

        [[ 0.2706,  1.2963],
         [-0.3505,  1.8123],
         [ 0.2242,  0.7206],
         [ 0.4386, -0.3949]]], grad_fn=<TransposeBackward0>)

## Calculate the attention output for each head

In [92]:
def attention_score(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dk: "torch.Tensor | float") -> torch.Tensor:
    """Compute scaled dot-product attention.

    Evaluates ``softmax(q @ k^T / sqrt(dk)) @ v``, with the softmax taken
    over the key axis.

    Args:
        q: Queries with the feature axis last, e.g. ``(seq_len, d_k)``.
            Any number of leading batch/head axes is allowed, as long as
            they match or broadcast with those of ``k`` and ``v``.
        k: Keys, laid out like ``q``.
        v: Values; the second-to-last axis must match the key length.
        dk: Per-head feature size used for scaling. May be a scalar
            tensor or a plain Python number.

    Returns:
        The attention-weighted combination of ``v`` for every query.
    """
    # as_tensor takes both tensors and plain numbers; converting to q's
    # dtype and device keeps the division from mixing dtypes or devices.
    scale = torch.sqrt(torch.as_tensor(dk, dtype=q.dtype, device=q.device))

    # Swap only the last two axes so leading batch/head axes survive
    # (`k.T` reverses every axis and is only meaningful for 2-D tensors).
    scores = (q @ k.transpose(-2, -1)) / scale

    # Normalize over the key axis, which is always the last one.
    return torch.softmax(scores, dim=-1) @ v

In [111]:
# Run attention independently for every head: iterating the split tensors
# walks their leading (head) axis, yielding one (q, k, v) triple per head.
res = [
    attention_score(head_q, head_k, head_v, dk=torch.tensor(D_K))
    for head_q, head_k, head_v in zip(split_q, split_k, split_v)
]
res

[tensor([[ 0.0217, -0.1993],
         [ 0.1869, -0.3428],
         [ 0.0521, -0.2263],
         [-0.1752, -0.0223]], grad_fn=<MmBackward0>),
 tensor([[0.0619, 1.1051],
         [0.1031, 0.9833],
         [0.1002, 0.9954],
         [0.2571, 0.4498]], grad_fn=<MmBackward0>)]

In [112]:
# Collapse the list of per-head outputs into a single tensor; torch.stack
# adds the new (head) axis at position 0 by default.
res = torch.stack(res)
res

tensor([[[ 0.0217, -0.1993],
         [ 0.1869, -0.3428],
         [ 0.0521, -0.2263],
         [-0.1752, -0.0223]],

        [[ 0.0619,  1.1051],
         [ 0.1031,  0.9833],
         [ 0.1002,  0.9954],
         [ 0.2571,  0.4498]]], grad_fn=<StackBackward0>)

In [114]:
# Inspect the stacked shape: (heads, seq_len, d_k).
res.shape

torch.Size([2, 4, 2])

## Merge all head outputs

In [125]:
# (heads, seq_len, d_k) -> (seq_len, heads, d_k) -> (seq_len, heads * d_k).
# The target shape is derived from `res` instead of being hard-coded, so
# each token's row is the concatenation of its per-head outputs for any
# sequence length or head count.
result = res.permute(1, 0, 2).reshape(res.size(1), -1)
result

tensor([[ 0.0217, -0.1993,  0.0619,  1.1051],
        [ 0.1869, -0.3428,  0.1031,  0.9833],
        [ 0.0521, -0.2263,  0.1002,  0.9954],
        [-0.1752, -0.0223,  0.2571,  0.4498]], grad_fn=<UnsafeViewBackward0>)

## Final Linear Layer pass

In [126]:
# Output projection applied to the concatenated heads. Note that `o_w` is
# the Linear layer itself, not just its weight matrix.
o_w = torch.nn.Linear(in_features=EMBEDDING_DIM, out_features=EMBEDDING_DIM)
o_w(result)

tensor([[ 0.1213, -0.1479,  0.5272, -0.5390],
        [ 0.0650, -0.0441,  0.6500, -0.6313],
        [ 0.0741, -0.0966,  0.5649, -0.5407],
        [-0.1407,  0.0220,  0.5020, -0.2861]], grad_fn=<AddmmBackward0>)