## Multi-head attention 
### Encoder and Decoder
### (Without masking)

Pytorch's implementation (in built)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):


Source embeddings :- 

tensor([[[-0.7902,  0.2501,  1.5278, -0.5044],
         [-1.6202, -0.6824,  0.4107,  0.1833],
         [ 0.4261, -1.3366, -1.2978, -0.8236],
         [-0.4685,  1.9740,  0.3854, -0.9693]]], grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[-0.7213,  0.7379,  1.2300,  1.0213],
         [ 1.3452,  0.7265,  1.7796, -1.9332],
         [ 2.4084,  1.7481, -1.9929,  0.3660],
         [ 2.4084,  1.7481, -1.9929,  0.3660]]], grad_fn=<EmbeddingBackward0>)

Positional encoded source embeddings:
tensor([[[-0.7902,  2.2501,  3.5278,  3.4956],
         [-0.7787,  0.8579,  2.4207,  4.1832],
         [ 1.3354, -0.7528,  0.7222,  3.1762],
         [-0.3274,  1.9840,  2.4154,  3.0303]]], grad_fn=<AddBackward0>)

Positional encoded target embeddings:
tensor([[[0.2787, 1.7379, 4.2300, 5.0213],
         [3.1867, 1.2668, 4.7896, 2.0668],
         [4.3177, 1.3320, 1.0271, 4.3658],
         [3.5496, 0.7582, 1.0371, 4.3655]]], grad_fn=<AddBackward0>)

query =  tensor([[[-0.



In [None]:
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=4, nhead=2, num_encoder_layers=1, num_decoder_layers=1):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):

        print("Source embeddings :- \n")
        print(self.src_embedding(src))
        print()

        print("Target embeddings :- \n")
        print(self.tgt_embedding(tgt))
        print()


        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

    

src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 4  # Dimension of the model
num_heads = 2
model = TransformerModel(src_vocab_size, tgt_vocab_size, d_model=d_model, nhead=num_heads)
# Source sentence in the source language
src_sentence = torch.tensor([[0, 1, 2, 3]])  # Source sequence

# Target sentence in the target language (translation of the source sentence)
tgt_sentence = torch.tensor([[1, 0, 3, 3]])  # Target sequence

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output)



## Functions to get the manual calulations for each component

In [2]:
from torch import functional as F

# Using the state dictionary to get the intermediate outputs
state_dict = model.state_dict()

need_weights = False

### Embeddings and Positional encoding

In [3]:
def get_embedding_outputs(src_sentence,  tgt_sentence, state_dict, d_model):

    src_vocab_embeds = state_dict["src_embedding.weight"]

    src_embedding = torch.zeros(src_sentence.size(0), src_sentence.size(1), d_model)

    print("Source sentence embedding")
    for i in range(src_sentence.size(0)):
        for j in range(src_sentence.size(1)):
            # Get the index for the current word in the sequence
            word_index = src_sentence[i, j].item()

            # Check if the index is within valid range
            if word_index < 0 or word_index >= src_vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            src_embedding[i, j, :] = src_vocab_embeds[word_index, :]

            # Print intermediate results for debugging
            
            print(f"Word index: {word_index}, Embedding: {src_vocab_embeds[word_index, :]}")
    print()


    tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

    tgt_embedding = torch.zeros(tgt_sentence.size(0), tgt_sentence.size(1), d_model)


    print("Target sentence embedding")
    for i in range(tgt_sentence.size(0)):
        for j in range(tgt_sentence.size(1)):
            # Get the index for the current word in the sequence
            word_index = tgt_sentence[i, j].item()

            # Check if the index is within valid range
            if word_index < 0 or word_index >= tgt_vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            tgt_embedding[i, j, :] = tgt_vocab_embeds[word_index, :]

            # Print intermediate results for debugging
            print(f"Word index: {word_index}, Embedding: {tgt_vocab_embeds[word_index, :]}")
    print()


    pe = PositionalEncoding(d_model = d_model, dropout=0)

    pe_src_embeds = src_embedding + pe(src_sentence)

    pe_tgt_embeds = tgt_embedding + pe(tgt_sentence)

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    print("PE target embeddings : \n")
    print(pe_tgt_embeds)
    print()

    return pe_src_embeds, pe_tgt_embeds



In [4]:
pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence,  tgt_sentence, state_dict, d_model = d_model)

Source sentence embedding
Word index: 0, Embedding: tensor([-0.7902,  0.2501,  1.5278, -0.5044])
Word index: 1, Embedding: tensor([-1.6202, -0.6824,  0.4107,  0.1833])
Word index: 2, Embedding: tensor([ 0.4261, -1.3366, -1.2978, -0.8236])
Word index: 3, Embedding: tensor([-0.4685,  1.9740,  0.3854, -0.9693])

Target sentence embedding
Word index: 1, Embedding: tensor([-0.7213,  0.7379,  1.2300,  1.0213])
Word index: 0, Embedding: tensor([ 1.3452,  0.7265,  1.7796, -1.9332])
Word index: 3, Embedding: tensor([ 2.4084,  1.7481, -1.9929,  0.3660])
Word index: 3, Embedding: tensor([ 2.4084,  1.7481, -1.9929,  0.3660])

PE source embeddings : 

tensor([[[-0.7902,  2.2501,  3.5278,  3.4956],
         [-0.7787,  0.8579,  2.4207,  4.1832],
         [ 1.3354, -0.7528,  0.7222,  3.1762],
         [-0.3274,  1.9840,  2.4154,  3.0303]]])

PE target embeddings : 

tensor([[[0.2787, 1.7379, 4.2300, 5.0213],
         [3.1867, 1.2668, 4.7896, 2.0668],
         [4.3177, 1.3320, 1.0271, 4.3658],
        

## Encoder function to display the intermediate outputs and get the final outputs from the encoder

### Self attention 

In [5]:
def encoder_block_attn_output(x, state_dict, layer_num, need_weights = False):

    query_enc = key_enc = value_enc = x

    tgt_len, bsz, embed_dim = x.shape

    W_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads

    tempop1 = query_enc[0]@W_enc.T
    tempop1 = tempop1.T

    Q_enc,K_enc,V_enc = tempop1.T.chunk(3, dim= -1)


    Q_enc = Q_enc.unsqueeze(0)
    K_enc = K_enc.unsqueeze(0)
    V_enc = V_enc.unsqueeze(0)


    Q_enc = Q_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_enc = K_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    V_enc = V_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_enc_{} = ".format(layer_num))
    print(Q_enc)
    print()

    print("K_enc_{} = ".format(layer_num))
    print(K_enc)
    print()

    print("V_enc_{} = ".format(layer_num))
    print(V_enc)
    print()

    src_len = K_enc.size(1)


    if need_weights is False:
        Q_enc1 = Q_enc.view(bsz, num_heads, tgt_len, head_dim)
        K_enc1 = K_enc.view(bsz, num_heads, src_len, head_dim)
        V_enc1 = V_enc.view(bsz, num_heads, src_len, head_dim)


        L, S = Q_enc1.size(-2), K_enc1.size(-2)

        scale_factor = 1 / math.sqrt(Q_enc1.size(-1)) 
        # scale_factor = 1
        attn_bias = torch.zeros(L, S, dtype=Q_enc1.dtype)


        attn_weight = Q_enc1 @ K_enc1.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        
        attn_output = attn_weight @ V_enc1

        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

        print("Encoder Attention output = ")
        print(attn_output)
        print()

        # print(src_len, pe_src_embeds.shape)
        return attn_output, src_len, head_dim, None
    
    else:

        B, Nt, E = Q_enc.shape

        Q_enc_scaled = Q_enc / math.sqrt(E)

        Q_enc_scaled

        temp_pdt_matrix_enc = torch.bmm(Q_enc_scaled, K_enc.transpose(-2, -1))

        # Obtaining the diagonal entries of the un-normalised product (K.Q_scaled)
        pdt_matrix_enc = torch.diagonal(temp_pdt_matrix_enc, dim1=-2, dim2=-1).unsqueeze(-1)

        attn_wt_matrix_enc = torch.nn.functional.softmax(pdt_matrix_enc, dim=-1)

        attn_enc_output = torch.bmm(attn_wt_matrix_enc, V_enc)

        attn_enc_output = attn_enc_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

        print("Encoder Attention output = ")
        print(attn_enc_output)
        print()

        return attn_enc_output, src_len, head_dim, attn_wt_matrix_enc

    

In [7]:
# need_weights = True

attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(pe_src_embeds, state_dict, layer_num = 0, need_weights = need_weights)


Q_enc_0 = 
tensor([[[ 0.5279, -2.2983]],

        [[-1.0550, -1.1707]],

        [[ 0.6662, -1.4912]],

        [[-1.2663, -1.0478]],

        [[ 1.1158,  0.0281]],

        [[-0.6286, -1.5645]],

        [[ 0.7990, -1.7561]],

        [[-0.9101, -1.1756]]])

K_enc_0 = 
tensor([[[-0.6694,  0.4451]],

        [[-0.7058,  4.3720]],

        [[-0.7618,  0.1292]],

        [[-1.4209,  3.7035]],

        [[-0.5826, -0.2429]],

        [[-1.7014,  1.3370]],

        [[-0.9576,  0.1478]],

        [[-0.8108,  3.5906]]])

V_enc_0 = 
tensor([[[ 4.4475,  0.1953]],

        [[ 1.3194,  2.5191]],

        [[ 4.0303,  0.9956]],

        [[-0.0675,  1.8946]],

        [[ 1.9893,  1.7287]],

        [[-0.5388, -0.3410]],

        [[ 3.4502,  0.4872]],

        [[ 0.9050,  1.7593]]])

Encoder Attention output = 
tensor([[ 4.4475,  0.1953,  1.3194,  2.5191],
        [ 4.0303,  0.9956, -0.0675,  1.8946],
        [ 1.9893,  1.7287, -0.5388, -0.3410],
        [ 3.4502,  0.4872,  0.9050,  1.7593]])



### Post self attention in encoder

In [8]:
def encoder_block_post_attn_output(x, attn_enc_output, state_dict, layer_num, bsz, tgt_len):

    op_enc_1 = torch.matmul(attn_enc_output, state_dict["transformer.encoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]

    attn_enc_output = op_enc_1.view(tgt_len, bsz, attn_enc_output.size(1))


    # Here src is the original passed inputs to the 1st transformer encoder layer which 
    # are pe_src_embeds 
    output_enc_1 = attn_enc_output + x
    output_enc_1

    linear_result_enc_1 = output_enc_1*state_dict["transformer.encoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm1.bias".format(layer_num)]

    # Layer normalization from Torch's implementation
    layernorm_enc_1 = torch.nn.LayerNorm(normalized_shape=linear_result_enc_1.shape[2:])
    linear_op_enc_1 = layernorm_enc_1(linear_result_enc_1)


    # Manual Layer Normalization
    x = linear_result_enc_1

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_enc_1.weight
    b = layernorm_enc_1.bias

    linear_result_enc_1f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_enc_1f.mean(dim=-1, keepdim=True)
    std = linear_result_enc_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_enc_1 = (linear_result_enc_1f - mean) / (std + epsilon) * w + b


    op_enc_1 = torch.matmul(normalized_result_enc_1, state_dict["transformer.encoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear1.bias".format(layer_num)]
    op_enc_1_relu = torch.nn.functional.relu(op_enc_1)
    op_enc_2 = torch.matmul(op_enc_1_relu, state_dict["transformer.encoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear2.bias".format(layer_num)]


    output_enc_2 = op_enc_2 + linear_op_enc_1
    output_enc_2_norm = output_enc_2*state_dict["transformer.encoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm2.bias".format(layer_num)]

    # Layer normalization from Torch's implementation
    layernorm_enc_final = torch.nn.LayerNorm(normalized_shape=output_enc_2_norm.shape[2:])
    output_enc_final = layernorm_enc_final(output_enc_2_norm)


    # Manual Layer Normalization 
    x = output_enc_2_norm

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_enc_final.weight
    b = layernorm_enc_final.bias

    linear_result_enc_2 = w*x + b

    epsilon = 1e-5  
    mean = linear_result_enc_2.mean(dim=-1, keepdim=True)
    std = linear_result_enc_2.std(dim=-1, unbiased=False, keepdim=True)
    output_enc_final = (linear_result_enc_2 - mean) / (std + epsilon) * w + b

    print("Final Encoder {} Output :".format(layer_num))
    print("norm2(norm1(x + self_atten(x)) + feed_fwd_op)\n")
    print(output_enc_final)
    print()

    return output_enc_final


In [9]:
tgt_len, bsz, embed_dim = pe_src_embeds.shape
output_enc_final = encoder_block_post_attn_output(pe_src_embeds, attn_enc_output, state_dict, layer_num = 0 , bsz = bsz, tgt_len = tgt_len)

Final Encoder 0 Output :
norm2(norm1(x + self_atten(x)) + feed_fwd_op)

tensor([[[-1.4351,  0.2488,  1.3596, -0.1734],
         [-1.5803,  0.6344,  1.0443, -0.0984],
         [ 0.5213,  1.2944, -1.3575, -0.4581],
         [-1.5292,  0.5588,  1.1479, -0.1776]]], grad_fn=<AddBackward0>)



## Decoder functions to display the intermediate outputs and get the final outputs from the decoder


### Self attention outputs from a decoder block

In [10]:
def decoder_block_self_attn_output(x, state_dict, layer_num, need_weights = False):

    query_dec = key_dec = value_dec = x

    tgt_len, bsz, embed_dim = x.shape

    W_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads

    tempop1 = query_dec[0]@W_dec.T
    tempop1 = tempop1.T

    Q_dec,K_dec,V_dec = tempop1.T.chunk(3, dim= -1)


    Q_dec = Q_dec.unsqueeze(0)
    K_dec = K_dec.unsqueeze(0)
    V_dec = V_dec.unsqueeze(0)


    Q_dec = Q_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_dec = K_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    V_dec = V_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec)
    print()

    src_len = K_dec.size(1)


    if need_weights is False:
        Q_dec1 = Q_dec.view(bsz, num_heads, tgt_len, head_dim)
        K_dec1 = K_dec.view(bsz, num_heads, src_len, head_dim)
        V_dec1 = V_dec.view(bsz, num_heads, src_len, head_dim)


        L, S = Q_dec1.size(-2), K_dec1.size(-2)

        scale_factor = 1 / math.sqrt(Q_dec1.size(-1)) 
        # scale_factor = 1
        attn_bias = torch.zeros(L, S, dtype=Q_dec1.dtype)


        attn_weight = Q_dec1 @ K_dec1.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        
        attn_output = attn_weight @ V_dec1

        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

        print("Decoder Self Attention = ")
        print(attn_output)
        print()

        op_dec_1 = torch.matmul(attn_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_output.size(1))

        return attn_dec_output, None
    
    else:


        B, Nt, E = Q_dec.shape

        Q_dec_scaled = Q_dec / math.sqrt(E)

        Q_dec_scaled

        temp_pdt_matrix_dec = torch.bmm(Q_dec_scaled, K_dec.transpose(-2, -1))

        # Obtaining the diagonal entries of the un-normalised product (K.Q_scaled)
        pdt_matrix_dec = torch.diagonal(temp_pdt_matrix_dec, dim1=-2, dim2=-1).unsqueeze(-1)

        print("Q_scaled @ Kt = ", pdt_matrix_dec)

        attn_wt_matrix_dec = torch.nn.functional.softmax(pdt_matrix_dec, dim=-1)

        attn_dec_output = torch.bmm(attn_wt_matrix_dec, V_dec)

        attn_dec_output = attn_dec_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

        print("Decoder Attention output = ")
        print(attn_wt_matrix_dec)
        print()


        op_dec_1 = torch.matmul(attn_dec_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_dec_output.size(1))
    

        return attn_dec_output, attn_wt_matrix_dec




# query_dec = key_dec = value_dec = pe_tgt_embeds 

# tgt_len, bsz, embed_dim = pe_src_embeds.shape
    

In [11]:
need_weights

False

In [12]:
# need_weights = True

self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(pe_tgt_embeds, state_dict, layer_num = 0, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.5871, -3.8831]],

        [[ 4.2142, -2.5878]],

        [[ 2.9569, -1.3021]],

        [[ 4.5938,  0.1096]],

        [[ 0.0119, -0.3000]],

        [[ 4.3522, -0.2060]],

        [[-0.4440, -0.9314]],

        [[ 3.7144, -0.5702]]])

K_dec_0 = 
tensor([[[0.9020, 3.8072]],

        [[2.5476, 2.3639]],

        [[3.1392, 2.3370]],

        [[3.9712, 2.5925]],

        [[1.8144, 2.4839]],

        [[2.6009, 2.5871]],

        [[1.1623, 2.4910]],

        [[2.2101, 2.6568]]])

V_dec_0 = 
tensor([[[ 1.6530, -3.7049]],

        [[-0.6983, -0.4213]],

        [[ 0.4364, -3.0483]],

        [[ 0.2365, -1.1687]],

        [[-1.8952, -1.8379]],

        [[-3.1507, -2.0810]],

        [[-1.3617, -1.7368]],

        [[-2.8310, -1.5591]]])

Decoder Self Attention = 
tensor([[ 1.6530, -3.7049, -0.6983, -0.4213],
        [ 0.4364, -3.0483,  0.2365, -1.1687],
        [-1.8952, -1.8379, -3.1507, -2.0810],
        [-1.3617, -1.7368, -2.8310, -1.5591]])



In [13]:
self_attn_dec

tensor([[[ 0.8414, -2.6533,  1.9088,  1.6049],
         [ 0.7398, -2.4941,  2.2667,  0.5465],
         [-0.7669, -2.4307, -0.1783, -1.5623],
         [-0.6269, -2.1283, -0.1805, -1.0618]]])

In [14]:
def dec_post_self_attn(self_attn_dec, x, state_dict, layer_num):
    output_dec_1 = self_attn_dec + x

    linear_result_dec_1 = output_dec_1*state_dict["transformer.decoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm1.bias".format(layer_num)]

    layernorm_dec_1 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_1.shape[2:])
    linear_op_dec_1 = layernorm_dec_1(linear_result_dec_1)

    # From torch's implementation
    # print(linear_op_dec_1)

    x = linear_result_dec_1

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_dec_1.weight
    b = layernorm_dec_1.bias

    linear_result_dec_1f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_1f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_1 = (linear_result_dec_1f - mean) / (std + epsilon) * w + b


    print("Decoder_{} norm1(x + sa(x))".format(layer_num))
    print(normalized_result_dec_1)
    print()

    return normalized_result_dec_1

In [15]:
x_dec = dec_post_self_attn(self_attn_dec, pe_tgt_embeds, state_dict, layer_num = 0)

Decoder_0 norm1(x + sa(x))
tensor([[[-0.6579, -1.2888,  0.8978,  1.0489],
         [ 0.2808, -1.4541,  1.3345, -0.1612],
         [ 1.1197, -1.4515, -0.3745,  0.7064],
         [ 0.8014, -1.5007, -0.3065,  1.0058]]], grad_fn=<AddBackward0>)



## Cross attention in decoder 
### query, key, value =  x, mem, mem 

### Cross attention between the encoder's final output and the decoder layer 

In [16]:

memory = output_enc_final
x_dec

tensor([[[-0.6579, -1.2888,  0.8978,  1.0489],
         [ 0.2808, -1.4541,  1.3345, -0.1612],
         [ 1.1197, -1.4515, -0.3745,  0.7064],
         [ 0.8014, -1.5007, -0.3065,  1.0058]]], grad_fn=<AddBackward0>)

In [17]:
def decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num, tgt_len, src_len, head_dim, need_weights = False):

    query_dec_mha = x_dec
    key_dec_mha, value_dec_mha = memory, memory


    W_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_weight".format(layer_num)]
    b_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_bias".format(layer_num)]


    ########## CHANGE FROM SELF ATTENTION ############

    W_q, W_k, W_v = W_dec_mha.chunk(3)

    Q_dec_mha = query_dec_mha[0]@W_q.T

    K_dec_mha = key_dec_mha[0]@W_k.T

    V_dec_mha = value_dec_mha[0]@W_v.T

    #################################################

    K_dec_mha = K_dec_mha.unsqueeze(0)
    V_dec_mha = V_dec_mha.unsqueeze(0)
    Q_dec_mha = Q_dec_mha.unsqueeze(0)

    Q_dec_mha = Q_dec_mha.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_dec_mha = K_dec_mha.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    V_dec_mha = V_dec_mha.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec_mha)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec_mha)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec_mha)
    print()

    ### With 'need_weights' = False

    if need_weights is False:
        Q_dec_mha1 = Q_dec_mha.view(bsz, num_heads, tgt_len, head_dim)
        K_dec_mha1 = K_dec_mha.view(bsz, num_heads, src_len, head_dim)
        V_dec_mha1 = V_dec_mha.view(bsz, num_heads, src_len, head_dim)


        L, S = Q_dec_mha1.size(-2), K_dec_mha1.size(-2)


        scale_factor = 1 / math.sqrt(Q_dec_mha1.size(-1)) 
        # scale_factor = 1

        attn_bias = torch.zeros(L, S, dtype=Q_dec_mha1.dtype)


        attn_weight = Q_dec_mha1 @ K_dec_mha1.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_output_dec_mha = attn_weight @ V_dec_mha1

        attn_output_dec_mha = attn_output_dec_mha.permute(2, 0, 1, 3).view(bsz * tgt_len, embed_dim)
        
        print("Cross attention in decoder_{}".format(layer_num))
        print(attn_output_dec_mha)
        print()

        op_dec_mha_1 = torch.matmul(attn_output_dec_mha, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]

        attn_dec_mha_output = op_dec_mha_1.view(tgt_len, bsz, attn_output_dec_mha.size(1))

        return attn_dec_mha_output, None
    
    else: 

        B, Nt, E = Q_dec_mha.shape

        Q_dec_mha_scaled = Q_dec_mha / math.sqrt(E)

        Q_dec_mha_scaled

        temp_pdt_matrix_dec_mha = torch.bmm(Q_dec_mha_scaled, K_dec_mha.transpose(-2, -1))

        # Obtaining the diagonal entries of the un-normalised product (K.Q_scaled)
        pdt_matrix_dec_mha = torch.diagonal(temp_pdt_matrix_dec_mha, dim1=-2, dim2=-1).unsqueeze(-1)

        print("Q_scaled @ Kt = ", pdt_matrix_dec_mha)

        attn_wt_matrix_dec_mha = torch.nn.functional.softmax(pdt_matrix_dec_mha, dim=-1)

        attn_dec_mha_output = torch.bmm(attn_wt_matrix_dec_mha, V_dec_mha)

        attn_dec_mha_output = attn_dec_mha_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

        print("Decoder mha Attention output = ")
        print(attn_dec_mha_output)
        print()


        op_dec_mha = torch.matmul(attn_dec_mha_output, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_mha.view(tgt_len, bsz, attn_dec_mha_output.size(1))
    


        # B, Nt, E = Q_dec_mha.shape

        # Q_dec_scaled = Q_dec_mha / math.sqrt(E)

        # Q_dec_scaled

        # temp_pdt_matrix_dec = torch.bmm(Q_dec_scaled, K_dec_mha.transpose(-2, -1))

        # # Obtaining the diagonal entries of the un-normalised product (K.Q_scaled)
        # pdt_matrix_dec = torch.diagonal(temp_pdt_matrix_dec, dim1=-2, dim2=-1).unsqueeze(-1)

        # attn_wt_matrix_dec = F.softmax(pdt_matrix_dec, dim=-1)
        # attn_wt_matrix_dec = attn_wt_matrix_dec.view(-1, 1, 1)

        # attn_dec_output = torch.bmm(attn_wt_matrix_dec,V_dec_mha.permute(1, 0, 2))

        # attn_dec_output = attn_dec_output.transpose(0, 1).view(tgt_len * bsz, embed_dim)

        return attn_dec_output , attn_wt_matrix_dec_mha

        





In [18]:
# need_weights = True

attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = 0, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.9136, -0.6283]],

        [[-0.8274, -0.4683]],

        [[ 1.0450, -0.3151]],

        [[-0.8564, -0.7002]],

        [[ 0.8594,  0.4279]],

        [[-0.9335, -0.7778]],

        [[ 0.9027,  0.2643]],

        [[-0.9781, -0.7504]]], grad_fn=<TransposeBackward0>)

K_dec_0 = 
tensor([[[-0.0191, -0.7460]],

        [[-1.2817,  1.1603]],

        [[-0.1216, -0.6533]],

        [[-1.0173,  1.2253]],

        [[-0.2807,  0.9704]],

        [[ 1.7984, -0.6181]],

        [[-0.0946, -0.6538]],

        [[-1.0587,  1.1941]]], grad_fn=<TransposeBackward0>)

V_dec_0 = 
tensor([[[ 0.6794,  0.3024]],

        [[ 0.1560, -1.2189]],

        [[ 0.8529, -0.0325]],

        [[ 0.1457, -1.3982]],

        [[-0.2030, -1.4923]],

        [[-0.1055, -0.4424]],

        [[ 0.7814,  0.0261]],

        [[ 0.1489, -1.3966]]], grad_fn=<TransposeBackward0>)

Cross attention in decoder_0
tensor([[ 0.6794,  0.3024,  0.1560, -1.2189],
        [ 0.8529, -0.0325,  0.1457, -1.3982],
        [

In [19]:
def decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num):

    output_dec_2 = attn_dec_mha_output + x_dec

    linear_result_dec_2 = output_dec_2*state_dict["transformer.decoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm2.bias".format(layer_num)]

    # Layer normalization from Torch's implementation 
    layernorm_dec_2 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_2.shape[2:])
    linear_op_dec_2 = layernorm_dec_2(linear_result_dec_2)

    x = linear_result_dec_2
    w = layernorm_dec_2.weight
    b = layernorm_dec_2.bias

    linear_result_dec_2f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_2f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_2f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_2 = (linear_result_dec_2f - mean) / (std + epsilon) * w + b

    print("norm2(x' + mha(x', mem)), \n where x' = Decoder_curr_layer norm1(x + sa(x))")
    print(normalized_result_dec_2)
    print("\n\n")

    x_dec2_norm = normalized_result_dec_2

    op_dec_1 = torch.matmul(x_dec2_norm, state_dict["transformer.decoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear1.bias".format(layer_num)]
    op_dec_1_relu = torch.nn.functional.relu(op_dec_1)
    op_dec_2 = torch.matmul(op_dec_1_relu, state_dict["transformer.decoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear2.bias".format(layer_num)]

    ff_dec = op_dec_2

    x_dec3_unorm = x_dec2_norm + ff_dec

    linear_result_dec_3 = x_dec3_unorm*state_dict["transformer.decoder.layers.0.norm3.weight"] + state_dict["transformer.decoder.layers.0.norm3.bias"]

    # Layer normalization from Torch's implementation 
    layernorm_dec_3 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_3.shape[2:])
    linear_op_dec_3 = layernorm_dec_3(linear_result_dec_3)

    # From torch's implementation
    x = linear_result_dec_3

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_dec_3.weight
    b = layernorm_dec_3.bias

    linear_result_dec_3f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_3f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_3f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_3 = (linear_result_dec_3f - mean) / (std + epsilon) * w + b

    print("norm3(x'' + ff(x'')) \n where, x'' = Decoder_curr_layer norm2(x' + mha(x'))")
    print(normalized_result_dec_3)
    print()

    return normalized_result_dec_3

def feef_fwd_transformer(dec_output_final, state_dict):

    final_op = dec_output_final@state_dict["fc.weight"].T + state_dict["fc.bias"]

    return final_op

In [20]:
final_attn_op_decoder = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = 0)

norm2(x' + mha(x', mem)), 
 where x' = Decoder_curr_layer norm1(x + sa(x))
tensor([[[-0.9941, -0.6906,  0.0954,  1.5892],
         [ 0.9460, -1.5991,  0.7356, -0.0825],
         [ 0.9940, -1.4993, -0.3113,  0.8166],
         [ 0.8530, -0.6956, -1.2569,  1.0995]]], grad_fn=<AddBackward0>)



norm3(x'' + ff(x'')) 
 where, x'' = Decoder_curr_layer norm2(x' + mha(x'))
tensor([[[-0.4839, -1.0977, -0.0186,  1.6002],
         [ 1.0258, -1.6139,  0.5856,  0.0024],
         [ 1.1085, -1.4979, -0.2804,  0.6698],
         [ 1.0902, -0.9253, -1.0677,  0.9029]]], grad_fn=<AddBackward0>)



In [21]:
final_op = feef_fwd_transformer(final_attn_op_decoder, state_dict)

In [22]:
final_op

tensor([[[-0.4911,  0.1506,  0.7609,  1.1714,  0.0706,  0.0303,  0.0905,
           0.0423, -0.8750, -0.0365],
         [-0.6612, -0.7883,  0.2774,  0.2475, -0.7074,  0.0212, -1.3313,
          -0.3360, -0.1797,  0.7370],
         [-1.0149, -0.8154,  0.5574,  0.2583,  0.0395,  0.0466, -0.6874,
          -0.8217, -0.5338,  0.7056],
         [-1.2632, -0.8038,  0.6587,  0.1799,  0.6946, -0.0544, -0.0378,
          -1.0047, -0.7610,  0.4892]]], grad_fn=<AddBackward0>)

########################################################################

### Examaple :- 

Transformer with 3 encoders and 3 decoders

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=4, nhead=2, num_encoder_layers=1, num_decoder_layers=1):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):

        print("Source embeddings :- \n")
        print(self.src_embedding(src))
        print()

        print("Target embeddings :- \n")
        print(self.tgt_embedding(tgt))
        print()


        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

    

src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 4  # Dimension of the model
num_heads = 2
num_encoder_layers = 3
num_decoder_layers = 3
model = TransformerModel(src_vocab_size, tgt_vocab_size, d_model=d_model, nhead=num_heads, num_encoder_layers = num_encoder_layers, num_decoder_layers=num_decoder_layers)
# Source sentence in the source language
src_sentence = torch.tensor([[0, 1, 2, 3]])  # Source sequence

# Target sentence in the target language (translation of the source sentence)
tgt_sentence = torch.tensor([[1, 0, 3, 3]])  # Target sequence

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output)




Source embeddings :- 

tensor([[[-0.4019,  1.1379, -0.5372,  1.8170],
         [-0.9956,  0.5849, -0.1061,  0.5031],
         [ 1.0417,  0.3042,  1.3250,  1.6455],
         [ 0.5991, -1.5866,  0.1186,  0.1184]]], grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[ 0.6701, -0.7239,  0.8397,  0.7060],
         [ 0.8538, -0.2595, -0.4081, -1.8921],
         [-1.2405,  0.6807,  0.4640, -0.9147],
         [-1.2405,  0.6807,  0.4640, -0.9147]]], grad_fn=<EmbeddingBackward0>)

Positional encoded source embeddings:
tensor([[[-0.4019,  3.1379,  1.4628,  5.8170],
         [-0.1541,  2.1252,  1.9039,  4.5030],
         [ 1.9510,  0.8881,  3.3450,  5.6453],
         [ 0.7402, -1.5766,  2.1486,  4.1180]]], grad_fn=<AddBackward0>)

Positional encoded target embeddings:
tensor([[[ 1.6701,  0.2761,  3.8397,  4.7060],
         [ 2.6952,  0.2808,  2.6019,  2.1078],
         [ 0.6688,  0.2646,  3.4840,  3.0851],
         [-0.0993, -0.3093,  3.4940,  3.0849]]], grad_fn=<AddBackward0>)

query 

In [26]:
def get_all_intermediate_outputs(src_sentence, tgt_sentence,d_model, model, num_encoder_layers , num_decoder_layers):

    state_dict = model.state_dict()

    pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence,  tgt_sentence, state_dict, d_model = d_model)

    print("###"*25)
    print("### Encoder Start ###")
    print()

    x_enc = pe_src_embeds

    for lno in range(num_encoder_layers):
        attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(x_enc, state_dict, layer_num = lno, need_weights = False)

        if lno == 0:
            tgt_len, bsz, embed_dim = x_enc.shape

        output_enc_final = encoder_block_post_attn_output(x_enc, attn_enc_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

        x_enc = output_enc_final

    
    print("### Encoder Done ###")
    print("\n\n\n")
    print("### Decoder Start ###")

    x_dec = pe_tgt_embeds
    memory = x_enc

    for lno in range(num_decoder_layers):

        self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, need_weights = False)
        x_dec = dec_post_self_attn(self_attn_dec, x_dec, state_dict, layer_num = lno)

        attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = lno, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = False)
        final_op = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = lno)

        print(pe_tgt_embeds.shape, final_op.shape)
        x_dec = final_op


    final_op = feef_fwd_transformer(final_op, state_dict)

    print("### Decoder Done ###")


    return final_op

    

In [27]:
d_model = 4
num_encoder_layers = 3
num_decoder_layers = 3

get_all_intermediate_outputs(src_sentence, tgt_sentence, model = model, num_encoder_layers = num_encoder_layers , num_decoder_layers = num_encoder_layers, d_model=d_model)

Source sentence embedding
Word index: 0, Embedding: tensor([-0.4019,  1.1379, -0.5372,  1.8170])
Word index: 1, Embedding: tensor([-0.9956,  0.5849, -0.1061,  0.5031])
Word index: 2, Embedding: tensor([1.0417, 0.3042, 1.3250, 1.6455])
Word index: 3, Embedding: tensor([ 0.5991, -1.5866,  0.1186,  0.1184])

Target sentence embedding
Word index: 1, Embedding: tensor([ 0.6701, -0.7239,  0.8397,  0.7060])
Word index: 0, Embedding: tensor([ 0.8538, -0.2595, -0.4081, -1.8921])
Word index: 3, Embedding: tensor([-1.2405,  0.6807,  0.4640, -0.9147])
Word index: 3, Embedding: tensor([-1.2405,  0.6807,  0.4640, -0.9147])

PE source embeddings : 

tensor([[[-0.4019,  3.1379,  1.4628,  5.8170],
         [-0.1541,  2.1252,  1.9039,  4.5030],
         [ 1.9510,  0.8881,  3.3450,  5.6453],
         [ 0.7402, -1.5766,  2.1486,  4.1180]]])

PE target embeddings : 

tensor([[[ 1.6701,  0.2761,  3.8397,  4.7060],
         [ 2.6952,  0.2808,  2.6019,  2.1078],
         [ 0.6688,  0.2646,  3.4840,  3.0851],


tensor([[[ 0.9643, -1.2317,  0.7367, -0.0064, -0.7884, -0.7311, -0.0141,
           0.7371, -0.5670,  0.1799],
         [ 0.9586, -1.1972,  0.8049, -0.0248, -0.7074, -0.7126,  0.0332,
           0.7759, -0.5306,  0.1491],
         [ 0.8940, -0.9926,  0.8377,  0.0088, -0.8084, -0.8657, -0.0520,
           0.7437, -0.4131,  0.0413],
         [ 0.8677, -0.9303,  0.8531,  0.0138, -0.8124, -0.8961, -0.0622,
           0.7322, -0.3766,  0.0090]]], grad_fn=<AddBackward0>)