<a href="https://colab.research.google.com/github/adnaen/machine-learning-notes/blob/main/llm/attention/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

## Tokenizing & Embedding

In [None]:
text = "the cat sat on mat"
text_idx = torch.tensor([1, 2, 3, 4, 5])
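
The indices in `text_idx` are written by hand. As a minimal sketch of where such indices could come from, assuming plain whitespace tokenization and an id 0 reserved for a padding token (neither is stated in this notebook; `build_vocab` and `encode` are hypothetical helper names), a word-to-index table might be built like this:

```python
def build_vocab(sentence: str) -> dict[str, int]:
    """Map each distinct word to an integer id, in order of first appearance."""
    vocab = {"<pad>": 0}  # assumption: id 0 is reserved, so real words start at 1
    for word in sentence.split():
        if word not in vocab:
            vocab[word] = len(vocab)
    return vocab


def encode(sentence: str, vocab: dict[str, int]) -> list[int]:
    """Convert a sentence into a list of token ids using the vocabulary."""
    return [vocab[word] for word in sentence.split()]


vocab = build_vocab("the cat sat on mat")
token_ids = encode("the cat sat on mat", vocab)
print(vocab)
print(token_ids)
```

Under these assumptions the vocabulary has six entries and the sentence maps to ids 1 through 5, which is consistent with `VOCAB_SIZE = 6` and the hand-written `text_idx`.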

In [None]:
VOCAB_SIZE: int = 6
EMBEDDING_DIM: int = 2

embedder = torch.nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM)
embd_text = embedder(text_idx)
print("Each word gets a 2-dimensional embedding vector")
print(f"embedding for  'the' : {embd_text[0]}")
print(f"embedding for  'cat' : {embd_text[1]}")
print(f"embedding for  'sat' : {embd_text[2]}")
print(f"embedding for  'on' : {embd_text[3]}")
print(f"embedding for  'mat' : {embd_text[4]}")

Each word gets a 2-dimensional embedding vector
embedding for  'the' : tensor([0.5517, 1.2525], grad_fn=<SelectBackward0>)
embedding for  'cat' : tensor([1.1554, 1.3356], grad_fn=<SelectBackward0>)
embedding for  'sat' : tensor([-0.0685,  0.6911], grad_fn=<SelectBackward0>)
embedding for  'on' : tensor([-0.5127,  0.6953], grad_fn=<SelectBackward0>)
embedding for  'mat' : tensor([0.3279, 1.6875], grad_fn=<SelectBackward0>)
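
An embedding layer is a lookup table: calling it with token ids returns the matching rows of its weight matrix. The sketch below illustrates this with a fresh layer of the same size. The fixed seed is an assumption added only to make the sketch repeatable; the notebook itself uses random initialization, so its printed numbers will differ from run to run.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only so the sketch is repeatable

embedder = torch.nn.Embedding(num_embeddings=6, embedding_dim=2)
text_idx = torch.tensor([1, 2, 3, 4, 5])

looked_up = embedder(text_idx)        # shape (5, 2): one row per token
indexed = embedder.weight[text_idx]   # the same rows, picked by plain indexing

print(looked_up.shape)
print(torch.equal(looked_up, indexed))
```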


## Project Query, Key, Value

In [None]:
ATTENTION_DIM: int = 3

Q_W = torch.nn.Linear(EMBEDDING_DIM, ATTENTION_DIM)
K_W = torch.nn.Linear(EMBEDDING_DIM, ATTENTION_DIM)
V_W = torch.nn.Linear(EMBEDDING_DIM, ATTENTION_DIM)

In [None]:
Q = Q_W(embd_text)
K = K_W(embd_text)
V = V_W(embd_text)

In [None]:
Q

tensor([[-0.8839,  0.1819, -1.2290],
        [-1.0480,  0.0450, -1.4616],
        [-0.6030,  0.0626, -0.8578],
        [-0.4977,  0.1991, -0.7049],
        [-0.9331,  0.4870, -1.2723]], grad_fn=<AddmmBackward0>)

In [None]:
K

tensor([[ 0.3891, -0.8027,  0.7165],
        [ 0.7597, -0.8657,  1.1348],
        [ 0.2505, -0.4388,  0.0124],
        [-0.0555, -0.4335, -0.2577],
        [ 0.0146, -1.0722,  0.8299]], grad_fn=<AddmmBackward0>)

In [None]:
V

tensor([[-0.5666,  0.6943, -0.0059],
        [-0.7490,  0.5401, -0.2813],
        [-0.3925,  0.6570,  0.2172],
        [-0.2565,  0.7974,  0.4280],
        [-0.4862,  0.9432,  0.1547]], grad_fn=<AddmmBackward0>)
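
Each of `Q_W`, `K_W` and `V_W` is a `torch.nn.Linear` layer, so every projection is a matrix multiplication plus a bias. A small sketch of that equivalence follows, using a random stand-in for `embd_text` (the names `x`, `proj` and `manual` are illustrative, not from the notebook):

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only so the sketch is repeatable

x = torch.randn(5, 2)           # stand-in for embd_text: 5 tokens, 2 features each
proj = torch.nn.Linear(2, 3)    # same input/output sizes as Q_W, K_W and V_W

out = proj(x)                             # shape (5, 3)
manual = x @ proj.weight.T + proj.bias    # weight is stored as (out, in) = (3, 2)

print(out.shape)
print(torch.allclose(out, manual, atol=1e-6))
```

This is why the embedding dimension (2) and the attention dimension (3) can differ: the projection changes the last dimension while keeping one row per token.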

## Calculate Attention Score

In [None]:
# Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

In [None]:
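d_k = ATTENTION_DIM  # dimension of the key vectors
# scaled dot-product scores: entry (i, j) compares the query of word i with the key of word j
first_term = Q @ K.T / d_k ** 0.5
first_term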

tensor([[-0.7913, -1.2839, -0.1827,  0.1657, -0.7090],
        [-0.8609, -1.4398, -0.1734,  0.2398, -0.7370],
        [-0.5193, -0.8578, -0.1092,  0.1313, -0.4549],
        [-0.4956, -0.7796, -0.1275,  0.0710, -0.4652],
        [-0.9616, -1.4863, -0.2674,  0.0974, -0.9190]], grad_fn=<DivBackward0>)
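
To make the matrix expression concrete, a single entry of the score matrix can be recomputed by hand. The sketch below uses values copied from the printed `Q` and `K` tensors above (rounded to four decimals, so results agree only up to rounding); `scaled_score` is an illustrative helper name.

```python
import math

# Values copied from the printed Q and K tensors (4 decimals).
q_the = [-0.8839, 0.1819, -1.2290]   # query vector for 'the'
k_the = [0.3891, -0.8027, 0.7165]    # key vector for 'the'
k_on = [-0.0555, -0.4335, -0.2577]   # key vector for 'on'
d_k = 3


def scaled_score(q, k, d_k):
    """Dot product of one query and one key, divided by sqrt(d_k)."""
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)


score_the_the = scaled_score(q_the, k_the, d_k)
score_the_on = scaled_score(q_the, k_on, d_k)
print(score_the_the, score_the_on)
```

Up to rounding, the two values should match the first and fourth entries of the first row of the score matrix.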

In [None]:
prob_w = torch.softmax(first_term, dim=1)
print("Attention weights: how strongly each word attends to every word in the sentence")
print(f"'the' (pos: 0) attention weights : {prob_w[0]}")
print(f"'cat' (pos: 1) attention weights : {prob_w[1]}")
print(f"'sat' (pos: 2) attention weights : {prob_w[2]}")
print(f"'on'  (pos: 3) attention weights : {prob_w[3]}")
print(f"'mat' (pos: 4) attention weights : {prob_w[4]}")

Attention weights: how strongly each word attends to every word in the sentence
'the' (pos: 0) attention weights : tensor([0.1401, 0.0856, 0.2574, 0.3648, 0.1521], grad_fn=<SelectBackward0>)
'cat' (pos: 1) attention weights : tensor([0.1301, 0.0729, 0.2587, 0.3911, 0.1472], grad_fn=<SelectBackward0>)
'sat' (pos: 2) attention weights : tensor([0.1612, 0.1149, 0.2429, 0.3090, 0.1719], grad_fn=<SelectBackward0>)
'on'  (pos: 3) attention weights : tensor([0.1669, 0.1256, 0.2412, 0.2942, 0.1721], grad_fn=<SelectBackward0>)
'mat' (pos: 4) attention weights : tensor([0.1330, 0.0787, 0.2662, 0.3834, 0.1388], grad_fn=<SelectBackward0>)
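
The softmax step turns each row of scores into non-negative weights that sum to 1. As a worked sketch, the first row of the scaled scores printed earlier (copied at four decimals) can be pushed through a hand-written softmax; the `softmax` helper here is illustrative and subtracts the row maximum for numerical stability, which does not change the result.

```python
import math

# First row of the scaled scores, copied from the printed tensor (4 decimals).
scores_the = [-0.7913, -1.2839, -0.1827, 0.1657, -0.7090]


def softmax(row):
    """Exponentiate each score (shifted by the row max) and normalize to sum 1."""
    shift = max(row)
    exps = [math.exp(s - shift) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


weights_the = softmax(scores_the)
print([round(w, 4) for w in weights_the])
print(sum(weights_the))
```

The largest weight falls on position 3 ('on'). Because the embedding and projection weights here are random and untrained, this ranking carries no linguistic meaning.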


In [None]:
result = prob_w @ V
# in self-attention the result has one row per word and the same shape as V
print("Since ATTENTION_DIM is 3, each word gets a 3-dimensional output vector")
print("This is the final weighted sum of the value vectors, which acts as the word's representation in the next layer.")
print(f"'the' : {result[0]}")
print(f"'cat' : {result[1]}")
print(f"'sat' : {result[2]}")
print(f"'on'  : {result[3]}")
print(f"'mat' : {result[4]}")

Since ATTENTION_DIM is 3, each word gets a 3-dimensional output vector
This is the final weighted sum of the value vectors, which acts as the word's representation in the next layer.
'the' : tensor([-0.4120,  0.7470,  0.2107], grad_fn=<SelectBackward0>)
'cat' : tensor([-0.4017,  0.7504,  0.2251], grad_fn=<SelectBackward0>)
'sat' : tensor([-0.4356,  0.7422,  0.1784], grad_fn=<SelectBackward0>)
'on'  : tensor([-0.4424,  0.7391,  0.1686], grad_fn=<SelectBackward0>)
'mat' : tensor([-0.4045,  0.7463,  0.2205], grad_fn=<SelectBackward0>)
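
The three steps in this section (scaled scores, softmax, weighted sum) can be collected into one function. The sketch below does that and, assuming PyTorch 2.0 or later is installed, compares the result against the built-in `torch.nn.functional.scaled_dot_product_attention`. The function name `attention_sketch` and the random stand-in tensors are illustrative and not part of the notebook.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)  # assumption: fixed seed only so the sketch is repeatable


def attention_sketch(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Single-head scaled dot-product attention without masking or dropout."""
    d_k = K.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (tokens, tokens)
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V                                 # (tokens, d_v)


# Random stand-ins with the same shapes as Q, K and V in this notebook.
Q, K, V = torch.randn(5, 3), torch.randn(5, 3), torch.randn(5, 3)

manual = attention_sketch(Q, K, V)

# A leading batch dimension is added to match the documented input layout,
# then removed again so the shapes are comparable.
builtin = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
).squeeze(0)

print(manual.shape)
print(torch.allclose(manual, builtin, atol=1e-5))
```

Using `dim=-1` and `transpose(-2, -1)` instead of `dim=1` and `.T` lets the same function work unchanged if a batch dimension is added later.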
