## <center>Coding Attention Mechanisms</center>

We broke down the self-attention computation into many steps for clarity, examining it piece by piece. Now, we’ll bring everything together by implementing a compact Python class for self-attention based on what we’ve learned so far.

<br><p>
    
---

In [9]:
%load_ext watermark
%watermark -v -p watermark,torch

Python implementation: CPython
Python version       : 3.12.9
IPython version      : 9.5.0

watermark: 2.5.0
torch    : 2.8.0



In [2]:
import torch 
import torch.nn as nn 

------

<br><p> 

## 1. Self-Attention 

In [3]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out): 
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out)) 

    def forward(self, x): 
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value 
        attention_scores = queries @ keys.T 
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1] ** 0.5), dim=-1)
        context_vector = attention_weights @ values
        return context_vector

In [4]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [6]:
torch.manual_seed(123)
d_in = inputs.shape[1]     
d_out = 2 
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


As a quick check, notice that the second row (`[0.3061, 0.8210]`) matches the contents of `context_vec_2` on slide 57.
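
This check can be reproduced with the sketch below. It recomputes only the second context vector (for the query token "journey") step by step and compares it with the second row of the `SelfAttention_v1` output. The approach relies on `SelfAttention_v1.__init__` drawing `W_query`, `W_key`, and `W_value` with `torch.rand(3, 2)` in that order. Re-seeding with 123 and drawing three matrices the same way should therefore reproduce the same weights.

```python
import torch
import torch.nn as nn

# Same toy embeddings as above ("Your journey starts with one step")
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attention_weights @ values

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
full_output = sa_v1(inputs)

# Redraw the same three weight matrices, in the same order, from the same seed
torch.manual_seed(123)
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)
W_value = torch.rand(3, 2)

x_2 = inputs[1]                          # embedding of "journey"
query_2 = x_2 @ W_query                  # shape (2,)
keys = inputs @ W_key                    # shape (6, 2)
values = inputs @ W_value                # shape (6, 2)
attn_scores_2 = keys @ query_2           # one score per input token, shape (6,)
attn_weights_2 = torch.softmax(attn_scores_2 / keys.shape[-1] ** 0.5, dim=-1)
context_vec_2 = attn_weights_2 @ values  # shape (2,)

print(context_vec_2)
print(torch.allclose(context_vec_2, full_output[1]))  # True
```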

---

**NOTE**
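- As shown below, we can improve the `SelfAttention_v1` implementation further by using PyTorch’s `nn.Linear` layers. With the bias units disabled, these layers effectively perform a matrix multiplication. Note that `nn.Linear` stores its weight as a `(d_out, d_in)` matrix and computes `x @ weight.T`.
- Additionally, `nn.Linear` has a significant advantage over manually creating `nn.Parameter(torch.rand(...))`: it uses a better default weight initialization scheme. Its weights are drawn from a uniform distribution centered at zero with bound $1/\sqrt{d_\text{in}}$, instead of from $[0, 1)$ as with `torch.rand`. This contributes to more stable and effective model training.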
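
Both points can be checked with a small sketch. The seed and the random input `x` below are arbitrary illustration choices. The block confirms that a bias-free `nn.Linear` computes `x @ weight.T`, and that its default weights fall within $\pm 1/\sqrt{d_\text{in}}$.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
linear = nn.Linear(d_in, d_out, bias=False)

# The weight is stored as (d_out, d_in), so the layer multiplies by its transpose
print(linear.weight.shape)  # torch.Size([2, 3])

x = torch.rand(6, d_in)     # hypothetical batch of 6 token embeddings
manual = x @ linear.weight.T
print(torch.allclose(linear(x), manual))  # True

# Default init: U(-sqrt(k), sqrt(k)) with k = 1 / in_features, i.e. centered at zero
bound = 1 / math.sqrt(d_in)
print(bound)
print(linear.weight)
```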

----

In [12]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False): 
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x): 
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attention_scores = queries @ keys.T 
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1] ** 0.5), dim=-1)
        context_vector = attention_weights @ values
        return context_vector

In [14]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


<br><p>

**NOTE** 

`SelfAttention_v1` and `SelfAttention_v2` give different outputs because their weight matrices start from different values. `nn.Linear` uses a different, more sophisticated weight initialization scheme than `torch.rand`, and the two cells also use different random seeds (123 vs. 789).
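
To confirm that the difference comes only from the weights, we can copy the weights of `SelfAttention_v2` into `SelfAttention_v1`. The outputs should then match. This is a minimal sketch: because `nn.Linear` stores each weight as `(d_out, d_in)` while `SelfAttention_v1` expects `(d_in, d_out)`, each matrix is transposed before copying.

```python
import torch
import torch.nn as nn

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]]
)

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        attention_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attention_weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attention_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attention_weights @ values

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)

# Transpose each (d_out, d_in) Linear weight into the (d_in, d_out) layout of v1
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

print(torch.allclose(sa_v1(inputs), sa_v2(inputs), atol=1e-6))  # True
```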

---

## 2. Causal Attention (Masked Attention)

First, let's make sure the code can handle batches consisting of more than one input, so that the `CausalAttention` class supports the batched outputs produced by the data loader. For simplicity, we simulate such a batch by duplicating the input text example.

In [44]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])
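
With a 3-D batch, the `keys.T` used in `SelfAttention_v1` and `SelfAttention_v2` is no longer appropriate. For tensors with more than two dimensions, `.T` reverses *all* dimensions, and PyTorch deprecates that use. For this reason, `CausalAttention` below swaps only dimensions 1 and 2. The sketch uses hypothetical random `queries` and `keys` with the batch shape produced above. It shows the resulting shapes and checks that each batch entry matches the unbatched 2-D computation.

```python
import torch

torch.manual_seed(0)
b, num_tokens, d_out = 2, 6, 2
# Hypothetical projected queries/keys with shape (batch, num_tokens, d_out)
queries = torch.rand(b, num_tokens, d_out)
keys = torch.rand(b, num_tokens, d_out)

# Swap only the token and feature dimensions; the batch dimension (0) stays in place
keys_t = keys.transpose(1, 2)
print(keys_t.shape)                   # torch.Size([2, 2, 6])
print(torch.equal(keys_t, keys.mT))   # True: .mT transposes the last two dimensions

# Batched matrix multiplication: one (num_tokens x num_tokens) score matrix per example
attention_scores = queries @ keys_t
print(attention_scores.shape)         # torch.Size([2, 6, 6])

# Each batch entry equals the unbatched 2-D computation
print(torch.allclose(attention_scores[0], queries[0] @ keys[0].T))  # True
```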


In [45]:
class CausalAttention(nn.Module): 
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__() 
        self.d_out = d_out  
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x): 
        b, num_tokens, d_in = x.shape #<--- 2 x 6 x 3 (batch size, number of tokens in input sequence, input embedding size)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attention_scores = queries @ keys.transpose(1, 2) #<--- We transpose dimensions 1 and 2, keeping the batch dimension at the first position (0)
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        context_vector = attention_weights @ values
        return context_vector

In [50]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])
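
The masking step inside `CausalAttention.forward` can be inspected on a small example. This sketch uses a hypothetical 4x4 matrix of random scores and assumes `d_out = 2` for the scaling factor. It shows why the scores above the diagonal are set to `-torch.inf` *before* the softmax. Since $e^{-\infty} = 0$, those future positions receive exactly zero weight, and each row still sums to 1, so no separate renormalization step is needed.

```python
import torch

torch.manual_seed(0)
num_tokens = 4
attention_scores = torch.rand(num_tokens, num_tokens)  # hypothetical unmasked scores

# Ones strictly above the diagonal mark "future" tokens that must be hidden
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)
print(mask)

masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
attention_weights = torch.softmax(masked / 2 ** 0.5, dim=-1)  # assumes d_out = 2
print(attention_weights)

# Future positions get exactly zero weight, and every row still sums to 1
print(bool(torch.all(attention_weights[mask.bool()] == 0)))                    # True
print(torch.allclose(attention_weights.sum(dim=-1), torch.ones(num_tokens)))   # True
```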


---- 
**NOTE**

- While most of the added code lines should be familiar at this point, we also added a `self.register_buffer()` call in the `__init__` method.
- Using `register_buffer` in PyTorch is not strictly necessary for all use cases, but it offers several advantages here.
 > For instance, when we use the `CausalAttention` class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training our LLM. Buffers are also saved in the model's `state_dict`, but, unlike parameters, they are not updated by the optimizer.
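
The difference between a registered buffer and a plain tensor attribute can be seen with two minimal modules. `MaskOnly` and `PlainAttribute` are hypothetical names used only for this illustration. The buffer shows up in `state_dict()` but not in `parameters()`, and it follows the module when it is moved to a GPU. The plain attribute does neither.

```python
import torch
import torch.nn as nn

class MaskOnly(nn.Module):  # hypothetical module: mask registered as a buffer
    def __init__(self, context_length):
        super().__init__()
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

class PlainAttribute(nn.Module):  # hypothetical module: same mask as a plain attribute
    def __init__(self, context_length):
        super().__init__()
        self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

with_buffer = MaskOnly(6)
without_buffer = PlainAttribute(6)

print(list(with_buffer.state_dict().keys()))     # ['mask']
print(list(without_buffer.state_dict().keys()))  # []
print(len(list(with_buffer.parameters())))       # 0: the optimizer never sees the mask

if torch.cuda.is_available():
    with_buffer.to("cuda")
    without_buffer.to("cuda")
    print(with_buffer.mask.device)     # cuda:0
    print(without_buffer.mask.device)  # cpu: plain tensor attributes are not moved
```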

<p><hr>

<br><p>

## 3. Multi-head Attention

In [56]:
class MultiHeadAttentionWrapper(nn.Module): 
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])

    def forward(self, x): 
        return torch.cat([head(x) for head in self.heads], dim=-1) 

torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2  

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads      #<--- Each head works on d_out // num_heads dimensions
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) #<--- Uses a Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)                    # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a num_heads dimension.
        # Then we unroll the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim).
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transposes from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)            #<--- Computes dot product for each head
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]  #<--- Mask truncated to the number of tokens

        attn_scores.masked_fill_(mask_bool, -torch.inf)         #<--- Uses the mask to fill attention scores
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)   #<--- Tensor shape: (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)  #<--- Combines heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = self.out_proj(context_vec)                #<--- Applies the final linear projection to the combined heads
        return context_vec

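
The head-splitting steps in `MultiHeadAttention.forward` can be checked in isolation on random tensors. Those steps are the `view` into `(b, num_tokens, num_heads, head_dim)` followed by `transpose(1, 2)`. The sketch below uses arbitrary hypothetical sizes and random tensors, and confirms two things:

- A single batched matmul over the head dimension gives the same scores as computing each head separately on its own slice of feature columns.
- Transposing back and calling `.contiguous().view(...)` restores the original `(b, num_tokens, d_out)` layout.

```python
import torch

torch.manual_seed(0)
b, num_tokens, num_heads, head_dim = 2, 6, 2, 3
d_out = num_heads * head_dim

# Hypothetical projected queries/keys with shape (b, num_tokens, d_out)
queries = torch.rand(b, num_tokens, d_out)
keys = torch.rand(b, num_tokens, d_out)

# Split the last dim into heads, then move the head dim next to the batch dim
q = queries.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)  # (b, num_heads, num_tokens, head_dim)
k = keys.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)

# One batched matmul computes the scores for all heads at once
scores = q @ k.transpose(2, 3)  # (b, num_heads, num_tokens, num_tokens)

# Same result head by head: head h uses feature columns h*head_dim:(h+1)*head_dim
for h in range(num_heads):
    cols = slice(h * head_dim, (h + 1) * head_dim)
    per_head = queries[:, :, cols] @ keys[:, :, cols].transpose(1, 2)
    assert torch.allclose(scores[:, h], per_head)

# Undo the split. A matmul result in (b, num_heads, num_tokens, head_dim) layout is
# contiguous in that layout, so q is materialized the same way before transposing back.
back = q.contiguous().transpose(1, 2)  # (b, num_tokens, num_heads, head_dim)
print(back.is_contiguous())            # False, so calling .view() directly would raise an error
merged = back.contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, queries))    # True
print(scores.shape)                    # torch.Size([2, 2, 6, 6])
```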