## Multi-head attention transformer
### Encoder and Decoder
### (With masking)

### Initialising the same transformer as in the other notebook


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    Builds a ``(1, max_len, d_model)`` table whose even feature columns hold
    sines and odd feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Width of each encoding vector (odd values are supported).
        max_len: Number of positions to precompute.
        dropout: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, d_model, max_len=512, dropout=0):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / d_model)
        )

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A non-persistent buffer follows the module across .to(device) calls
        # while staying out of state_dict(), so checkpoint keys are unchanged.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, detached.

        Only the size of dim 1 of ``x`` is read; its values are ignored.
        """
        # NOTE(review): the slice length comes from dim 1 of x. Callers that
        # pass sequence-first (seq, batch) tensors get one row per batch
        # element rather than one per time step -- confirm this is intended.
        return self.encoding[:, :x.size(1)].detach()



class TransformerModel1(nn.Module):
    """Token-embedding + positional-encoding wrapper around ``nn.Transformer``.

    Args:
        src_vocab_size: Number of rows in the source embedding table.
        tgt_vocab_size: Number of rows in the target embedding table; also the
            width of the output projection.
        d_model: Embedding / model width.
        num_heads: Attention heads per layer.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        max_seq_len: Number of positions precomputed by the positional encoding.
        d_ff: Hidden width of the feed-forward sub-layers.
        dropout: Dropout probability passed to ``nn.Transformer``.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super().__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def generate_mask(self, src, tgt):
        """Return ``(src_mask, tgt_mask)`` for one forward pass.

        ``src_mask`` is always ``None``. ``tgt_mask`` is a square boolean
        matrix of side ``tgt.size(0)`` that is ``True`` strictly above the
        diagonal, i.e. at the positions a query must not attend to.
        """
        src_mask = None
        seq_length = tgt.size(0)

        # Build the mask on the input's device so the model also runs off-CPU.
        nopeak_mask = torch.triu(
            torch.ones(seq_length, seq_length, device=tgt.device), diagonal=1
        ).bool()

        return src_mask, nopeak_mask

    def forward(self, src, tgt):
        """Embed both sequences, run the transformer, and project to the vocabulary.

        Inputs appear to be expected sequence-first (the mask size is taken
        from ``tgt.size(0)`` and ``nn.Transformer`` is built without
        ``batch_first``) -- TODO confirm with callers.
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        src_emb = self.src_embedding(src)
        tgt_emb = self.tgt_embedding(tgt)

        # .to() is a no-op when the table already lives on the right device.
        src_emb = src_emb + self.positional_encoding(src).to(src_emb.device)
        tgt_emb = tgt_emb + self.positional_encoding(tgt).to(tgt_emb.device)

        output = self.transformer(src_emb, tgt_emb, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        output = self.fc(output)

        return output
    

### Hyperparameters defined 

In [2]:
import numpy as np

# Sizes of the source and target embedding tables.
src_vocab_size = 20
tgt_vocab_size = 20
# Model width; it is later split as d_model // num_heads per attention head.
d_model = 16
num_heads = 4
num_encoder_layers = 1
num_decoder_layers = 1
# Hidden width of the feed-forward sub-layer.
d_ff = 20
# Number of positions precomputed by the positional encoding.
max_seq_len = 5
dropout = 0


# Token ids stored as column vectors: one row per position, one column.
src_data = np.array([[2], [1], [5], [4]])
tgt_data = np.array([[1], [16], [5], [3], [9]])



In [3]:
# Seed before construction so the randomly initialised weights are repeatable.
torch.manual_seed(0)

transformer = TransformerModel1(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    max_seq_len=max_seq_len,
    d_ff=d_ff,
)

# Every weight used by the manual NumPy walk-through is read from this mapping.
state_dict = transformer.state_dict()



### 1. Source token embeddings 

In [4]:

# Source embedding table: one row per vocabulary id.
src_vocab_embeds = state_dict["src_embedding.weight"]

# One d_model-wide row per source position, filled in below.
src_embedding = np.zeros((src_data.shape[0], d_model))

for i, word_index in enumerate(src_data):
    # Refuse ids that fall outside the embedding table.
    if not (0 <= word_index < src_vocab_embeds.shape[0]):
        raise ValueError(f"Invalid word index: {word_index}")
    src_embedding[i, :] = src_vocab_embeds[word_index, :]

    print(f"Word index: {word_index}, Embedding: {src_vocab_embeds[word_index, :]}")
print()
print(src_embedding.shape)


Word index: [2], Embedding: tensor([[-0.6136,  0.0316, -0.4927,  0.2484,  0.4397,  0.1124,  0.6408,  0.4412,
         -0.1023,  0.7924, -0.2897,  0.0525,  0.5229,  2.3022, -1.4689, -1.5867]])
Word index: [1], Embedding: tensor([[-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,  1.8530,
          0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463, -0.8437]])
Word index: [5], Embedding: tensor([[-9.3348e-02,  6.8705e-01, -8.3832e-01,  8.9182e-04,  8.4189e-01,
         -4.0003e-01,  1.0395e+00,  3.5815e-01, -2.4600e-01,  2.3025e+00,
         -1.8817e+00, -4.9727e-02, -1.0450e+00, -9.5650e-01,  3.3532e-02,
          7.1009e-01]])
Word index: [4], Embedding: tensor([[-0.5692,  0.9200,  1.1108,  1.2899, -1.4782,  2.5672, -0.4731,  0.3356,
         -1.6293, -0.5497, -0.4798, -0.4997, -1.0670,  1.1149, -0.1407,  0.8058]])

(4, 16)


In [5]:
# Display the (positions, d_model) shape of the looked-up source embeddings.
src_embedding.shape

(4, 16)

In [6]:
import numpy as np

class PositionalEncoding:
    """NumPy sinusoidal position table stored with a leading axis of size 1."""

    def __init__(self, d_model, max_len=512, dropout=0):
        # `dropout` is accepted but never read by this class.
        positions = np.arange(0, max_len).reshape(-1, 1).astype(np.float32)
        even_dims = np.arange(0, d_model, 2).astype(np.float32)
        div_term = np.exp(even_dims * -(np.log(10000.0) / d_model))
        angles = positions * div_term

        # Even feature columns take the sine, odd feature columns the cosine.
        table = np.zeros((max_len, d_model))
        table[:, 0::2] = np.sin(angles)
        table[:, 1::2] = np.cos(angles)
        self.encoding = table[np.newaxis, :]

    def forward(self, x):
        # NOTE(review): `x` is not consulted; the row for position 0 is
        # returned for every call -- confirm this is the intended behaviour.
        return self.encoding[0][0]

### 2. Source token embeddings + positional embeddings 

In [7]:
d_model, max_seq_len = 16, 5

pe = PositionalEncoding(d_model, max_len=max_seq_len)

# The encoding row returned by forward() broadcasts over every source position.
pe_src_embeds = pe.forward(src_data) + src_embedding

pe_src_embeds.shape, pe_src_embeds

((4, 16),
 array([[-0.61358309,  1.03159274, -0.49267703,  1.24841475,  0.43969586,
          1.11241119,  0.64079237,  1.44115627, -0.10230965,  1.79244399,
         -0.28966758,  1.05250749,  0.52286041,  3.30220532, -1.46889389,
         -0.58668876],
        [-1.35265374, -0.69593132,  0.56665051,  1.79350841,  0.59883946,
         -0.55509508, -0.3413603 ,  2.85300612,  0.75018942,  0.41450286,
         -0.17339702,  1.18347792,  1.38936615,  2.58633435,  0.94629836,
          0.15632319],
        [-0.09334823,  1.68705022, -0.83831537,  1.00089182,  0.84189409,
          0.59996545,  1.03946197,  1.3581531 , -0.24600095,  3.30251646,
         -1.88168919,  0.95027298, -1.04497862,  0.04349947,  0.03353186,
          1.71008658],
        [-0.56924802,  1.91997129,  1.11081612,  2.28987384, -1.47817433,
          3.56723285, -0.4731198 ,  1.33555073, -1.62932599,  0.45025635,
         -0.47983426,  0.50031784, -1.06698   ,  2.11493957, -0.14067143,
          1.80575365]]))

In [8]:
# Encoder input: source embeddings with the positional encoding added.
x_enc = pe_src_embeds
x_enc.shape

(4, 16)

### 3. Getting the Q, K, V matrices from the model's initialised weights

In [9]:
import numpy as np

# Self-attention: queries, keys and values all come from the encoder input.
query_enc = key_enc = value_enc = x_enc

tgt_len, embed_dim = x_enc.shape

layer_num = 0

# Packed input projection: the Q, K and V weight blocks are stacked row-wise,
# and the bias vector is laid out the same way.
W_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_weight".format(layer_num)].numpy()
b_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_bias".format(layer_num)].numpy()

head_dim = embed_dim // num_heads

# embed_dim
E = query_enc.shape[-1]

# Apply the full affine projection; the bias term was previously left out.
tempop1 = np.matmul(query_enc, W_enc.T) + b_enc

# Columns [0, E) are Q, [E, 2E) are K and [2E, 3E) are V.
Q_enc, K_enc, V_enc = tempop1[:, 0:E], tempop1[:, E:2*E], tempop1[:, 2*E:3*E]

print("Q_enc_shape = ", Q_enc.shape)
print("K_enc_shape = ", K_enc.shape)
print("V_enc_shape = ", V_enc.shape)
print()

print("After reshaping... \n")

# Split the feature axis into heads and move the head axis to the front:
# (tokens, E) -> (tokens, heads, head_dim) -> (heads, tokens, head_dim).
Q_enc = np.transpose(np.reshape(Q_enc, (tgt_len, num_heads, head_dim)), (1, 0, 2))
K_enc = np.transpose(np.reshape(K_enc, (K_enc.shape[0], num_heads, head_dim)), (1, 0, 2))
V_enc = np.transpose(np.reshape(V_enc, (V_enc.shape[0], num_heads, head_dim)), (1, 0, 2))

print("Q_enc_shape = ", Q_enc.shape)
print("K_enc_shape = ", K_enc.shape)
print("V_enc_shape = ", V_enc.shape)


Q_enc_shape =  (4, 16)
K_enc_shape =  (4, 16)
V_enc_shape =  (4, 16)

After reshaping... 

Q_enc_shape =  (4, 4, 4)
K_enc_shape =  (4, 4, 4)
V_enc_shape =  (4, 4, 4)


In [10]:
# Dump the per-head projections in Q, K, V order.
for label, matrix in (("Q", Q_enc), ("K", K_enc), ("V", V_enc)):
    print(f"{label}_enc_{layer_num} = ")
    print(matrix)
    print()

Q_enc_0 = 
[[[-1.36735794  0.38451974 -0.62079629  0.31211979]
  [-1.11743603  0.75657466 -0.97703017  1.40062991]
  [-1.38411373  0.49395312 -0.61819885 -1.53646693]
  [-0.82986259 -0.55943665 -1.79558619 -0.35191729]]

 [[ 1.43866634 -0.02124564  0.67543413 -0.15366692]
  [ 1.07486503  0.93711323 -0.45114472  0.68147772]
  [-0.07431921 -0.52279624  0.71932026 -1.18875829]
  [ 0.984559   -0.70037252  1.17993382 -0.53922323]]

 [[-0.48074855  1.16585093  1.03184366  0.8862985 ]
  [-0.296841    1.25991946  0.79888877  0.90739251]
  [-0.11831912  0.54421682  2.078344   -0.07897106]
  [-0.99450992 -0.40800719  0.73816109  1.63374404]]

 [[ 0.18291803 -0.66727623 -1.28919706  1.01650888]
  [-0.43740222 -0.45591594 -2.20209216  1.09265755]
  [ 0.2551974   0.32470998  1.05439451 -1.19979177]
  [ 1.92657715  0.53277614  0.60048078  0.66183671]]]

K_enc_0 = 
[[[ 1.38859394 -0.09582228 -0.18015615  1.11433203]
  [ 0.85509775 -0.32077069  0.12798366  1.44537783]
  [ 1.49764453 -1.35984899 -0.659

In [11]:
# Display the (heads, tokens, head_dim) shapes of the per-head projections.
K_enc.shape, Q_enc.shape, V_enc.shape

((4, 4, 4), (4, 4, 4), (4, 4, 4))

### 4. Self attention for encoder layer 1

#### Attention calculation with Q,K and V matrices

In [12]:
src_len = K_enc.shape[1]

Q_enc1 = Q_enc
K_enc1 = K_enc
V_enc1 = V_enc

# Scaled dot-product attention: scores are divided by sqrt(head_dim).
scale_factor = 1 / math.sqrt(Q_enc1.shape[-1])

# Swap the last two axes of K so each head computes Q @ K^T.
K_enc1_T = np.transpose(K_enc1, axes=(0, 2, 1))

attn_weight = Q_enc1 @ K_enc1_T * scale_factor

# Softmax over the key axis. Subtracting the row maximum first leaves the
# result mathematically unchanged but keeps np.exp from overflowing.
exp_attn_weight = np.exp(attn_weight - np.max(attn_weight, axis=-1, keepdims=True))
sum_exp_attn_weight = np.sum(exp_attn_weight, axis=-1, keepdims=True)
softmax_attn_weight = exp_attn_weight / sum_exp_attn_weight

print("SoftMax (Scaled Dot Product of Q_enc and K_enc) = ")
softmax_attn_weight

SoftMax (Scaled Dot Product of Q_enc and K_enc) = 


array([[[0.24622994, 0.32498347, 0.17093051, 0.25785608],
        [0.29817845, 0.40020476, 0.09113512, 0.21048167],
        [0.14513324, 0.14001347, 0.30002674, 0.41482654],
        [0.18297399, 0.17394416, 0.47776105, 0.16532081]],

       [[0.14055266, 0.44753894, 0.16198232, 0.24992608],
        [0.10363997, 0.39102536, 0.23427295, 0.27106171],
        [0.39596284, 0.19193935, 0.14620412, 0.2658937 ],
        [0.22779559, 0.38733568, 0.16309258, 0.22177616]],

       [[0.26347481, 0.42828542, 0.18578728, 0.1224525 ],
        [0.25692131, 0.41706922, 0.19829975, 0.12770972],
        [0.23173862, 0.14811665, 0.30689119, 0.31325353],
        [0.27611089, 0.44084326, 0.11295147, 0.17009438]],

       [[0.3910106 , 0.17491111, 0.38204033, 0.05203796],
        [0.39744782, 0.19527263, 0.38060955, 0.02667   ],
        [0.11367704, 0.22355699, 0.09936722, 0.56339875],
        [0.32765678, 0.17762593, 0.37389094, 0.12082634]]])

In [13]:
# Weighted sum of the value vectors, computed independently for each head.
attn_output = np.matmul(softmax_attn_weight, V_enc1)

In [14]:
# Heading and values are emitted on separate lines, as two print calls would.
print("Final Attention Values = \n", attn_output, sep="\n")

Final Attention Values = 

[[[ 0.71996271  0.37472803  0.19898526 -0.096828  ]
  [ 0.74778941  0.41424777  0.14055333 -0.09750273]
  [ 0.6835169   0.43522338  0.44485323 -0.15091048]
  [ 0.55486835 -0.18135736 -0.03144966  0.15490696]]

 [[ 0.72312544  0.11965125 -0.61675933  0.41149321]
  [ 0.83353188  0.10672244 -0.50738729  0.4131849 ]
  [ 0.68540562  0.11745112 -0.58773571  0.70770561]
  [ 0.65851092  0.14600813 -0.62973254  0.47600028]]

 [[-0.17895635 -0.97088938  0.21564846  0.34026236]
  [-0.15921384 -0.95981102  0.22985971  0.33606611]
  [ 0.17079556 -0.57683199  0.32521917  0.4177635 ]
  [-0.21387931 -0.8561706   0.1344209   0.39797618]]

 [[ 0.45370401  0.97827891  0.88422869 -0.1181432 ]
  [ 0.39785858  0.99090853  0.84784162 -0.1675067 ]
  [ 0.97294562  0.8204363   1.46646829 -0.43123435]
  [ 0.54188749  0.98235489  0.95371993 -0.10789927]]]


#### Reshaping the final attention output

In [15]:
# The axis order (0, 1, 2) is the identity permutation, so this is a view of
# attn_output with the same layout; the heads are regrouped in later cells.
attn_output_permuted_sa = attn_output.transpose(0, 1, 2)

attn_output_permuted_sa

array([[[ 0.71996271,  0.37472803,  0.19898526, -0.096828  ],
        [ 0.74778941,  0.41424777,  0.14055333, -0.09750273],
        [ 0.6835169 ,  0.43522338,  0.44485323, -0.15091048],
        [ 0.55486835, -0.18135736, -0.03144966,  0.15490696]],

       [[ 0.72312544,  0.11965125, -0.61675933,  0.41149321],
        [ 0.83353188,  0.10672244, -0.50738729,  0.4131849 ],
        [ 0.68540562,  0.11745112, -0.58773571,  0.70770561],
        [ 0.65851092,  0.14600813, -0.62973254,  0.47600028]],

       [[-0.17895635, -0.97088938,  0.21564846,  0.34026236],
        [-0.15921384, -0.95981102,  0.22985971,  0.33606611],
        [ 0.17079556, -0.57683199,  0.32521917,  0.4177635 ],
        [-0.21387931, -0.8561706 ,  0.1344209 ,  0.39797618]],

       [[ 0.45370401,  0.97827891,  0.88422869, -0.1181432 ],
        [ 0.39785858,  0.99090853,  0.84784162, -0.1675067 ],
        [ 0.97294562,  0.8204363 ,  1.46646829, -0.43123435],
        [ 0.54188749,  0.98235489,  0.95371993, -0.10789927]]])

In [16]:
# Flatten (heads, tokens, head_dim) to (heads * tokens, head_dim).
# head_dim is used directly so the global embed_dim keeps its real meaning
# instead of being overwritten with the per-head width.
numh_tgt_len = num_heads * tgt_len
attn_output_reshaped_sa = attn_output_permuted_sa.reshape(numh_tgt_len, head_dim)


attn_output_reshaped_sa, attn_output_reshaped_sa.shape

(array([[ 0.71996271,  0.37472803,  0.19898526, -0.096828  ],
        [ 0.74778941,  0.41424777,  0.14055333, -0.09750273],
        [ 0.6835169 ,  0.43522338,  0.44485323, -0.15091048],
        [ 0.55486835, -0.18135736, -0.03144966,  0.15490696],
        [ 0.72312544,  0.11965125, -0.61675933,  0.41149321],
        [ 0.83353188,  0.10672244, -0.50738729,  0.4131849 ],
        [ 0.68540562,  0.11745112, -0.58773571,  0.70770561],
        [ 0.65851092,  0.14600813, -0.62973254,  0.47600028],
        [-0.17895635, -0.97088938,  0.21564846,  0.34026236],
        [-0.15921384, -0.95981102,  0.22985971,  0.33606611],
        [ 0.17079556, -0.57683199,  0.32521917,  0.4177635 ],
        [-0.21387931, -0.8561706 ,  0.1344209 ,  0.39797618],
        [ 0.45370401,  0.97827891,  0.88422869, -0.1181432 ],
        [ 0.39785858,  0.99090853,  0.84784162, -0.1675067 ],
        [ 0.97294562,  0.8204363 ,  1.46646829, -0.43123435],
        [ 0.54188749,  0.98235489,  0.95371993, -0.10789927]]),
 (16, 

In [17]:
# Concatenate the heads for every token.
# Rows of attn_output_reshaped_sa are grouped head-major (all tokens of head 0,
# then all tokens of head 1, ...). Restoring the (heads, tokens, head_dim)
# axes, swapping heads and tokens, and flattening the last two axes places
# each token's head outputs side by side. The sizes come from num_heads and
# tgt_len, so this no longer relies on them happening to equal head_dim.
final_attn_sa_op = (
    attn_output_reshaped_sa
    .reshape(num_heads, tgt_len, -1)
    .transpose(1, 0, 2)
    .reshape(tgt_len, -1)
    .astype(np.float64)  # float64 result regardless of the input dtype
)

final_attn_sa_op, final_attn_sa_op.shape

(array([[ 0.71996271,  0.37472803,  0.19898526, -0.096828  ,  0.72312544,
          0.11965125, -0.61675933,  0.41149321, -0.17895635, -0.97088938,
          0.21564846,  0.34026236,  0.45370401,  0.97827891,  0.88422869,
         -0.1181432 ],
        [ 0.74778941,  0.41424777,  0.14055333, -0.09750273,  0.83353188,
          0.10672244, -0.50738729,  0.4131849 , -0.15921384, -0.95981102,
          0.22985971,  0.33606611,  0.39785858,  0.99090853,  0.84784162,
         -0.1675067 ],
        [ 0.6835169 ,  0.43522338,  0.44485323, -0.15091048,  0.68540562,
          0.11745112, -0.58773571,  0.70770561,  0.17079556, -0.57683199,
          0.32521917,  0.4177635 ,  0.97294562,  0.8204363 ,  1.46646829,
         -0.43123435],
        [ 0.55486835, -0.18135736, -0.03144966,  0.15490696,  0.65851092,
          0.14600813, -0.62973254,  0.47600028, -0.21387931, -0.8561706 ,
          0.1344209 ,  0.39797618,  0.54188749,  0.98235489,  0.95371993,
         -0.10789927]]),
 (4, 16))

### 5. Post self attention in the encoder block

In [18]:
layer_num = 0

weight_enc = state_dict[f"transformer.encoder.layers.{layer_num}.self_attn.out_proj.weight"].numpy()
bias_enc = state_dict[f"transformer.encoder.layers.{layer_num}.self_attn.out_proj.bias"].numpy()

# Project the concatenated head outputs back to the model width.
op_enc_1 = final_attn_sa_op @ weight_enc.T + bias_enc

# First residual connection: add the attention sub-layer's input back in.
output_enc_1 = x_enc + op_enc_1


In [19]:
# Layer Norm 1

norm_weight = state_dict["transformer.encoder.layers.{}.norm1.weight".format(layer_num)].numpy()
norm_bias = state_dict["transformer.encoder.layers.{}.norm1.bias".format(layer_num)].numpy()

# Normalise each token's features first, using the biased variance with
# epsilon inside the square root (the formula nn.LayerNorm documents).
mean = np.mean(output_enc_1, axis=-1, keepdims=True)
epsilon = 1e-5
std = np.sqrt(np.var(output_enc_1, axis=-1, keepdims=True) + epsilon)
normalized_linear_result_enc_1 = (output_enc_1 - mean) / std

# Then apply the layer's own scale and shift from the state dict. Previously
# they were applied before normalisation and followed by a second, freshly
# initialised LayerNorm, which discards any non-default weights.
linear_result_enc_1 = normalized_linear_result_enc_1 * norm_weight + norm_bias

# Kept as a float32 torch tensor because the next cell calls .detach().numpy().
linear_op_enc_1 = torch.tensor(linear_result_enc_1, dtype=torch.float32)


In [20]:
# Convert the layer-norm result from a torch tensor to a NumPy array.
linear_op_enc_1 = linear_op_enc_1.detach().numpy()

In [21]:
# Feed-forward sub-layer: linear1 -> ReLU -> linear2

linear1_weight = state_dict[f"transformer.encoder.layers.{layer_num}.linear1.weight"].numpy()
linear1_bias = state_dict[f"transformer.encoder.layers.{layer_num}.linear1.bias"].numpy()
linear2_weight = state_dict[f"transformer.encoder.layers.{layer_num}.linear2.weight"].numpy()
linear2_bias = state_dict[f"transformer.encoder.layers.{layer_num}.linear2.bias"].numpy()

# First projection (note: this reuses the name op_enc_1 from the attention step).
op_enc_1 = linear_op_enc_1 @ linear1_weight.T + linear1_bias

# ReLU: negative pre-activations are clamped to zero.
op_enc_1_relu = np.maximum(op_enc_1, 0)

# Second projection, back to the model width.
op_enc_2 = op_enc_1_relu @ linear2_weight.T + linear2_bias


In [22]:
# Residual connection 2: add the feed-forward input back to its output.

output_enc_2 = linear_op_enc_1 + op_enc_2

In [23]:
# Layer Norm 2

norm_weight = state_dict["transformer.encoder.layers.{}.norm2.weight".format(layer_num)].numpy()
norm_bias = state_dict["transformer.encoder.layers.{}.norm2.bias".format(layer_num)].numpy()

# Normalise each token's features first, using the biased variance with
# epsilon inside the square root (the formula nn.LayerNorm documents).
mean = np.mean(output_enc_2, axis=-1, keepdims=True)
epsilon = 1e-5
std = np.sqrt(np.var(output_enc_2, axis=-1, keepdims=True) + epsilon)
normalized_linear_result_enc_2 = (output_enc_2 - mean) / std

# Then apply the layer's own scale and shift from the state dict. Previously
# they were applied before normalisation and followed by a second, freshly
# initialised LayerNorm, which discards any non-default weights.
linear_result_enc_2 = normalized_linear_result_enc_2 * norm_weight + norm_bias

# Kept as a float32 torch tensor because a later cell calls .detach().numpy().
linear_op_enc_2 = torch.tensor(linear_result_enc_2, dtype=torch.float32)


In [24]:
# The second layer-norm result is the output of this encoder layer.
output_enc_final =  linear_op_enc_2

In [25]:
# Display the encoder layer output.
output_enc_final

tensor([[ 0.0776, -0.8134, -0.6337,  0.2314, -0.1776,  0.3632, -1.3015,  1.5257,
         -0.5745,  1.5865,  0.4041,  0.2572, -1.0101,  1.9088, -1.7131, -0.1305],
        [-0.5666, -1.6901, -0.2921,  1.2146, -0.1103,  0.0079, -2.1157,  1.6536,
          0.0881,  0.5683,  0.0386,  0.5650, -0.1463,  1.7027, -0.6104, -0.3074],
        [ 0.1136, -0.5709, -1.0399, -0.1699,  0.9766, -0.3185, -0.4165,  1.2750,
         -1.3753,  2.3096, -0.5499,  0.8579, -1.1126, -0.9835, -0.0990,  1.1033],
        [ 0.1810,  0.0223,  0.5764,  0.7853, -1.5473,  0.8328, -1.9499,  0.4452,
         -0.7445,  1.1691,  0.1031, -0.4145, -1.4597,  1.3111, -0.5877,  1.2772]],
       grad_fn=<NativeLayerNormBackward0>)

In [84]:
# Convert the encoder output from a torch tensor to a NumPy array.
output_enc_final = output_enc_final.detach().numpy()

In [85]:
# This will be used for the next encoder layer / decoder layer (cross attention)
# NOTE(review): the state dict also lists transformer.encoder.norm.weight/bias,
# which are not applied here -- confirm whether that final encoder norm should
# be applied before this value is used as the decoder memory.
x_enc = output_enc_final

########################################################################

## Decoder block


### Self attention outputs from a decoder block

### 6. Target token embeddings 

In [86]:
# List the parameter names available for the decoder walk-through.
state_dict.keys()

odict_keys(['src_embedding.weight', 'tgt_embedding.weight', 'transformer.encoder.layers.0.self_attn.in_proj_weight', 'transformer.encoder.layers.0.self_attn.in_proj_bias', 'transformer.encoder.layers.0.self_attn.out_proj.weight', 'transformer.encoder.layers.0.self_attn.out_proj.bias', 'transformer.encoder.layers.0.linear1.weight', 'transformer.encoder.layers.0.linear1.bias', 'transformer.encoder.layers.0.linear2.weight', 'transformer.encoder.layers.0.linear2.bias', 'transformer.encoder.layers.0.norm1.weight', 'transformer.encoder.layers.0.norm1.bias', 'transformer.encoder.layers.0.norm2.weight', 'transformer.encoder.layers.0.norm2.bias', 'transformer.encoder.norm.weight', 'transformer.encoder.norm.bias', 'transformer.decoder.layers.0.self_attn.in_proj_weight', 'transformer.decoder.layers.0.self_attn.in_proj_bias', 'transformer.decoder.layers.0.self_attn.out_proj.weight', 'transformer.decoder.layers.0.self_attn.out_proj.bias', 'transformer.decoder.layers.0.multihead_attn.in_proj_weight'

In [87]:
# Decoder input: every target row except the last one
# (presumably the usual shifted decoder input -- TODO confirm).
tgt_data1 = tgt_data[:-1,:]

In [88]:

# Target embedding table: one row per vocabulary id.
tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

# One d_model-wide row per decoder-input position, filled in below.
tgt_embedding = np.zeros((tgt_data1.shape[0], d_model))

for i, word_index in enumerate(tgt_data1):
    # Refuse ids that fall outside the embedding table.
    if not (0 <= word_index < tgt_vocab_embeds.shape[0]):
        raise ValueError(f"Invalid word index: {word_index}")
    tgt_embedding[i, :] = tgt_vocab_embeds[word_index, :]

    print(f"Word index: {word_index}, Embedding: {tgt_vocab_embeds[word_index, :]}")
print()
print(tgt_embedding.shape)


Word index: [1], Embedding: tensor([[ 0.6442,  3.9300, -0.1244,  0.2953,  0.3827, -0.5497, -0.9940,  1.3459,
          1.9457, -1.2904, -2.3495, -2.0689,  0.9094, -0.6946,  1.9595, -1.1038]])
Word index: [16], Embedding: tensor([[-0.8733,  0.0043, -1.2579, -1.0845,  0.7530,  0.3236, -0.2750,  1.3056,
          0.2118,  0.2720, -0.9268, -2.7330, -0.5642, -0.2740,  0.1398,  0.5086]])
Word index: [5], Embedding: tensor([[-0.7645,  0.2408,  0.1664, -2.2318,  1.3892, -0.5023,  1.6797, -1.0240,
          1.6859, -1.2177,  0.7650,  1.1971, -0.7128, -0.0656,  2.2050,  1.7852]])
Word index: [3], Embedding: tensor([[ 0.4990,  0.8780,  0.3894,  1.4625,  0.4795, -0.5334, -0.0347,  0.6573,
         -0.3112, -0.5620, -0.4835, -1.2721, -0.1740,  0.5541, -0.1817, -0.2345]])

(4, 16)


### 7. Target token embeddings + positional embeddings 

In [89]:
pe = PositionalEncoding(d_model=d_model, max_len=max_seq_len)

# Pass the truncated decoder input (tgt_data1), i.e. the same sequence that
# produced tgt_embedding, rather than the full-length tgt_data.
pe_tgt_embeds = tgt_embedding + pe.forward(tgt_data1)


pe_tgt_embeds.shape, pe_tgt_embeds

((4, 16),
 array([[ 0.64423001,  4.93000388, -0.12442428,  1.29534167,  0.38265419,
          0.45027864, -0.99403578,  2.34593689,  1.94566822, -0.29036391,
         -2.3494761 , -1.06886196,  0.90942109,  0.30537993,  1.95945716,
         -0.10382783],
        [-0.87330669,  1.00426142, -1.25788677, -0.08446777,  0.7529794 ,
          1.32364774, -0.27501002,  2.30561185,  0.21175182,  1.27196231,
         -0.92684317, -1.7329998 , -0.5641737 ,  0.72600037,  0.13978058,
          1.50856197],
        [-0.76447284,  1.24084058,  0.16642573, -1.23181415,  1.38921094,
          0.49766743,  1.67969298, -0.02395296,  1.68592429, -0.21769202,
          0.76496333,  2.19711864, -0.71278685,  0.93442459,  2.20497036,
          2.78517103],
        [ 0.49895304,  1.87799746,  0.38944435,  2.4625175 ,  0.47950602,
          0.46660012, -0.03465135,  1.65729696, -0.31122431,  0.43799645,
         -0.48349261, -0.27211261, -0.17401844,  1.55411685, -0.18165524,
          0.76552661]]))

In [90]:
# Decoder input: target embeddings with the positional encoding added.
x_dec = pe_tgt_embeds
x_dec.shape

(4, 16)

## SELF ATTENTION ( with target mask ) 

In [91]:
# Take the mask size from the decoder input instead of hard-coding it.
seq_length = x_dec.shape[0]

# True strictly above the diagonal: query i must not attend to any key j > i.
tgt_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')

tgt_mask

array([[False,  True,  True,  True],
       [False, False,  True,  True],
       [False, False, False,  True],
       [False, False, False, False]])

### 8. Getting the Q, K, V matrices from the model's initialised weights

In [92]:
import numpy as np

# Self-attention: queries, keys and values all come from the decoder input.
query_dec = key_dec = value_dec = x_dec

tgt_len, embed_dim = x_dec.shape

layer_num = 0

# Packed input projection: the Q, K and V weight blocks are stacked row-wise,
# and the bias vector is laid out the same way.
W_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_weight".format(layer_num)].numpy()
b_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_bias".format(layer_num)].numpy()

head_dim = embed_dim // num_heads

# embed_dim
E = query_dec.shape[-1]

# Apply the full affine projection; the bias term was previously left out.
tempop1 = np.matmul(query_dec, W_dec.T) + b_dec

# Columns [0, E) are Q, [E, 2E) are K and [2E, 3E) are V.
Q_dec, K_dec, V_dec = tempop1[:, 0:E], tempop1[:, E:2*E], tempop1[:, 2*E:3*E]

print("Q_dec_shape = ", Q_dec.shape)
print("K_dec_shape = ", K_dec.shape)
print("V_dec_shape = ", V_dec.shape)
print()

print("After reshaping... \n")

# Split the feature axis into heads and move the head axis to the front:
# (tokens, E) -> (tokens, heads, head_dim) -> (heads, tokens, head_dim).
Q_dec = np.transpose(np.reshape(Q_dec, (tgt_len, num_heads, head_dim)), (1, 0, 2))
K_dec = np.transpose(np.reshape(K_dec, (K_dec.shape[0], num_heads, head_dim)), (1, 0, 2))
V_dec = np.transpose(np.reshape(V_dec, (V_dec.shape[0], num_heads, head_dim)), (1, 0, 2))

print("Q_dec_shape = ", Q_dec.shape)
print("K_dec_shape = ", K_dec.shape)
print("V_dec_shape = ", V_dec.shape)


Q_dec_shape =  (4, 16)
K_dec_shape =  (4, 16)
V_dec_shape =  (4, 16)

After reshaping... 

Q_dec_shape =  (4, 4, 4)
K_dec_shape =  (4, 4, 4)
V_dec_shape =  (4, 4, 4)


In [93]:
# Dump the per-head projections in Q, K, V order.
for label, matrix in (("Q", Q_dec), ("K", K_dec), ("V", V_dec)):
    print(f"{label}_dec_{layer_num} = ")
    print(matrix)
    print()

Q_dec_0 = 
[[[-0.4777198   0.69994168 -0.24317164  0.137528  ]
  [ 0.81450903 -0.44311232  0.54724521 -0.11311741]
  [-0.96026027  2.46727906 -0.2959999   2.93302749]
  [-0.20450042  0.66613672  0.35755138 -0.31146674]]

 [[-0.98550985 -0.75627789  2.19837301 -1.03167234]
  [ 0.24665634  0.88958552 -0.30725672 -0.97874922]
  [-1.8929974   0.58479245  1.4320325   0.12516218]
  [-1.55000703  0.3509483  -0.06880262 -0.52188314]]

 [[-0.83401293  1.4440153   0.40271761 -0.44414086]
  [ 0.43542792 -0.32328431  0.66460001  0.18381045]
  [-0.3863224  -1.61058622 -0.63673838  0.19069625]
  [-0.50454088  0.39886381 -0.02718482 -0.26922808]]

 [[ 0.50697246  1.36147471 -0.83323292 -0.15125591]
  [-0.10894405  0.72768938 -0.01391727  0.4556665 ]
  [ 2.7266259   1.4596128   1.58609343  1.13671036]
  [ 0.22424157  1.67205158 -1.2045582   0.0386641 ]]]

K_dec_0 = 
[[[ 7.83983927e-01 -2.89776999e+00 -1.99690021e+00 -8.15225161e-02]
  [ 1.19432196e-01 -1.32870406e+00  6.87517271e-02  7.77369245e-01]
 

### Preparing the mask for decoder attention mechanisms

In [94]:
# Additive attention bias: 0 where attention is allowed, -inf where it is not.
attn_bias = np.zeros(tgt_mask.shape)

if tgt_mask is not None:
    if tgt_mask.dtype == 'bool':
        # Turn the boolean mask into its float form (True -> -inf, False -> 0)
        # and rebind tgt_mask to it.
        masked_tensor = np.where(tgt_mask, -np.inf, 0.0)
        tgt_mask = masked_tensor

    # Either way tgt_mask is a float mask by now; add it onto the zero bias.
    attn_bias += tgt_mask


In [95]:
# Show the additive bias that is added to the attention scores.
print("Attention mask = \n", attn_bias)

Attention mask = 
 [[  0. -inf -inf -inf]
 [  0.   0. -inf -inf]
 [  0.   0.   0. -inf]
 [  0.   0.   0.   0.]]


#### Attention calculation with Q,K and V matrices

In [96]:
src_len = K_dec.shape[1]

Q_dec1 = Q_dec
K_dec1 = K_dec
V_dec1 = V_dec

# Scaled dot-product attention: scores are divided by sqrt(head_dim).
scale_factor = 1 / math.sqrt(Q_dec1.shape[-1])

# Swap the last two axes of K so each head computes Q @ K^T.
K_dec1_T = np.transpose(K_dec1, axes=(0, 2, 1))

attn_weight_dec = Q_dec1 @ K_dec1_T * scale_factor
# The bias broadcasts over the head axis; -inf entries become weight 0 below.
attn_weight_dec += attn_bias

# Softmax over the key axis. Subtracting the row maximum first leaves the
# result mathematically unchanged but keeps np.exp from overflowing. Each row
# keeps a finite entry on the diagonal, so the maximum is never -inf.
exp_attn_weight_dec_sa = np.exp(attn_weight_dec - np.max(attn_weight_dec, axis=-1, keepdims=True))
sum_exp_attn_weight_dec_sa = np.sum(exp_attn_weight_dec_sa, axis=-1, keepdims=True)
softmax_attn_weight_dec_sa = exp_attn_weight_dec_sa / sum_exp_attn_weight_dec_sa

print("SoftMax (Scaled Dot Product of Q_dec and K_dec) = ")
softmax_attn_weight_dec_sa

SoftMax (Scaled Dot Product of Q_dec and K_dec) = 


array([[[1.        , 0.        , 0.        , 0.        ],
        [0.52538981, 0.47461019, 0.        , 0.        ],
        [0.02873559, 0.71104075, 0.26022366, 0.        ],
        [0.11259072, 0.25719802, 0.34845905, 0.28175221]],

       [[1.        , 0.        , 0.        , 0.        ],
        [0.63127246, 0.36872754, 0.        , 0.        ],
        [0.23052338, 0.15412591, 0.61535071, 0.        ],
        [0.24104742, 0.15464001, 0.34063577, 0.2636768 ]],

       [[1.        , 0.        , 0.        , 0.        ],
        [0.54215452, 0.45784548, 0.        , 0.        ],
        [0.19330943, 0.27523381, 0.53145675, 0.        ],
        [0.31028955, 0.3000996 , 0.20510235, 0.1845085 ]],

       [[1.        , 0.        , 0.        , 0.        ],
        [0.44371194, 0.55628806, 0.        , 0.        ],
        [0.07702968, 0.2608052 , 0.66216512, 0.        ],
        [0.30255118, 0.29141826, 0.14721781, 0.25881275]]])

In [97]:
# Weighted sum of the value vectors, computed independently for each head.
attn_output_dec_sa = np.matmul(softmax_attn_weight_dec_sa, V_dec1)

In [98]:
# Heading and values are emitted on separate lines, as two print calls would.
print("Final Attention Values (Self attention decoder)= \n", attn_output_dec_sa, sep="\n")

Final Attention Values (Self attention decoder)= 

[[[-1.02595520e+00  9.66156820e-01 -2.29126388e+00  1.80570101e+00]
  [-4.94813659e-01  3.63254786e-02 -1.18165435e+00  1.23624897e+00]
  [ 3.36763757e-01 -5.12013737e-01  8.94194254e-02  5.29699367e-01]
  [ 4.78021036e-01 -4.68885516e-02 -2.54583696e-01  4.71600681e-01]]

 [[-2.39825494e-02 -1.33530222e+00 -1.67510990e+00 -9.04894110e-01]
  [-8.35023637e-02 -1.15296018e+00 -1.09796131e+00 -5.63298594e-01]
  [ 1.44866013e-01 -1.71497168e-01  2.57711394e-01 -4.27455359e-01]
  [-5.85121327e-02 -7.90012909e-01 -3.01850517e-01 -4.11758949e-01]]

 [[ 1.53036526e+00 -1.85140924e+00 -1.51471974e+00  9.21301429e-01]
  [ 1.37871488e+00 -1.49100274e+00 -1.79188216e+00  7.92854025e-01]
  [ 1.35127615e+00 -8.67529220e-01 -1.43847060e+00 -6.94138398e-01]
  [ 1.29407121e+00 -1.14007400e+00 -1.47584499e+00  2.71347134e-01]]

 [[ 5.28027889e-01  7.11727333e-01 -2.57351917e-01  1.51332413e+00]
  [ 4.19328104e-01  7.33425081e-01  5.74105293e-02  1.38451

#### Reshaping the final attention output

In [99]:
# The axis order (0, 1, 2) is the identity permutation, so this is a view of
# attn_output_dec_sa with the same layout; heads are regrouped in later cells.
attn_output_dec_sa_permuted = attn_output_dec_sa.transpose(0, 1, 2)

attn_output_dec_sa_permuted.shape

(4, 4, 4)

In [100]:
# Flatten (heads, tokens, head_dim) to (heads * tokens, head_dim).
# head_dim is used directly so the global embed_dim keeps its real meaning
# instead of being overwritten with the per-head width.
numh_tgt_len = num_heads * tgt_len

attn_output_dec_sa_permuted_reshaped = attn_output_dec_sa_permuted.reshape(numh_tgt_len, head_dim)

attn_output_dec_sa_permuted_reshaped, attn_output_dec_sa_permuted_reshaped.shape

(array([[-1.02595520e+00,  9.66156820e-01, -2.29126388e+00,
          1.80570101e+00],
        [-4.94813659e-01,  3.63254786e-02, -1.18165435e+00,
          1.23624897e+00],
        [ 3.36763757e-01, -5.12013737e-01,  8.94194254e-02,
          5.29699367e-01],
        [ 4.78021036e-01, -4.68885516e-02, -2.54583696e-01,
          4.71600681e-01],
        [-2.39825494e-02, -1.33530222e+00, -1.67510990e+00,
         -9.04894110e-01],
        [-8.35023637e-02, -1.15296018e+00, -1.09796131e+00,
         -5.63298594e-01],
        [ 1.44866013e-01, -1.71497168e-01,  2.57711394e-01,
         -4.27455359e-01],
        [-5.85121327e-02, -7.90012909e-01, -3.01850517e-01,
         -4.11758949e-01],
        [ 1.53036526e+00, -1.85140924e+00, -1.51471974e+00,
          9.21301429e-01],
        [ 1.37871488e+00, -1.49100274e+00, -1.79188216e+00,
          7.92854025e-01],
        [ 1.35127615e+00, -8.67529220e-01, -1.43847060e+00,
         -6.94138398e-01],
        [ 1.29407121e+00, -1.14007400e+00, 

In [101]:
# Concatenate the heads for every token.
# Rows of attn_output_dec_sa_permuted_reshaped are grouped head-major (all
# tokens of head 0, then all tokens of head 1, ...). Restoring the
# (heads, tokens, head_dim) axes, swapping heads and tokens, and flattening
# the last two axes places each token's head outputs side by side. The sizes
# come from num_heads and tgt_len, so this no longer relies on them happening
# to equal head_dim.
final_attn_dec_sa_op = (
    attn_output_dec_sa_permuted_reshaped
    .reshape(num_heads, tgt_len, -1)
    .transpose(1, 0, 2)
    .reshape(tgt_len, -1)
    .astype(np.float64)  # float64 result regardless of the input dtype
)

final_attn_dec_sa_op, final_attn_dec_sa_op.shape

(array([[-1.02595520e+00,  9.66156820e-01, -2.29126388e+00,
          1.80570101e+00, -2.39825494e-02, -1.33530222e+00,
         -1.67510990e+00, -9.04894110e-01,  1.53036526e+00,
         -1.85140924e+00, -1.51471974e+00,  9.21301429e-01,
          5.28027889e-01,  7.11727333e-01, -2.57351917e-01,
          1.51332413e+00],
        [-4.94813659e-01,  3.63254786e-02, -1.18165435e+00,
          1.23624897e+00, -8.35023637e-02, -1.15296018e+00,
         -1.09796131e+00, -5.63298594e-01,  1.37871488e+00,
         -1.49100274e+00, -1.79188216e+00,  7.92854025e-01,
          4.19328104e-01,  7.33425081e-01,  5.74105293e-02,
          1.38451307e-01],
        [ 3.36763757e-01, -5.12013737e-01,  8.94194254e-02,
          5.29699367e-01,  1.44866013e-01, -1.71497168e-01,
          2.57711394e-01, -4.27455359e-01,  1.35127615e+00,
         -8.67529220e-01, -1.43847060e+00, -6.94138398e-01,
          1.55118275e-01,  2.82365411e-03,  1.78452262e-01,
          6.81186312e-02],
        [ 4.7802103

### 9. Post self attention in the decoder self attention block

In [102]:
layer_num = 0

weight_dec = state_dict[f"transformer.decoder.layers.{layer_num}.self_attn.out_proj.weight"].numpy()
bias_dec = state_dict[f"transformer.decoder.layers.{layer_num}.self_attn.out_proj.bias"].numpy()

# Project the concatenated head outputs back to the model width.
op_dec_1 = final_attn_dec_sa_op @ weight_dec.T + bias_dec

# First residual connection: add the self-attention sub-layer's input back in.
output_dec_1 = x_dec + op_dec_1


In [103]:
# Layer Norm 1

norm_weight = state_dict["transformer.decoder.layers.{}.norm1.weight".format(layer_num)].numpy()
norm_bias = state_dict["transformer.decoder.layers.{}.norm1.bias".format(layer_num)].numpy()

# Normalise each token's features first, using the biased variance with
# epsilon inside the square root (the formula nn.LayerNorm documents).
mean = np.mean(output_dec_1, axis=-1, keepdims=True)
epsilon = 1e-5
std = np.sqrt(np.var(output_dec_1, axis=-1, keepdims=True) + epsilon)
normalized_linear_result_dec_1 = (output_dec_1 - mean) / std

# Then apply the layer's own scale and shift from the state dict. Previously
# they were applied before normalisation and followed by a second, freshly
# initialised LayerNorm, which discards any non-default weights.
linear_result_dec_1 = normalized_linear_result_dec_1 * norm_weight + norm_bias

# Kept as a float32 torch tensor because a later cell calls .detach().numpy().
linear_op_dec_1 = torch.tensor(linear_result_dec_1, dtype=torch.float32)


In [104]:
# Show the decoder's post-self-attention layer-norm output.
print(f"Decoder_{layer_num} norm1(x + sa(x)) = \n")
print(linear_op_dec_1)

Decoder_0 norm1(x + sa(x)) = 

tensor([[ 0.7908,  2.6168,  0.1973,  0.1147, -0.8474, -0.2229,  0.2594,  0.6727,
         -0.1353, -0.3840, -1.5653, -1.8525,  0.6594, -0.4768,  0.5674, -0.3942],
        [-0.2944,  1.0971, -0.4035,  0.2262, -0.7558,  0.2582,  0.4699,  1.7153,
         -0.6264,  0.8104, -1.2103, -2.6650,  0.3666,  0.4385, -0.2554,  0.8285],
        [-2.0041,  0.1375,  0.0984, -1.5183, -0.8067, -0.8488,  0.9485,  0.4903,
          1.1037, -0.7078, -0.4518,  0.9925, -0.5847,  1.0922,  0.4934,  1.5658],
        [ 0.3119,  1.1703,  0.4139,  2.0400, -1.0685, -0.5466, -0.1719,  1.3237,
         -0.9920, -0.0162, -0.9911, -1.6993,  0.1209,  1.0915, -0.9577, -0.0288]],
       grad_fn=<NativeLayerNormBackward0>)


In [105]:
# NumPy copy of the decoder self-attention block output; it supplies the
# queries for cross-attention.
self_attn_dec = linear_op_dec_1.detach().numpy()

## 10. CROSS ATTENTION (with masking)

In [106]:
# The encoder output serves as the keys/values source for cross-attention.
memory = x_enc

In [107]:
layer_num = 0

# Cross-attention: queries come from the decoder, keys and values from the
# encoder memory (both names are bound to the same array).
query_dec_ca = self_attn_dec
key_dec_ca = value_dec_ca = memory

tgt_len, embed_dim = query_dec_ca.shape

# Packed input projection of the decoder's encoder-decoder attention.
W_dec_ca = state_dict[f"transformer.decoder.layers.{layer_num}.multihead_attn.in_proj_weight"].numpy()
b_dec_ca = state_dict[f"transformer.decoder.layers.{layer_num}.multihead_attn.in_proj_bias"].numpy()

In [125]:
import numpy as np

# embed_dim, taken from the query itself so the cell does not depend on the
# current value of the global `embed_dim`
E = query_dec_ca.shape[-1]

head_dim = E // num_heads

# in_proj_weight / in_proj_bias pack the three projections row-wise:
# rows [0, E) -> Q, rows [E, 2E) -> K, rows [2E, 3E) -> V
# (the torch equivalent is W_dec_ca.split([E, E * 2]))
split_indices = [E, E * 2]


# Split the packed weight and bias into the query part and the key/value part
W_q = W_dec_ca[:split_indices[0], :]
W_kv = W_dec_ca[split_indices[0]:, :]
b_q = b_dec_ca[:split_indices[0]]
b_kv = b_dec_ca[split_indices[0]:]


# Every input projection is affine: x @ W.T + b
Q_dec_ca = np.matmul(query_dec_ca, W_q.T) + b_q

# Keys and values are both computed from the encoder memory, so one packed
# matmul yields [K | V] side by side
tempop1 = np.matmul(key_dec_ca, W_kv.T) + b_kv

K_dec_ca, V_dec_ca = tempop1[:, 0:E], tempop1[:, E:2*E]

print("Q_dec_ca_shape = ", Q_dec_ca.shape)
print("K_dec_ca_shape = ", K_dec_ca.shape)
print("V_dec_ca_shape = ", V_dec_ca.shape)
print()

print("After reshaping... \n")

# Split the feature axis into heads: (len, E) -> (len, num_heads, head_dim),
# then move the head axis to the front -> (num_heads, len, head_dim)
Q_dec_ca = np.transpose(np.reshape(Q_dec_ca, (tgt_len, num_heads, head_dim)), (1, 0, 2))
K_dec_ca = np.transpose(np.reshape(K_dec_ca, (K_dec_ca.shape[0], num_heads, head_dim)), (1, 0, 2))
V_dec_ca = np.transpose(np.reshape(V_dec_ca, (V_dec_ca.shape[0], num_heads, head_dim)), (1, 0, 2))

print("Q_dec_ca_shape = ", Q_dec_ca.shape)
print("K_dec_ca_shape = ", K_dec_ca.shape)
print("V_dec_ca_shape = ", V_dec_ca.shape)


Q_dec_ca_shape =  (4, 16)
K_dec_ca_shape =  (4, 16)
V_dec_ca_shape =  (4, 16)

After reshaping... 

Q_dec_ca_shape =  (4, 4, 4)
K_dec_ca_shape =  (4, 4, 4)
V_dec_ca_shape =  (4, 4, 4)


In [126]:
# Dump the per-head query, key and value tensors of the current decoder layer
for caption, heads in (
    ("Q_dec_ca_{} = ", Q_dec_ca),
    ("K_dec_ca_{} = ", K_dec_ca),
    ("V_dec_ca_{} = ", V_dec_ca),
):
    print(caption.format(layer_num))
    print(heads)
    print()

Q_dec_ca_0 = 
[[[ 1.1264023  -0.99372196  1.4196618   0.3692187 ]
  [ 0.8368701  -0.39764592  1.7715282  -0.20809937]
  [-0.80293286 -0.44238973 -0.44811496 -0.08087807]
  [ 0.14036858  0.35611072  1.3061354   0.5647947 ]]

 [[ 0.34243727  0.6420679   1.0896097  -0.35935146]
  [ 0.5932352  -0.02133346  0.17508404 -0.10132178]
  [ 0.22525965 -0.34866813 -1.386473    0.5934426 ]
  [ 1.0678444  -0.36406344  0.22718026 -0.70181704]]

 [[ 0.2477008   0.7129092  -1.5833832  -0.24581257]
  [ 0.74955493 -0.14404999 -0.86124974  0.54051477]
  [ 1.2827715  -0.9601834  -0.98511726 -0.02957391]
  [ 0.630754    0.139621    0.13323686  0.7768027 ]]

 [[-0.5755281   0.7306101  -0.5634235  -1.4578571 ]
  [-1.0339795   0.39026204 -0.66240704 -0.78707576]
  [ 0.80315405 -0.98282826  0.02963641  0.14811504]
  [-1.5224515   1.3331584  -0.5388493  -0.9295683 ]]]

K_dec_ca_0 = 
[[[ 0.95142776 -0.33068627  0.6432074   0.57482743]
  [ 0.74835044 -0.771075    0.7212278   0.51452005]
  [-0.85368836 -0.758265   

In [144]:
# Scaled dot-product attention weights: softmax(Q @ K^T / sqrt(head_dim))
src_len = K_dec_ca.shape[1]

Q_dec_ca1 = Q_dec_ca
K_dec_ca1 = K_dec_ca
V_dec_ca1 = V_dec_ca

scale_factor = 1 / math.sqrt(Q_dec_ca1.shape[-1])

# Swap the last two axes so that each head computes Q @ K^T
K_dec_ca1_T = np.transpose(K_dec_ca1, axes=(0, 2, 1))

# No additive attention bias / mask is applied to these scores
attn_weight_ca = Q_dec_ca1 @ K_dec_ca1_T * scale_factor

# Softmax over the key axis. Subtracting the row maximum leaves the result
# mathematically unchanged but keeps np.exp from overflowing on large scores.
exp_attn_weight_ca = np.exp(attn_weight_ca - np.max(attn_weight_ca, axis=-1, keepdims=True))
sum_exp_attn_weight_ca = np.sum(exp_attn_weight_ca, axis=-1, keepdims=True)
# The name is reused on purpose: from here on it holds the normalised weights
sum_exp_attn_weight_ca = exp_attn_weight_ca / sum_exp_attn_weight_ca

print("SoftMax (Scaled Dot Product of Q_dec_ca and K_dec_ca) = ")
sum_exp_attn_weight_ca

SoftMax (Scaled Dot Product of Q_dec_ca and K_dec_ca) = 


array([[[0.37049463, 0.4298884 , 0.08394103, 0.115676  ],
        [0.34362885, 0.37149283, 0.13540995, 0.14946844],
        [0.1550957 , 0.18271734, 0.42273828, 0.23944868],
        [0.3496295 , 0.32968232, 0.12520583, 0.19548234]],

       [[0.22347003, 0.17626783, 0.19703059, 0.4032315 ],
        [0.27847624, 0.245871  , 0.18112177, 0.29453105],
        [0.29538497, 0.30592418, 0.2537745 , 0.14491634],
        [0.2999228 , 0.281205  , 0.11160002, 0.30727214]],

       [[0.3081194 , 0.34238714, 0.17676455, 0.17272897],
        [0.22735739, 0.19056454, 0.29483008, 0.287248  ],
        [0.23956823, 0.16378114, 0.27357313, 0.32307753],
        [0.17402405, 0.15002577, 0.34607592, 0.32987425]],

       [[0.21477   , 0.16061941, 0.32229054, 0.30232006],
        [0.23761584, 0.20637609, 0.24874291, 0.30726513],
        [0.23228645, 0.22125815, 0.28067073, 0.2657847 ],
        [0.25358972, 0.2304999 , 0.23621926, 0.27969116]]], dtype=float32)

In [145]:
# Per head, weight the value vectors by the softmax-normalised attention weights
# (batched matmul over the leading head axis).
attn_output_dec_ca = sum_exp_attn_weight_ca @ V_dec_ca

In [146]:
# Show the per-head cross-attention output
print("Final Attention Values (Cross attention decoder)= \n", attn_output_dec_ca, sep="\n")

Final Attention Values (Cross attention decoder)= 

[[[-0.18553464 -0.57655513 -0.6966696   0.25312534]
  [-0.23444562 -0.6136954  -0.6690333   0.3031558 ]
  [-0.48267326 -0.72678804 -0.47050145  0.5191439 ]
  [-0.24223913 -0.64934593 -0.69094515  0.31772697]]

 [[-0.4313032   0.17086504 -0.28128847 -0.18745682]
  [-0.39735857  0.08389443 -0.2219498  -0.2780158 ]
  [-0.32980356  0.06160462 -0.13966209 -0.39402306]
  [-0.40939105  0.01804294 -0.23653856 -0.2851182 ]]

 [[ 0.79633325 -0.6871874  -0.10476258  0.01925761]
  [ 0.79108244 -0.6832218  -0.15286937  0.0438292 ]
  [ 0.7985909  -0.68131524 -0.1448296   0.03213922]
  [ 0.78334427 -0.68820757 -0.175838    0.05549478]]

 [[ 0.1786326  -0.71823883 -0.38515076 -0.23016493]
  [ 0.23649485 -0.8250915  -0.31696987 -0.29983407]
  [ 0.19556943 -0.7640394  -0.3157303  -0.24747981]
  [ 0.23015694 -0.8368094  -0.2903072  -0.294601  ]]]


#### Reshaping the final attention output

In [147]:
# NOTE: axes (0, 1, 2) is the identity permutation, so the data layout is left
# exactly as it was; the regrouping of heads per target position is done by the
# row-shuffling cell further down.
attn_output_dec_ca_permuted = attn_output_dec_ca.transpose(0, 1, 2)

attn_output_dec_ca_permuted.shape

(4, 4, 4)

In [148]:
# Stack the heads on top of each other: one row of length head_dim per
# (head, target position) pair. The width is passed as head_dim directly so
# the global `embed_dim` (the model width) is not overwritten; earlier cells
# derive head_dim from it and would break when re-run.
numh_tgt_len = num_heads * tgt_len

attn_output_dec_ca_permuted_reshaped = attn_output_dec_ca_permuted.reshape(numh_tgt_len, head_dim)

attn_output_dec_ca_permuted_reshaped, attn_output_dec_ca_permuted_reshaped.shape

(array([[-0.18553464, -0.57655513, -0.6966696 ,  0.25312534],
        [-0.23444562, -0.6136954 , -0.6690333 ,  0.3031558 ],
        [-0.48267326, -0.72678804, -0.47050145,  0.5191439 ],
        [-0.24223913, -0.64934593, -0.69094515,  0.31772697],
        [-0.4313032 ,  0.17086504, -0.28128847, -0.18745682],
        [-0.39735857,  0.08389443, -0.2219498 , -0.2780158 ],
        [-0.32980356,  0.06160462, -0.13966209, -0.39402306],
        [-0.40939105,  0.01804294, -0.23653856, -0.2851182 ],
        [ 0.79633325, -0.6871874 , -0.10476258,  0.01925761],
        [ 0.79108244, -0.6832218 , -0.15286937,  0.0438292 ],
        [ 0.7985909 , -0.68131524, -0.1448296 ,  0.03213922],
        [ 0.78334427, -0.68820757, -0.175838  ,  0.05549478],
        [ 0.1786326 , -0.71823883, -0.38515076, -0.23016493],
        [ 0.23649485, -0.8250915 , -0.31696987, -0.29983407],
        [ 0.19556943, -0.7640394 , -0.3157303 , -0.24747981],
        [ 0.23015694, -0.8368094 , -0.2903072 , -0.294601  ]],
       

In [149]:
# Concatenate the heads per target position.
#
# The input rows are ordered head-major (row = head * n_positions + position).
# Viewing them as (num_heads, n_positions, head_dim), swapping the first two
# axes and flattening the last two puts all heads of one position side by side.
# Unlike a loop driven by shape[1] alone, this does not rely on the number of
# positions, the number of heads and head_dim all being equal.
n_rows, per_head_dim = attn_output_dec_ca_permuted_reshaped.shape
n_positions = n_rows // num_heads

final_attn_dec_ca_op = (
    attn_output_dec_ca_permuted_reshaped
    .reshape(num_heads, n_positions, per_head_dim)
    .transpose(1, 0, 2)
    .reshape(n_positions, num_heads * per_head_dim)
    .astype(np.float64)  # double precision, the dtype np.zeros would give by default
)

final_attn_dec_ca_op, final_attn_dec_ca_op.shape

(array([[-0.18553464, -0.57655513, -0.69666958,  0.25312534, -0.4313032 ,
          0.17086504, -0.28128847, -0.18745682,  0.79633325, -0.68718737,
         -0.10476258,  0.01925761,  0.1786326 , -0.71823883, -0.38515076,
         -0.23016493],
        [-0.23444562, -0.61369538, -0.66903329,  0.30315581, -0.39735857,
          0.08389443, -0.2219498 , -0.27801579,  0.79108244, -0.68322182,
         -0.15286937,  0.0438292 ,  0.23649485, -0.82509148, -0.31696987,
         -0.29983407],
        [-0.48267326, -0.72678804, -0.47050145,  0.51914388, -0.32980356,
          0.06160462, -0.13966209, -0.39402306,  0.7985909 , -0.68131524,
         -0.1448296 ,  0.03213922,  0.19556943, -0.7640394 , -0.3157303 ,
         -0.24747981],
        [-0.24223913, -0.64934593, -0.69094515,  0.31772697, -0.40939105,
          0.01804294, -0.23653856, -0.28511819,  0.78334427, -0.68820757,
         -0.17583799,  0.05549478,  0.23015694, -0.8368094 , -0.29030719,
         -0.29460099]]),
 (4, 16))

### 11. Post cross attention in the decoder block

In [150]:
layer_num = 0

# Learned output projection of this layer's cross-attention module
bias_dec_ca = state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)].numpy()
weight_dec_ca = state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].numpy()

# Project the concatenated heads back to the model width
op_dec_ca1 = final_attn_dec_ca_op @ weight_dec_ca.T + bias_dec_ca

# Residual connection 2: add the cross-attention result to the sub-layer input
output_dec_2 = op_dec_ca1 + self_attn_dec


In [151]:
# Layer Norm 2: norm2(x' + mha(x'))
#
# LayerNorm standardises every row over the feature axis first and only then
# applies the learned affine parameters:
#     y = (x - mean) / sqrt(var + eps) * weight + bias

norm_weight = state_dict["transformer.decoder.layers.{}.norm2.weight".format(layer_num)].numpy()
norm_bias = state_dict["transformer.decoder.layers.{}.norm2.bias".format(layer_num)].numpy()


# Manual computation
epsilon = 1e-5  # default eps of torch.nn.LayerNorm
mean = np.mean(output_dec_2, axis=-1, keepdims=True)
# eps is added to the (biased) variance inside the square root, not to the std
std = np.sqrt(np.var(output_dec_2, axis=-1, keepdims=True) + epsilon)
normalized_linear_result_dec_2 = (output_dec_2 - mean) / std
linear_result_dec_2 = normalized_linear_result_dec_2 * norm_weight + norm_bias


# Same computation through torch.nn.LayerNorm. A freshly constructed module
# has weight=1 and bias=0, so the layer's learned parameters are loaded into it.
layernorm_dec_2 = torch.nn.LayerNorm(output_dec_2.shape[1:])
layernorm_dec_2.load_state_dict({"weight": torch.tensor(norm_weight), "bias": torch.tensor(norm_bias)})

linear_op_dec_2 = layernorm_dec_2(torch.tensor(output_dec_2, dtype=torch.float32))

# The hand-written result has to agree with the torch module
np.testing.assert_allclose(linear_result_dec_2, linear_op_dec_2.detach().numpy(), rtol=1e-4, atol=1e-4)


In [152]:
# Drop the autograd graph and switch back to NumPy for the feed-forward cells.
# The name is rebound in place, so re-running this cell alone fails (a NumPy
# array has no .detach()).
linear_op_dec_2 = linear_op_dec_2.detach().numpy()

In [153]:
# Position-wise feed-forward sub-layer: linear2(relu(linear1(x)))

linear2_bias = state_dict["transformer.decoder.layers.{}.linear2.bias".format(layer_num)].numpy()
linear2_weight = state_dict["transformer.decoder.layers.{}.linear2.weight".format(layer_num)].numpy()
linear1_bias = state_dict["transformer.decoder.layers.{}.linear1.bias".format(layer_num)].numpy()
linear1_weight = state_dict["transformer.decoder.layers.{}.linear1.weight".format(layer_num)].numpy()


# Linear projection 1
op_dec_1 = linear_op_dec_2 @ linear1_weight.T + linear1_bias

# ReLU: clamp negative activations to zero
op_dec_1_relu = np.maximum(op_dec_1, 0)

# Linear projection 2, whose result is the feed-forward output
ff_dec = op_dec_2 = op_dec_1_relu @ linear2_weight.T + linear2_bias

In [154]:
# Residual connection 3: add the feed-forward output to the sub-layer input.
# (The variable is named `output_enc_3`, but it belongs to the decoder.)

output_enc_3 = linear_op_dec_2 + ff_dec

In [155]:
# Layer Norm 3: norm3(x'' + ff(x''))
#
# LayerNorm standardises every row over the feature axis first and only then
# applies the learned affine parameters:
#     y = (x - mean) / sqrt(var + eps) * weight + bias

# The third sub-layer has its own parameters (norm3, not norm2)
norm_weight = state_dict["transformer.decoder.layers.{}.norm3.weight".format(layer_num)].numpy()
norm_bias = state_dict["transformer.decoder.layers.{}.norm3.bias".format(layer_num)].numpy()


# Manual computation
epsilon = 1e-5  # default eps of torch.nn.LayerNorm
mean = np.mean(output_enc_3, axis=-1, keepdims=True)
# eps is added to the (biased) variance inside the square root, not to the std
std = np.sqrt(np.var(output_enc_3, axis=-1, keepdims=True) + epsilon)
normalized_linear_result_dec_3 = (output_enc_3 - mean) / std
linear_result_dec_3 = normalized_linear_result_dec_3 * norm_weight + norm_bias


# Same computation through torch.nn.LayerNorm. A freshly constructed module
# has weight=1 and bias=0, so the layer's learned parameters are loaded into it.
layernorm_dec_3 = torch.nn.LayerNorm(output_enc_3.shape[1:])
layernorm_dec_3.load_state_dict({"weight": torch.tensor(norm_weight), "bias": torch.tensor(norm_bias)})

linear_op_dec_3 = layernorm_dec_3(torch.tensor(output_enc_3, dtype=torch.float32))

# The hand-written result has to agree with the torch module
np.testing.assert_allclose(linear_result_dec_3, linear_op_dec_3.detach().numpy(), rtol=1e-4, atol=1e-4)


In [156]:
# Drop the autograd graph and switch back to NumPy. The name is rebound in
# place, so re-running this cell alone fails (a NumPy array has no .detach()).
linear_op_dec_3 = linear_op_dec_3.detach().numpy()

In [157]:
# Show the output of the last decoder sub-layer, followed by one blank line
print(
    "norm3(x'' + ff(x'')) \n where, x'' = Decoder_curr_layer norm2(x' + mha(x'))\n",
    linear_op_dec_3,
    "",
    sep="\n",
)


norm3(x'' + ff(x'')) 
 where, x'' = Decoder_curr_layer norm2(x' + mha(x'))

[[-0.5356866   1.9318144  -0.51183444 -0.06886522 -0.35834506  0.00633696
   0.74473906  0.38577303  0.36656466  0.24361272 -1.5169702  -1.1969402
   1.2439213  -1.0069604   1.6122651  -1.3394252 ]
 [-0.95041865  0.7818231  -0.9022948  -0.16324307 -0.3592338   0.16954553
   0.85607445  1.721027   -0.5615748   1.3511652  -1.3389051  -2.1498117
   0.5818949   0.4841478   0.7635639  -0.28375974]
 [-1.7267284   0.1277106  -0.20488055 -1.9431874   0.58657104 -1.0944295
   0.8495381   0.56639     0.32088426 -0.5575043  -1.0710053   0.7707297
  -0.08021259  1.0479496   1.6733918   0.7347833 ]
 [-0.18048695  1.1450682  -0.14022484  1.8074682  -0.66768867 -0.591716
   0.12012903  1.2484851  -1.2300383   0.6516065  -1.4823984  -1.1182386
   0.0629568   1.5355982  -0.04201251 -1.1185073 ]]



### 12. Final linear layer (projection onto the target vocabulary)

In [158]:
# The norm3 output of the decoder layer is taken as the final decoder output.
# NOTE(review): nn.Transformer normally also applies a final LayerNorm after the
# last decoder layer (state_dict keys "transformer.decoder.norm.*"); it looks
# like that step is skipped here -- confirm.
dec_output_final = linear_op_dec_3

In [159]:
# Parameters of the model's last linear layer `fc`
b_ff = state_dict["fc.bias"].numpy()
W_ff = state_dict["fc.weight"].numpy()

# Affine map of every decoder output row: x @ W.T + b
final_op = np.matmul(dec_output_final, W_ff.T) + b_ff

In [160]:
# Display the result of the final projection together with its shape
final_op, final_op.shape

(array([[ 0.37746853,  0.07337449, -0.8095904 ,  0.14691624, -0.4905476 ,
          0.37394795, -0.2704848 , -0.92419446,  0.6273823 , -0.6122049 ,
          0.25933215, -0.16878009,  0.13227718,  1.3098803 ,  0.7643256 ,
         -0.61062217, -0.13705996,  0.75523925, -0.5415498 ,  0.7422502 ],
        [ 0.43439144,  0.35608122, -0.49823174, -0.43256435, -0.23772678,
          0.8758416 , -1.064142  , -0.3422812 ,  0.3772208 , -0.445173  ,
          0.65189123, -0.26663378, -0.06459779,  0.780676  ,  0.5272392 ,
         -0.6572926 ,  0.09763362,  1.0803541 , -0.9153302 ,  1.1743307 ],
        [-0.5873185 ,  0.21614808, -0.60998964, -0.07736345, -0.5957562 ,
          0.68771374, -0.267695  , -0.44631606,  0.51722544,  1.1136649 ,
          0.10397579, -0.30067962, -1.1872319 , -0.25477836,  0.6387973 ,
          0.09141681,  0.02895935, -0.05193508,  0.38341376, -0.10680325],
        [ 0.7888366 ,  0.7641175 , -0.12451119, -0.02487959,  0.18648735,
         -0.01425064, -0.9725598 , 