# Findings 

- PyTorch's `nn.MultiheadAttention` takes the inputs for Q, K, and V as separate arguments rather than assuming they are always derived from the same source.

- Because the three inputs can differ, the same module covers self-attention (all three come from one sequence) and cross-attention (queries come from one sequence, keys and values from another).
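As a small illustration of that flexibility, the sketch below calls one `nn.MultiheadAttention` layer twice: once with the same tensor for all three inputs, and once with a shorter, separate tensor for keys and values. The names `x` and `memory` and the sizes are illustrative assumptions, not part of this notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes chosen only for this illustration.
layer = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)

x = torch.rand(1, 8, 4)        # sequence that produces the queries
memory = torch.rand(1, 5, 4)   # a different, shorter sequence for keys and values

with torch.no_grad():
    # Self-attention: Q, K and V all come from the same tensor.
    self_out, self_w = layer(x, x, x)
    # Cross-attention: queries from x, keys and values from memory.
    cross_out, cross_w = layer(x, memory, memory)

print(self_out.shape, self_w.shape)
print(cross_out.shape, cross_w.shape)
```

The output always has one row per query token, while the last dimension of the attention weights follows the length of the key/value sequence.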


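### If `bias` = True and `add_bias_kv` = True

$Q = x \cdot W^Q + b^Q$

$K = \operatorname{concat}\left(x \cdot W^K + b^K,\ \text{bias}_{K,\text{shared}}\right)$

$V = \operatorname{concat}\left(x \cdot W^V + b^V,\ \text{bias}_{V,\text{shared}}\right)$

Here $\text{bias}_{K,\text{shared}}$ and $\text{bias}_{V,\text{shared}}$ are learned vectors that `add_bias_kv` appends to the key and value sequences as one extra position (concatenation along the token dimension) rather than adding them elementwise.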



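For simplicity, we set `add_bias_kv` to `False` in `nn.MultiheadAttention`. The constructor call in PART ONE keeps `bias=True`, but PyTorch initializes those projection biases to zero, so the layer behaves exactly like a bias-free one.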
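A quick way to see why the `bias` flag does not change a freshly constructed layer is to inspect its bias parameters. The sketch below is a standalone check with its own layer (the name `layer` is an assumption for this example); it only reads the `in_proj_bias` and `out_proj.bias` attributes.

```python
import torch
from torch import nn

# A fresh layer with bias terms enabled.
layer = nn.MultiheadAttention(embed_dim=4, num_heads=2, bias=True, batch_first=True)

# in_proj_bias holds b_q, b_k, b_v stacked together: shape (3 * embed_dim,).
print(layer.in_proj_bias)
# out_proj.bias is the bias of the output projection: shape (embed_dim,).
print(layer.out_proj.bias)

in_bias_is_zero = bool(torch.all(layer.in_proj_bias == 0))
out_bias_is_zero = bool(torch.all(layer.out_proj.bias == 0))
print(in_bias_is_zero, out_bias_is_zero)
```

Since both biases start at zero, a bias-free reimplementation can reproduce the output of an untrained layer; after training the biases would no longer be zero and would have to be copied as well.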

In [22]:
import torch 
from torch import nn 

Equal Hyperparameters for PART ONE and PART TWO 

In [23]:
embed_dim = 4
num_heads = 2
num_tokens = 8
batch_size = 1

Equal inputs (with `torch.manual_seed`) for PART ONE and PART TWO 

In [24]:
torch.manual_seed(1)

input_to_produce_Q = torch.rand(batch_size,num_tokens,embed_dim)
input_to_produce_K = torch.rand(batch_size,num_tokens,embed_dim)
input_to_produce_V = torch.rand(batch_size,num_tokens,embed_dim)

# PART ONE : PyTorch's Implementation

In [25]:
torch.manual_seed(1)

multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                       num_heads=num_heads,
                                       dropout=0,
                                       bias = True,
                                       add_bias_kv=False,
                                       batch_first=True,    # important
                                       device=None)



# Inspect the weights that PyTorch initialized
print('The following are W_q,W_k,W_v concatenated into one matrix, this is how pytorch initializes the weight matrix\n')
print(multihead_attn.in_proj_weight)
print(multihead_attn.in_proj_weight.shape)


print('----'*20)

print('THe following is out projection weight\n')
print(multihead_attn.out_proj.weight)  # Output projection weights
print(multihead_attn.out_proj.weight.shape)  # Output projection weights


The following are W_q,W_k,W_v concatenated into one matrix, this is how pytorch initializes the weight matrix

Parameter containing:
tensor([[-0.2281, -0.3698, -0.1026, -0.2641],
        [-0.1962,  0.0293,  0.3651,  0.3328],
        [-0.5986,  0.3796,  0.1711,  0.5809],
        [ 0.4042, -0.5580, -0.5822, -0.2954],
        [ 0.5377, -0.1020,  0.2621, -0.2846],
        [ 0.6009, -0.2591,  0.4592,  0.0073],
        [-0.3226,  0.3148, -0.3251,  0.1801],
        [-0.1768, -0.0671, -0.5887, -0.2920],
        [ 0.3323, -0.1489,  0.6100,  0.4909],
        [-0.0287, -0.4087,  0.3729,  0.1901],
        [-0.3958,  0.3978,  0.3718,  0.5431],
        [-0.3433, -0.1008, -0.0119,  0.0894]], requires_grad=True)
torch.Size([12, 4])
--------------------------------------------------------------------------------
THe following is out projection weight

Parameter containing:
tensor([[ 0.2576, -0.2207, -0.0969,  0.2347],
        [-0.4707,  0.2999, -0.1029,  0.2544],
        [ 0.0695, -0.0612,  0.1387,  0.
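The packed layout of `in_proj_weight` can be checked with a short sketch: projecting once with the full matrix and splitting the result should match projecting separately with each chunk. This uses a fresh layer and a random input `x`, both assumptions made for this example only.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(1)
layer = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)

# Rows 0-3 are W_q, rows 4-7 are W_k, rows 8-11 are W_v.
W_q, W_k, W_v = layer.in_proj_weight.chunk(3, dim=0)

x = torch.rand(1, 8, 4)
with torch.no_grad():
    # One projection with the packed matrix gives (1, 8, 12) ...
    packed = F.linear(x, layer.in_proj_weight)
    # ... which splits into Q, K, V along the last dimension.
    q, k, v = packed.chunk(3, dim=-1)

    same_q = torch.allclose(q, F.linear(x, W_q))
    same_k = torch.allclose(k, F.linear(x, W_k))
    same_v = torch.allclose(v, F.linear(x, W_v))

print(W_q.shape, packed.shape)
print(same_q, same_k, same_v)
```

Note that `F.linear(x, W)` computes `x @ W.T`, which is why the chunks are taken along the rows of the weight matrix.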

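Note: PyTorch's `nn.MultiheadAttention` does not mask $QK^T$ by default; a causal mask has to be passed through the `attn_mask` argument. In a boolean mask, `True` marks the positions a query is not allowed to attend to.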

In [26]:
attn_mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1).bool()
attn_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])
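To see what such a mask does to the attention weights, the sketch below applies a smaller 4-token version of it to an all-zero score matrix and then takes the softmax. The zero scores are an assumption chosen so the result is easy to read.

```python
import torch

num_tokens = 4  # smaller than the notebook's 8, to keep the printout short

# True above the diagonal: token i may not look at tokens j > i.
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# With identical scores, every allowed position gets the same weight.
scores = torch.zeros(num_tokens, num_tokens)
masked_scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)

# Row i is uniform over positions 0..i and exactly zero afterwards.
print(weights)
```

Setting masked scores to negative infinity before the softmax makes their weights exactly zero while each row still sums to one.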

### PART ONE RESULT

In [27]:

with torch.no_grad():
    PART_ONE_RESULT, attn_output_weights_part_one  = multihead_attn(input_to_produce_Q,input_to_produce_K,input_to_produce_V,
                                                                    attn_mask = attn_mask )

print('PART ONE RESULT :')
print(PART_ONE_RESULT)
print(PART_ONE_RESULT.shape)

print('----'*20)

print('PART ONE ATTENTION OUTPUT WEIGHTS :')
print(attn_output_weights_part_one)    # By default PyTorch returns the weights averaged across heads, not one matrix per head.
print(attn_output_weights_part_one.shape)

PART ONE RESULT :
tensor([[[ 0.0546, -0.3211,  0.1117,  0.0680],
         [ 0.0543, -0.3043,  0.1045,  0.0583],
         [ 0.0645, -0.3686,  0.1079,  0.0824],
         [ 0.0657, -0.3751,  0.1066,  0.0920],
         [ 0.0610, -0.3521,  0.1016,  0.0831],
         [ 0.0655, -0.3559,  0.0966,  0.0876],
         [ 0.0621, -0.3573,  0.0956,  0.0895],
         [ 0.0596, -0.3492,  0.0888,  0.0911]]])
torch.Size([1, 8, 4])
--------------------------------------------------------------------------------
PART ONE ATTENTION OUTPUT WEIGHTS :
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4735, 0.5265, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3138, 0.3364, 0.3498, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2402, 0.2487, 0.2547, 0.2565, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1835, 0.2049, 0.2169, 0.2079, 0.1869, 0.0000, 0.0000, 0.0000],
         [0.1592, 0.1664, 0.1703, 0.1673, 0.1603, 0.1765, 0.0000, 0.0000],
         [0.124

We'll reuse the weights that `nn.MultiheadAttention` initialized in our `FromScratchMultiheadAttention`, so that both implementations start from identical parameters.

# PART TWO : FROM SCRATCH

In [28]:
class FromScratchMultiheadAttention(nn.Module):
  def __init__(self,context_window,embed_dim,num_heads,dropout=0,add_bias_kv=False,device=None):
    super().__init__()

    # we will assume d_in == d_out and they are both embed_dim.

    # Handling dimensions
    assert embed_dim % num_heads == 0, 'Embedding must be divisible by Number of heads'
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = self.embed_dim//self.num_heads

    # W_q, W_k, W_v
    self.W_q      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.W_k      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.W_v      = nn.Linear(embed_dim,embed_dim,bias=False)
    self.out_proj = nn.Linear(embed_dim,embed_dim,bias=False)


    # Reuse the weights of the `multihead_attn` layer created in PART ONE.
    # PyTorch stores W_q, W_k and W_v stacked in one (3*embed_dim, embed_dim) matrix, so split it into three.
    W_q,W_k,W_v = multihead_attn.in_proj_weight.chunk(3)
    out_proj = multihead_attn.out_proj.weight

    # Copy those weights into our layers so both implementations can be compared on identical parameters.
    self.W_q.weight.data = W_q
    self.W_k.weight.data = W_k
    self.W_v.weight.data = W_v
    self.out_proj.weight.data = out_proj



    # Miscellaneous
    self.register_buffer('mask',torch.triu(torch.ones(context_window,context_window),diagonal=1))
    self.dropout = nn.Dropout(dropout)

    
  
  def forward(self,input_to_produce_Q,input_to_produce_K,input_to_produce_V):
    B_q,num_token_q,embed_dim_q = input_to_produce_Q.shape
    B_k,num_token_k,embed_dim_k = input_to_produce_K.shape
    B_v,num_token_v,embed_dim_v = input_to_produce_V.shape

    Q = self.W_q(input_to_produce_Q) 
    K = self.W_k(input_to_produce_K) 
    V = self.W_v(input_to_produce_V) 

    # splitting 
    Q = Q.view(B_q,num_token_q,self.num_heads,self.head_dim).transpose(1,2)
    K = K.view(B_k,num_token_k,self.num_heads,self.head_dim).transpose(1,2)
    V = V.view(B_v,num_token_v,self.num_heads,self.head_dim).transpose(1,2)

    # QK,mask,softmax,dropout
    attn_score = Q @ K.transpose(2,3)
    attn_score.masked_fill_(self.mask.bool()[:num_token_q,:num_token_k],-torch.inf)
    attn_weight = torch.softmax(attn_score/K.shape[-1]**0.5,dim=-1)
    attn_weight = self.dropout(attn_weight)

    # context_vec
    context_vec = attn_weight @ V

    # Putting the heads back together 
    context_vec = context_vec.transpose(1,2).contiguous().view(B_q,num_token_q,self.embed_dim)    # the batch size is the same for Q, K and V, so B_q is used

    # projection 
    context_vec = self.out_proj(context_vec)

    return context_vec,attn_weight
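The core of the forward pass (scale, mask, softmax, multiply by V) can be cross-checked against PyTorch's fused `F.scaled_dot_product_attention`. This is a minimal sketch that assumes PyTorch 2.0 or newer, where that function is available; the random `Q`, `K`, `V` tensors are stand-ins for already projected and split heads.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# (batch, heads, tokens, head_dim), matching the shapes after the split step.
B, H, T, D = 1, 2, 8, 2
Q, K, V = (torch.rand(B, H, T, D) for _ in range(3))

# Manual version: scaled scores, causal mask, softmax, weighted sum of V.
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = (Q @ K.transpose(-2, -1)) / D ** 0.5
weights = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
manual = weights @ V

# Fused version: is_causal=True applies the same lower-triangular masking.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

matches = torch.allclose(manual, fused, atol=1e-5)
print(matches)
```

Masking before or after dividing by the square root of the head dimension gives the same result, because negative infinity stays negative infinity under scaling.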

    

### PART TWO RESULT  

Note: you can think of `context_window` as the maximum number of tokens the model can process in one go. Since we are feeding the model 8 tokens (`num_tokens = 8`), any `context_window` of at least 8 will do.
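The reason a larger `context_window` is harmless can be sketched in a few lines: slicing the top-left corner of a large causal mask gives exactly the mask that would have been built for the shorter sequence.

```python
import torch

context_window = 1024
num_tokens = 8

# Mask built once for the maximum sequence length.
full_mask = torch.triu(torch.ones(context_window, context_window), diagonal=1).bool()

# Top-left corner, as taken in the forward pass for the actual sequence length.
sliced = full_mask[:num_tokens, :num_tokens]

# Mask built directly for 8 tokens.
direct = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

same = torch.equal(sliced, direct)
print(same)
```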


In [29]:
mha = FromScratchMultiheadAttention(context_window=1024,   # see the above note.
                                    embed_dim=embed_dim,
                                    num_heads=num_heads,
                                    dropout=0.)


with torch.no_grad():
    PART_TWO_RESULT,attn_output_weights_part_two = mha(input_to_produce_Q,input_to_produce_K,input_to_produce_V)

print('PART TWO RESULT :')
print(PART_TWO_RESULT)
print(PART_TWO_RESULT.shape)

print('----'*20)

print('PART TWO ATTENTION OUTPUT WEIGHTS :')
print(attn_output_weights_part_two)    # Unlike PyTorch's default, our implementation returns one weight matrix per head: (batch, heads, tokens, tokens).
print(attn_output_weights_part_two.shape)

PART TWO RESULT :
tensor([[[ 0.0546, -0.3211,  0.1117,  0.0680],
         [ 0.0543, -0.3043,  0.1045,  0.0583],
         [ 0.0645, -0.3686,  0.1079,  0.0824],
         [ 0.0657, -0.3751,  0.1066,  0.0920],
         [ 0.0610, -0.3521,  0.1016,  0.0831],
         [ 0.0655, -0.3559,  0.0966,  0.0876],
         [ 0.0621, -0.3573,  0.0956,  0.0895],
         [ 0.0596, -0.3492,  0.0888,  0.0911]]])
torch.Size([1, 8, 4])
--------------------------------------------------------------------------------
PART TWO ATTENTION OUTPUT WEIGHTS :
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4982, 0.5018, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3322, 0.3297, 0.3381, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2468, 0.2415, 0.2451, 0.2666, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.1961, 0.1983, 0.2063, 0.2041, 0.1951, 0.0000, 0.0000, 0.0000],
          [0.1642, 0.1645, 0.1671, 0.1681, 0.1646, 0.1715, 0.0000, 0.0000],
         

In [30]:
attn_output_weights_part_two = attn_output_weights_part_two.mean(dim=1)   # taking average across heads dimension.
attn_output_weights_part_two

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4735, 0.5265, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3138, 0.3364, 0.3498, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2402, 0.2487, 0.2547, 0.2565, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1835, 0.2049, 0.2169, 0.2079, 0.1869, 0.0000, 0.0000, 0.0000],
         [0.1592, 0.1664, 0.1703, 0.1673, 0.1603, 0.1765, 0.0000, 0.0000],
         [0.1249, 0.1399, 0.1482, 0.1445, 0.1283, 0.1636, 0.1507, 0.0000],
         [0.1023, 0.1168, 0.1248, 0.1264, 0.1075, 0.1425, 0.1287, 0.1510]]])
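The same averaging can be observed from PyTorch's side. The sketch below assumes a PyTorch version in which the forward call accepts `average_attn_weights` (recent releases do); it uses a fresh layer and input rather than the notebook's variables.

```python
import torch
from torch import nn

torch.manual_seed(1)
layer = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
x = torch.rand(1, 8, 4)

with torch.no_grad():
    # Default: weights averaged over heads, shape (batch, tokens, tokens).
    _, averaged = layer(x, x, x)
    # Per-head weights, shape (batch, heads, tokens, tokens).
    _, per_head = layer(x, x, x, average_attn_weights=False)

print(averaged.shape)
print(per_head.shape)

# Averaging the per-head weights over the head dimension recovers the default output.
same = torch.allclose(per_head.mean(dim=1), averaged)
print(same)
```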

# COMPARISON

### Comparing `PART_ONE_RESULT` with `PART_TWO_RESULT`

In [31]:
print(f'Output of PART ONE :\n\n{PART_ONE_RESULT}')
print('---'*30)
print(f'Output of PART TWO :\n\n{PART_TWO_RESULT}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(PART_ONE_RESULT, PART_TWO_RESULT)}')


Output of PART ONE :

tensor([[[ 0.0546, -0.3211,  0.1117,  0.0680],
         [ 0.0543, -0.3043,  0.1045,  0.0583],
         [ 0.0645, -0.3686,  0.1079,  0.0824],
         [ 0.0657, -0.3751,  0.1066,  0.0920],
         [ 0.0610, -0.3521,  0.1016,  0.0831],
         [ 0.0655, -0.3559,  0.0966,  0.0876],
         [ 0.0621, -0.3573,  0.0956,  0.0895],
         [ 0.0596, -0.3492,  0.0888,  0.0911]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[ 0.0546, -0.3211,  0.1117,  0.0680],
         [ 0.0543, -0.3043,  0.1045,  0.0583],
         [ 0.0645, -0.3686,  0.1079,  0.0824],
         [ 0.0657, -0.3751,  0.1066,  0.0920],
         [ 0.0610, -0.3521,  0.1016,  0.0831],
         [ 0.0655, -0.3559,  0.0966,  0.0876],
         [ 0.0621, -0.3573,  0.0956,  0.0895],
         [ 0.0596, -0.3492,  0.0888,  0.0911]]])
------------------------------------------------------------------------------------------

PART ONE is equal
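Besides `torch.allclose`, which only returns a boolean, `torch.testing.assert_close` raises an error that describes the mismatch when two tensors differ. The sketch below uses two small made-up tensors, not the notebook's results, to show both outcomes.

```python
import torch

# Made-up values for illustration only.
a = torch.tensor([0.0546, -0.3211, 0.1117, 0.0680])

# A difference far below the default tolerances passes silently.
torch.testing.assert_close(a, a + 1e-7)

# A larger difference raises an AssertionError with details about the mismatch.
try:
    torch.testing.assert_close(a, a + 1e-2)
    mismatch_detected = False
except AssertionError as err:
    mismatch_detected = True
    print(err)
```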

### Comparing `attn_output_weights_part_one` with `attn_output_weights_part_two`

In [32]:
print(f'Output of PART ONE :\n\n{attn_output_weights_part_one}')
print('---'*30)
print(f'Output of PART TWO :\n\n{attn_output_weights_part_two}')
print('---'*30)
print(f'\nPART ONE is equal to PART TWO ? {torch.allclose(attn_output_weights_part_one, attn_output_weights_part_two)}')

Output of PART ONE :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4735, 0.5265, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3138, 0.3364, 0.3498, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2402, 0.2487, 0.2547, 0.2565, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1835, 0.2049, 0.2169, 0.2079, 0.1869, 0.0000, 0.0000, 0.0000],
         [0.1592, 0.1664, 0.1703, 0.1673, 0.1603, 0.1765, 0.0000, 0.0000],
         [0.1249, 0.1399, 0.1482, 0.1445, 0.1283, 0.1636, 0.1507, 0.0000],
         [0.1023, 0.1168, 0.1248, 0.1264, 0.1075, 0.1425, 0.1287, 0.1510]]])
------------------------------------------------------------------------------------------
Output of PART TWO :

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4735, 0.5265, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3138, 0.3364, 0.3498, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2402, 0.2487, 0.2547, 0.25

# Afterword

We can feel comfortable proceeding with `nn.MultiheadAttention`, knowing exactly how the math is calculated.