<a href="https://colab.research.google.com/github/adnaen/machine-learning-notes/blob/main/llm/attention/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [70]:
import torch

## Tokenizing & Embedding

In [74]:
# Toy input sentence and its hand-assigned token ids (presumably one id per
# word; the vocabulary itself is not defined in this notebook).
text = "cat sat on mat"
text_idx = torch.tensor([1, 3, 4, 5], dtype=torch.long)

In [76]:
VOCAB_SIZE: int = 6
EMBEDDING_DIM: int = 2
SEED: int = 0

# Seed the global RNG before any layer is created: nn.Embedding (and the
# nn.Linear projections created afterwards) draw their initial weights from it,
# so without a seed every Restart & Run All prints different tensors.
torch.manual_seed(SEED)

# Lookup table with one learnable EMBEDDING_DIM-dimensional row per token id.
embedder = torch.nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM)
embd_text = embedder(text_idx)  # one embedding row per entry of text_idx
embd_text

tensor([[-1.2933,  0.6479],
        [ 0.3835, -1.0435],
        [ 1.3187, -0.5203],
        [-0.9606, -0.3084]], grad_fn=<EmbeddingBackward0>)

## Project Query, Key, Value

In [77]:
ATTENTION_DIM: int = 3

# Three independent learnable projections (created in Q, K, V order) that map
# each EMBEDDING_DIM-sized embedding to an ATTENTION_DIM-sized vector.
Q_W, K_W, V_W = (
    torch.nn.Linear(in_features=EMBEDDING_DIM, out_features=ATTENTION_DIM)
    for _ in range(3)
)

In [79]:
# Apply each projection to the same token embeddings to obtain the query, key
# and value matrices.
Q, K, V = (projection(embd_text) for projection in (Q_W, K_W, V_W))

In [81]:
# Display the query matrix (output of the Q_W projection).
Q

tensor([[-0.2956, -0.7481,  0.7668],
        [-0.4407, -0.3604, -0.3019],
        [-0.2979, -0.3940, -0.5250],
        [-0.4191, -0.5654,  0.3969]], grad_fn=<AddmmBackward0>)

In [82]:
# Display the key matrix (output of the K_W projection).
K

tensor([[-0.0361,  0.0909, -0.6690],
        [ 0.2678,  0.0417,  0.4907],
        [ 1.0626,  0.7140,  0.2973],
        [-0.2405, -0.2150, -0.0833]], grad_fn=<AddmmBackward0>)

In [105]:
# Display the value matrix (output of the V_W projection).
V

tensor([[-0.3913,  0.5381, -1.5445],
        [-1.0327, -0.1697, -0.0426],
        [-0.9210, -0.1783,  0.2962],
        [-0.7172,  0.2343, -1.0353]], grad_fn=<AddmmBackward0>)

## Calculate Attention Scores and Output

In [84]:
# Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V, where d_k is the query/key dimension (ATTENTION_DIM, named d_x in the code below)

In [87]:
import numpy as np

In [92]:
# d_x is the query/key dimension (d_k in "Attention Is All You Need").
d_x = ATTENTION_DIM
# Scaled dot-product scores: row i holds the similarity of query i to every key.
# transpose(-2, -1) swaps only the last two axes, so this also works if a batch
# dimension is added later (Tensor.T is deprecated for non-2-D tensors), and
# d_x ** 0.5 avoids building a throwaway integer tensor just to take a sqrt.
first_term = (Q @ K.transpose(-2, -1)) / d_x ** 0.5
first_term

tensor([[-0.3293,  0.1535, -0.3581,  0.0970],
        [ 0.1069, -0.1624, -0.4708,  0.1204],
        [ 0.1883, -0.2043, -0.4353,  0.1155],
        [-0.1742,  0.0340, -0.4221,  0.1093]], grad_fn=<DivBackward0>)

In [110]:
# Turn each query's scores into attention weights that sum to 1 across the
# keys. The keys live on the last axis, so dim=-1 stays correct even when
# leading batch/head dimensions are present (dim=1 only works for 2-D input).
prob_w = torch.softmax(first_term, dim=-1)
prob_w

tensor([[0.1952, 0.3163, 0.1896, 0.2989],
        [0.2995, 0.2288, 0.1681, 0.3036],
        [0.3184, 0.2150, 0.1706, 0.2960],
        [0.2304, 0.2838, 0.1798, 0.3060]], grad_fn=<SoftmaxBackward0>)

In [111]:
# Weighted sum of the value rows: each output row mixes V using that token's
# attention weights. Here the output has V's shape because queries and keys
# come from the same sequence (self-attention).
result = torch.matmul(prob_w, V)
result

tensor([[-0.7920,  0.0876, -0.5682],
        [-0.7260,  0.1635, -0.7368],
        [-0.7160,  0.1737, -0.7567],
        [-0.7683,  0.1154, -0.6315]], grad_fn=<MmBackward0>)