# LLM from scratch
This notebook contains code for LLM-from-scratch book.

## Ch 3 - Attention Module

In [5]:
import torch
import torch.nn as nn

### Simple attention example

In [6]:

X = torch.tensor([
    [0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55] # step (x^6)
])

# simple affinity : dot-product (to measure similarity)
def affinity(x, y):
    """Given 2 vectors, compute affinity"""
    return torch.dot(x, y)

# step 1 : calculate attention weights 
# idea : If query q : how much should each token of input X (i.e. x1, x2, ...) be weighed in importance 
# attention(query, x) for all x in input
query_idx = 1
query_token = X[query_idx]
attention_weights = torch.tensor([affinity(x_i, query_token) for (_, x_i) in enumerate(X)])
attention_weights = torch.tensor([a / attention_weights.sum() for a in attention_weights])
attention_weights = attention_weights.view(-1, 1)

print("\n\n-- attention --")
print(f"token[{query_idx}]: {query_token}")
print("A(.) is affinity")
for idx, score in enumerate(attention_weights):
    print(f"w({idx}) = A(x({query_idx}), x({idx})) : {score}")

# step 2 : compute context vectors  
# idea : Given query q and attention weights, create "information context" using weighted sum approach
# idea : "information context" tells LLM how to make use of all the input tokens
query = X[1]
list_context_vectors = attention_weights * X
context_vector = list_context_vectors.sum(dim=0, keepdim=True)
print("\n\n-- context --")
print("list_context_vectors : ", list_context_vectors.shape)
for idx, vec in enumerate(list_context_vectors):
    print(f"z({idx}) = w({idx})* x[{idx}] : {vec}")

print("\ncontext_wrt_query: ", context_vector.shape)
print(context_vector)

# step 3 - vectorize 
print("\n\n-- vectorize --")
attention_scores = X @ X.T # compute attention pair-wise for each x_i, x_j pair using dot-product 
attention_weights = torch.softmax(attention_scores, dim=-1) # row_i = attention weights w.r.t x_i
context_matrix = attention_weights @ X # output (n, k) where each row i is attention_context for x_i
print("context shape: ", context_matrix.shape)



-- attention --
token[1]: tensor([0.5500, 0.8700, 0.6600])
A(.) is affinity
w(0) = A(x(1), x(0)) : tensor([0.1455])
w(1) = A(x(1), x(1)) : tensor([0.2278])
w(2) = A(x(1), x(2)) : tensor([0.2249])
w(3) = A(x(1), x(3)) : tensor([0.1285])
w(4) = A(x(1), x(4)) : tensor([0.1077])
w(5) = A(x(1), x(5)) : tensor([0.1656])


-- context --
list_context_vectors :  torch.Size([6, 3])
z(0) = w(0)* x[0] : tensor([0.0625, 0.0218, 0.1295])
z(1) = w(1)* x[1] : tensor([0.1253, 0.1982, 0.1504])
z(2) = w(2)* x[2] : tensor([0.1282, 0.1911, 0.1439])
z(3) = w(3)* x[3] : tensor([0.0283, 0.0745, 0.0424])
z(4) = w(4)* x[4] : tensor([0.0830, 0.0269, 0.0108])
z(5) = w(5)* x[5] : tensor([0.0083, 0.1325, 0.0911])

context_wrt_query:  torch.Size([1, 3])
tensor([[0.4355, 0.6451, 0.5680]])


-- vectorize --
context shape:  torch.Size([6, 3])


### Self-attention 
Self-attention introduces 3 trainable parameters ($W_q$(query), $W_k$(key), $W_v$(value)) matrices ontop of attention mechanism

In [9]:
# hyperparameters
d_in = X.shape[1]
d_out = 2
x_2 = X[1]

# define trainable parameters
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# step 1 : map X into query(x) and key(x)
X_query = X @ W_query 
X_key = X @ W_key 

# step 2 : compute attention scores
# note : a_ij = query(x_i) dot key(x_j)
d_k = X_key.shape[-1]
attention_scores = X_query @ X_key.T 
attention_weights = torch.softmax(attention_scores / d_k **0.5, dim=-1)

# step 3 : compute context 
# idea : context = attention_score * value 
X_value = X @ W_value
context = attention_weights @ X_value

# attention_weights.sum(dim=1, keepdim=True)
print(X.shape)
print(attention_scores.shape)
print(attention_weights.shape)
print(X_value.shape)

torch.Size([6, 3])
torch.Size([6, 6])
torch.Size([6, 6])
torch.Size([6, 2])


In [51]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def update_matrices(self, W_q, W_k, W_v):
        self.W_query = nn.Parameter(W_q)
        self.W_key = nn.Parameter(W_k)
        self.W_value = nn.Parameter(W_v)

    def forward(self, x):
        # step 1 : map X into query(x) and key(x)
        x_query = x @ self.W_query
        x_key = x @ self.W_key
        x_value = x @ self.W_value

        print("x_query:", x_query)
        print("x_key:", x_key)
        print("x_value:", x_value)
        dk_constant = x_key.shape[-1]

        # step 2 : compute attention
        attention_scores = x_query @ x_key.T 
        attention_weights = torch.softmax(attention_scores / dk_constant **0.5, dim=-1)

        # step 3 : compute context 
        return attention_weights @ x_value


class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        # step 1 : map X into query(x) and key(x)
        x_query = self.W_query(x)
        x_key = self.W_key(x)
        x_value = self.W_value(x)
        dk_constant = x_key.shape[-1]

        # step 2 : compute attention
        attention_scores = x_query @ x_key.T 
        attention_weights = torch.softmax(attention_scores / dk_constant **0.5, dim=-1)

        # step 3 : compute context 
        return attention_weights @ x_value
        


In [52]:
torch.manual_seed(789)

# exercise 3.1 
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v1.update_matrices(W_k=sa_v2.W_key.weight.T, W_q=sa_v2.W_query.weight.T, W_v=sa_v2.W_value.weight.T)

# check same result
y_1 = sa_v1.forward(X)
y_2 = sa_v2.forward(X)

x_query: tensor([[ 0.6600, -0.2047],
        [ 0.9091, -0.4471],
        [ 0.8960, -0.4419],
        [ 0.5034, -0.2633],
        [ 0.4088, -0.2232],
        [ 0.6628, -0.3292]], grad_fn=<MmBackward0>)
x_key: tensor([[ 0.3147, -0.4016],
        [-0.0298, -0.4459],
        [-0.0170, -0.4262],
        [-0.1054, -0.2724],
        [ 0.2185,  0.0482],
        [-0.2258, -0.4782]], grad_fn=<MmBackward0>)
x_value: tensor([[-0.0872,  0.0286],
        [-0.1137,  0.0766],
        [-0.1018,  0.0927],
        [-0.0912, -0.0026],
        [ 0.1395,  0.3580],
        [-0.2085, -0.1546]], grad_fn=<MmBackward0>)
x_query: tensor([[ 0.6600, -0.2047],
        [ 0.9091, -0.4471],
        [ 0.8960, -0.4419],
        [ 0.5034, -0.2633],
        [ 0.4088, -0.2232],
        [ 0.6628, -0.3292]], grad_fn=<MmBackward0>)
x_key: tensor([[ 0.3147, -0.4016],
        [-0.0298, -0.4459],
        [-0.0170, -0.4262],
        [-0.1054, -0.2724],
        [ 0.2185,  0.0482],
        [-0.2258, -0.4782]], grad_fn=<MmBackward0>)

In [47]:
print(sa_v1.W_key)
print(sa_v2.W_key.weight)

Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True)
Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True)
