# LLM from scratch
This notebook contains code accompanying the *LLM from scratch* book.

## Ch 3 - Attention Module

In [None]:
import torch
# Toy input: one row per token of the sentence "Your journey starts with one
# step", each token represented by a 3-value vector.
X = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: "Your"
        [0.55, 0.87, 0.66],  # x^2: "journey"
        [0.57, 0.85, 0.64],  # x^3: "starts"
        [0.22, 0.58, 0.33],  # x^4: "with"
        [0.77, 0.25, 0.10],  # x^5: "one"
        [0.05, 0.80, 0.55],  # x^6: "step"
    ]
)

# Affinity score between two token vectors: their (unnormalized) dot product,
# used here as a simple similarity measure.
def affinity(x, y):
    """Return the dot product of ``x`` and ``y`` as a 0-d tensor.

    Delegates to ``Tensor.dot``, so both arguments must be 1-D tensors of
    equal length and matching dtype.
    """
    return x.dot(y)

# step 1 : calculate attention weights
# Score every input token against the query token with `affinity`, then divide
# by the total so the weights sum to 1 (plain normalization, not softmax).
query_idx = 1
query_token = X[query_idx]
attention_scores = torch.stack([affinity(x_i, query_token) for x_i in X])
# Single vectorized division: the sum is computed once, not once per element.
attention_weights = attention_scores / attention_scores.sum()
# Reshape to a column of shape (len(X), 1) so each weight broadcasts across
# its row of X in step 2.
attention_weights = attention_weights.view(-1, 1)

print("\n\n-- attention --")
print(f"token[{query_idx}]: {query_token}")
print("A(.) is affinity")
for idx, score in enumerate(attention_weights):
    print(f"w({idx}) = A(x({query_idx}), x({idx})) : {score}")

# step 2
# combine attention weights with vectors: scale each row of X by its weight,
# then sum the rows into a (1, X.shape[1]) context vector for the query.
# `query` used to be hard-coded as X[1]; tie it to query_idx so the two
# cannot drift apart (kept in case other cells reference it).
query = query_token
list_context_vectors = attention_weights * X
context_vector = list_context_vectors.sum(dim=0, keepdim=True)
print("\n\n-- context --")
print("list_context_vectors : ", list_context_vectors.shape)
for idx, vec in enumerate(list_context_vectors):
    print(f"w({idx})* x[{idx}] : {vec}")
print("list_context_vectors : ", list_context_vectors)
print("context_wrt_query: ", context_vector.shape)
print(context_vector)



-- attention --
token[1]: tensor([0.5500, 0.8700, 0.6600])
A(.) is affinity
w(0) = A(x(1), x(0)) : tensor([0.1455])
w(1) = A(x(1), x(1)) : tensor([0.2278])
w(2) = A(x(1), x(2)) : tensor([0.2249])
w(3) = A(x(1), x(3)) : tensor([0.1285])
w(4) = A(x(1), x(4)) : tensor([0.1077])
w(5) = A(x(1), x(5)) : tensor([0.1656])


-- context --
list_context_vectors :  torch.Size([6, 3])
list_context_vectors :  tensor([[0.0625, 0.0218, 0.1295],
        [0.1253, 0.1982, 0.1504],
        [0.1282, 0.1911, 0.1439],
        [0.0283, 0.0745, 0.0424],
        [0.0830, 0.0269, 0.0108],
        [0.0083, 0.1325, 0.0911]])
context_wrt_query:  torch.Size([1, 3])
tensor([[0.4355, 0.6451, 0.5680]])


tensor([0.4355, 0.6451, 0.5680])
