## Multi-head attention transformer
### Encoder and Decoder
### (With masking)

Pytorch's implementation (in built)

NOTE :- A new exmple must be used for testing the masked attention

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=4, nhead=2, num_encoder_layers=1, num_decoder_layers=1):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):

        print("Source embeddings :- \n")
        print(self.src_embedding(src))
        print()

        print("Target embeddings :- \n")
        print(self.tgt_embedding(tgt))
        print()


        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

    

src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 4  # Dimension of the model
num_heads = 2
model = TransformerModel(src_vocab_size, tgt_vocab_size, d_model=d_model, nhead=num_heads)
# Source sentence in the source language
src_sentence = torch.tensor([[0, 1, 2, 3]])  # Source sequence

# Target sentence in the target language (translation of the source sentence)
tgt_sentence = torch.tensor([[1, 0, 3, 3]])  # Target sequence

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output)



Source embeddings :- 

tensor([[[-0.5432,  1.5555,  0.3620,  0.0788],
         [-0.1357,  1.8274,  0.5430,  1.2430],
         [ 0.2074, -0.5689, -0.5513, -0.1528],
         [-1.2523,  0.0173,  1.2028, -0.4656]]], grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[-0.3095, -0.8334,  0.5895, -0.5896],
         [-1.5573, -0.1136,  0.9670, -0.5579],
         [-0.4276,  2.0494,  1.2553,  0.2608],
         [-0.4276,  2.0494,  1.2553,  0.2608]]], grad_fn=<EmbeddingBackward0>)

Positional encoded source embeddings:
tensor([[[-0.5432,  2.5555,  0.3620,  1.0788],
         [ 0.7058,  2.3677,  0.5530,  2.2430],
         [ 1.1167, -0.9851, -0.5314,  0.8470],
         [-1.1112, -0.9727,  1.2328,  0.5340]]], grad_fn=<AddBackward0>)

Positional encoded target embeddings:
tensor([[[-0.3095,  0.1666,  0.5895,  0.4104],
         [-0.7159,  0.4267,  0.9770,  0.4421],
         [ 0.4817,  1.6332,  1.2753,  1.2606],
         [-0.2865,  1.0594,  1.2853,  1.2604]]], grad_fn=<AddBackward0>)

query 



## Functions to get the manual calulations for each component

In [31]:
from torch import functional as F

# Using the state dictionary to get the intermediate outputs
state_dict = model.state_dict()

need_weights = False

src_mask = None

tgt_mask = None

memory_mask = None

embed_dim = 4

num_heads = 2


In [33]:
def look_up_table(sentence, vocab_embeds, embedding):

    for i in range(sentence.size(0)):
        for j in range(sentence.size(1)):
            # Get the index for the current word in the sequence
            word_index = sentence[i, j].item()

            # Check if the index is within valid range
            if word_index < 0 or word_index >= vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            embedding[i, j, :] = vocab_embeds[word_index, :]

            # Print intermediate results for debugging
            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")
    print()
    

    return embedding

### Embeddings and Positional encoding

In [34]:
def get_embedding_outputs(src_sentence,  tgt_sentence, state_dict, d_model):

    src_vocab_embeds = state_dict["src_embedding.weight"]

    src_embedding = torch.zeros(src_sentence.size(0), src_sentence.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_sentence, src_vocab_embeds, src_embedding)

    tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

    tgt_embedding = torch.zeros(tgt_sentence.size(0), tgt_sentence.size(1), d_model)


    print("Target sentence embedding")
    tgt_embedding =  look_up_table(tgt_sentence, tgt_vocab_embeds, tgt_embedding)


    pe = PositionalEncoding(d_model = d_model, dropout=0)

    pe_src_embeds = src_embedding + pe(src_sentence)

    pe_tgt_embeds = tgt_embedding + pe(tgt_sentence)

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    print("PE target embeddings : \n")
    print(pe_tgt_embeds)
    print()

    return pe_src_embeds, pe_tgt_embeds



In [35]:
pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence,  tgt_sentence, state_dict, d_model = d_model)

Source sentence embedding
Word index: 0, Embedding: tensor([-0.5432,  1.5555,  0.3620,  0.0788])
Word index: 1, Embedding: tensor([-0.1357,  1.8274,  0.5430,  1.2430])
Word index: 2, Embedding: tensor([ 0.2074, -0.5689, -0.5513, -0.1528])
Word index: 3, Embedding: tensor([-1.2523,  0.0173,  1.2028, -0.4656])

Target sentence embedding
Word index: 1, Embedding: tensor([-0.3095, -0.8334,  0.5895, -0.5896])
Word index: 0, Embedding: tensor([-1.5573, -0.1136,  0.9670, -0.5579])
Word index: 3, Embedding: tensor([-0.4276,  2.0494,  1.2553,  0.2608])
Word index: 3, Embedding: tensor([-0.4276,  2.0494,  1.2553,  0.2608])

PE source embeddings : 

tensor([[[-0.5432,  2.5555,  0.3620,  1.0788],
         [ 0.7058,  2.3677,  0.5530,  2.2430],
         [ 1.1167, -0.9851, -0.5314,  0.8470],
         [-1.1112, -0.9727,  1.2328,  0.5340]]])

PE target embeddings : 

tensor([[[-0.3095,  0.1666,  0.5895,  0.4104],
         [-0.7159,  0.4267,  0.9770,  0.4421],
         [ 0.4817,  1.6332,  1.2753,  1.260

## Encoder function to display the intermediate outputs and get the final outputs from the encoder

### Self attention 

In [46]:
def atten_product_needs_wts_false(Q, V, K, bsz, head_dim, src_len, tgt_len, attn_mask):
    Q1 = Q.view(bsz, num_heads, tgt_len, head_dim)
    K1 = K.view(bsz, num_heads, src_len, head_dim)
    V1 = V.view(bsz, num_heads, src_len, head_dim)


    L, S = Q1.size(-2), K1.size(-2)

    scale_factor = 1 / math.sqrt(Q1.size(-1)) 
    # scale_factor = 1
    attn_bias = torch.zeros(L, S, dtype=Q1.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    print(Q1.shape, K1.shape, K1.transpose(-2, -1).shape)
    attn_weight = Q1 @ K1.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    print("Attention weights = ")
    print(attn_weight.shape, attn_weight)

    attn_weight = torch.softmax(attn_weight, dim=-1)
    
    attn_output = attn_weight @ V1

    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    return attn_output



def atten_product_needs_wts_true(Q, K, V, bsz, tgt_len, attn_mask):

    B, Nt, E = Q.shape

    Q_scaled = Q / math.sqrt(E)

    if attn_mask is not None:
        temp_pdt_matrix = torch.baddbmm(attn_mask, Q_scaled, K.transpose(-2, -1))
    else:
        temp_pdt_matrix = torch.bmm(Q_scaled, K.transpose(-2, -1))


    # Obtaining the diagonal entries of the un-normalised product (K.Q_scaled)
    pdt_matrix = torch.diagonal(temp_pdt_matrix, dim1=-2, dim2=-1).unsqueeze(-1)

    attn_wt_matrix = torch.nn.functional.softmax(pdt_matrix, dim=-1)

    attn_output = torch.bmm(attn_wt_matrix, V)

    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

    print("Encoder Attention output = ")
    print(attn_output)
    print()

    return attn_output, attn_wt_matrix


In [47]:
def encoder_block_attn_output(x, state_dict, layer_num, embed_dim, num_heads, need_weights = False, src_mask = None):

    
    query_enc = key_enc = value_enc = x

    tgt_len, bsz, embed_dim = x.shape


    W_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads

    tempop1 = query_enc[0]@W_enc.T
    tempop1 = tempop1.T

    Q_enc,K_enc,V_enc = tempop1.T.chunk(3, dim= -1)


    Q_enc = Q_enc.unsqueeze(0)
    K_enc = K_enc.unsqueeze(0)
    V_enc = V_enc.unsqueeze(0)


    Q_enc = Q_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_enc = K_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    V_enc = V_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_enc_{} = ".format(layer_num))
    print(Q_enc)
    print()

    print("K_enc_{} = ".format(layer_num))
    print(K_enc)
    print()

    print("V_enc_{} = ".format(layer_num))
    print(V_enc)
    print()

    src_len = K_enc.size(1)


    attn_mask = src_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)


    if need_weights is False:

        attn_output = atten_product_needs_wts_false(Q_enc, V_enc, K_enc, bsz, head_dim, src_len, tgt_len, attn_mask)

        return attn_output, src_len, head_dim, None
    
    else:

        attn_enc_output,attn_wt_matrix_enc = atten_product_needs_wts_true(Q_enc, K_enc, V_enc, bsz, tgt_len, attn_mask)

        return attn_enc_output, src_len, head_dim, attn_wt_matrix_enc

    

In [53]:
# need_weights = True

attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(pe_src_embeds, state_dict, layer_num = 0, need_weights = need_weights, embed_dim=embed_dim, num_heads=num_heads)


Q_enc_0 = 
tensor([[[-0.6798,  1.4007]],

        [[-0.2742, -1.0386]],

        [[-1.9563,  0.6655]],

        [[-0.6882, -0.6257]],

        [[-0.6674, -0.6622]],

        [[-0.4141,  0.8993]],

        [[-0.2284, -0.3723]],

        [[-0.8246, -0.2570]]])

K_enc_0 = 
tensor([[[ 1.3458,  1.3452]],

        [[ 0.7641, -1.2760]],

        [[ 0.8736,  1.8444]],

        [[ 1.0950, -0.9651]],

        [[-0.5440, -0.1505]],

        [[-0.0823,  0.2863]],

        [[-0.6199,  0.2640]],

        [[-0.8359,  1.0204]]])

V_enc_0 = 
tensor([[[ 1.5610,  0.2397]],

        [[ 1.2207,  0.6767]],

        [[ 1.8320,  0.3493]],

        [[ 1.1892,  0.1905]],

        [[-0.4484, -0.4026]],

        [[-0.0178, -0.5572]],

        [[ 0.4017,  0.0647]],

        [[-0.5237, -0.4228]]])

Encoder Attention output = 
tensor([[ 1.5610,  0.2397,  1.2207,  0.6767],
        [ 1.8320,  0.3493,  1.1892,  0.1905],
        [-0.4484, -0.4026, -0.0178, -0.5572],
        [ 0.4017,  0.0647, -0.5237, -0.4228]])



In [52]:
# attn_weights

### Post self attention in encoder

In [65]:
def encoder_block_post_attn_output(x, attn_enc_output, state_dict, layer_num, bsz, tgt_len):

    op_enc_1 = torch.matmul(attn_enc_output, state_dict["transformer.encoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]

    attn_enc_output = op_enc_1.view(tgt_len, bsz, attn_enc_output.size(1))


    # Here src is the original passed inputs to the 1st transformer encoder layer which 
    # are pe_src_embeds 
    output_enc_1 = attn_enc_output + x
    output_enc_1

    linear_result_enc_1 = output_enc_1*state_dict["transformer.encoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm1.bias".format(layer_num)]

    # Layer normalization from Torch's implementation
    layernorm_enc_1 = torch.nn.LayerNorm(normalized_shape=linear_result_enc_1.shape[2:])
    linear_op_enc_1 = layernorm_enc_1(linear_result_enc_1)


    # Manual Layer Normalization
    x = linear_result_enc_1

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_enc_1.weight
    b = layernorm_enc_1.bias

    linear_result_enc_1f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_enc_1f.mean(dim=-1, keepdim=True)
    std = linear_result_enc_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_enc_1 = (linear_result_enc_1f - mean) / (std + epsilon) * w + b


    op_enc_1 = torch.matmul(normalized_result_enc_1, state_dict["transformer.encoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear1.bias".format(layer_num)]
    op_enc_1_relu = torch.nn.functional.relu(op_enc_1)
    op_enc_2 = torch.matmul(op_enc_1_relu, state_dict["transformer.encoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear2.bias".format(layer_num)]


    output_enc_2 = op_enc_2 + linear_op_enc_1
    output_enc_2_norm = output_enc_2*state_dict["transformer.encoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm2.bias".format(layer_num)]

    # Layer normalization from Torch's implementation
    layernorm_enc_final = torch.nn.LayerNorm(normalized_shape=output_enc_2_norm.shape[2:])
    output_enc_final = layernorm_enc_final(output_enc_2_norm)


    # Manual Layer Normalization 
    x = output_enc_2_norm

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_enc_final.weight
    b = layernorm_enc_final.bias

    linear_result_enc_2 = w*x + b

    epsilon = 1e-5  
    mean = linear_result_enc_2.mean(dim=-1, keepdim=True)
    std = linear_result_enc_2.std(dim=-1, unbiased=False, keepdim=True)
    output_enc_final = (linear_result_enc_2 - mean) / (std + epsilon) * w + b

    print("Final Encoder {} Output :".format(layer_num))
    print("norm2(norm1(x + self_atten(x)) + feed_fwd_op)\n")
    print(output_enc_final)
    print()

    return output_enc_final


In [66]:
tgt_len, bsz, embed_dim = pe_src_embeds.shape
output_enc_final = encoder_block_post_attn_output(pe_src_embeds, attn_enc_output, state_dict, layer_num = 0 , bsz = bsz, tgt_len = tgt_len)

Final Encoder 0 Output :
norm2(norm1(x + self_atten(x)) + feed_fwd_op)

tensor([[[ 0.2780,  1.5154, -0.9923, -0.8010],
         [ 0.8535,  0.9828, -1.4747, -0.3617],
         [ 0.3958, -1.2856, -0.5023,  1.3922],
         [-0.9210, -0.9146,  1.4792,  0.3564]]], grad_fn=<AddBackward0>)



## Decoder functions to display the intermediate outputs and get the final outputs from the decoder


### Self attention outputs from a decoder block

In [67]:
def decoder_block_self_attn_output(x, state_dict, layer_num, tgt_mask = None,need_weights = False):

    query_dec = key_dec = value_dec = x

    tgt_len, bsz, embed_dim = x.shape

    W_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads

    tempop1 = query_dec[0]@W_dec.T
    tempop1 = tempop1.T

    Q_dec,K_dec,V_dec = tempop1.T.chunk(3, dim= -1)


    Q_dec = Q_dec.unsqueeze(0)
    K_dec = K_dec.unsqueeze(0)
    V_dec = V_dec.unsqueeze(0)


    Q_dec = Q_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_dec = K_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    V_dec = V_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec)
    print()

    src_len = K_dec.size(1)


    attn_mask = tgt_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)


    

    if need_weights is False:
        attn_output = atten_product_needs_wts_false(Q_dec, V_dec, K_dec, bsz, head_dim, src_len, tgt_len, attn_mask)

        print("Decoder Self Attention = ")
        print(attn_output)
        print()

        op_dec_1 = torch.matmul(attn_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_output.size(1))

        return attn_dec_output, None
    
    else:

        attn_dec_output,attn_wt_matrix_dec = atten_product_needs_wts_true(Q_dec, K_dec, V_dec, bsz, tgt_len, attn_mask)

        print("Decoder Attention output = ")
        print(attn_wt_matrix_dec)
        print()

        op_dec_1 = torch.matmul(attn_dec_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_dec_output.size(1))
    

        return attn_dec_output, attn_wt_matrix_dec




# query_dec = key_dec = value_dec = pe_tgt_embeds 

# tgt_len, bsz, embed_dim = pe_src_embeds.shape
    

In [57]:
need_weights 

False

In [58]:
# need_weights = True

self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(pe_tgt_embeds, state_dict, layer_num = 0, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.0632,  0.2018]],

        [[ 0.5668, -0.7167]],

        [[-0.0096,  0.2866]],

        [[ 0.8592, -1.2781]],

        [[ 0.5190,  0.4592]],

        [[ 2.0303, -1.8184]],

        [[ 0.3563,  0.6196]],

        [[ 1.7640, -1.9240]]])

K_dec_0 = 
tensor([[[-0.2527, -0.2785]],

        [[-0.2614, -0.5521]],

        [[-0.2099, -0.4170]],

        [[-0.6074, -0.8707]],

        [[ 0.0705, -1.5230]],

        [[-0.9272, -0.7401]],

        [[-0.3300, -1.0468]],

        [[-0.8275, -1.0830]]])

V_dec_0 = 
tensor([[[ 0.0145, -0.0974]],

        [[-0.1493,  0.3044]],

        [[ 0.0765, -0.2509]],

        [[-0.4107,  0.5028]],

        [[-0.1104,  1.0977]],

        [[ 1.0221,  0.1571]],

        [[ 0.0020,  0.4013]],

        [[ 0.3516,  0.5678]]])

torch.Size([4, 2, 1, 2]) torch.Size([4, 2, 1, 2]) torch.Size([4, 2, 2, 1])
Attention weights = 
torch.Size([4, 2, 1, 1]) tensor([[[[-0.0510]],

         [[ 0.1750]]],


        [[[-0.0831]],

         [[ 0.4179]]],


   

In [59]:
self_attn_dec

tensor([[[ 0.3354,  0.2949,  0.3642, -0.1185],
         [ 0.7282,  0.5757,  0.7341, -0.0761],
         [-1.3969, -0.5911, -1.1591, -0.9625],
         [-0.2074,  0.1613, -0.0765, -0.7489]]])

In [60]:
def dec_post_self_attn(self_attn_dec, x, state_dict, layer_num):
    output_dec_1 = self_attn_dec + x

    linear_result_dec_1 = output_dec_1*state_dict["transformer.decoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm1.bias".format(layer_num)]

    layernorm_dec_1 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_1.shape[2:])
    linear_op_dec_1 = layernorm_dec_1(linear_result_dec_1)

    # From torch's implementation
    # print(linear_op_dec_1)

    x = linear_result_dec_1

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_dec_1.weight
    b = layernorm_dec_1.bias

    linear_result_dec_1f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_1f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_1 = (linear_result_dec_1f - mean) / (std + epsilon) * w + b


    print("Decoder_{} norm1(x + sa(x))".format(layer_num))
    print(normalized_result_dec_1)
    print()

    return normalized_result_dec_1

In [61]:
x_dec = dec_post_self_attn(self_attn_dec, pe_tgt_embeds, state_dict, layer_num = 0)

Decoder_0 norm1(x + sa(x))
tensor([[[-1.2045,  0.0837,  1.5386, -0.4179],
         [-1.1748,  0.3544,  1.4489, -0.6285],
         [-1.5035,  1.2978, -0.0273,  0.2330],
         [-1.5796,  0.8699,  0.8529, -0.1432]]], grad_fn=<AddBackward0>)



## Cross attention in decoder 
### query, key, value =  x, mem, mem 

### Cross attention between the encoder's final output and the decoder layer 

In [68]:

memory = output_enc_final
x_dec

tensor([[[-1.2045,  0.0837,  1.5386, -0.4179],
         [-1.1748,  0.3544,  1.4489, -0.6285],
         [-1.5035,  1.2978, -0.0273,  0.2330],
         [-1.5796,  0.8699,  0.8529, -0.1432]]], grad_fn=<AddBackward0>)

In [69]:
def decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num, tgt_len, src_len, head_dim, memory_mask = None,need_weights = False):

    query_dec_mha = x_dec
    key_dec_mha, value_dec_mha = memory, memory


    W_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_weight".format(layer_num)]
    b_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_bias".format(layer_num)]


    ########## CHANGE FROM SELF ATTENTION ############

    W_q, W_k, W_v = W_dec_mha.chunk(3)

    Q_dec_mha = query_dec_mha[0]@W_q.T

    K_dec_mha = key_dec_mha[0]@W_k.T

    V_dec_mha = value_dec_mha[0]@W_v.T

    #################################################

    K_dec_mha = K_dec_mha.unsqueeze(0)
    V_dec_mha = V_dec_mha.unsqueeze(0)
    Q_dec_mha = Q_dec_mha.unsqueeze(0)

    Q_dec_mha = Q_dec_mha.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_dec_mha = K_dec_mha.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    V_dec_mha = V_dec_mha.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec_mha)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec_mha)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec_mha)
    print()


    attn_mask = memory_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)



    if need_weights is False:

        attn_output_dec_mha = atten_product_needs_wts_false(Q_dec_mha, V_dec_mha, K_dec_mha, bsz, head_dim, src_len, tgt_len, attn_mask)
        
        print("Cross attention in decoder_{}".format(layer_num))
        print(attn_output_dec_mha)
        print()

        op_dec_mha_1 = torch.matmul(attn_output_dec_mha, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]

        attn_dec_mha_output = op_dec_mha_1.view(tgt_len, bsz, attn_output_dec_mha.size(1))

        return attn_dec_mha_output, None
    
    else: 

    
        attn_dec_mha_output ,attn_wt_matrix_dec_mha = atten_product_needs_wts_true(Q_dec_mha, K_dec_mha, V_dec_mha, bsz, tgt_len, attn_mask)

        print("Decoder mha Attention output = ")
        print(attn_dec_mha_output)
        print()


        op_dec_mha = torch.matmul(attn_dec_mha_output, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_mha.view(tgt_len, bsz, attn_dec_mha_output.size(1))


        return attn_dec_output , attn_wt_matrix_dec_mha

        





In [70]:
# need_weights = True

attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = 0, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.2437, -0.8084]],

        [[-0.2721,  0.8057]],

        [[ 0.5289, -0.7797]],

        [[-0.1821,  0.9150]],

        [[ 1.1448, -0.3859]],

        [[ 0.2385, -0.1027]],

        [[ 0.8741, -0.6890]],

        [[ 0.0020,  0.4243]]], grad_fn=<TransposeBackward0>)

K_dec_0 = 
tensor([[[ 1.1384,  0.6546]],

        [[ 0.8390, -0.5678]],

        [[ 1.1360,  1.1477]],

        [[ 0.4548,  0.0592]],

        [[-0.5776,  0.2464]],

        [[-1.0737,  1.1551]],

        [[-1.1307, -1.1934]],

        [[-0.4308, -0.1117]]], grad_fn=<TransposeBackward0>)

V_dec_0 = 
tensor([[[ 0.2457,  0.0424]],

        [[ 0.5397,  0.8359]],

        [[-0.1711, -0.0743]],

        [[ 0.2930,  1.1949]],

        [[-0.7912, -0.3454]],

        [[-0.6070,  0.2404]],

        [[ 0.1975,  0.0695]],

        [[-0.2703, -1.2053]]], grad_fn=<TransposeBackward0>)

torch.Size([4, 2, 1, 2]) torch.Size([4, 2, 1, 2]) torch.Size([4, 2, 2, 1])
Attention weights = 
torch.Size([4, 2, 1, 1]) tensor([[

In [71]:
def decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num):

    output_dec_2 = attn_dec_mha_output + x_dec

    linear_result_dec_2 = output_dec_2*state_dict["transformer.decoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm2.bias".format(layer_num)]

    # Layer normalization from Torch's implementation 
    layernorm_dec_2 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_2.shape[2:])
    linear_op_dec_2 = layernorm_dec_2(linear_result_dec_2)

    x = linear_result_dec_2
    w = layernorm_dec_2.weight
    b = layernorm_dec_2.bias

    linear_result_dec_2f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_2f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_2f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_2 = (linear_result_dec_2f - mean) / (std + epsilon) * w + b

    print("norm2(x' + mha(x', mem)), \n where x' = Decoder_curr_layer norm1(x + sa(x))")
    print(normalized_result_dec_2)
    print("\n\n")

    x_dec2_norm = normalized_result_dec_2

    op_dec_1 = torch.matmul(x_dec2_norm, state_dict["transformer.decoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear1.bias".format(layer_num)]
    op_dec_1_relu = torch.nn.functional.relu(op_dec_1)
    op_dec_2 = torch.matmul(op_dec_1_relu, state_dict["transformer.decoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear2.bias".format(layer_num)]

    ff_dec = op_dec_2

    x_dec3_unorm = x_dec2_norm + ff_dec

    linear_result_dec_3 = x_dec3_unorm*state_dict["transformer.decoder.layers.0.norm3.weight"] + state_dict["transformer.decoder.layers.0.norm3.bias"]

    # Layer normalization from Torch's implementation 
    layernorm_dec_3 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_3.shape[2:])
    linear_op_dec_3 = layernorm_dec_3(linear_result_dec_3)

    # From torch's implementation
    x = linear_result_dec_3

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_dec_3.weight
    b = layernorm_dec_3.bias

    linear_result_dec_3f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_3f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_3f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_3 = (linear_result_dec_3f - mean) / (std + epsilon) * w + b

    print("norm3(x'' + ff(x'')) \n where, x'' = Decoder_curr_layer norm2(x' + mha(x'))")
    print(normalized_result_dec_3)
    print()

    return normalized_result_dec_3

def feef_fwd_transformer(dec_output_final, state_dict):

    final_op = dec_output_final@state_dict["fc.weight"].T + state_dict["fc.bias"]

    return final_op

In [72]:
final_attn_op_decoder = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = 0)

norm2(x' + mha(x', mem)), 
 where x' = Decoder_curr_layer norm1(x + sa(x))
tensor([[[-0.6314, -1.0291,  1.5930,  0.0675],
         [-1.1430, -0.4263,  1.5847, -0.0154],
         [-1.4210,  1.3967, -0.1094,  0.1336],
         [-1.3878,  0.9595,  0.9427, -0.5144]]], grad_fn=<AddBackward0>)



norm3(x'' + ff(x'')) 
 where, x'' = Decoder_curr_layer norm2(x' + mha(x'))
tensor([[[-0.8488, -0.8435,  1.5998,  0.0924],
         [-1.2543, -0.3016,  1.5280,  0.0280],
         [-1.5712,  1.2070,  0.1186,  0.2456],
         [-1.4815,  0.8364,  0.9924, -0.3473]]], grad_fn=<AddBackward0>)



In [73]:
final_op = feef_fwd_transformer(final_attn_op_decoder, state_dict)

In [74]:
final_op

tensor([[[ 1.4197, -0.4107, -0.1168,  0.9968, -0.2262,  0.8008,  0.1751,
          -0.5905,  1.7112, -0.8088],
         [ 1.5076, -0.4805, -0.0610,  1.3007, -0.1737,  0.7942,  0.3358,
          -0.8888,  1.4928, -0.8529],
         [ 1.0689, -0.2571, -0.0720,  1.4597,  0.3436,  0.2593,  0.7902,
          -0.8153,  0.1954, -0.6916],
         [ 1.2454, -0.4862, -0.1692,  1.4866, -0.2175,  0.4440,  0.5660,
          -1.1615,  0.6706, -0.7323]]], grad_fn=<AddBackward0>)

########################################################################

### Examaple :- 

Transformer with 3 encoders and 3 decoders

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=4, nhead=2, num_encoder_layers=1, num_decoder_layers=1):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):

        print("Source embeddings :- \n")
        print(self.src_embedding(src))
        print()

        print("Target embeddings :- \n")
        print(self.tgt_embedding(tgt))
        print()


        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

    

src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 4  # Dimension of the model
num_heads = 2
num_encoder_layers = 3
num_decoder_layers = 3

need_weights = False

src_mask = None
tgt_mask = None
memory_mask = None


model = TransformerModel(src_vocab_size, tgt_vocab_size, d_model=d_model, nhead=num_heads, num_encoder_layers = num_encoder_layers, num_decoder_layers=num_decoder_layers)
# Source sentence in the source language
src_sentence = torch.tensor([[0, 1, 2, 3]])  # Source sequence

# Target sentence in the target language (translation of the source sentence)
tgt_sentence = torch.tensor([[1, 0, 3, 3]])  # Target sequence

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output)


Source embeddings :- 

tensor([[[-0.2205,  1.7253, -0.2514,  0.9972],
         [ 0.1335, -0.5650, -0.7537,  0.0105],
         [-0.8226,  0.9250,  1.7916, -0.4149],
         [ 0.2400, -2.0239, -0.6187, -0.7110]]], grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[ 0.0958, -0.3552,  1.2360, -0.8826],
         [ 0.0204,  1.1688, -0.0480, -1.3756],
         [-0.1728,  0.3034, -1.1816, -1.1548],
         [-0.1728,  0.3034, -1.1816, -1.1548]]], grad_fn=<EmbeddingBackward0>)

Positional encoded source embeddings:
tensor([[[-0.2205,  3.7253,  1.7486,  4.9972],
         [ 0.9750,  0.9753,  1.2563,  4.0105],
         [ 0.0867,  1.5088,  3.8116,  3.5849],
         [ 0.3811, -2.0139,  1.4113,  3.2885]]], grad_fn=<AddBackward0>)

Positional encoded target embeddings:
tensor([[[ 1.0958,  0.6448,  4.2360,  3.1174],
         [ 1.8619,  1.7091,  2.9620,  2.6244],
         [ 1.7365, -0.1128,  1.8384,  2.8450],
         [ 0.9683, -0.6866,  1.8484,  2.8448]]], grad_fn=<AddBackward0>)

query 



In [53]:
def get_all_intermediate_outputs(src_sentence, tgt_sentence,d_model, model, num_encoder_layers , num_decoder_layers):

    state_dict = model.state_dict()

    pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence,  tgt_sentence, state_dict, d_model = d_model)

    print("###"*25)
    print("### Encoder Start ###")
    print()

    x_enc = pe_src_embeds

    for lno in range(num_encoder_layers):
        attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(x_enc, state_dict, layer_num = lno, need_weights = False, embed_dim = d_model, num_heads= num_heads, src_mask=None)

        if lno == 0:
            tgt_len, bsz, embed_dim = x_enc.shape

        output_enc_final = encoder_block_post_attn_output(x_enc, attn_enc_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

        x_enc = output_enc_final

    
    print("### Encoder Done ###")
    print("\n\n\n")
    print("### Decoder Start ###")

    x_dec = pe_tgt_embeds
    memory = x_enc

    for lno in range(num_decoder_layers):

        self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, need_weights = False, tgt_mask=None)
        x_dec = dec_post_self_attn(self_attn_dec, x_dec, state_dict, layer_num = lno)

        attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = lno, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = False, memory_mask=None)
        final_op = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = lno)

        print(pe_tgt_embeds.shape, final_op.shape)
        x_dec = final_op


    final_op = feef_fwd_transformer(final_op, state_dict)

    print("### Decoder Done ###")


    return final_op

    

In [54]:
d_model = 4
num_encoder_layers = 3
num_decoder_layers = 3

get_all_intermediate_outputs(src_sentence, tgt_sentence, model = model, num_encoder_layers = num_encoder_layers , num_decoder_layers = num_encoder_layers, d_model=d_model)

Source sentence embedding
Word index: 0, Embedding: tensor([-0.2205,  1.7253, -0.2514,  0.9972])
Word index: 1, Embedding: tensor([ 0.1335, -0.5650, -0.7537,  0.0105])
Word index: 2, Embedding: tensor([-0.8226,  0.9250,  1.7916, -0.4149])
Word index: 3, Embedding: tensor([ 0.2400, -2.0239, -0.6187, -0.7110])

Target sentence embedding
Word index: 1, Embedding: tensor([ 0.0958, -0.3552,  1.2360, -0.8826])
Word index: 0, Embedding: tensor([ 0.0204,  1.1688, -0.0480, -1.3756])
Word index: 3, Embedding: tensor([-0.1728,  0.3034, -1.1816, -1.1548])
Word index: 3, Embedding: tensor([-0.1728,  0.3034, -1.1816, -1.1548])

PE source embeddings : 

tensor([[[-0.2205,  3.7253,  1.7486,  4.9972],
         [ 0.9750,  0.9753,  1.2563,  4.0105],
         [ 0.0867,  1.5088,  3.8116,  3.5849],
         [ 0.3811, -2.0139,  1.4113,  3.2885]]])

PE target embeddings : 

tensor([[[ 1.0958,  0.6448,  4.2360,  3.1174],
         [ 1.8619,  1.7091,  2.9620,  2.6244],
         [ 1.7365, -0.1128,  1.8384,  2.845

tensor([[[-0.1736, -0.6635, -0.7838, -0.4535, -1.0496, -0.8957,  0.4070,
           0.4802,  0.3982, -0.3234],
         [-0.1878, -0.6338, -0.7410, -0.4279, -1.0766, -0.9809,  0.3347,
           0.5342,  0.4060, -0.4664],
         [-0.1235, -0.6708, -0.8959, -0.4473, -0.9361, -0.7919,  0.5501,
           0.3801,  0.3818, -0.0212],
         [-0.1777, -0.6718, -0.6719, -0.4529, -1.0911, -0.8548,  0.3259,
           0.5181,  0.4016, -0.4310]]], grad_fn=<AddBackward0>)

In [8]:
import torch


batch_size = 1
sequence_length = 3
embedding_size = 2


# Assuming Q_enc1 and K_enc1 are your input tensors
Q_enc1 = torch.randn((batch_size, sequence_length, embedding_size))
K_enc1 = torch.randn((batch_size, sequence_length, embedding_size))


# Calculate attention weights
attn_weight = Q_enc1 @ K_enc1.transpose(-2, -1) 
attn_weight = torch.softmax(attn_weight, dim=-1)

# Display the result
print("Original Attention Weights:")
print(attn_weight)

print("\nSum of Attention Weights along the last dimension:")
print(torch.sum(attn_weight, dim=-1))



Original Attention Weights:
tensor([[[0.2494, 0.1807, 0.5699],
         [0.5087, 0.4590, 0.0323],
         [0.3180, 0.3567, 0.3253]]])

Sum of Attention Weights along the last dimension:
tensor([[1., 1., 1.]])


In [25]:
m = nn.Softmax(dim=0)
input = torch.randint(low= 1, high=5, size= (2, 3), dtype = torch.float32)

print(input)

output = m(input)

print(output)

tensor([[3., 1., 1.],
        [4., 1., 3.]])
tensor([[0.2689, 0.5000, 0.1192],
        [0.7311, 0.5000, 0.8808]])


In [26]:
m = nn.Softmax(dim=1)
output = m(input)

print(output)


tensor([[0.7870, 0.1065, 0.1065],
        [0.7054, 0.0351, 0.2595]])


In [27]:
m = nn.Softmax(dim=-1)
output = m(input)

print(output)

tensor([[0.7870, 0.1065, 0.1065],
        [0.7054, 0.0351, 0.2595]])
