# Topic 

This notebook reproduces PyTorch's built-in module `nn.MultiheadAttention` from scratch.

# Motivation

Implementing multi-head attention from scratch is the best way to understand its core mechanics. Once that is mastered, pre-built implementations are the more practical choice: they are easier to maintain, easier to read, and better optimized. Even so, it is worth verifying how these pre-built components work under the hood, so that you can adapt them to different use cases and avoid unexpected behavior.

# Structure 

Configuration
  - Identical hyperparameters and inputs used for both parts


Implementation Methods

  - Part One: PyTorch's Built-in Implementation
  - Part Two: Custom Implementation from Scratch


Validation

  - Results comparison between Part One and Part Two
  - Verification of output equality

Afterword


In [5]:
import torch 
from torch import nn 

# Configuration

Identical hyperparameters for PART ONE and PART TWO.

In [6]:
embed_dim = 4
num_heads = 2
num_tokens = 8
batch_size = 1

Identical inputs (seeded with `torch.manual_seed`) for PART ONE and PART TWO.

In [7]:
torch.manual_seed(1)

# Why three different inputs to produce Q, K and V?
# Stay tuned.

input_to_produce_Q = torch.rand(batch_size,num_tokens,embed_dim)
input_to_produce_K = torch.rand(batch_size,num_tokens,embed_dim)
input_to_produce_V = torch.rand(batch_size,num_tokens,embed_dim)

# Implementation Methods 

### PART ONE : PyTorch's Implementation

In [8]:
torch.manual_seed(1)

multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                       num_heads=num_heads,
                                       dropout=0,
                                       bias = False,
                                       add_bias_kv=False,
                                       batch_first=True,    # important
                                       device=None)



# Grab the weights that `nn.MultiheadAttention` initialized.
print('The following are W_q, W_k, W_v concatenated into one matrix; this is how PyTorch stores the input projection weights\n')
print(multihead_attn.in_proj_weight)
print(multihead_attn.in_proj_weight.shape)


print('----'*20)

print('The following is the output projection weight\n')
print(multihead_attn.out_proj.weight)        # output projection weights
print(multihead_attn.out_proj.weight.shape)  # shape of the output projection weights


The following are W_q, W_k, W_v concatenated into one matrix; this is how PyTorch stores the input projection weights

Parameter containing:
tensor([[ 0.0888, -0.0024,  0.5353,  0.1906],
        [-0.2281, -0.3698, -0.1026, -0.2641],
        [-0.1962,  0.0293,  0.3651,  0.3328],
        [-0.5986,  0.3796,  0.1711,  0.5809],
        [ 0.4042, -0.5580, -0.5822, -0.2954],
        [ 0.5377, -0.1020,  0.2621, -0.2846],
        [ 0.6009, -0.2591,  0.4592,  0.0073],
        [-0.3226,  0.3148, -0.3251,  0.1801],
        [-0.1768, -0.0671, -0.5887, -0.2920],
        [ 0.3323, -0.1489,  0.6100,  0.4909],
        [-0.0287, -0.4087,  0.3729,  0.1901],
        [-0.3958,  0.3978,  0.3718,  0.5431]], requires_grad=True)
torch.Size([12, 4])
--------------------------------------------------------------------------------
The following is the output projection weight

Parameter containing:
tensor([[ 0.2576, -0.2207, -0.0969,  0.2347],
        [-0.4707,  0.2999, -0.1029,  0.2544],
        [ 0.0695, -0.0612,  0.1387,  0.
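
The stacked layout of `in_proj_weight` can be checked in isolation. The sketch below is a minimal illustration, assuming a fresh module (`mha_demo`) and a random input `x` that are not part of this notebook: it splits the stacked matrix with `chunk(3)` and compares one fused projection against three separate ones.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads = 4, 2

# Hypothetical demo module, separate from the notebook's `multihead_attn`.
mha_demo = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)

# in_proj_weight stacks W_q, W_k, W_v row-wise: shape (3 * embed_dim, embed_dim).
W_q, W_k, W_v = mha_demo.in_proj_weight.chunk(3, dim=0)

x = torch.rand(1, 8, embed_dim)

# One fused projection, then split the last dimension into Q, K, V ...
Q_fused, K_fused, V_fused = F.linear(x, mha_demo.in_proj_weight).chunk(3, dim=-1)

# ... should match three separate projections with the chunked weights.
Q_sep, K_sep, V_sep = F.linear(x, W_q), F.linear(x, W_k), F.linear(x, W_v)

print(W_q.shape, W_k.shape, W_v.shape)
print(torch.allclose(Q_fused, Q_sep, atol=1e-6),
      torch.allclose(K_fused, K_sep, atol=1e-6),
      torch.allclose(V_fused, V_sep, atol=1e-6))
```

The first third of the rows acts as W_q, the second as W_k, and the last as W_v, which is the order PART TWO relies on when it calls `chunk(3)`.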

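### A few words on `nn.MultiheadAttention` arguments

### If `bias=True` and `add_bias_kv=True`

$Q = x \cdot W^Q + b^Q$

$K = \left[\, x \cdot W^K + b^K \;;\; \text{bias}_k \,\right]$

$V = \left[\, x \cdot W^V + b^V \;;\; \text{bias}_v \,\right]$

Here $[\,\cdot\,;\,\cdot\,]$ denotes concatenation along the sequence dimension. `add_bias_kv` does not add a vector to every key and value; it appends one extra learnable position ($\text{bias}_k$, $\text{bias}_v$) to the key and value sequences.

For simplicity, we set both `bias` and `add_bias_kv` to `False` in `nn.MultiheadAttention`.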
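
The effect of `add_bias_kv` is easiest to see in the shapes. The PyTorch documentation describes it as adding bias to the key and value sequences at dim=0, so the key sequence should become one position longer. The sketch below is a minimal check under that reading; `mha_kv` and the random inputs are illustrative names, not part of this notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical demo module with add_bias_kv switched on.
mha_kv = nn.MultiheadAttention(embed_dim=4, num_heads=2, bias=False,
                               add_bias_kv=True, batch_first=True)

q = torch.rand(1, 8, 4)
k = torch.rand(1, 8, 4)
v = torch.rand(1, 8, 4)

with torch.no_grad():
    out, weights = mha_kv(q, k, v)

print(mha_kv.bias_k.shape)  # one learnable key position
print(out.shape)            # output keeps the query length
print(weights.shape)        # last dimension is key length + 1
```

With 8 key tokens, the attention weights have 9 columns: each query can also attend to the appended bias position.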

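Note: PyTorch's `nn.MultiheadAttention` does not apply a causal mask to $QK^T$ by default; we have to pass one explicitly through the `attn_mask` argument. In a boolean mask, `True` marks a position that may not be attended to.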

In [9]:
attn_mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1).bool()
attn_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])
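
`attn_mask` accepts either a boolean mask (`True` means "do not attend") or a float mask that is added to the attention scores. The sketch below is a small standalone illustration, using a hypothetical 4-token example and random scores, of why the two forms give the same softmax result.

```python
import torch

torch.manual_seed(0)
num_tokens_demo = 4

# Boolean causal mask: True above the diagonal marks future positions.
bool_mask = torch.triu(torch.ones(num_tokens_demo, num_tokens_demo), diagonal=1).bool()

# Equivalent float mask: 0 where attention is allowed, -inf where it is not.
float_mask = torch.zeros(num_tokens_demo, num_tokens_demo).masked_fill(bool_mask, float('-inf'))

scores = torch.rand(num_tokens_demo, num_tokens_demo)  # stand-in for Q @ K^T

w_bool = torch.softmax(scores.masked_fill(bool_mask, float('-inf')), dim=-1)
w_float = torch.softmax(scores + float_mask, dim=-1)

print(w_bool)
print(torch.allclose(w_bool, w_float))
```

Each row still sums to 1, and every masked position receives exactly zero weight because `exp(-inf)` is 0.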

### PART ONE RESULT

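- PyTorch's `nn.MultiheadAttention` is designed to be flexible across use cases, so it takes separate query, key, and value inputs rather than assuming they always come from the same source. This is why we prepared three different inputs above: passing the same tensor three times gives self-attention, while different sources give cross-attention.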
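
One case where the separate inputs matter is cross-attention, where the query sequence and the key/value sequence have different lengths. The sketch below is a minimal illustration; `cross_attn`, `decoder_states`, and `encoder_states` are hypothetical names, not part of this notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical demo module for a cross-attention call.
cross_attn = nn.MultiheadAttention(embed_dim=4, num_heads=2, bias=False, batch_first=True)

decoder_states = torch.rand(1, 3, 4)  # 3 query tokens
encoder_states = torch.rand(1, 8, 4)  # 8 key/value tokens

with torch.no_grad():
    out, weights = cross_attn(decoder_states, encoder_states, encoder_states)

print(out.shape)      # one output vector per query token
print(weights.shape)  # (batch, query length, key length)
```

The output length follows the query, while the attention matrix is rectangular: 3 queries by 8 keys.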

In [10]:

with torch.no_grad():
    PART_ONE_RESULT, attn_output_weights_part_one  = multihead_attn(input_to_produce_Q,input_to_produce_K,input_to_produce_V,
                                                                    attn_mask = attn_mask )

print('PART ONE RESULT :')
print(PART_ONE_RESULT)
print(PART_ONE_RESULT.shape)

print('----'*20)

print('PART ONE ATTENTION OUTPUT WEIGHTS :')
print(attn_output_weights_part_one)    # Note: by default PyTorch returns the attention weights averaged across heads, not one matrix per head.
print(attn_output_weights_part_one.shape)

PART ONE RESULT :
tensor([[[-0.1419,  0.5573, -0.0425, -0.2406],
         [-0.1758,  0.5663, -0.0353, -0.2560],
         [-0.2062,  0.6112, -0.0474, -0.2779],
         [-0.1789,  0.5776, -0.0528, -0.2556],
         [-0.1743,  0.5528, -0.0469, -0.2469],
         [-0.1835,  0.5435, -0.0493, -0.2459],
         [-0.1755,  0.5388, -0.0514, -0.2401],
         [-0.1663,  0.5134, -0.0536, -0.2267]]])
torch.Size([1, 8, 4])
--------------------------------------------------------------------------------
PART ONE ATTENTION OUTPUT WEIGHTS :
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5267, 0.4733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3607, 0.3203, 0.3190, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2671, 0.2432, 0.2426, 0.2471, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2208, 0.1900, 0.1895, 0.1925, 0.2072, 0.0000, 0.0000, 0.0000],
         [0.1757, 0.1642, 0.1639, 0.1645, 0.1703, 0.1613, 0.0000, 0.0000],
         [0.158

We'll reuse the weights that `nn.MultiheadAttention` initialized in our `FromScratchMultiheadAttention`.

### PART TWO : FROM SCRATCH

In [11]:
class FromScratchMultiheadAttention(nn.Module):
  def __init__(self,context_window,embed_dim,num_heads,dropout=0,add_bias_kv=False,device=None):
    super().__init__()

    # we will assume d_in == d_out and they are both embed_dim.

    # Handling dimensions
    assert embed_dim % num_heads == 0, 'Embedding must be divisible by Number of heads'
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = self.embed_dim//self.num_heads

    # W_q, W_k, W_v
    self.W_q      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.W_k      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.W_v      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.out_proj = nn.Linear(embed_dim,embed_dim,bias=False)


    # Reuse the weights initialized by the global `multihead_attn` from PART ONE, so that
    # both implementations can be compared. PyTorch stores W_q, W_k, W_v stacked in one matrix.
    W_q,W_k,W_v = multihead_attn.in_proj_weight.chunk(3)
    out_proj = multihead_attn.out_proj.weight

    # Copy the values into our layers; copying avoids sharing storage with `multihead_attn`.
    with torch.no_grad():
      self.W_q.weight.copy_(W_q)
      self.W_k.weight.copy_(W_k)
      self.W_v.weight.copy_(W_v)
      self.out_proj.weight.copy_(out_proj)



    # Miscellaneous
    self.register_buffer('mask',torch.triu(torch.ones(context_window,context_window),diagonal=1))
    self.dropout = nn.Dropout(dropout)

    
  
  def forward(self,input_to_produce_Q,input_to_produce_K,input_to_produce_V):
    B_q,num_token_q,embed_dim_q = input_to_produce_Q.shape
    B_k,num_token_k,embed_dim_k = input_to_produce_K.shape
    B_v,num_token_v,embed_dim_v = input_to_produce_V.shape

    Q = self.W_q(input_to_produce_Q) 
    K = self.W_k(input_to_produce_K) 
    V = self.W_v(input_to_produce_V) 

    # splitting 
    Q = Q.view(B_q,num_token_q,self.num_heads,self.head_dim).transpose(1,2)
    K = K.view(B_k,num_token_k,self.num_heads,self.head_dim).transpose(1,2)
    V = V.view(B_v,num_token_v,self.num_heads,self.head_dim).transpose(1,2)

    # QK,mask,softmax,dropout
    attn_score = Q @ K.transpose(2,3)
    attn_score.masked_fill_(self.mask.bool()[:num_token_q,:num_token_k],-torch.inf)
    attn_weight = torch.softmax(attn_score/K.shape[-1]**0.5,dim=-1)
    attn_weight = self.dropout(attn_weight)

    # context_vec
    context_vec = attn_weight @ V

    # Putting the heads back together 
    context_vec = context_vec.transpose(1,2).contiguous().view(B_q,num_token_q,self.embed_dim)    # B_q, B_k and B_v are all equal, so any of them works here

    # projection 
    context_vec = self.out_proj(context_vec)

    return context_vec,attn_weight
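
The core of `forward` (scores, causal mask, softmax, weighted sum) can also be cross-checked against PyTorch's fused kernel. The sketch below assumes PyTorch 2.0 or newer, where `torch.nn.functional.scaled_dot_product_attention` is available; the tensors are random stand-ins for already-split heads, not the notebook's data. Masking before or after the division by the square root of the head dimension gives the same result, since `-inf` stays `-inf`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 1, 2, 8, 2  # batch, heads, tokens, head_dim (illustrative values)

Q, K, V = (torch.rand(B, H, T, D) for _ in range(3))

# Manual path: scaled scores, causal mask, softmax, weighted sum of values.
causal = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = (Q @ K.transpose(-2, -1)) / D ** 0.5
scores = scores.masked_fill(causal, float('-inf'))
manual = torch.softmax(scores, dim=-1) @ V

# Fused path: same computation in one call (assumes PyTorch >= 2.0).
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(manual.shape)
print(torch.allclose(manual, fused, atol=1e-5))
```

This only covers the per-head attention step; the input and output projections and the head split and merge remain as written in the class above.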

    

### PART TWO RESULT  

Note: you can think of `context_window` as the maximum number of tokens the model can process in one go. Since we are feeding the model 8 tokens (`num_tokens = 8`), any `context_window` of 8 or more will do.


In [12]:
mha = FromScratchMultiheadAttention(context_window=1024,   # see the above note.
                                    embed_dim=embed_dim,
                                    num_heads=num_heads,
                                    dropout=0.)


with torch.no_grad():
    PART_TWO_RESULT,attn_output_weights_part_two = mha(input_to_produce_Q,input_to_produce_K,input_to_produce_V)

print('PART TWO RESULT :')
print(PART_TWO_RESULT)
print(PART_TWO_RESULT.shape)

print('----'*20)

print('PART TWO ATTENTION OUTPUT WEIGHTS :')
print(attn_output_weights_part_two)    # Unlike PyTorch's default, our implementation returns the weights of each head separately: shape (batch, num_heads, num_tokens, num_tokens).
print(attn_output_weights_part_two.shape)

PART TWO RESULT :
tensor([[[-0.1419,  0.5573, -0.0425, -0.2406],
         [-0.1758,  0.5663, -0.0353, -0.2560],
         [-0.2062,  0.6112, -0.0474, -0.2779],
         [-0.1789,  0.5776, -0.0528, -0.2556],
         [-0.1743,  0.5528, -0.0469, -0.2469],
         [-0.1835,  0.5435, -0.0493, -0.2459],
         [-0.1755,  0.5388, -0.0514, -0.2401],
         [-0.1663,  0.5134, -0.0536, -0.2267]]])
torch.Size([1, 8, 4])
--------------------------------------------------------------------------------
PART TWO ATTENTION OUTPUT WEIGHTS :
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5547, 0.4453, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3986, 0.3052, 0.2961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2869, 0.2303, 0.2247, 0.2581, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2441, 0.1714, 0.1646, 0.1933, 0.2266, 0.0000, 0.0000, 0.0000],
          [0.1863, 0.1589, 0.1560, 0.1677, 0.1801, 0.1511, 0.0000, 0.0000],
         

In [13]:
attn_output_weights_part_two = attn_output_weights_part_two.mean(dim=1)   # taking average across heads dimension.
attn_output_weights_part_two

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5267, 0.4733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3607, 0.3203, 0.3190, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2671, 0.2432, 0.2426, 0.2471, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2208, 0.1900, 0.1895, 0.1925, 0.2072, 0.0000, 0.0000, 0.0000],
         [0.1757, 0.1642, 0.1639, 0.1645, 0.1703, 0.1613, 0.0000, 0.0000],
         [0.1585, 0.1371, 0.1363, 0.1437, 0.1519, 0.1343, 0.1381, 0.0000],
         [0.1358, 0.1172, 0.1160, 0.1333, 0.1362, 0.1175, 0.1209, 0.1231]]])
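
The averaging step can also be confirmed from the PyTorch side. `nn.MultiheadAttention.forward` accepts `average_attn_weights=False` (available in recent PyTorch releases), which returns one attention matrix per head. The sketch below is a minimal check on a hypothetical fresh module (`mha_demo`) with random inputs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical demo module, separate from the notebook's `multihead_attn`.
mha_demo = nn.MultiheadAttention(embed_dim=4, num_heads=2, bias=False, batch_first=True)

q, k, v = (torch.rand(1, 8, 4) for _ in range(3))
mask = torch.triu(torch.ones(8, 8), diagonal=1).bool()

with torch.no_grad():
    _, averaged = mha_demo(q, k, v, attn_mask=mask)
    _, per_head = mha_demo(q, k, v, attn_mask=mask, average_attn_weights=False)

print(averaged.shape)  # (batch, tokens, tokens)
print(per_head.shape)  # (batch, heads, tokens, tokens)
print(torch.allclose(per_head.mean(dim=1), averaged))
```

Taking the mean over the heads dimension of the per-head weights reproduces the default output, which is the same operation applied to `attn_output_weights_part_two` above.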

# Validation

### Comparing `PART_ONE_RESULT` with `PART_TWO_RESULT`

In [14]:
print(f'Output of PART ONE :\n\n{PART_ONE_RESULT}')
print('---'*30)
print(f'Output of PART TWO :\n\n{PART_TWO_RESULT}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(PART_ONE_RESULT, PART_TWO_RESULT)}')


Output of PART ONE :

tensor([[[-0.1419,  0.5573, -0.0425, -0.2406],
         [-0.1758,  0.5663, -0.0353, -0.2560],
         [-0.2062,  0.6112, -0.0474, -0.2779],
         [-0.1789,  0.5776, -0.0528, -0.2556],
         [-0.1743,  0.5528, -0.0469, -0.2469],
         [-0.1835,  0.5435, -0.0493, -0.2459],
         [-0.1755,  0.5388, -0.0514, -0.2401],
         [-0.1663,  0.5134, -0.0536, -0.2267]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[-0.1419,  0.5573, -0.0425, -0.2406],
         [-0.1758,  0.5663, -0.0353, -0.2560],
         [-0.2062,  0.6112, -0.0474, -0.2779],
         [-0.1789,  0.5776, -0.0528, -0.2556],
         [-0.1743,  0.5528, -0.0469, -0.2469],
         [-0.1835,  0.5435, -0.0493, -0.2459],
         [-0.1755,  0.5388, -0.0514, -0.2401],
         [-0.1663,  0.5134, -0.0536, -0.2267]]])
------------------------------------------------------------------------------------------

PART ONE is equal
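
The comparison uses `torch.allclose` rather than exact equality because two implementations of the same math can differ in the last few bits of a float. The sketch below is a small standalone illustration with made-up values; it also shows `torch.testing.assert_close`, which raises an error with a diagnostic message instead of returning a boolean.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a * (1 + 1e-6)  # tiny relative perturbation, similar in size to rounding noise

exact = torch.equal(a, b)     # bitwise comparison of values
close = torch.allclose(a, b)  # |a - b| <= atol + rtol * |b|, defaults rtol=1e-5, atol=1e-8

print(exact)
print(close)

# Raises AssertionError if the tensors are not close; silent otherwise.
torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-8)
```

For a validation cell, `assert_close` has the advantage that a mismatch stops the notebook and reports how far apart the tensors are.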

### Comparing `attn_output_weights_part_one` with `attn_output_weights_part_two`

In [15]:
print(f'Output of PART ONE :\n\n{attn_output_weights_part_one}')
print('---'*30)
print(f'Output of PART TWO :\n\n{attn_output_weights_part_two}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(attn_output_weights_part_one, attn_output_weights_part_two)}')

Output of PART ONE :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5267, 0.4733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3607, 0.3203, 0.3190, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2671, 0.2432, 0.2426, 0.2471, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2208, 0.1900, 0.1895, 0.1925, 0.2072, 0.0000, 0.0000, 0.0000],
         [0.1757, 0.1642, 0.1639, 0.1645, 0.1703, 0.1613, 0.0000, 0.0000],
         [0.1585, 0.1371, 0.1363, 0.1437, 0.1519, 0.1343, 0.1381, 0.0000],
         [0.1358, 0.1172, 0.1160, 0.1333, 0.1362, 0.1175, 0.1209, 0.1231]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5267, 0.4733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3607, 0.3203, 0.3190, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2671, 0.2432, 0.2426, 0.24

# Afterword

We can feel comfortable proceeding with `nn.MultiheadAttention`, knowing exactly how the math is calculated.