# Topic 

This notebook reproduces PyTorch's built-in module `nn.MultiheadAttention` from scratch and checks that both implementations produce the same output.

# Structure 

Configuration
  - Identical hyperparameters and inputs used for both parts


Implementation Methods

  - Part One: PyTorch's Built-in Implementation
  - Part Two: Custom Implementation from Scratch


Validation

  - Results comparison between Part One and Part Two
  - Verification of output equality

Afterword


In [1]:
import torch 
from torch import nn 

# Configuration

Identical hyperparameters for PART ONE and PART TWO

In [2]:
num_heads = 2

Identical inputs for PART ONE and PART TWO. The input tensor is hard-coded, so `torch.manual_seed` only matters if you swap in random inputs.

In [3]:
torch.manual_seed(1)

# Why separate inputs to produce Q, K and V?
# Stay tuned.


x = torch.tensor(
    [[[ 0.2096,  1.4551, -0.3562, -1.3084],
    [ 1.1463,  0.5509, -1.5349, -0.1624],
    [-0.3144,  1.7046, -0.5748, -0.8154],
    [ 1.6361, -0.5812, -0.0645, -0.9905]]]
         )

print(x.shape)


batch_size,num_tokens,embed_dim = x.shape


input_to_produce_Q = x 
input_to_produce_K = x 
input_to_produce_V = x 



torch.Size([1, 4, 4])


In [4]:
W_q = nn.Parameter(torch.tensor([
    [-0.5, 0.2, 0.7, -0.9],
    [0.1, -0.3, 0.8, 0.4],
    [-0.7, 0.6, -0.2, 0.9],
    [0.3, -0.8, 0.5, -0.1]
]))


W_k = nn.Parameter(torch.tensor([
    [0.3, -0.5, 0.2, 0.7],
    [-0.4, 0.1, -0.6, -0.2],
    [0.8, -0.3, 0.5, -0.7],
    [-0.1, 0.6, -0.9, 0.4]
]))

W_v = nn.Parameter(torch.tensor([
    [0.2, -0.8, 0.3, 0.5],
    [-0.7, 0.4, -0.1, -0.6],
    [0.9, -0.2, 0.7, -0.3],
    [-0.5, 0.1, -0.4, 0.8]
]))

out_proj = nn.Parameter(torch.tensor([
    [0.5, -0.3, 0.4, 0.2],
    [-0.6, 0.4, -0.2, -0.5], 
    [0.3, -0.7, 0.6, -0.4],
    [-0.2, 0.5, -0.4, 0.3]
]))

# Implementation Methods 

### PART ONE : PyTorch's Implementation

In [5]:
torch.manual_seed(5)

multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                       num_heads=num_heads,
                                       dropout=0,
                                       bias = False,
                                       add_bias_kv=False,
                                       batch_first=True,    # important
                                       device=None)



# Overwrite the randomly initialized projection weights with our own.
# nn.MultiheadAttention computes x @ W.T, hence the transposes.
multihead_attn.in_proj_weight = nn.Parameter(torch.concat([W_q.T,W_k.T,W_v.T]))
print('The following are W_q,W_k,W_v concatenated into one matrix, this is how pytorch initializes the weight matrix\n')
print(multihead_attn.in_proj_weight)
print(multihead_attn.in_proj_weight.shape)



print('----'*20)

multihead_attn.out_proj.weight = nn.Parameter(out_proj.T)
print('THe following is out projection weight\n')
print(multihead_attn.out_proj.weight)  # Output projection weights
print(multihead_attn.out_proj.weight.shape)  # Output projection shape


The following are W_q,W_k,W_v concatenated into one matrix, this is how pytorch initializes the weight matrix

Parameter containing:
tensor([[-0.5000,  0.1000, -0.7000,  0.3000],
        [ 0.2000, -0.3000,  0.6000, -0.8000],
        [ 0.7000,  0.8000, -0.2000,  0.5000],
        [-0.9000,  0.4000,  0.9000, -0.1000],
        [ 0.3000, -0.4000,  0.8000, -0.1000],
        [-0.5000,  0.1000, -0.3000,  0.6000],
        [ 0.2000, -0.6000,  0.5000, -0.9000],
        [ 0.7000, -0.2000, -0.7000,  0.4000],
        [ 0.2000, -0.7000,  0.9000, -0.5000],
        [-0.8000,  0.4000, -0.2000,  0.1000],
        [ 0.3000, -0.1000,  0.7000, -0.4000],
        [ 0.5000, -0.6000, -0.3000,  0.8000]], requires_grad=True)
torch.Size([12, 4])
--------------------------------------------------------------------------------
THe following is out projection weight

Parameter containing:
tensor([[ 0.5000, -0.6000,  0.3000, -0.2000],
        [-0.3000,  0.4000, -0.7000,  0.5000],
        [ 0.4000, -0.2000,  0.6000, -0.4000],
        [ 0.2000, -0.5000, -0.4000,  0.3000]], requires_grad=True)
torch.Size([4, 4])
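
To see why the three transposed matrices are stacked row-wise, here is a minimal sketch. It uses random illustrative tensors rather than the notebook's weights, and it assumes the packed projection is applied as `F.linear(x, in_proj_weight)`, i.e. `x @ in_proj_weight.T`. The check is that the three chunks of the packed result match the three separate products.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 4)                              # (batch, tokens, embed_dim), illustrative values
W_q, W_k, W_v = (torch.randn(4, 4) for _ in range(3))

# Stack the transposed matrices row-wise, as done for `in_proj_weight` above.
in_proj_weight = torch.cat([W_q.T, W_k.T, W_v.T])     # shape (3 * embed_dim, embed_dim)

# F.linear computes x @ weight.T, so one call yields Q, K and V side by side.
qkv = F.linear(x, in_proj_weight)                     # shape (1, 4, 12)
Q, K, V = qkv.chunk(3, dim=-1)

same_q = torch.allclose(Q, x @ W_q, atol=1e-6)
same_k = torch.allclose(K, x @ W_k, atol=1e-6)
same_v = torch.allclose(V, x @ W_v, atol=1e-6)
print(same_q, same_k, same_v)
```

Under this convention, the transposes in the notebook are what let the from-scratch version write `x @ W_q` directly.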

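### A few words on `nn.MultiheadAttention` arguments

### If `bias=True` and `add_bias_kv=True`

$Q = x \cdot W^Q + b^Q$

$K = \begin{bmatrix} x \cdot W^K + b^K \\ \text{bias}_k \end{bmatrix}$

$V = \begin{bmatrix} x \cdot W^V + b^V \\ \text{bias}_v \end{bmatrix}$

`bias=True` adds the usual per-projection biases $b^Q$, $b^K$, $b^V$. `add_bias_kv=True` does not add anything to the projections themselves: it appends one learned row, $\text{bias}_k$ and $\text{bias}_v$, to the key and value sequences, so every query can attend to one extra position.

For simplicity we'll set both `bias` and `add_bias_kv` to `False` in `nn.MultiheadAttention`.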
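
A quick way to check what these two flags do is to inspect the parameters they create and the shape of the returned attention weights. The sketch below uses a random illustrative input and a throwaway module named `mha_bias` (a hypothetical name, not used elsewhere in this notebook).

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads, num_tokens = 4, 2, 4
x = torch.randn(1, num_tokens, embed_dim)      # illustrative input

mha_bias = nn.MultiheadAttention(embed_dim, num_heads, bias=True,
                                 add_bias_kv=True, batch_first=True)

# `bias=True` creates one packed bias covering the Q, K and V projections.
print(mha_bias.in_proj_bias.shape)             # torch.Size([12])

# `add_bias_kv=True` creates one learned row each for the keys and the values.
print(mha_bias.bias_k.shape, mha_bias.bias_v.shape)

with torch.no_grad():
    out, weights = mha_bias(x, x, x)

print(out.shape)       # torch.Size([1, 4, 4])
print(weights.shape)   # torch.Size([1, 4, 5]): one extra key position per query
```

The extra column in `weights` is the position contributed by `bias_k` and `bias_v`, which is why the flag changes the sequence length seen by the softmax rather than the projections.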

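Note: PyTorch's `nn.MultiheadAttention` does not mask $QK^T$ by default. To get causal attention we have to pass `attn_mask` ourselves; in a boolean mask, `True` marks positions that may not be attended to.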

In [6]:
attn_mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1).bool()
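
As a small illustration of what this mask does, the sketch below applies it to a matrix of all-zero scores (an illustrative stand-in for $QK^T$). Each row then spreads its weight evenly over the current and earlier tokens only.

```python
import torch

num_tokens = 4
attn_mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# True sits strictly above the diagonal: token i may not look at tokens j > i.
print(attn_mask)

# Illustrative scores: all zeros, so softmax is uniform over the allowed positions.
scores = torch.zeros(num_tokens, num_tokens)
weights = torch.softmax(scores.masked_fill(attn_mask, float("-inf")), dim=-1)
print(weights)   # row 0 -> [1, 0, 0, 0], row 1 -> [0.5, 0.5, 0, 0], ...
```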

### PART ONE RESULT

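- PyTorch's `nn.MultiheadAttention` is designed to be flexible, so its forward pass takes separate `query`, `key`, and `value` inputs rather than assuming all three come from the same source. This is why we defined three inputs earlier: in self-attention they are the same tensor `x`, while in cross-attention the query can come from a different sequence than the key and value.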
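
The flexibility described above can be sketched with two hypothetical sequences of different lengths (`query_seq` and `memory_seq` are illustrative names, not part of this notebook). The same module handles both the self-attention and the cross-attention call.

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)

query_seq = torch.randn(1, 3, 4)    # 3 query tokens (hypothetical decoder side)
memory_seq = torch.randn(1, 5, 4)   # 5 key/value tokens (hypothetical encoder side)

with torch.no_grad():
    # Self-attention: query, key and value are the same tensor.
    self_out, self_w = attn(query_seq, query_seq, query_seq)
    # Cross-attention: queries from one sequence, keys and values from another.
    cross_out, cross_w = attn(query_seq, memory_seq, memory_seq)

print(self_w.shape)     # torch.Size([1, 3, 3])
print(cross_w.shape)    # torch.Size([1, 3, 5])
print(cross_out.shape)  # torch.Size([1, 3, 4]): one output row per query token
```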

In [None]:
with torch.no_grad():
    PART_ONE_RESULT, attn_output_weights_part_one  = multihead_attn(input_to_produce_Q,
                                                                    input_to_produce_K,
                                                                    input_to_produce_V,
                                                                    attn_mask = attn_mask)
    

Our `FromScratchMultiheadAttention` uses the same `W_q`, `W_k`, `W_v` and `out_proj` that we loaded into `nn.MultiheadAttention` above.

### PART TWO : FROM SCRATCH

In [8]:
class FromScratchMultiheadAttention(nn.Module):
  def __init__(self,context_window,embed_dim,num_heads,dropout=0,add_bias_kv=False,device=None):
    super().__init__()

    # we will assume d_in == d_out and they are both embed_dim.

    # Handling dimensions
    assert embed_dim % num_heads == 0, 'Embedding must be divisible by Number of heads'
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = self.embed_dim//self.num_heads


    # Miscellaneous
    self.register_buffer('mask',torch.triu(torch.ones(context_window,context_window),diagonal=1))
    self.dropout = nn.Dropout(dropout)

    
  
  def forward(self,input_to_produce_Q,input_to_produce_K,input_to_produce_V):
    B_q,num_token_q,_ = input_to_produce_Q.shape
    B_k,num_token_k,_ = input_to_produce_K.shape
    B_v,num_token_v,_ = input_to_produce_V.shape

    # W_q, W_k and W_v are the module-level parameters defined in the Configuration section
    Q = input_to_produce_Q @ W_q
    K = input_to_produce_K @ W_k
    V = input_to_produce_V @ W_v


    # splitting (turn it to multi-head)
    Q = Q.view(B_q,num_token_q,self.num_heads,self.head_dim).transpose(1,2)
    K = K.view(B_k,num_token_k,self.num_heads,self.head_dim).transpose(1,2)
    V = V.view(B_v,num_token_v,self.num_heads,self.head_dim).transpose(1,2)

    # QK,mask,softmax,dropout
    attn_score = Q @ K.transpose(2,3)
    attn_score.masked_fill_(self.mask.bool()[:num_token_q,:num_token_k],-torch.inf)
    attn_weight = torch.softmax(attn_score/K.shape[-1]**0.5,dim=-1)
    attn_weight = self.dropout(attn_weight)

    # context_vec
    context_vec = attn_weight @ V

    # Putting the heads back together 
    context_vec = context_vec.transpose(1,2).contiguous().view(B_q,num_token_q,self.embed_dim)    # it doesn't matter which (B) you choose

    # projection 
    context_vec = context_vec @ out_proj

    return context_vec,attn_weight
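
The head-splitting step in `forward` is easiest to follow on a tensor whose values are just a running count. The sketch below uses this notebook's sizes with illustrative values. It shows that `view` plus `transpose(1, 2)` gives each head a contiguous slice of the embedding columns, and that the merge step undoes the split exactly.

```python
import torch

batch_size, num_tokens, embed_dim, num_heads = 1, 4, 4, 2
head_dim = embed_dim // num_heads

# Illustrative "Q": values 0..15 so each entry is easy to trace.
Q = torch.arange(batch_size * num_tokens * embed_dim, dtype=torch.float32)
Q = Q.view(batch_size, num_tokens, embed_dim)

# Split: (B, T, E) -> (B, T, H, D) -> (B, H, T, D)
heads = Q.view(batch_size, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)   # torch.Size([1, 2, 4, 2])

# Head 0 holds columns 0..1 of every token, head 1 holds columns 2..3.
first_head_ok = torch.equal(heads[0, 0], Q[0, :, :head_dim])
second_head_ok = torch.equal(heads[0, 1], Q[0, :, head_dim:])

# Merge: (B, H, T, D) -> (B, T, H, D) -> (B, T, E)
merged = heads.transpose(1, 2).contiguous().view(batch_size, num_tokens, embed_dim)
round_trip_ok = torch.equal(merged, Q)
print(first_head_ok, second_head_ok, round_trip_ok)
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout.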

    

### PART TWO RESULT  

Note: you can think of `context_window` as the maximum number of tokens the model can process in one go. Since we are feeding the model 4 tokens (`num_tokens = 4`), any `context_window` of 4 or more will do.
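
The reason any sufficiently large `context_window` works is that the registered mask is sliced down to the actual sequence length inside `forward`. A minimal check, using the same sizes as this notebook:

```python
import torch

context_window, num_tokens = 1024, 4

# Mask built once for the largest supported sequence length.
full_mask = torch.triu(torch.ones(context_window, context_window), diagonal=1).bool()

# Slicing the top-left corner gives the mask for the current sequence.
sliced = full_mask[:num_tokens, :num_tokens]

# Mask built directly for num_tokens, as done for `attn_mask` in PART ONE.
direct = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

same = torch.equal(sliced, direct)
print(same)   # True
```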


In [9]:
mha = FromScratchMultiheadAttention(context_window=1024,   # see the above note.
                                    embed_dim=embed_dim,
                                    num_heads=num_heads,
                                    dropout=0.)


with torch.no_grad():
    PART_TWO_RESULT,attn_output_weights_part_two = mha(input_to_produce_Q,input_to_produce_K,input_to_produce_V)
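
As an extra sanity check on the core of the forward pass (mask, scaled softmax, weighted sum), the sketch below compares those steps against `torch.nn.functional.scaled_dot_product_attention`. It assumes PyTorch 2.0 or newer, where that function is available, and uses random illustrative per-head tensors instead of the notebook's projections.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative per-head tensors: (batch, heads, tokens, head_dim)
Q, K, V = (torch.randn(1, 2, 4, 2) for _ in range(3))

# Manual version: same order of operations as the from-scratch forward pass.
causal = torch.triu(torch.ones(4, 4), diagonal=1).bool()
scores = (Q @ K.transpose(-2, -1)).masked_fill(causal, float("-inf"))
manual = torch.softmax(scores / math.sqrt(Q.shape[-1]), dim=-1) @ V

# Built-in version (assumes PyTorch >= 2.0).
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

matches = torch.allclose(manual, fused, atol=1e-5)
print(matches)
```

Masking before dividing by $\sqrt{d_{head}}$ is harmless here, because $-\infty$ divided by a positive number is still $-\infty$.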

In [10]:
attn_output_weights_part_two = attn_output_weights_part_two.mean(dim=1)   # average over the heads dimension, as nn.MultiheadAttention does by default
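
The averaging is needed because `nn.MultiheadAttention` returns head-averaged attention weights unless told otherwise. The sketch below, on a randomly initialized module with an illustrative input, uses the `average_attn_weights` argument of the forward call to get per-head weights and confirms that their mean matches the default output.

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
x = torch.randn(1, 4, 4)   # illustrative input

with torch.no_grad():
    _, averaged = attn(x, x, x)                                  # default: averaged over heads
    _, per_head = attn(x, x, x, average_attn_weights=False)      # one matrix per head

print(averaged.shape)   # torch.Size([1, 4, 4])
print(per_head.shape)   # torch.Size([1, 2, 4, 4])

mean_matches = torch.allclose(per_head.mean(dim=1), averaged, atol=1e-6)
print(mean_matches)
```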

# Validation

### Comparing `PART_ONE_RESULT` with `PART_TWO_RESULT`

In [11]:
print(f'Output of PART ONE :\n\n{PART_ONE_RESULT}')
print('---'*30)
print(f'Output of PART TWO :\n\n{PART_TWO_RESULT}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(PART_ONE_RESULT, PART_TWO_RESULT)}')


Output of PART ONE :

tensor([[[-0.1353, -0.6532,  0.4699, -0.8950],
         [-0.1263, -0.6508,  0.3784, -0.7918],
         [-0.6962,  0.3417, -0.4282, -0.4440],
         [-0.5026, -0.2205,  0.0192, -0.8127]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[-0.1353, -0.6532,  0.4699, -0.8950],
         [-0.1263, -0.6508,  0.3784, -0.7918],
         [-0.6962,  0.3417, -0.4282, -0.4440],
         [-0.5026, -0.2205,  0.0192, -0.8127]]])
------------------------------------------------------------------------------------------

PART ONE is equal to PART TWO ? True
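
Note that `torch.allclose` checks equality up to a tolerance (by default `rtol=1e-05`, `atol=1e-08`), not bit-for-bit equality. A small sketch with illustrative tensors shows the difference, and how to report the largest absolute deviation alongside the boolean.

```python
import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-6                      # tiny perturbation, illustrative only

close = torch.allclose(a, b)      # tolerance-based comparison
exact = torch.equal(a, b)         # exact comparison
max_diff = (a - b).abs().max().item()

print(close, exact)               # True False
print(max_diff)
```

Printing `(PART_ONE_RESULT - PART_TWO_RESULT).abs().max()` in the same way would show how far apart the two implementations actually are.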


### Comparing `attn_output_weights_part_one` with `attn_output_weights_part_two`

In [12]:
print(f'Output of PART ONE :\n\n{attn_output_weights_part_one}')
print('---'*30)
print(f'Output of PART TWO :\n\n{attn_output_weights_part_two}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(attn_output_weights_part_one, attn_output_weights_part_two)}')

Output of PART ONE :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.7915, 0.2085, 0.0000, 0.0000],
         [0.3695, 0.3792, 0.2513, 0.0000],
         [0.3029, 0.1970, 0.4318, 0.0683]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.7915, 0.2085, 0.0000, 0.0000],
         [0.3695, 0.3792, 0.2513, 0.0000],
         [0.3029, 0.1970, 0.4318, 0.0683]]])
------------------------------------------------------------------------------------------

PART ONE is equal to PART TWO ? True


# Afterword

We can feel comfortable proceeding with `nn.MultiheadAttention`, knowing exactly how the math is calculated.