## Multi-head attention transformer
### Encoder and Decoder
### (With masking)

PyTorch's built-in implementation

NOTE: A new example must be used to test the masked attention.

In [187]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(0)

class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional-encoding table.

    The table has shape ``(1, max_len, d_model)``: even feature columns hold
    ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Width of each encoding vector. Odd widths are supported.
        max_len: Number of positions pre-computed in the table.
        dropout: Accepted for call-site compatibility; it is not used.
    """

    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        angles = position * div_term

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer cosine column than sine
        # column, so trim the angles to the number of odd columns.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A non-persistent buffer follows the module across .to()/.cuda()
        # without adding an entry to state_dict().
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, detached.

        Only the size of ``x`` along dimension 1 is read; the result has
        shape ``(1, x.size(1), d_model)`` and is meant to be broadcast-added
        to the embeddings by the caller.
        """
        return self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    """Embeddings + positional encoding + ``nn.Transformer`` + output projection.

    The forward pass prints the raw and the position-encoded embeddings so
    that they can be compared with hand-computed values.

    Args:
        src_vocab_size: Number of rows in the source embedding table.
        tgt_vocab_size: Number of rows in the target embedding table; also
            the width of the final linear projection.
        max_seq_len: ``max_len`` of the positional-encoding table.
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size,max_seq_len, d_model=4, nhead=2, num_encoder_layers=1, num_decoder_layers=1):
        super(TransformerModel, self).__init__()
        # Sub-module construction order is kept fixed: it determines how the
        # seeded RNG is consumed and therefore the initial weights.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """Run the model on token-index tensors and return the projected output.

        Args:
            src: Source token indexes.
            tgt: Target token indexes.
            src_mask: Optional mask forwarded to ``nn.Transformer``.
            tgt_mask: Optional mask forwarded to ``nn.Transformer``.
            memory_mask: Optional mask forwarded to ``nn.Transformer``.

        Returns:
            The transformer output passed through the final linear layer.
        """
        # Look each embedding up once and reuse it for printing and the sum.
        src_embedded = self.src_embedding(src)
        tgt_embedded = self.tgt_embedding(tgt)

        print("Source embeddings :- \n")
        print(src_embedded)
        print()

        print("Target embeddings :- \n")
        print(tgt_embedded)
        print()

        # NOTE(review): positional_encoding slices its table by x.size(1),
        # while nn.Transformer is built here without batch_first. If the
        # inputs are (seq_len, batch), the slice length is the batch size and
        # the encoding is broadcast over the sequence axis - confirm this is
        # intended. Behaviour is left as-is so the manual cells still match.
        src = src_embedded + self.positional_encoding(src)
        tgt = tgt_embedded + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
        )
        output = self.fc(output)

        return output

    

# Toy hyper-parameters for the worked example.
src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 6  # Dimension of the model
num_heads = 2
max_seq_len = 5  # an earlier variant of this example used 4

model = TransformerModel(
    src_vocab_size,
    tgt_vocab_size,
    max_seq_len=max_seq_len,
    d_model=d_model,
    nhead=num_heads,
)

# Source / target sentences as token indexes into their vocabularies.
# Each tensor has one row per sequence position and one column per sentence.
# Earlier single-sentence variant, kept for reference:
#   src_sentence = torch.tensor([[0], [1], [2], [3]])
#   tgt_sentence = torch.tensor([[1], [0], [3], [3]])
src_sentence = torch.tensor([[0, 2], [1, 0], [2, 2], [3, 5]])
tgt_sentence = torch.tensor([[1, 7], [3, 4], [5, 2], [8, 0], [6, 1]])  # Target sequence

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output, output.shape)


Source embeddings :- 

tensor([[[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920],
         [ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959]],

        [[-0.3160, -2.1152,  0.3223, -1.2633,  0.3500,  0.3081],
         [-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920]],

        [[ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959],
         [ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959]],

        [[ 0.5667,  0.7935,  0.5988, -1.5551, -0.3414,  1.8530],
         [ 0.9463, -0.8437, -0.6136,  0.0316, -0.4927,  0.2484]]],
       grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[-8.3832e-01,  8.9182e-04, -7.5043e-01,  1.8541e-01,  6.2114e-01,
           6.3818e-01],
         [-3.7015e-01, -1.2103e+00,  1.1404e+00, -8.9882e-02,  7.2980e-01,
          -1.8453e+00]],

        [[-2.0252e-02, -4.3717e-01,  1.6459e+00, -1.3602e+00,  3.4457e-01,
           5.1987e-01],
         [-3.6562e-01, -1.3024e+00,  9.9403e-02,  4.4182e-01,  2.4693e-01,
           7.6

## Functions to get the manual calculations for each component

In [188]:
# NOTE(review): this rebinds `F`, which the first cell bound to
# torch.nn.functional, to torch.functional - confirm that is intended.
from torch import functional as F

# Using the state dictionary to get the intermediate outputs
state_dict = model.state_dict()

# Flag handed to the attention helpers to select which attention routine runs.
need_weights = False

# No masks are used in this example run.
src_mask = None

tgt_mask = None

memory_mask = None

embed_dim = 6

# Several helper functions defined later read `num_heads` as a module-level global.
num_heads = 2

max_seq_len = 5


In [189]:
def look_up_table(sentence, vocab_embeds, embedding):
    """Fill ``embedding`` in place with one vocabulary row per token.

    Args:
        sentence: 2-D tensor of token indexes.
        vocab_embeds: Embedding table indexed by token index on dim 0.
        embedding: Pre-allocated output; ``embedding[i, j, :]`` receives the
            row for ``sentence[i, j]``.

    Returns:
        The same ``embedding`` tensor that was passed in.

    Raises:
        ValueError: If a token index falls outside the table.
    """
    for row in range(sentence.size(0)):
        for col in range(sentence.size(1)):
            token_id = sentence[row, col].item()

            # Reject indexes that do not address a row of the table.
            if not (0 <= token_id < vocab_embeds.size(0)):
                raise ValueError(f"Invalid word index: {token_id}")

            embedding[row, col, :] = vocab_embeds[token_id, :]
            print(f"Word index: {token_id}, Embedding: {vocab_embeds[token_id, :]}")

    print()
    return embedding

### Embeddings and Positional encoding

In [190]:
def get_embedding_outputs(src_sentence,  tgt_sentence, max_seq_len, state_dict, d_model):
    """Rebuild the position-encoded source and target embeddings by hand.

    Embedding rows are read from ``state_dict`` via ``look_up_table`` and a
    fresh ``PositionalEncoding`` is added to each. Every intermediate result
    is printed.

    Args:
        src_sentence: 2-D tensor of source token indexes.
        tgt_sentence: 2-D tensor of target token indexes.
        max_seq_len: ``max_len`` for the positional-encoding table.
        state_dict: Mapping holding ``"src_embedding.weight"`` and
            ``"tgt_embedding.weight"``.
        d_model: Embedding width.

    Returns:
        Tuple ``(pe_src_embeds, pe_tgt_embeds)``.
    """
    # --- source side ---
    src_table = state_dict["src_embedding.weight"]
    src_buffer = torch.zeros(src_sentence.size(0), src_sentence.size(1), d_model)
    print("Source sentence embedding")
    src_embedding = look_up_table(src_sentence, src_table, src_buffer)
    print(src_embedding.shape)

    # --- target side ---
    tgt_table = state_dict["tgt_embedding.weight"]
    tgt_buffer = torch.zeros(tgt_sentence.size(0), tgt_sentence.size(1), d_model)
    print("Target sentence embedding")
    tgt_embedding = look_up_table(tgt_sentence, tgt_table, tgt_buffer)

    # The encoder is deterministic, so each slice is computed once and reused
    # for both the printout and the sum.
    pe = PositionalEncoding(d_model=d_model, dropout=0, max_len=max_seq_len)
    src_positions = pe(src_embedding)
    tgt_positions = pe(tgt_embedding)

    print("PE of src :")
    print(src_positions)
    print()
    print("PE of tgt :")
    print(tgt_positions)
    print()

    pe_src_embeds = src_embedding + src_positions
    pe_tgt_embeds = tgt_embedding + tgt_positions

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    print("PE target embeddings : \n")
    print(pe_tgt_embeds)
    print()

    return pe_src_embeds, pe_tgt_embeds



In [191]:
# Hand-computed position-encoded embeddings for the example sentences,
# built from the weights stored in state_dict.
pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence=src_sentence,  tgt_sentence=tgt_sentence, state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)


Source sentence embedding
Word index: 0, Embedding: tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920])
Word index: 2, Embedding: tensor([ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959])
Word index: 1, Embedding: tensor([-0.3160, -2.1152,  0.3223, -1.2633,  0.3500,  0.3081])
Word index: 0, Embedding: tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920])
Word index: 2, Embedding: tensor([ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959])
Word index: 2, Embedding: tensor([ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959])
Word index: 3, Embedding: tensor([ 0.5667,  0.7935,  0.5988, -1.5551, -0.3414,  1.8530])
Word index: 5, Embedding: tensor([ 0.9463, -0.8437, -0.6136,  0.0316, -0.4927,  0.2484])

torch.Size([4, 2, 6])
Target sentence embedding
Word index: 1, Embedding: tensor([-0.8383,  0.0009, -0.7504,  0.1854,  0.6211,  0.6382])
Word index: 7, Embedding: tensor([-0.3701, -1.2103,  1.1404, -0.0899,  0.7298, -1.8453])
Word index: 3, Embedding: tensor([-

## Encoder function to display the intermediate outputs and get the final outputs from the encoder

### Self attention 

In [223]:
def atten_product_needs_wts_false(Q, V, K, bsz, head_dim, src_len, tgt_len, attn_mask):
    """Scaled dot-product attention over all heads; returns only the output.

    Note the positional order of the tensor arguments: ``Q, V, K``.

    Args:
        Q: Queries, shape ``(bsz * num_heads, tgt_len, head_dim)``.
        V: Values, shape ``(bsz * num_heads, src_len, head_dim)``.
        K: Keys, shape ``(bsz * num_heads, src_len, head_dim)``.
        bsz: Batch size; the head count is derived as ``Q.size(0) // bsz``.
        head_dim: Per-head feature width.
        src_len: Number of key/value positions.
        tgt_len: Number of query positions.
        attn_mask: ``None``, a boolean mask (``True`` = position may be
            attended, ``False`` positions get ``-inf``), or a float mask that
            is added to the scores. It must broadcast against
            ``(bsz, num_heads, tgt_len, src_len)``. The mask is not modified.

    Returns:
        Tensor of shape ``(tgt_len * bsz, num_heads * head_dim)``.

    Raises:
        AssertionError: If any row of attention weights does not sum to 1
            (for example when a row is fully masked).
    """
    # Derive the head layout from the inputs instead of module-level globals.
    num_heads = Q.size(0) // bsz
    embed_dim = num_heads * head_dim

    # (bsz*num_heads, len, head_dim) -> (bsz, num_heads, len, head_dim)
    Q1 = Q.reshape(bsz, num_heads, tgt_len, head_dim)
    K1 = K.reshape(bsz, num_heads, src_len, head_dim)
    V1 = V.reshape(bsz, num_heads, src_len, head_dim)

    L, S = Q1.size(-2), K1.size(-2)

    scale_factor = 1 / math.sqrt(Q1.size(-1))
    attn_bias = torch.zeros(L, S, dtype=Q1.dtype, device=Q1.device)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Build the additive bias out of place so the caller's mask is
            # untouched and masks with leading broadcast dims are accepted.
            attn_bias = torch.zeros(
                attn_mask.shape, dtype=Q1.dtype, device=Q1.device
            ).masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            # Out-of-place add: an in-place += on the (L, S) bias would fail
            # for masks that carry extra leading dimensions.
            attn_bias = attn_bias + attn_mask

    # (bsz, num_heads, tgt_len, head_dim) @ (bsz, num_heads, head_dim, src_len)
    #   -> (bsz, num_heads, tgt_len, src_len)
    attn_weight = Q1 @ K1.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias

    attn_weight = torch.softmax(attn_weight, dim=-1)

    print("Attention weights = ")
    print(attn_weight.shape, attn_weight)

    # Sanity check: every softmax row must be a probability distribution.
    sum_last_dim = attn_weight.sum(dim=-1)

    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    # (bsz, num_heads, tgt_len, src_len) @ (bsz, num_heads, src_len, head_dim)
    #   -> (bsz, num_heads, tgt_len, head_dim)
    attn_output = attn_weight @ V1

    # -> (tgt_len, bsz, num_heads, head_dim) -> (tgt_len*bsz, embed_dim)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    return attn_output



def atten_product_needs_wts_true(Q, K, V, bsz, tgt_len, attn_mask):
    """Scaled dot-product attention that also returns the attention weights.

    Args:
        Q: Queries, shape ``(bsz * num_heads, tgt_len, head_dim)``.
        K: Keys, shape ``(bsz * num_heads, src_len, head_dim)``.
        V: Values, shape ``(bsz * num_heads, src_len, head_dim)``.
        bsz: Batch size; the head count is derived as ``Q.size(0) // bsz``.
        tgt_len: Number of query positions.
        attn_mask: ``None`` or an additive float mask. A 4-D mask is
            flattened to ``(-1, tgt_len, src_len)`` before use because
            ``torch.baddbmm`` works on 3-D batches. Boolean masks are not
            converted here.

    Returns:
        Tuple ``(attn_output, attn_weights)`` with shapes
        ``(tgt_len * bsz, num_heads * head_dim)`` and
        ``(bsz * num_heads, tgt_len, src_len)``.

    Raises:
        AssertionError: If any row of attention weights does not sum to 1.
    """
    B, Nt, E = Q.shape
    # Derive the model width from the inputs instead of a module-level global.
    embed_dim = (B // bsz) * E

    Q_scaled = Q / math.sqrt(E)

    if attn_mask is not None:
        if attn_mask.dim() == 4:
            # (1, 1, L, S) -> (1, L, S); (bsz, num_heads, L, S) -> (bsz*num_heads, L, S)
            attn_mask = attn_mask.reshape(-1, attn_mask.size(-2), attn_mask.size(-1))
        temp_pdt_matrix = torch.baddbmm(attn_mask, Q_scaled, K.transpose(-2, -1))
    else:
        temp_pdt_matrix = torch.bmm(Q_scaled, K.transpose(-2, -1))

    attn_wt_matrix = torch.nn.functional.softmax(temp_pdt_matrix, dim=-1)

    attn_output = torch.bmm(attn_wt_matrix, V)

    # (bsz*num_heads, tgt_len, head_dim) -> (tgt_len, bsz*num_heads, head_dim)
    #   -> (tgt_len*bsz, embed_dim)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

    # Sanity check: every softmax row must be a probability distribution.
    sum_last_dim = attn_wt_matrix.sum(dim=-1)

    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    print("Encoder Attention output = ")
    print(attn_output)
    print()

    return attn_output, attn_wt_matrix


In [224]:
def get_qkv(query, key, value ,W, b):
    """Apply the packed attention in-projection to query, key and value.

    ``W`` stacks the query, key and value weights along dim 0 and ``b``
    stacks the matching biases; each projection is ``x @ W_part.T + b_part``.

    Three cases are distinguished by object identity:

    * ``query is key is value``: one matmul with the full ``W``, split into
      three equal chunks along the last dimension.
    * ``key is value`` only: the query uses the first ``E`` rows of ``W``;
      key/value share one matmul with the remaining ``2 * E`` rows.
    * otherwise: three independent projections.

    Args:
        query: Tensor whose last dimension is ``E``.
        key: Tensor whose last dimension is ``E``.
        value: Tensor whose last dimension is ``E``.
        W: Packed weight of shape ``(3 * E, E)``.
        b: Packed bias of shape ``(3 * E,)``, or ``None`` for no bias.

    Returns:
        Tuple ``(Q, K, V)`` of projected tensors.
    """

    def project(inputs, weight, bias):
        # Linear layer written out by hand; the bias term is optional.
        projected = inputs @ weight.T
        return projected if bias is None else projected + bias

    # embed_dim
    E = query.size(-1)

    if key is value:
        if query is key:
            # Self-attention: (..., E) @ (3E, E).T -> (..., 3E), then split.
            packed = project(query, W, b)
            q_proj, k_proj, v_proj = (part.contiguous() for part in packed.chunk(3, dim=-1))
            return q_proj, k_proj, v_proj

        # Cross-attention: query projected on its own, key/value together.
        W_q, W_kv = W.split([E, E * 2])
        if b is None:
            b_q = b_kv = None
        else:
            b_q, b_kv = b.split([E, E * 2])

        q_proj = project(query, W_q, b_q)
        packed_kv = project(key, W_kv, b_kv)
        k_proj, v_proj = (part.contiguous() for part in packed_kv.chunk(2, dim=-1))
        return q_proj, k_proj, v_proj

    # Fully distinct inputs: three separate projections.
    W_q, W_k, W_v = W.chunk(3)
    if b is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = b.chunk(3)

    return project(query, W_q, b_q), project(key, W_k, b_k), project(value, W_v, b_v)



    

In [252]:
def encoder_block_attn_output(x, state_dict, layer_num, embed_dim, num_heads, need_weights = False, src_mask = None):
    """Hand-computed multi-head self-attention of one encoder layer.

    Projects ``x`` with the layer's packed in-projection, splits the result
    into heads, prints Q/K/V and runs one of the two attention helpers. The
    output projection is NOT applied here.

    Args:
        x: Encoder layer input, shape ``(src_len, bsz, embed_dim)``.
        state_dict: Mapping holding the
            ``transformer.encoder.layers.{layer_num}.self_attn.*`` tensors.
        layer_num: Index of the encoder layer.
        embed_dim: Ignored; the width is re-read from ``x.shape``.
        num_heads: Number of attention heads.
        need_weights: ``False`` selects ``atten_product_needs_wts_false``;
            any other value selects ``atten_product_needs_wts_true``.
        src_mask: Optional attention mask, ``(L, S)`` or
            ``(bsz * num_heads, L, S)``.

    Returns:
        Tuple ``(attn_output, src_len, head_dim, attn_weights)``;
        ``attn_weights`` is ``None`` when ``need_weights`` is ``False``.

    Raises:
        RuntimeError: If ``src_mask`` has an unsupported rank or shape.
    """
    # Self-attention: query, key and value are the same tensor object.
    query_enc = key_enc = value_enc = x

    tgt_len, bsz, embed_dim = x.shape

    key_prefix = "transformer.encoder.layers.{}.self_attn.".format(layer_num)
    # Packed in-projection weight and bias for Q, K and V.
    W_enc = state_dict[key_prefix + "in_proj_weight"]
    b_enc = state_dict[key_prefix + "in_proj_bias"]

    head_dim = embed_dim//num_heads

    # Each of Q, K, V: (src_len, bsz, embed_dim)
    Q_enc,K_enc,V_enc = get_qkv(query_enc, key_enc, value_enc ,W_enc, b_enc)

    # (src_len, bsz, embed_dim) -> (bsz*num_heads, src_len, head_dim)
    Q_enc = Q_enc.reshape(Q_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    K_enc = K_enc.reshape(K_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_enc = V_enc.reshape(V_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_enc_{} = ".format(layer_num))
    print(Q_enc)
    print()

    print("K_enc_{} = ".format(layer_num))
    print(K_enc)
    print()

    print("V_enc_{} = ".format(layer_num))
    print(V_enc)
    print()

    src_len = K_enc.size(1)

    attn_mask = src_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    if need_weights is False:
        # Only this path works on (N, num_heads, L, S) tensors, so the mask
        # is lifted to 4-D here; the weights path keeps the 3-D mask.
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        attn_output = atten_product_needs_wts_false(Q_enc, V_enc, K_enc, bsz, head_dim, src_len, tgt_len, attn_mask)

        return attn_output, src_len, head_dim, None

    attn_enc_output,attn_wt_matrix_enc = atten_product_needs_wts_true(Q_enc, K_enc, V_enc, bsz, tgt_len, attn_mask)

    return attn_enc_output, src_len, head_dim, attn_wt_matrix_enc

    

In [253]:
# need_weights = True

# Self-attention of encoder layer 0 on the position-encoded source embeddings.
# src_len and head_dim returned here are reused by later cells.
attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(pe_src_embeds, state_dict, layer_num = 0, need_weights = need_weights, embed_dim=embed_dim, num_heads=num_heads)


Q_enc_0 = 
tensor([[[-0.1929,  1.2673,  1.2277],
         [-0.1973,  0.2204,  0.7046],
         [ 0.4541,  0.1583, -0.5595],
         [ 1.4683,  1.0174,  0.7612]],

        [[-0.8873,  0.1377, -1.1826],
         [-1.0105,  0.2393, -0.3701],
         [ 0.2194,  0.4263,  0.6572],
         [-0.8920, -0.3286, -1.3039]],

        [[ 0.7191, -0.1656, -0.8970],
         [ 0.0721,  0.9433,  0.8903],
         [ 0.7191, -0.1656, -0.8970],
         [ 1.1037,  0.6604, -0.7935]],

        [[ 0.1979,  0.1143,  1.1358],
         [-0.9088, -0.1743, -0.7039],
         [ 0.1979,  0.1143,  1.1358],
         [-0.6033, -0.9457,  0.4983]]])

K_enc_0 = 
tensor([[[ 0.7118, -0.1985,  0.4590],
         [ 0.4985, -0.2911, -0.1758],
         [-0.5501, -0.3624,  0.1840],
         [-0.0511, -1.8880, -0.6759]],

        [[ 0.4204, -0.8352, -1.2161],
         [ 0.6379, -0.1607, -0.2691],
         [-0.6402,  0.4023,  0.7949],
         [-0.1434,  0.6434,  0.6594]],

        [[-0.4659, -0.2948, -0.0826],
         [ 0.79

### Post self attention in encoder

In [254]:
def encoder_block_post_attn_output(x, attn_enc_output, state_dict, layer_num, bsz, tgt_len):
    """Finish one encoder layer by hand after self-attention.

    Computes ``norm2(h + feed_forward(h))`` with
    ``h = norm1(x + out_proj(attn_enc_output))``, using the layer's weights
    from ``state_dict``, and prints the result.

    Args:
        x: Input of the encoder layer, shape ``(tgt_len, bsz, embed_dim)``.
        attn_enc_output: Attention output before the output projection,
            shape ``(tgt_len * bsz, embed_dim)``.
        state_dict: Mapping holding the
            ``transformer.encoder.layers.{layer_num}.*`` tensors.
        layer_num: Index of the encoder layer.
        bsz: Batch size.
        tgt_len: Sequence length of ``x``.

    Returns:
        Tensor of shape ``(tgt_len, bsz, embed_dim)``.
    """

    def layer_norm(inputs, weight, bias, eps=1e-5):
        # Normalise over the last dimension with the biased variance, then
        # apply the learned scale and shift.
        mean = inputs.mean(dim=-1, keepdim=True)
        var = inputs.var(dim=-1, unbiased=False, keepdim=True)
        return (inputs - mean) / torch.sqrt(var + eps) * weight + bias

    prefix = "transformer.encoder.layers.{}.".format(layer_num)

    # Output projection of the attention heads:
    # (tgt_len*bsz, embed_dim) @ (embed_dim, embed_dim).T
    projected = torch.matmul(attn_enc_output, state_dict[prefix + "self_attn.out_proj.weight"].t()) + state_dict[prefix + "self_attn.out_proj.bias"]

    # (tgt_len*bsz, embed_dim) -> (tgt_len, bsz, embed_dim)
    projected = projected.view(tgt_len, bsz, attn_enc_output.size(1))

    # Residual connection around self-attention, then norm1.
    normed_1 = layer_norm(projected + x, state_dict[prefix + "norm1.weight"], state_dict[prefix + "norm1.bias"])

    # Position-wise feed-forward: linear1 -> ReLU -> linear2.
    hidden = torch.matmul(normed_1, state_dict[prefix + "linear1.weight"].t()) + state_dict[prefix + "linear1.bias"]
    hidden = torch.nn.functional.relu(hidden)
    feed_forward = torch.matmul(hidden, state_dict[prefix + "linear2.weight"].t()) + state_dict[prefix + "linear2.bias"]

    # Residual connection around the feed-forward block, then norm2.
    output_enc_final = layer_norm(normed_1 + feed_forward, state_dict[prefix + "norm2.weight"], state_dict[prefix + "norm2.bias"])

    print("Final Encoder {} Output :".format(layer_num))
    print("norm2(norm1(x + self_atten(x)) + feed_fwd_op)\n")
    print(output_enc_final)
    print()

    # (tgt_len, bsz, embed_dim)
    return output_enc_final


In [255]:
# NOTE: `tgt_len` is taken from the *source* embeddings' first dimension
# here; this rebinds the module-level names tgt_len, bsz and embed_dim.
tgt_len, bsz, embed_dim = pe_src_embeds.shape
# Finish encoder layer 0 (output projection, residuals, norms, feed-forward).
output_enc_final = encoder_block_post_attn_output(pe_src_embeds, attn_enc_output, state_dict, layer_num = 0 , bsz = bsz, tgt_len = tgt_len)

Final Encoder 0 Output :
norm2(norm1(x + self_atten(x)) + feed_fwd_op)

tensor([[[-1.4971, -0.8204, -0.0893,  0.9685, -0.0243,  1.4625],
         [ 0.6772,  1.2398,  0.4367,  0.3601, -1.3300, -1.3837]],

        [[-0.6427, -1.7172,  0.7557,  0.3321, -0.1216,  1.3936],
         [-0.4358, -1.5367, -0.8115,  0.7172,  0.7689,  1.2979]],

        [[-0.3059,  1.1240,  0.7801,  0.8206, -1.6761, -0.7427],
         [ 0.6772,  1.2398,  0.4367,  0.3601, -1.3300, -1.3837]],

        [[-0.2236,  0.3292,  0.1818, -0.5737, -1.5059,  1.7921],
         [ 1.6791, -0.7317, -1.0684,  0.7775, -0.9129,  0.2563]]],
       grad_fn=<AddBackward0>)



## Decoder functions to display the intermediate outputs and get the final outputs from the decoder


### Self attention outputs from a decoder block

In [256]:
def decoder_block_self_attn_output(x, state_dict, layer_num, tgt_mask = None,need_weights = False):
    """Hand-computed self-attention (with output projection) of a decoder layer.

    Reads ``num_heads`` from the module-level global of that name.

    Args:
        x: Decoder layer input, shape ``(tgt_len, bsz, embed_dim)``.
        state_dict: Mapping holding the
            ``transformer.decoder.layers.{layer_num}.self_attn.*`` tensors.
        layer_num: Index of the decoder layer.
        tgt_mask: Optional attention mask, ``(L, S)`` or
            ``(bsz * num_heads, L, S)``.
        need_weights: ``False`` selects ``atten_product_needs_wts_false``;
            any other value selects ``atten_product_needs_wts_true``.

    Returns:
        Tuple ``(attn_dec_output, attn_weights)``. ``attn_dec_output`` has
        shape ``(tgt_len, bsz, embed_dim)``; ``attn_weights`` is ``None``
        when ``need_weights`` is ``False``.

    Raises:
        RuntimeError: If ``tgt_mask`` has an unsupported rank or shape.
    """
    # Self-attention: query, key and value are the same tensor object.
    query_dec = key_dec = value_dec = x

    tgt_len, bsz, embed_dim = x.shape

    key_prefix = "transformer.decoder.layers.{}.self_attn.".format(layer_num)
    # Packed in-projection weight and bias for Q, K and V.
    W_dec = state_dict[key_prefix + "in_proj_weight"]
    b_dec = state_dict[key_prefix + "in_proj_bias"]

    head_dim = embed_dim//num_heads

    # Each of Q, K, V: (tgt_len, bsz, embed_dim)
    Q_dec,K_dec,V_dec = get_qkv(query_dec, key_dec, value_dec ,W_dec, b_dec)

    print(Q_dec.shape, K_dec.shape , V_dec.shape)
    print(tgt_len, bsz * num_heads, head_dim)

    # (tgt_len, bsz, embed_dim) -> (bsz*num_heads, tgt_len, head_dim)
    Q_dec = Q_dec.reshape(Q_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    K_dec = K_dec.reshape(K_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_dec = V_dec.reshape(V_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec)
    print()

    src_len = K_dec.size(1)

    attn_mask = tgt_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    if need_weights is False:
        # Only this path works on (N, num_heads, L, S) tensors, so the mask
        # is lifted to 4-D here; the weights path keeps the 3-D mask.
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        attn_output = atten_product_needs_wts_false(Q_dec, V_dec, K_dec, bsz, head_dim, src_len, tgt_len, attn_mask)
        attn_wt_matrix_dec = None

        print("Decoder Self Attention = ")
        print(attn_output)
        print()

    else:
        attn_output, attn_wt_matrix_dec = atten_product_needs_wts_true(Q_dec, K_dec, V_dec, bsz, tgt_len, attn_mask)

        print("Decoder Attention output = ")
        print(attn_wt_matrix_dec)
        print()

    # Output projection, shared by both paths:
    # (tgt_len*bsz, embed_dim) -> (tgt_len, bsz, embed_dim)
    op_dec_1 = torch.matmul(attn_output, state_dict[key_prefix + "out_proj.weight"].t()) + state_dict[key_prefix + "out_proj.bias"]
    attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_output.size(1))

    return attn_dec_output, attn_wt_matrix_dec


    

In [257]:
need_weights 

False

In [258]:
# need_weights = True

# Self-attention of decoder layer 0 on the position-encoded target
# embeddings; no tgt_mask is passed in this call.
self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(pe_tgt_embeds, state_dict, layer_num = 0, need_weights = need_weights)

torch.Size([5, 2, 6]) torch.Size([5, 2, 6]) torch.Size([5, 2, 6])
5 4 3
Q_dec_0 = 
tensor([[[ 0.7030, -0.8076,  0.3689],
         [-0.6962, -0.8609, -0.2664],
         [-0.3426,  0.1588,  0.7452],
         [-1.3988, -1.9553,  1.4398],
         [ 0.8151, -0.3381, -0.6950]],

        [[-1.1457,  1.3671, -1.2293],
         [ 0.3132,  0.1273, -0.3309],
         [-1.0411, -0.0952, -0.4006],
         [-0.0275,  0.6348, -0.6489],
         [ 0.5771,  0.8648,  0.1867]],

        [[-0.5888,  0.3034, -0.0136],
         [ 0.3244, -0.7439, -0.1468],
         [-0.2403,  0.2895, -0.0071],
         [ 0.8990, -0.9110,  0.0977],
         [ 0.4716, -0.6661,  0.0042]],

        [[ 1.1584, -0.1576,  0.6310],
         [ 0.4388,  0.9813,  0.3833],
         [-2.4030,  2.8845, -2.7083],
         [-0.8612,  1.2382, -0.5424],
         [-0.7776,  1.5754, -0.8309]]])

K_dec_0 = 
tensor([[[-0.2597,  0.2710,  1.8737],
         [-0.4699, -0.5533,  0.2744],
         [-0.3356, -0.5741,  0.3522],
         [-1.9068, -1.0

In [259]:
self_attn_dec

tensor([[[ 0.2572,  0.0169, -0.2472,  0.6397,  0.3312,  0.2136],
         [ 0.3283,  0.7148, -0.9729,  0.7399,  0.6407, -0.2077]],

        [[ 0.2581, -0.0156, -0.2736,  0.9474,  0.3807,  0.1914],
         [ 0.3769,  0.5179, -0.8612,  0.6676,  0.5279, -0.1483]],

        [[ 0.4009, -0.0706, -0.2101,  1.0242,  0.2440,  0.1878],
         [ 0.6117, -0.4366, -0.4381,  0.7823,  0.0391, -0.0629]],

        [[ 0.5340, -0.3637, -0.1184,  1.4839,  0.4896,  0.7024],
         [ 0.5136,  0.0206, -0.6401,  0.6380,  0.2400, -0.1107]],

        [[ 0.1315,  0.2273, -0.2956,  0.4926,  0.3270,  0.0492],
         [ 0.5247, -0.0511, -0.6052,  0.6707,  0.2199, -0.0904]]])

In [260]:
def dec_post_self_attn(self_attn_dec, x, state_dict, layer_num):
    """Compute ``norm1(x + self_attn_dec)`` for one decoder layer by hand.

    Args:
        self_attn_dec: Self-attention output, same shape as ``x``.
        x: Decoder layer input.
        state_dict: Mapping holding
            ``transformer.decoder.layers.{layer_num}.norm1.weight`` / ``.bias``.
        layer_num: Index of the decoder layer.

    Returns:
        Layer-normalised tensor with the same shape as ``x``.
    """

    def layer_norm(inputs, weight, bias, eps=1e-5):
        # Normalise over the last dimension with the biased variance, then
        # apply the learned scale and shift.
        mean = inputs.mean(dim=-1, keepdim=True)
        var = inputs.var(dim=-1, unbiased=False, keepdim=True)
        return (inputs - mean) / torch.sqrt(var + eps) * weight + bias

    prefix = "transformer.decoder.layers.{}.".format(layer_num)

    # Residual connection around self-attention.
    residual = self_attn_dec + x

    normalized_result_dec_1 = layer_norm(residual, state_dict[prefix + "norm1.weight"], state_dict[prefix + "norm1.bias"])

    print("Decoder_{} norm1(x + sa(x))".format(layer_num))
    print(normalized_result_dec_1)
    print()

    return normalized_result_dec_1

In [261]:
# norm1(x + self_attention(x)) for decoder layer 0; this is the query input
# of the cross-attention step below.
x_dec = dec_post_self_attn(self_attn_dec, pe_tgt_embeds, state_dict, layer_num = 0)

Decoder_0 norm1(x + sa(x))
tensor([[[-1.1433,  0.3085, -1.5215,  1.0415,  0.2490,  1.0658],
         [ 0.3276, -0.5101, -0.3225,  1.2700,  0.9634, -1.7284]],

        [[-1.2253, -0.6194,  0.9963, -0.5410, -0.2708,  1.6602],
         [ 0.2594, -0.9520, -1.4724,  1.6461,  0.1757,  0.3432]],

        [[ 0.6545,  1.4398, -1.4784,  0.6425, -0.8618, -0.3966],
         [ 0.0491,  0.8160, -2.1772,  0.3844,  0.5451,  0.3826]],

        [[-1.4528, -0.0252,  0.4831,  1.3705, -1.1115,  0.7359],
         [-0.1049, -0.0916, -1.8584,  1.1055, -0.1787,  1.1281]],

        [[-0.4127, -1.6150, -0.7144,  1.2463,  0.8871,  0.6087],
         [-0.1294, -0.1666, -1.9404,  1.1789,  0.1815,  0.8760]]],
       grad_fn=<AddBackward0>)



## Cross attention in decoder 
### query, key, value =  x, mem, mem 

### Cross attention between the encoder's final output and the decoder layer 

In [262]:

# The hand-computed encoder layer output is used as the cross-attention memory.
# NOTE(review): nn.Transformer's encoder also applies a final LayerNorm
# (transformer.encoder.norm) after its last layer - confirm whether that
# step needs to be reproduced before using this tensor as memory.
memory = output_enc_final
x_dec

tensor([[[-1.1433,  0.3085, -1.5215,  1.0415,  0.2490,  1.0658],
         [ 0.3276, -0.5101, -0.3225,  1.2700,  0.9634, -1.7284]],

        [[-1.2253, -0.6194,  0.9963, -0.5410, -0.2708,  1.6602],
         [ 0.2594, -0.9520, -1.4724,  1.6461,  0.1757,  0.3432]],

        [[ 0.6545,  1.4398, -1.4784,  0.6425, -0.8618, -0.3966],
         [ 0.0491,  0.8160, -2.1772,  0.3844,  0.5451,  0.3826]],

        [[-1.4528, -0.0252,  0.4831,  1.3705, -1.1115,  0.7359],
         [-0.1049, -0.0916, -1.8584,  1.1055, -0.1787,  1.1281]],

        [[-0.4127, -1.6150, -0.7144,  1.2463,  0.8871,  0.6087],
         [-0.1294, -0.1666, -1.9404,  1.1789,  0.1815,  0.8760]]],
       grad_fn=<AddBackward0>)

In [263]:
memory.shape

torch.Size([4, 2, 6])

In [264]:
def decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num, tgt_len, src_len, head_dim, memory_mask = None,need_weights = False):
    """Hand-computed cross-attention (with output projection) of a decoder layer.

    Queries come from ``x_dec``; keys and values both come from ``memory``.
    Reads ``num_heads`` from the module-level global of that name.

    Args:
        x_dec: Query input, shape ``(tgt_len, bsz, embed_dim)``.
        memory: Key/value input, shape ``(src_len, bsz, embed_dim)``.
        state_dict: Mapping holding the
            ``transformer.decoder.layers.{layer_num}.multihead_attn.*`` tensors.
        layer_num: Index of the decoder layer.
        tgt_len: Ignored; re-read from ``x_dec.shape``.
        src_len: Number of memory positions.
        head_dim: Per-head feature width.
        memory_mask: Optional attention mask, ``(L, S)`` or
            ``(bsz * num_heads, L, S)``.
        need_weights: ``False`` selects ``atten_product_needs_wts_false``;
            any other value selects ``atten_product_needs_wts_true``.

    Returns:
        Tuple ``(attn_output, attn_weights)``. ``attn_output`` has shape
        ``(tgt_len, bsz, embed_dim)``; ``attn_weights`` is ``None`` when
        ``need_weights`` is ``False``.

    Raises:
        RuntimeError: If ``memory_mask`` has an unsupported rank or shape.
    """
    # (tgt_len, bsz, embed_dim)
    query_dec_mha = x_dec

    # (src_len, bsz, embed_dim); key and value are the same object so that
    # get_qkv takes its shared key/value projection branch.
    key_dec_mha, value_dec_mha = memory, memory

    tgt_len, bsz, embed_dim = query_dec_mha.shape

    key_prefix = "transformer.decoder.layers.{}.multihead_attn.".format(layer_num)
    W_dec_mha = state_dict[key_prefix + "in_proj_weight"]
    b_dec_mha = state_dict[key_prefix + "in_proj_bias"]

    ########## CHANGE FROM SELF ATTENTION ############

    Q_dec_mha,K_dec_mha,V_dec_mha = get_qkv(query_dec_mha, key_dec_mha, value_dec_mha ,W_dec_mha, b_dec_mha)

    #################################################

    # (len, bsz, embed_dim) -> (bsz*num_heads, len, head_dim)
    Q_dec_mha = Q_dec_mha.reshape(Q_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    K_dec_mha = K_dec_mha.reshape(K_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_dec_mha = V_dec_mha.reshape(V_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec_mha)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec_mha)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec_mha)
    print()

    attn_mask = memory_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    if need_weights is False:
        # Only this path works on (N, num_heads, L, S) tensors, so the mask
        # is lifted to 4-D here; the weights path keeps the 3-D mask.
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        attn_output_dec_mha = atten_product_needs_wts_false(Q_dec_mha, V_dec_mha, K_dec_mha, bsz, head_dim, src_len, tgt_len, attn_mask)
        attn_wt_matrix_dec_mha = None

        print("Cross attention in decoder_{}".format(layer_num))
        print(attn_output_dec_mha)
        print()

    else:
        attn_output_dec_mha, attn_wt_matrix_dec_mha = atten_product_needs_wts_true(Q_dec_mha, K_dec_mha, V_dec_mha, bsz, tgt_len, attn_mask)

        print("Decoder mha Attention output = ")
        print(attn_output_dec_mha)
        print()

    # Output projection, shared by both paths:
    # (tgt_len*bsz, embed_dim) -> (tgt_len, bsz, embed_dim)
    op_dec_mha = torch.matmul(attn_output_dec_mha, state_dict[key_prefix + "out_proj.weight"].t()) + state_dict[key_prefix + "out_proj.bias"]
    attn_dec_mha_output = op_dec_mha.view(tgt_len, bsz, attn_output_dec_mha.size(1))

    return attn_dec_mha_output, attn_wt_matrix_dec_mha

        


In [265]:
# need_weights = True

# Cross-attention of decoder layer 0: queries from x_dec, keys/values from
# the encoder memory. The `tgt_len` passed here is the module-level value
# unpacked from pe_src_embeds.shape; the function re-derives tgt_len from
# x_dec's shape.
attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = 0, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.8053, -0.0713,  0.7679],
         [-0.8762,  0.1282,  0.2041],
         [ 1.8555,  0.2675,  0.3083],
         [ 0.4550,  0.0176,  0.1286],
         [-0.4132, -0.8115,  0.1700]],

        [[-0.1601,  0.3679,  0.6959],
         [ 0.1825,  0.3493, -0.3748],
         [ 0.3822,  0.4362,  1.0671],
         [ 0.4914,  0.8859, -0.3512],
         [-1.1962, -0.3612, -0.2075]],

        [[-0.0825, -0.5636, -0.3997],
         [ 0.6957, -0.7197,  0.2492],
         [ 1.1379,  0.0230,  0.7246],
         [ 1.1512, -0.2868,  0.6664],
         [ 1.0151, -0.3636,  0.6390]],

        [[-0.5889, -0.5484, -0.5096],
         [-1.0535, -0.1210,  0.3782],
         [-0.2853, -0.0195,  1.2119],
         [-0.5751,  0.2492,  0.9650],
         [-0.6918,  0.0941,  0.8875]]], grad_fn=<TransposeBackward0>)

K_dec_0 = 
tensor([[[-0.7485,  1.1345,  0.6806],
         [-1.0826,  1.7761,  0.2926],
         [-0.6303,  0.0104,  0.8447],
         [-0.3431,  1.4132,  0.9495]],

        [[ 0.0087,  1.013

In [266]:
def _manual_layer_norm(x, weight, bias, eps=1e-5):
    """Layer-normalise the last dimension of ``x`` and apply the learned affine.

    Written out by hand (instead of calling ``nn.LayerNorm``) so the arithmetic
    stays visible: ``(x - mean) / sqrt(var + eps) * weight + bias`` with the
    biased (population) variance, which is the formula ``nn.LayerNorm`` uses.
    """
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias


def decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num):
    """Finish one decoder layer after cross-attention.

    Computes, for decoder layer ``layer_num``::

        x''  = norm2(x' + mha(x', mem))
        out  = norm3(x'' + linear2(relu(linear1(x''))))

    Args:
        x_dec: Decoder state ``x'`` entering cross-attention,
            shaped (tgt_len, bsz, embed_dim).
        attn_dec_mha_output: Cross-attention output with the same shape as
            ``x_dec``.
        state_dict: Model ``state_dict``; the ``norm2``, ``linear1``,
            ``linear2`` and ``norm3`` tensors of
            ``transformer.decoder.layers.<layer_num>`` are read from it.
        layer_num: Index of the decoder layer whose parameters are used.

    Returns:
        The layer output ``norm3(x'' + ff(x''))``, same shape as ``x_dec``.
        Both normalised tensors are also printed for inspection.
    """
    # Every parameter is looked up under the same per-layer prefix, so norm3
    # follows ``layer_num`` instead of always reading layer 0.
    prefix = "transformer.decoder.layers.{}.".format(layer_num)

    # Residual connection around cross-attention: (tgt_len, bsz, embed_dim)
    output_dec_2 = attn_dec_mha_output + x_dec
    x_dec2_norm = _manual_layer_norm(
        output_dec_2,
        state_dict[prefix + "norm2.weight"],
        state_dict[prefix + "norm2.bias"],
    )

    print("norm2(x' + mha(x', mem)), \n where x' = Decoder_curr_layer norm1(x + sa(x))")
    print(x_dec2_norm)
    print("\n\n")

    # Position-wise feed-forward network: linear1 -> ReLU -> linear2
    op_dec_1 = torch.matmul(x_dec2_norm, state_dict[prefix + "linear1.weight"].t()) + state_dict[prefix + "linear1.bias"]
    op_dec_1_relu = torch.nn.functional.relu(op_dec_1)
    ff_dec = torch.matmul(op_dec_1_relu, state_dict[prefix + "linear2.weight"].t()) + state_dict[prefix + "linear2.bias"]

    # Residual connection around the feed-forward network, then the final norm.
    x_dec3_unorm = x_dec2_norm + ff_dec
    normalized_result_dec_3 = _manual_layer_norm(
        x_dec3_unorm,
        state_dict[prefix + "norm3.weight"],
        state_dict[prefix + "norm3.bias"],
    )

    print("norm3(x'' + ff(x'')) \n where, x'' = Decoder_curr_layer norm2(x' + mha(x'))")
    print(normalized_result_dec_3)
    print()

    return normalized_result_dec_3

def feef_fwd_transformer(dec_output_final, state_dict):
    """Project the final decoder output through the model's ``fc`` linear layer.

    Args:
        dec_output_final: Decoder output, (tgt_len, bsz, embed_dim).
        state_dict: Mapping that holds ``"fc.weight"`` (vocab_size, embed_dim)
            and ``"fc.bias"``.

    Returns:
        Tensor of shape (tgt_len, bsz, vocab_size).
    """
    fc_weight = state_dict["fc.weight"]
    fc_bias = state_dict["fc.bias"]

    # (tgt_len, bsz, embed_dim) x (embed_dim, vocab_size) -> (tgt_len, bsz, vocab_size)
    projected = dec_output_final.matmul(fc_weight.t())

    return projected + fc_bias

In [267]:
# Finish decoder layer 0: residual + norm2, feed-forward, residual + norm3.
# The function prints both normalised tensors and returns the second one.
final_attn_op_decoder = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = 0)

norm2(x' + mha(x', mem)), 
 where x' = Decoder_curr_layer norm1(x + sa(x))
tensor([[[-1.4574,  0.1825, -1.1540,  0.7673,  0.3103,  1.3513],
         [ 0.3961, -0.3919, -0.4246,  1.3567,  0.8027, -1.7390]],

        [[-1.4126, -0.4391,  0.9497, -0.5774, -0.1213,  1.6006],
         [ 0.4802, -0.9630, -1.4157,  1.6747,  0.0487,  0.1751]],

        [[ 0.1056,  1.6139, -1.4920,  0.5495, -0.9149,  0.1379],
         [ 0.1936,  0.8049, -2.1927,  0.4950,  0.4415,  0.2577]],

        [[-1.7267, -0.0143,  0.6227,  0.9632, -0.8418,  0.9969],
         [ 0.0873, -0.1348, -1.8618,  1.2062, -0.2832,  0.9862]],

        [[-0.6382, -1.7092, -0.3937,  0.9758,  0.9360,  0.8294],
         [ 0.0811, -0.2151, -1.9421,  1.2829,  0.0689,  0.7244]]],
       grad_fn=<AddBackward0>)



norm3(x'' + ff(x'')) 
 where, x'' = Decoder_curr_layer norm2(x' + mha(x'))
tensor([[[-1.1004,  0.1155, -1.5022,  0.5930,  0.5146,  1.3794],
         [ 0.5376, -0.1928, -0.8114,  1.2325,  0.8826, -1.6484]],

        [[-1.1626, -0.51

In [268]:
# Project the decoder output through the model's final linear layer ("fc.weight" / "fc.bias").
# Result shape: (tgt_len, bsz, vocab_dim)
final_op = feef_fwd_transformer(final_attn_op_decoder, state_dict)

In [269]:
final_op

tensor([[[-6.6667e-01,  3.2296e-01, -4.0249e-01, -3.8964e-01, -2.7701e-01,
          -1.1482e-01,  3.5929e-01,  1.2512e+00,  2.0429e-01,  5.5085e-01],
         [ 7.3170e-02, -3.3670e-01,  2.7040e-01, -2.0084e-01,  2.5609e-01,
           6.8348e-02, -4.1680e-01, -5.0878e-01, -1.9309e-01,  9.6282e-01]],

        [[-1.2892e+00,  1.1669e+00, -1.3184e+00, -3.9436e-01,  1.2273e-01,
          -1.4213e-01, -4.2032e-02,  5.8542e-01,  2.4001e-01,  9.7504e-02],
         [-3.1303e-01,  9.1863e-02, -1.9785e-01, -6.2570e-01, -1.6641e-01,
          -2.3585e-01, -1.9337e-01,  3.2745e-01,  3.4555e-02,  1.0395e-01]],

        [[ 6.3494e-01, -8.2564e-01,  6.4853e-01, -4.0874e-01, -1.2193e+00,
           5.8824e-02,  8.0115e-01,  4.2675e-01,  1.4104e-01,  2.0981e-01],
         [ 2.9972e-01, -7.1025e-01,  7.6157e-04,  4.7561e-03, -8.9210e-01,
          -2.6281e-01,  5.6628e-01,  7.5355e-01,  3.6227e-01,  7.0990e-01]],

        [[-1.0802e+00,  1.1114e+00, -7.8312e-02, -1.1748e+00,  2.2005e-01,
           3.

In [270]:
final_op.shape

torch.Size([5, 2, 10])

########################################################################

### Example :- 

Transformer with 3 encoders and 3 decoders

In [271]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    The table is built once in ``__init__`` and stored on ``self.encoding``
    with shape (1, max_len, d_model): even feature columns hold sines, odd
    columns hold cosines. ``dropout`` is accepted but not used.
    """

    def __init__(self, d_model, max_len=512, dropout=0):
        super().__init__()

        # Column vector of positions: (max_len, 1)
        positions = torch.arange(0, max_len).unsqueeze(1).float()

        # One frequency per sin/cos pair: exp(-2i * ln(10000) / d_model)
        log_scale = -(torch.log(torch.tensor(10000.0)) / d_model)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * log_scale)

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)

        # Leading singleton axis so the table broadcasts against the embeddings.
        self.encoding = table.unsqueeze(0)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, detached from autograd.

        NOTE(review): the slice length comes from dimension 1 of ``x``. The
        model feeds sequence-first tensors to ``nn.Transformer``, so dimension 1
        looks like the batch axis rather than the sequence axis; confirm this
        is intended before relying on it.
        """
        n_rows = x.size(1)
        return self.encoding[:, :n_rows].detach()

class TransformerModel(nn.Module):
    """Token embeddings + positional encoding + ``nn.Transformer`` + vocabulary head.

    ``forward`` prints the raw and the positionally-encoded embeddings for
    both inputs, so intermediate values can be compared against a manual
    re-implementation.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size,max_seq_len, num_encoder_layers=1, num_decoder_layers=1, d_model=4, nhead=2):
        super().__init__()

        # Sub-modules are created in the same order as before so parameter
        # initialisation and state_dict key order are unchanged.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        transformer_config = dict(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0,
        )
        self.transformer = nn.Transformer(**transformer_config)

        # Maps each decoder output vector to target-vocabulary logits.
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Run the full model on token-id tensors ``src`` and ``tgt``.

        Returns the output of the final linear layer applied to the
        transformer output.
        """
        src_embedded = self.src_embedding(src)
        print("Source embeddings :- \n")
        print(src_embedded)
        print()

        tgt_embedded = self.tgt_embedding(tgt)
        print("Target embeddings :- \n")
        print(tgt_embedded)
        print()

        # The positional encoding is looked up from the token-id tensors.
        src_encoded = src_embedded + self.positional_encoding(src)
        tgt_encoded = tgt_embedded + self.positional_encoding(tgt)

        for heading, tensor in (
            ("Positional encoded source embeddings:", src_encoded),
            ("Positional encoded target embeddings:", tgt_encoded),
        ):
            print(heading)
            print(tensor)
            print()

        decoded = self.transformer(src_encoded, tgt_encoded)
        return self.fc(decoded)
    

# Hyper-parameters for the 3-encoder / 3-decoder example.
src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 6  # Dimension of the model (single definition; the earlier `d_model = 4` was dead)
num_heads = 2
num_encoder_layers = 3
num_decoder_layers = 3

need_weights = False

# No masking in this example.
src_mask = None
tgt_mask = None
memory_mask = None

max_seq_len = 5


model = TransformerModel(src_vocab_size, tgt_vocab_size, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, d_model=d_model, max_seq_len = max_seq_len, nhead=num_heads)
# Source sentence in the source language
src_sentence = torch.tensor([[0, 2], [1, 0], [2, 2], [3, 5]])
print(src_sentence.shape)
# Target sentence in the target language (translation of the source sentence)
tgt_sentence = torch.tensor([[1, 7], [3, 4], [5, 2], [8, 0], [6, 1]])  # Target sequence
print(tgt_sentence.shape)

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output)


torch.Size([4, 2])
torch.Size([5, 2])
Source embeddings :- 

tensor([[[ 0.3485,  0.9371,  0.6244, -0.0230, -0.1803,  1.6200],
         [-1.2299,  0.8020,  0.4715,  0.4010, -0.5315, -1.7415]],

        [[-0.1163,  0.3765, -0.3016, -0.2952, -0.5298, -0.6820],
         [ 0.3485,  0.9371,  0.6244, -0.0230, -0.1803,  1.6200]],

        [[-1.2299,  0.8020,  0.4715,  0.4010, -0.5315, -1.7415],
         [-1.2299,  0.8020,  0.4715,  0.4010, -0.5315, -1.7415]],

        [[-0.2243,  1.6587, -0.3522, -0.6067, -0.2162, -0.6181],
         [-0.4952, -2.1157,  1.7508, -0.6661, -0.6780,  0.7846]]],
       grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[ 0.1129, -1.0142, -0.9221,  0.8812, -1.6048,  0.2050],
         [ 1.6339, -0.6463, -0.1945,  0.6870,  0.3138,  0.4030]],

        [[-0.7208,  0.4244, -1.1285,  1.7011,  2.0456,  0.1133],
         [ 0.7506, -0.2308,  1.0457, -0.1141, -0.8790, -0.7974]],

        [[-0.1423,  0.0984, -1.4589,  1.0058, -0.5254,  0.3244],
         [ 0.1265, -0



In [272]:
def get_all_intermediate_outputs(src_sentence, tgt_sentence,d_model, model, num_encoder_layers , num_decoder_layers, num_heads=None, max_seq_len=None):
    """Re-run the model's forward pass by hand, layer by layer, from its ``state_dict``.

    The encoder layers are applied to the positionally-encoded source
    embeddings, the decoder layers to the positionally-encoded target
    embeddings (attending over the final encoder output), and the result is
    projected through the final linear layer. Each helper prints its own
    intermediate tensors.

    Args:
        src_sentence: Source token ids passed on to ``get_embedding_outputs``.
        tgt_sentence: Target token ids passed on to ``get_embedding_outputs``.
        d_model: Embedding dimension of the model.
        model: The model whose ``state_dict()`` supplies every weight.
        num_encoder_layers: Number of encoder layers to replay (at least 1).
        num_decoder_layers: Number of decoder layers to replay.
        num_heads: Number of attention heads. Defaults to
            ``model.transformer.nhead`` rather than a notebook global.
        max_seq_len: Positional-encoding table length. Defaults to the
            table length stored on ``model.positional_encoding``.

    Returns:
        The output of ``feef_fwd_transformer`` applied to the final decoder state.

    Raises:
        ValueError: If ``num_encoder_layers`` is less than 1, because the
            decoder needs ``src_len`` and ``head_dim`` from an encoder layer.
    """
    if num_encoder_layers < 1:
        raise ValueError("num_encoder_layers must be at least 1")

    # Take these from the model itself so the function does not depend on
    # whatever globals happen to be defined in the notebook.
    if num_heads is None:
        num_heads = model.transformer.nhead
    if max_seq_len is None:
        max_seq_len = model.positional_encoding.encoding.size(1)

    state_dict = model.state_dict()

    pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence=src_sentence,  tgt_sentence=tgt_sentence, state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)
    print("###"*25)
    print("### Encoder Start ###")
    print()

    x_enc = pe_src_embeds

    # Shape of the encoder input. In encoder self-attention the query length
    # ("tgt_len") is the source sequence length.
    tgt_len, bsz, _ = x_enc.shape

    for lno in range(num_encoder_layers):
        attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(x_enc, state_dict, layer_num = lno, need_weights = False, embed_dim = d_model, num_heads= num_heads, src_mask=None)

        x_enc = encoder_block_post_attn_output(x_enc, attn_enc_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

    print("### Encoder Done ###")
    print("\n\n\n")
    print("### Decoder Start ###")

    x_dec = pe_tgt_embeds
    memory = x_enc

    for lno in range(num_decoder_layers):

        self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, need_weights = False, tgt_mask=None)
        x_dec = dec_post_self_attn(self_attn_dec, x_dec, state_dict, layer_num = lno)

        # NOTE(review): `tgt_len` here is still the encoder-side length captured
        # above, exactly as before; confirm whether the cross-attention helper
        # expects the decoder length (x_dec.shape[0]) instead.
        attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = lno, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = False, memory_mask=None)
        layer_output = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = lno)

        print(pe_tgt_embeds.shape, layer_output.shape)
        x_dec = layer_output

    print("### Decoder Done ###")

    # Use the running decoder state so the projection is defined even when
    # no decoder layer was replayed.
    final_op = feef_fwd_transformer(x_dec, state_dict)

    return final_op

    

In [273]:
d_model = 6
num_encoder_layers = 3
num_decoder_layers = 3

# Replay the forward pass by hand. The decoder depth is passed from
# `num_decoder_layers` (it was previously fed `num_encoder_layers` by mistake).
final_op = get_all_intermediate_outputs(src_sentence, tgt_sentence, model = model, num_encoder_layers = num_encoder_layers , num_decoder_layers = num_decoder_layers, d_model=d_model)

Source sentence embedding
Word index: 0, Embedding: tensor([ 0.3485,  0.9371,  0.6244, -0.0230, -0.1803,  1.6200])
Word index: 2, Embedding: tensor([-1.2299,  0.8020,  0.4715,  0.4010, -0.5315, -1.7415])
Word index: 1, Embedding: tensor([-0.1163,  0.3765, -0.3016, -0.2952, -0.5298, -0.6820])
Word index: 0, Embedding: tensor([ 0.3485,  0.9371,  0.6244, -0.0230, -0.1803,  1.6200])
Word index: 2, Embedding: tensor([-1.2299,  0.8020,  0.4715,  0.4010, -0.5315, -1.7415])
Word index: 2, Embedding: tensor([-1.2299,  0.8020,  0.4715,  0.4010, -0.5315, -1.7415])
Word index: 3, Embedding: tensor([-0.2243,  1.6587, -0.3522, -0.6067, -0.2162, -0.6181])
Word index: 5, Embedding: tensor([-0.4952, -2.1157,  1.7508, -0.6661, -0.6780,  0.7846])

torch.Size([4, 2, 6])
Target sentence embedding
Word index: 1, Embedding: tensor([ 0.1129, -1.0142, -0.9221,  0.8812, -1.6048,  0.2050])
Word index: 7, Embedding: tensor([ 1.6339, -0.6463, -0.1945,  0.6870,  0.3138,  0.4030])
Word index: 3, Embedding: tensor([-

In [274]:
output.shape, final_op.shape


(torch.Size([5, 2, 10]), torch.Size([5, 2, 10]))

In [275]:
output

tensor([[[ 0.0351, -1.1913, -0.3768, -0.1533, -0.6715,  0.2967,  0.9567,
          -0.4592, -0.7431,  0.5005],
         [ 0.5240, -1.0788, -0.4459, -0.1862, -0.0582,  0.5289,  0.4782,
          -0.9111, -0.0584,  0.9501]],

        [[ 0.3253, -1.2071, -0.4222, -0.2448, -0.6172,  0.4023,  0.9607,
          -0.6148, -0.7525,  0.4649],
         [ 0.6851, -1.1969, -0.5047,  0.0477, -0.4231,  0.6040,  0.7750,
          -0.5608, -0.2766,  0.7577]],

        [[ 0.1552, -1.1610, -0.4760, -0.2129, -0.6705,  0.3657,  1.0074,
          -0.4979, -0.7873,  0.5562],
         [ 0.6300, -1.2311, -0.5301, -0.1492, -0.4834,  0.5986,  0.8887,
          -0.6590, -0.5067,  0.7288]],

        [[ 0.3160, -1.2016, -0.5434, -0.1745, -0.7006,  0.4669,  1.0648,
          -0.4734, -0.7825,  0.6207],
         [ 0.7733, -1.1772, -0.5043, -0.2409, -0.3084,  0.6118,  0.7447,
          -0.8627, -0.3858,  0.7044]],

        [[ 0.1892, -1.2463, -0.4136, -0.1336, -0.6515,  0.3959,  0.9682,
          -0.4872, -0.6962,  0.

In [276]:
final_op

tensor([[[ 0.0351, -1.1913, -0.3768, -0.1533, -0.6715,  0.2967,  0.9567,
          -0.4592, -0.7431,  0.5005],
         [ 0.5240, -1.0788, -0.4459, -0.1862, -0.0582,  0.5289,  0.4782,
          -0.9111, -0.0584,  0.9501]],

        [[ 0.3253, -1.2071, -0.4222, -0.2448, -0.6172,  0.4023,  0.9607,
          -0.6148, -0.7525,  0.4649],
         [ 0.6851, -1.1969, -0.5047,  0.0477, -0.4231,  0.6040,  0.7749,
          -0.5608, -0.2766,  0.7577]],

        [[ 0.1552, -1.1610, -0.4760, -0.2129, -0.6705,  0.3657,  1.0074,
          -0.4979, -0.7873,  0.5562],
         [ 0.6300, -1.2311, -0.5301, -0.1492, -0.4834,  0.5986,  0.8887,
          -0.6590, -0.5067,  0.7288]],

        [[ 0.3160, -1.2016, -0.5434, -0.1745, -0.7005,  0.4669,  1.0648,
          -0.4734, -0.7824,  0.6207],
         [ 0.7733, -1.1772, -0.5043, -0.2409, -0.3084,  0.6118,  0.7447,
          -0.8627, -0.3858,  0.7044]],

        [[ 0.1892, -1.2463, -0.4136, -0.1336, -0.6515,  0.3959,  0.9682,
          -0.4872, -0.6962,  0.