# LLM from scratch
This notebook contains code for the LLM-from-scratch book.

## Ch 3 - Attention Module

In [15]:
import torch
import torch.nn as nn

### Simple attention example

In [16]:

# Toy input: one 3-dimensional embedding per token of a six-token sentence.
X = torch.tensor([
    [0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]  # step     (x^6)
])

# simple affinity : dot-product (to measure similarity)
def affinity(x, y):
    """Return the dot product of two 1-D vectors as their similarity score."""
    return torch.dot(x, y)

# step 1 : calculate attention weights
# idea : for a query q, how much should each token of input X (i.e. x1, x2, ...) be weighed in importance
# attention(query, x) for all x in input
query_idx = 1
query_token = X[query_idx]
affinity_scores = torch.stack([affinity(x_i, query_token) for x_i in X])
# Normalise by the plain sum (not softmax) so the weights add up to 1,
# then make it a column so it broadcasts against the rows of X below.
attention_weights = (affinity_scores / affinity_scores.sum()).view(-1, 1)

print("\n\n-- attention --")
print(f"token[{query_idx}]: {query_token}")
print("A(.) is affinity")
for idx, score in enumerate(attention_weights):
    print(f"w({idx}) = A(x({query_idx}), x({idx})) : {score}")

# step 2 : compute context vectors
# idea : given query q and attention weights, create "information context" using a weighted sum
# idea : "information context" tells the LLM how to make use of all the input tokens
query = query_token  # same token as in step 1; follows query_idx instead of a hard-coded index
list_context_vectors = attention_weights * X
context_vector = list_context_vectors.sum(dim=0, keepdim=True)
print("\n\n-- context --")
print("list_context_vectors : ", list_context_vectors.shape)
for idx, vec in enumerate(list_context_vectors):
    print(f"z({idx}) = w({idx})* x[{idx}] : {vec}")

print("\ncontext_wrt_query: ", context_vector.shape)
print(context_vector)

# step 3 - vectorize
# NOTE: from here on `attention_weights` holds the softmax-normalised (n, n) matrix
# for every query at once, replacing the single-query column from step 1.
print("\n\n-- vectorize --")
attention_scores = X @ X.T # pair-wise dot-product for every (x_i, x_j)
attention_weights = torch.softmax(attention_scores, dim=-1) # row_i = attention weights w.r.t x_i
context_matrix = attention_weights @ X # output (n, k) where row i is the attention context for x_i
print("context shape: ", context_matrix.shape)



-- attention --
token[1]: tensor([0.5500, 0.8700, 0.6600])
A(.) is affinity
w(0) = A(x(1), x(0)) : tensor([0.1455])
w(1) = A(x(1), x(1)) : tensor([0.2278])
w(2) = A(x(1), x(2)) : tensor([0.2249])
w(3) = A(x(1), x(3)) : tensor([0.1285])
w(4) = A(x(1), x(4)) : tensor([0.1077])
w(5) = A(x(1), x(5)) : tensor([0.1656])


-- context --
list_context_vectors :  torch.Size([6, 3])
z(0) = w(0)* x[0] : tensor([0.0625, 0.0218, 0.1295])
z(1) = w(1)* x[1] : tensor([0.1253, 0.1982, 0.1504])
z(2) = w(2)* x[2] : tensor([0.1282, 0.1911, 0.1439])
z(3) = w(3)* x[3] : tensor([0.0283, 0.0745, 0.0424])
z(4) = w(4)* x[4] : tensor([0.0830, 0.0269, 0.0108])
z(5) = w(5)* x[5] : tensor([0.0083, 0.1325, 0.0911])

context_wrt_query:  torch.Size([1, 3])
tensor([[0.4355, 0.6451, 0.5680]])


-- vectorize --
context shape:  torch.Size([6, 3])


### Self-attention 
Self-attention introduces three trainable parameter matrices, $W_q$ (query), $W_k$ (key), and $W_v$ (value), on top of the simple attention mechanism.

In [17]:
# hyperparameters
d_in = X.size(1)   # width of each input token embedding
d_out = 7          # width of the query / key / value projections
x_2 = X[1]

# projection matrices, drawn in the order query -> key -> value
# (gradient tracking is switched off for this walkthrough)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

# step 1 : project X into query space and key space
X_query = torch.matmul(X, W_query)
X_key = torch.matmul(X, W_key)

# step 2 : scaled attention scores, a_ij = query(x_i) . key(x_j)
d_k = X_key.size(-1)
attention_scores = torch.matmul(X_query, X_key.T)
# divide by sqrt(d_k) before the row-wise softmax; each row then sums to 1
attention_weights = (attention_scores / d_k ** 0.5).softmax(dim=-1)

# step 3 : context = attention weights applied to the value projection
X_value = torch.matmul(X, W_value)
context = torch.matmul(attention_weights, X_value)

# show the shape of every intermediate, one per line
print(
    X.shape,
    attention_scores.shape,
    attention_weights.shape,
    X_value.shape,
    sep="\n",
)

torch.Size([6, 3])
torch.Size([6, 6])
torch.Size([6, 6])
torch.Size([6, 7])


In [18]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices of
    shape ``(d_in, d_out)`` initialised from ``torch.rand``. No mask and no
    dropout are applied.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def update_matrices(self, W_q, W_k, W_v):
        """Replace the three projection matrices with the given tensors."""
        self.W_query = nn.Parameter(W_q)
        self.W_key = nn.Parameter(W_k)
        self.W_value = nn.Parameter(W_v)

    def forward(self, x):
        """Return the attention context for every token in ``x``.

        Args:
            x: tensor whose last dimension is ``d_in``; either a single
                sequence ``(num_tokens, d_in)`` or a batch
                ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``d_out``.
        """
        # step 1 : map x into query, key and value space
        x_query = x @ self.W_query
        x_key = x @ self.W_key
        x_value = x @ self.W_value
        dk_constant = x_key.shape[-1]

        # step 2 : compute attention
        # Swap only the last two axes so leading batch dimensions are preserved
        # (`.T` reverses *all* axes and is only correct for 2-D input).
        attention_scores = x_query @ x_key.transpose(-2, -1)
        attention_weights = torch.softmax(attention_scores / dk_constant ** 0.5, dim=-1)

        # step 3 : compute context
        return attention_weights @ x_value


class CasualAttention(nn.Module):
    """Single-head scaled dot-product attention with a causal (no-look-ahead) mask.

    ``forward`` applies the masked, batched variant with dropout. The other
    ``_forward_*`` methods are simpler variants of the same computation.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query / key / value projections.
        context_length: side length of the pre-built causal mask buffer;
            ``forward`` slices this buffer down to the input's token count.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the three linear projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.1, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.dropout = nn.Dropout(dropout)
        # 1 strictly above the diagonal marks "future" positions to be hidden.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def _forward_full_attention(self, x):
        """Unmasked attention: every token attends to every token.

        Accepts ``(num_tokens, d_in)`` or ``(batch, num_tokens, d_in)``.
        """
        # step 1 : map x into query, key and value space
        x_query = self.W_query(x)
        x_key = self.W_key(x)
        x_value = self.W_value(x)
        dk_constant = x_key.shape[-1] ** -0.5

        # step 2 : compute attention (transpose only the last two axes so batches work)
        attention_scores = x_query @ x_key.transpose(-2, -1)
        attention_weights = torch.softmax(attention_scores * dk_constant, dim=-1)

        # step 3 : compute context
        return attention_weights @ x_value

    def _forward_masked_attention(self, x):
        """Causal attention with a mask built from the input length (no dropout).

        Accepts ``(num_tokens, d_in)`` or ``(batch, num_tokens, d_in)``.
        """
        # step 1 : map x into query, key and value space
        x_query = self.W_query(x)
        x_key = self.W_key(x)
        x_value = self.W_value(x)
        dk_constant = x_key.shape[-1] ** -0.5

        # step 2 : compute attention
        attention_scores = x_query @ x_key.transpose(-2, -1)

        # step 3 : compute attention w/ mask
        # idea : a token cannot see future tokens, only itself and the past
        num_tokens = attention_scores.shape[-1]
        # Build the mask on the same device as the scores so this also works off-CPU.
        future = torch.triu(
            torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=attention_scores.device),
            diagonal=1,
        )
        attention_scores_mask = attention_scores.masked_fill(future, -torch.inf)
        attention_weights_mask = torch.softmax(attention_scores_mask * dk_constant, dim=-1)

        # step 4 : compute context
        return attention_weights_mask @ x_value

    def _forward_masked_attention_v2(self, x):
        """Batched causal attention using the registered mask buffer, with dropout.

        Args:
            x: 3-D tensor ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        _, num_tokens, _ = x.shape

        x_query = self.W_query(x)
        x_key = self.W_key(x)
        x_value = self.W_value(x)

        dk_constant = x_key.shape[-1] ** -0.5
        # Slice the buffer: the input may be shorter than context_length.
        mask_context = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores = x_query @ x_key.transpose(1, 2) # axis 0 is the batch, keep it in place
        attn_scores = attn_scores.masked_fill(mask_context, -torch.inf)
        attn_weights = torch.softmax(attn_scores * dk_constant, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ x_value

    def forward(self, x):
        """Apply batched causal attention to ``x`` of shape ``(batch, num_tokens, d_in)``."""
        return self._forward_masked_attention_v2(x)


class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent single-head modules.

    Holds ``num_heads`` separate ``CasualAttention`` heads, each projecting to
    ``d_out`` features, and joins their outputs along the last dimension.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Every head is configured identically; only the initial weights differ.
        head_config = dict(
            d_in=d_in,
            d_out=d_out,
            context_length=context_length,
            dropout=dropout,
            qkv_bias=qkv_bias,
        )
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(CasualAttention(**head_config))

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the feature axis."""
        per_head_context = []
        for head in self.heads:
            per_head_context.append(head(x))
        return torch.cat(per_head_context, dim=-1)



In [41]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention computed for all heads in one batched matmul.

    The ``d_out`` features of each projection are split into ``num_heads``
    contiguous slices of ``d_out // num_heads`` features. Each slice is attended
    to independently, and the per-head contexts are concatenated back in head
    order. No output projection is applied.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total projection size across all heads; must be divisible by
            ``num_heads``.
        context_length: side length of the pre-built causal mask buffer;
            ``forward`` slices this buffer down to the input's sequence length.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the three linear projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # validate input dimensions
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.attention_dim = d_out // num_heads

        # setup attention matrices
        self.W_q = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_k = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_v = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        # 1 strictly above the diagonal marks "future" positions to be hidden.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # setup dropout
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, n, seq_length):
        """Reshape ``(n, seq_length, d_out)`` to ``(n, num_heads, seq_length, attention_dim)``."""
        return projected.view(n, seq_length, self.num_heads, self.attention_dim).transpose(1, 2)

    def forward(self, x):
        """Return the causal multi-head attention context for ``x``.

        Args:
            x: 3-D tensor ``(n, seq_length, d_in)``.

        Returns:
            Tensor of shape ``(n, seq_length, d_out)``. Along the last axis,
            features ``[h * attention_dim, (h + 1) * attention_dim)`` belong to
            head ``h``.
        """
        n, seq_length, _ = x.shape

        # compute Q, K, V and split each into heads: (n, num_heads, seq_length, attention_dim)
        x_query = self._split_heads(self.W_q(x), n, seq_length)
        x_key = self._split_heads(self.W_k(x), n, seq_length)
        x_value = self._split_heads(self.W_v(x), n, seq_length)

        # compute attention scores (per-head)
        # Scale by the per-head feature size. Reading it off the transposed key
        # would pick up seq_length instead.
        dk_constant = self.attention_dim ** -0.5
        mask_context = self.mask.bool()[:seq_length, :seq_length]
        attention_scores = x_query @ x_key.transpose(2, 3)  # (n, num_heads, seq_length, seq_length)
        attention_scores = attention_scores.masked_fill(mask_context, -torch.inf)

        # compute attention weights
        # note : dropout goes on the weights, not the scores (dropout on -inf is not well-defined)
        attention_weights = torch.softmax(attention_scores * dk_constant, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # compute context: (n, num_heads, seq_length, attention_dim)
        context = attention_weights @ x_value

        # Move the head axis back next to the feature axis BEFORE flattening.
        # Without this, the view would interleave tokens from different heads.
        context = context.transpose(1, 2).contiguous().view(n, seq_length, self.d_out)

        return context


In [43]:
# Seed the global RNG before the module is built and called.
torch.manual_seed(789)

# input: a batch made of the same sequence twice
x = torch.stack((X, X))
n, seq_length, token_dim = x.size()

# hyperparameters
x_dim_in, x_dim_out = token_dim, 240
context_length, dropout = seq_length, 0.5

# Alternative modules, left disabled:
# single_attention = CasualAttention(d_in=x_dim_in, d_out=x_dim_out, context_length=context_length, dropout=dropout)
# single_attention(x)
# multi_attention = MultiHeadAttentionWrapper(d_in=x_dim_in, d_out=x_dim_out, context_length=context_length, dropout=dropout, num_heads=4)

# create the parallel multi-head attention module and run one forward pass
multi_attention = MultiHeadAttention(
    d_in=x_dim_in,
    d_out=x_dim_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=4,
)
multi_attention(x)
print(" ")




 
