# Topic 

This notebook reproduces PyTorch's built-in module `nn.MultiheadAttention` from scratch and checks that both implementations produce the same output.

# Structure 

Configuration
  - Identical hyperparameters and inputs used for both parts


Implementation Methods

  - Part One: PyTorch's Built-in Implementation
  - Part Two: Custom Implementation from Scratch


Validation

  - Results comparison between Part One and Part Two
  - Verification of output equality

Afterwords 


In [1]:
import torch 
from torch import nn 

# Configuration

Identical hyperparameters for PART ONE and PART TWO

In [2]:
num_heads = 2


Identical inputs (with `torch.manual_seed`) for PART ONE and PART TWO

In [15]:
torch.manual_seed(1)

# Why separate inputs to produce Q, K, V?
# Stay tuned.


x = torch.tensor(
    [[[ 0.2096,  1.4551, -0.3562, -1.3084],
    [ 1.1463,  0.5509, -1.5349, -0.1624],
    [-0.3144,  1.7046, -0.5748, -0.8154],
    [ 1.6361, -0.5812, -0.0645, -0.9905]]]
         )

print(x.shape)


batch_size,num_tokens,embed_dim = x.shape
# input_to_produce_Q = torch.rand(batch_size,num_tokens,embed_dim)
# input_to_produce_K = torch.rand(batch_size,num_tokens,embed_dim)
# input_to_produce_V = torch.rand(batch_size,num_tokens,embed_dim)

input_to_produce_Q = x 
input_to_produce_K = x 
input_to_produce_V = x 



torch.Size([1, 4, 4])


# Implementation Methods 

### PART ONE : PyTorch's Implementation

In [18]:
torch.manual_seed(1)

multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                       num_heads=num_heads,
                                       dropout=0,
                                       bias = False,
                                       add_bias_kv=False,
                                       batch_first=True,    # important: inputs are (batch, num_tokens, embed_dim)
                                       device=None)



# Grab the weights that `nn.MultiheadAttention` initialized.
print('The following are W_q,W_k,W_v concatenated into one matrix, this is how pytorch initializes the weight matrix\n')
print(multihead_attn.in_proj_weight)
print(multihead_attn.in_proj_weight.shape)


print('----'*20)

print('The following is the output projection weight\n')
print(multihead_attn.out_proj.weight)  # Output projection weights
print(multihead_attn.out_proj.weight.shape)  # Output projection shape


The following are W_q,W_k,W_v concatenated into one matrix, this is how pytorch initializes the weight matrix

Parameter containing:
tensor([[ 0.0888, -0.0024,  0.5353,  0.1906],
        [-0.2281, -0.3698, -0.1026, -0.2641],
        [-0.1962,  0.0293,  0.3651,  0.3328],
        [-0.5986,  0.3796,  0.1711,  0.5809],
        [ 0.4042, -0.5580, -0.5822, -0.2954],
        [ 0.5377, -0.1020,  0.2621, -0.2846],
        [ 0.6009, -0.2591,  0.4592,  0.0073],
        [-0.3226,  0.3148, -0.3251,  0.1801],
        [-0.1768, -0.0671, -0.5887, -0.2920],
        [ 0.3323, -0.1489,  0.6100,  0.4909],
        [-0.0287, -0.4087,  0.3729,  0.1901],
        [-0.3958,  0.3978,  0.3718,  0.5431]], requires_grad=True)
torch.Size([12, 4])
--------------------------------------------------------------------------------
The following is the output projection weight

Parameter containing:
tensor([[ 0.2576, -0.2207, -0.0969,  0.2347],
        [-0.4707,  0.2999, -0.1029,  0.2544],
        [ 0.0695, -0.0612,  0.1387,  0.

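### A few words on `nn.MultiheadAttention` arguments

### If `bias` = True and `add_bias_kv` = True

With `bias=True`, each input projection gets its own bias term:

$Q = x \cdot W^Q + b^Q$

$K = x \cdot W^K + b^K$

$V = x \cdot W^V + b^V$

`add_bias_kv=True` does not add another term to these equations. It creates two learnable vectors, `bias_k` and `bias_v`, and appends each one as an extra position along the sequence dimension of $K$ and $V$:

$K' = [\,K \,;\, \text{bias}_k\,]$

$V' = [\,V \,;\, \text{bias}_v\,]$

For simplicity we'll set both `bias` and `add_bias_kv` to `False` in `nn.MultiheadAttention`.

Note: PyTorch's `nn.MultiheadAttention` does not mask $QK^T$ by default. To get causal attention we have to pass a mask through the `attn_mask` argument of the forward call.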
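
To make the effect of `add_bias_kv` concrete, here is a minimal sketch. It is separate from the notebook's `multihead_attn`, and the names `mha_kv` and `x_demo` are illustrative only. It assumes the current PyTorch behavior in which the extra key/value vector is concatenated as one additional position, which should show up as an extra column in the returned attention weights.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads, num_tokens = 4, 2, 4

# Illustrative module, not the one used in the rest of the notebook.
mha_kv = nn.MultiheadAttention(embed_dim, num_heads, bias=False,
                               add_bias_kv=True, batch_first=True)
x_demo = torch.randn(1, num_tokens, embed_dim)

with torch.no_grad():
    out, attn_w = mha_kv(x_demo, x_demo, x_demo)

print(mha_kv.bias_k.shape)  # one learnable key vector: (1, 1, embed_dim)
print(out.shape)            # output length is unchanged: (1, num_tokens, embed_dim)
print(attn_w.shape)         # keys gained one position: (1, num_tokens, num_tokens + 1)
```

Because the extra position changes the shape of the attention weights, leaving `add_bias_kv=False` keeps the comparison with a from-scratch implementation simple.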

In [13]:
W_q,W_k,W_v = multihead_attn.in_proj_weight.chunk(3)
print(W_q)
print('----'*20)
print(W_k)
print('----'*20)
print(W_v)



tensor([[ 0.0888, -0.0024,  0.5353,  0.1906],
        [-0.2281, -0.3698, -0.1026, -0.2641],
        [-0.1962,  0.0293,  0.3651,  0.3328],
        [-0.5986,  0.3796,  0.1711,  0.5809]], grad_fn=<SplitBackward0>)
--------------------------------------------------------------------------------
tensor([[ 0.4042, -0.5580, -0.5822, -0.2954],
        [ 0.5377, -0.1020,  0.2621, -0.2846],
        [ 0.6009, -0.2591,  0.4592,  0.0073],
        [-0.3226,  0.3148, -0.3251,  0.1801]], grad_fn=<SplitBackward0>)
--------------------------------------------------------------------------------
tensor([[-0.1768, -0.0671, -0.5887, -0.2920],
        [ 0.3323, -0.1489,  0.6100,  0.4909],
        [-0.0287, -0.4087,  0.3729,  0.1901],
        [-0.3958,  0.3978,  0.3718,  0.5431]], grad_fn=<SplitBackward0>)
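
The row-wise split above can be sanity-checked with a small sketch. It uses a random stand-in matrix (`in_proj_weight` here is not the notebook's parameter, and `x_demo` is an illustrative input) and shows that one fused projection followed by a split of the last dimension should match three separate projections with the chunked matrices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 4
x_demo = torch.randn(1, 4, embed_dim)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # stand-in for multihead_attn.in_proj_weight

# Row-wise split: first embed_dim rows -> W_q, next -> W_k, last -> W_v.
W_q, W_k, W_v = in_proj_weight.chunk(3)

# One fused matmul, then split the last dimension into Q, K, V.
q_fused, k_fused, v_fused = F.linear(x_demo, in_proj_weight).chunk(3, dim=-1)

# Three separate projections; F.linear and nn.Linear compute x @ W.T.
q_sep, k_sep, v_sep = x_demo @ W_q.T, x_demo @ W_k.T, x_demo @ W_v.T

print(torch.allclose(q_fused, q_sep, atol=1e-6))
print(torch.allclose(k_fused, k_sep, atol=1e-6))
print(torch.allclose(v_fused, v_sep, atol=1e-6))
```

This is why the chunks can be loaded directly into three `nn.Linear(embed_dim, embed_dim, bias=False)` layers without transposing them.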


In [5]:
attn_mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1).bool()
attn_mask

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
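
In this boolean mask, `True` marks a position that may not be attended to. `attn_mask` also accepts a float mask whose values are added to the scores. The sketch below uses made-up scores (`scores` is illustrative, not taken from the model) to show that both conventions should lead to the same softmax output.

```python
import torch

num_tokens = 4
# Toy attention scores, only for illustration.
scores = torch.arange(16, dtype=torch.float32).reshape(num_tokens, num_tokens) / 10

# Boolean convention: True = masked out.
bool_mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
# Float convention: 0 where attention is allowed, -inf where it is not.
float_mask = torch.zeros(num_tokens, num_tokens).masked_fill(bool_mask, float("-inf"))

w_bool = torch.softmax(scores.masked_fill(bool_mask, float("-inf")), dim=-1)
w_float = torch.softmax(scores + float_mask, dim=-1)

print(torch.allclose(w_bool, w_float))
print(w_bool)  # upper triangle is exactly zero, each row sums to 1
```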

### PART ONE RESULT

- PyTorch's `nn.MultiheadAttention` is designed to be flexible for different use cases, so it takes the inputs for Q, K, and V as separate arguments rather than assuming all three come from the same source. For self-attention, as in this notebook, we pass the same tensor three times.

In [6]:

with torch.no_grad():
    PART_ONE_RESULT, attn_output_weights_part_one  = multihead_attn(input_to_produce_Q,
                                                                    input_to_produce_K,
                                                                    input_to_produce_V,
                                                                    attn_mask = attn_mask )

print('PART ONE RESULT :')
print(PART_ONE_RESULT)
print(PART_ONE_RESULT.shape)

print('----'*20)

print('PART ONE ATTENTION OUTPUT WEIGHTS :')
print(attn_output_weights_part_one)    # Note: by default PyTorch returns the attention weights averaged across heads, not one matrix per head.
print(attn_output_weights_part_one.shape)

PART ONE RESULT :
tensor([[[ 0.3536, -0.5042, -0.0514,  0.3310],
         [ 0.2896, -0.5853, -0.0499,  0.3369],
         [ 0.3767, -0.5305, -0.0452,  0.3406],
         [ 0.1185, -0.5664, -0.0260,  0.2704]]])
torch.Size([1, 4, 4])
--------------------------------------------------------------------------------
PART ONE ATTENTION OUTPUT WEIGHTS :
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5976, 0.4024, 0.0000, 0.0000],
         [0.3300, 0.2603, 0.4097, 0.0000],
         [0.2358, 0.2138, 0.2068, 0.3436]]])
torch.Size([1, 4, 4])
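
The separate query, key, and value arguments matter once the sources differ, as in cross-attention. The sketch below is illustrative only (`cross_attn`, `queries`, and `memory` are hypothetical names, not part of this notebook): queries come from one sequence, keys and values from another of a different length.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads = 4, 2
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)

queries = torch.randn(1, 3, embed_dim)  # e.g. 3 tokens on the querying side
memory = torch.randn(1, 5, embed_dim)   # e.g. 5 tokens providing keys and values

with torch.no_grad():
    out, attn_w = cross_attn(queries, memory, memory)

print(out.shape)     # one output vector per query token
print(attn_w.shape)  # one row per query token, one column per key token
```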


We'll reuse the weights that `nn.MultiheadAttention` initialized in our `FromScratchMultiheadAttention`.

### PART TWO : FROM SCRATCH

In [7]:
class FromScratchMultiheadAttention(nn.Module):
  def __init__(self,context_window,embed_dim,num_heads,dropout=0,add_bias_kv=False,device=None):
    super().__init__()

    # we will assume d_in == d_out and they are both embed_dim.

    # Handling dimensions
    assert embed_dim % num_heads == 0, 'Embedding must be divisible by Number of heads'
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = self.embed_dim//self.num_heads

    # W_q, W_k, W_v
    self.W_q      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.W_k      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.W_v      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.out_proj = nn.Linear(embed_dim,embed_dim,bias=False)


    # Reuse the weights initialized by nn.MultiheadAttention, which stores W_q, W_k, W_v stacked row-wise in one matrix.
    # Note: this reads the global `multihead_attn` created in PART ONE.
    W_q,W_k,W_v = multihead_attn.in_proj_weight.chunk(3)
    out_proj = multihead_attn.out_proj.weight

    # Copy them into our layers so both implementations run on identical parameters.
    with torch.no_grad():
      self.W_q.weight.copy_(W_q)
      self.W_k.weight.copy_(W_k)
      self.W_v.weight.copy_(W_v)
      self.out_proj.weight.copy_(out_proj)



    # Miscellaneous
    self.register_buffer('mask',torch.triu(torch.ones(context_window,context_window),diagonal=1))
    self.dropout = nn.Dropout(dropout)

    
  
  def forward(self,input_to_produce_Q,input_to_produce_K,input_to_produce_V):
    B_q,num_token_q,embed_dim_q = input_to_produce_Q.shape
    B_k,num_token_k,embed_dim_k = input_to_produce_K.shape
    B_v,num_token_v,embed_dim_v = input_to_produce_V.shape

    Q = self.W_q(input_to_produce_Q) 
    K = self.W_k(input_to_produce_K) 
    V = self.W_v(input_to_produce_V) 

    # splitting (turn it to multi-head)
    Q = Q.view(B_q,num_token_q,self.num_heads,self.head_dim).transpose(1,2)
    K = K.view(B_k,num_token_k,self.num_heads,self.head_dim).transpose(1,2)
    V = V.view(B_v,num_token_v,self.num_heads,self.head_dim).transpose(1,2)

    # QK,mask,softmax,dropout
    attn_score = Q @ K.transpose(2,3)
    attn_score.masked_fill_(self.mask.bool()[:num_token_q,:num_token_k],-torch.inf)
    attn_weight = torch.softmax(attn_score/K.shape[-1]**0.5,dim=-1)
    attn_weight = self.dropout(attn_weight)

    # context_vec
    context_vec = attn_weight @ V

    # Putting the heads back together 
    context_vec = context_vec.transpose(1,2).contiguous().view(B_q,num_token_q,self.embed_dim)    # B_q, B_k and B_v are equal, so any of them works as the batch size

    # projection 
    context_vec = self.out_proj(context_vec)

    return context_vec,attn_weight
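
As a cross-check of the core step in `forward` (scores, causal mask, scaled softmax, weighted sum), the sketch below compares it with `F.scaled_dot_product_attention`. This assumes PyTorch 2.0 or newer, where that function is available; the tensors are random stand-ins that already have the per-head shape `(batch, heads, tokens, head_dim)`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 1, 2, 4, 2  # batch, heads, tokens, head_dim
Q, K, V = (torch.randn(B, H, T, D) for _ in range(3))

# Manual path, same order of operations as in forward() above.
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = (Q @ K.transpose(2, 3)).masked_fill(mask, float("-inf"))
manual = torch.softmax(scores / D ** 0.5, dim=-1) @ V

# Built-in path; is_causal=True applies the same lower-triangular pattern.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

The built-in function does not return the attention weights, so the manual path is still the one to use when the weights themselves need to be inspected.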

    

### PART TWO RESULT  

Note: you can think of `context_window` as the maximum number of tokens the model can process in one go. Since we are feeding the model 4 tokens (`num_tokens = 4`), any `context_window` of 4 or more will do.
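
A short sketch of why a larger `context_window` is harmless: the mask is built once at full size, and only its top-left corner is used for the actual sequence length. The variable names below are illustrative.

```python
import torch

context_window, num_tokens = 1024, 4

# Mask built once for the largest supported sequence length.
full_mask = torch.triu(torch.ones(context_window, context_window), diagonal=1).bool()

# At run time only the top-left num_tokens x num_tokens corner is used.
sliced = full_mask[:num_tokens, :num_tokens]
direct = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

print(torch.equal(sliced, direct))  # True
```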


In [8]:
mha = FromScratchMultiheadAttention(context_window=1024,   # see the above note.
                                    embed_dim=embed_dim,
                                    num_heads=num_heads,
                                    dropout=0.)


with torch.no_grad():
    PART_TWO_RESULT,attn_output_weights_part_two = mha(input_to_produce_Q,input_to_produce_K,input_to_produce_V)

print('PART TWO RESULT :')
print(PART_TWO_RESULT)
print(PART_TWO_RESULT.shape)

print('----'*20)

print('PART TWO ATTENTION OUTPUT WEIGHTS :')
print(attn_output_weights_part_two)    # Unlike PyTorch's default, our implementation returns the attention weights of each head separately.
print(attn_output_weights_part_two.shape)

PART TWO RESULT :
tensor([[[ 0.3536, -0.5042, -0.0514,  0.3310],
         [ 0.2896, -0.5853, -0.0499,  0.3369],
         [ 0.3767, -0.5305, -0.0452,  0.3406],
         [ 0.1185, -0.5664, -0.0260,  0.2704]]])
torch.Size([1, 4, 4])
--------------------------------------------------------------------------------
PART TWO ATTENTION OUTPUT WEIGHTS :
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.6564, 0.3436, 0.0000, 0.0000],
          [0.3431, 0.2247, 0.4322, 0.0000],
          [0.2558, 0.2385, 0.2509, 0.2548]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5388, 0.4612, 0.0000, 0.0000],
          [0.3169, 0.2959, 0.3872, 0.0000],
          [0.2158, 0.1890, 0.1628, 0.4324]]]])
torch.Size([1, 2, 4, 4])


In [9]:
attn_output_weights_part_two = attn_output_weights_part_two.mean(dim=1)   # taking average across heads dimension.
attn_output_weights_part_two

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5976, 0.4024, 0.0000, 0.0000],
         [0.3300, 0.2603, 0.4097, 0.0000],
         [0.2358, 0.2138, 0.2068, 0.3436]]])
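
The averaging can also be checked from PyTorch's side. The sketch below assumes a PyTorch version whose `nn.MultiheadAttention.forward` accepts `average_attn_weights` (`mha_demo` and `x_demo` are illustrative, not the notebook's objects): with `average_attn_weights=False` the module should return one weight matrix per head, and their mean should match the default output.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads, num_tokens = 4, 2, 4
mha_demo = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)
x_demo = torch.randn(1, num_tokens, embed_dim)

with torch.no_grad():
    _, averaged = mha_demo(x_demo, x_demo, x_demo)
    _, per_head = mha_demo(x_demo, x_demo, x_demo, average_attn_weights=False)

print(averaged.shape)  # (batch, num_tokens, num_tokens)
print(per_head.shape)  # (batch, num_heads, num_tokens, num_tokens)
print(torch.allclose(per_head.mean(dim=1), averaged, atol=1e-6))
```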

# Validation

### Comparing `PART_ONE_RESULT` with `PART_TWO_RESULT`

In [10]:
print(f'Output of PART ONE :\n\n{PART_ONE_RESULT}')
print('---'*30)
print(f'Output of PART TWO :\n\n{PART_TWO_RESULT}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(PART_ONE_RESULT, PART_TWO_RESULT)}')


Output of PART ONE :

tensor([[[ 0.3536, -0.5042, -0.0514,  0.3310],
         [ 0.2896, -0.5853, -0.0499,  0.3369],
         [ 0.3767, -0.5305, -0.0452,  0.3406],
         [ 0.1185, -0.5664, -0.0260,  0.2704]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[ 0.3536, -0.5042, -0.0514,  0.3310],
         [ 0.2896, -0.5853, -0.0499,  0.3369],
         [ 0.3767, -0.5305, -0.0452,  0.3406],
         [ 0.1185, -0.5664, -0.0260,  0.2704]]])
------------------------------------------------------------------------------------------

PART ONE is equal to PART TWO ? True


### Comparing `attn_output_weights_part_one` with `attn_output_weights_part_two`

In [11]:
print(f'Output of PART ONE :\n\n{attn_output_weights_part_one}')
print('---'*30)
print(f'Output of PART TWO :\n\n{attn_output_weights_part_two}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(attn_output_weights_part_one, attn_output_weights_part_two)}')

Output of PART ONE :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5976, 0.4024, 0.0000, 0.0000],
         [0.3300, 0.2603, 0.4097, 0.0000],
         [0.2358, 0.2138, 0.2068, 0.3436]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5976, 0.4024, 0.0000, 0.0000],
         [0.3300, 0.2603, 0.4097, 0.0000],
         [0.2358, 0.2138, 0.2068, 0.3436]]])
------------------------------------------------------------------------------------------

PART ONE is equal to PART TWO ? True
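
A note on what "equal" means here: `torch.allclose` compares within a tolerance, which is the appropriate check because two implementations may order floating-point operations differently. The sketch below uses made-up values (`a` and `b` are illustrative) to contrast it with an exact comparison and with `torch.testing.assert_close`.

```python
import torch

a = torch.tensor([0.3536, -0.5042, -0.0514, 0.3310])
b = a + 1e-7  # simulate tiny float32 round-off differences

print(torch.equal(a, b))     # exact element-wise comparison
print(torch.allclose(a, b))  # passes if |a - b| <= atol + rtol * |b|

# Raises an AssertionError with a diagnostic message if the tensors differ too much.
torch.testing.assert_close(a, b)
```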


# Afterwords 

We can feel comfortable proceeding with `nn.MultiheadAttention`, knowing exactly how the math is calculated.