## Multi-head attention transformer
### Encoder and Decoder
### (With masking)

Pytorch's implementation (in built)

NOTE :- A new exmple must be used for testing the masked attention

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(6)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    
    def __init__(self, src_vocab_size, tgt_vocab_size,max_seq_len, d_model=4, nhead=2, num_encoder_layers=1, num_decoder_layers=1, d_ff = 4):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):

        print("Source embeddings :- \n")
        print(self.src_embedding(src))
        print()

        print("Target embeddings :- \n")
        print(self.tgt_embedding(tgt))
        print()


        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(src, tgt)
        output = self.fc(output)
        
        return output

    

src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 6  # Dimension of the model
num_heads = 2
# max_seq_len = 4
max_seq_len = 5
d_ff = 4

model = TransformerModel(src_vocab_size, tgt_vocab_size, d_model=d_model, max_seq_len = max_seq_len, nhead=num_heads, d_ff=d_ff)

# Source sentence in the source language
# Source token indexes from src vocabulary
src_sentence = torch.tensor([[0, 2], [1, 0], [2, 2], [3, 5]])


# Target sentence in the target language (translation of the source sentence)
# Target token indexes from tgt vocabulary
tgt_sentence = torch.tensor([[1, 7], [3, 4], [5, 2], [8, 0], [6, 1]])  # Target sequence


# Forward pass
output = model(src_sentence, tgt_sentence)
print(output, output.shape)


Source embeddings :- 

tensor([[[-1.2113,  0.6304, -1.4713, -1.3352, -0.4897,  0.1317],
         [-0.5432, -1.0841,  1.4612, -1.6279, -1.4801, -1.0631]],

        [[ 0.3295,  0.3264, -0.4806,  1.1032,  2.5485,  0.3006],
         [-1.2113,  0.6304, -1.4713, -1.3352, -0.4897,  0.1317]],

        [[-0.5432, -1.0841,  1.4612, -1.6279, -1.4801, -1.0631],
         [-0.5432, -1.0841,  1.4612, -1.6279, -1.4801, -1.0631]],

        [[ 0.3630,  0.3995,  0.1457, -0.7345, -0.9873,  1.8512],
         [ 0.0160,  0.4019,  1.9538, -0.4460,  1.7102,  0.8944]]],
       grad_fn=<EmbeddingBackward0>)

Target embeddings :- 

tensor([[[-0.3214, -0.6803, -0.5234, -0.8720,  0.2539,  0.3173],
         [ 0.1460,  1.4648,  0.9936, -0.3499, -0.9568,  1.7102]],

        [[-0.2372, -1.7878, -2.8351,  0.3415, -0.5858,  0.6764],
         [-1.4587, -1.0473, -0.4189,  2.6578, -0.4090, -0.6816]],

        [[-0.7104, -0.1158,  1.1307, -0.1447, -0.4148, -1.0460],
         [ 0.4195,  0.6915, -0.8920,  0.2698,  0.2428, -0.0



## Functions to get the manual calulations for each component

In [2]:
from torch import functional as F

# Using the state dictionary to get the intermediate outputs
state_dict = model.state_dict()

need_weights = False

src_mask = None

tgt_mask = None

memory_mask = None

embed_dim = 6

num_heads = 2

max_seq_len = 5


#### Function to fetch word embeddings with help of token indices

In [3]:
def look_up_table(sentence, vocab_embeds, embedding):

    for i in range(sentence.size(0)):
        for j in range(sentence.size(1)):
            
            # Get the index for the current word token index in the sequence
            word_index = sentence[i, j].item()

            if word_index < 0 or word_index >= vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            embedding[i, j, :] = vocab_embeds[word_index, :]

            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")
    print()
    

    return embedding

### Embeddings and Positional encoding

In [4]:
def get_embedding_outputs(src_sentence,  tgt_sentence, max_seq_len, state_dict, d_model):

    src_vocab_embeds = state_dict["src_embedding.weight"]

    src_embedding = torch.zeros(src_sentence.size(0), src_sentence.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_sentence, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)

    tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

    tgt_embedding = torch.zeros(tgt_sentence.size(0), tgt_sentence.size(1), d_model)


    print("Target sentence embedding")
    tgt_embedding =  look_up_table(tgt_sentence, tgt_vocab_embeds, tgt_embedding)

    pe = PositionalEncoding(d_model = d_model, dropout=0, max_len=max_seq_len)

    print("PE of src :")
    print(pe(src_sentence))
    print()
    print("PE of tgt :")
    print(pe(tgt_sentence))
    print()

    pe_src_embeds = src_embedding + pe(src_sentence)

    pe_tgt_embeds = tgt_embedding + pe(tgt_sentence)

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    print("PE target embeddings : \n")
    print(pe_tgt_embeds)
    print()

    return pe_src_embeds, pe_tgt_embeds



In [5]:
pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence=src_sentence,  tgt_sentence=tgt_sentence, state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)


Source sentence embedding
Word index: 0, Embedding: tensor([-1.2113,  0.6304, -1.4713, -1.3352, -0.4897,  0.1317])
Word index: 2, Embedding: tensor([-0.5432, -1.0841,  1.4612, -1.6279, -1.4801, -1.0631])
Word index: 1, Embedding: tensor([ 0.3295,  0.3264, -0.4806,  1.1032,  2.5485,  0.3006])
Word index: 0, Embedding: tensor([-1.2113,  0.6304, -1.4713, -1.3352, -0.4897,  0.1317])
Word index: 2, Embedding: tensor([-0.5432, -1.0841,  1.4612, -1.6279, -1.4801, -1.0631])
Word index: 2, Embedding: tensor([-0.5432, -1.0841,  1.4612, -1.6279, -1.4801, -1.0631])
Word index: 3, Embedding: tensor([ 0.3630,  0.3995,  0.1457, -0.7345, -0.9873,  1.8512])
Word index: 5, Embedding: tensor([ 0.0160,  0.4019,  1.9538, -0.4460,  1.7102,  0.8944])

torch.Size([4, 2, 6])
Target sentence embedding
Word index: 1, Embedding: tensor([-0.3214, -0.6803, -0.5234, -0.8720,  0.2539,  0.3173])
Word index: 7, Embedding: tensor([ 0.1460,  1.4648,  0.9936, -0.3499, -0.9568,  1.7102])
Word index: 3, Embedding: tensor([-

## Encoder function to display the intermediate outputs and get the final outputs from the encoder

### Self attention 

#### Functions to perform the attention calculation with Q,K and V matrices

In [6]:
def atten_product_needs_wts_false(Q, V, K, bsz, head_dim, src_len, tgt_len, embed_dim, attn_mask):


    # *** For multi-head attention ***
    #  (bsz*num_heads, src_len , head_dim) -> (bsz, num_heads, tgt_len, head_dim)
    Q1 = Q.view(bsz, num_heads, tgt_len, head_dim)
    K1 = K.view(bsz, num_heads, src_len, head_dim)
    V1 = V.view(bsz, num_heads, src_len, head_dim)


    L, S = Q1.size(-2), K1.size(-2)

    scale_factor = 1 / math.sqrt(Q1.size(-1)) 
    # scale_factor = 1
    attn_bias = torch.zeros(L, S, dtype=Q1.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))

            masked_tensor = attn_mask.float().masked_fill(attn_mask, float('-inf'))
            masked_tensor = masked_tensor.masked_fill(~attn_mask, 0)
            attn_mask = masked_tensor

            print("Attnetion mask infunction = ")
            print(attn_mask)
            print()
            attn_bias = attn_bias.unsqueeze(0).unsqueeze(0)
            attn_bias += attn_mask

        else:
            attn_bias += attn_mask
            attn_bias = attn_bias.unsqueeze(0).unsqueeze(0)


    # print("Attnetion bias = ", attn_bias.shape)
    # print(attn_bias)
    # print()
            

    # (bsz, num_heads, tgt_len, head_dim) @ (bsz, num_heads, head_dim, tgt_len) -> (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = Q1 @ K1.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias


    # print("INTERMEDIATE PDT = ", attn_weight)


    # (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = torch.softmax(attn_weight, dim=-1)

    # print("ATTN PDT = ", attn_weight)

    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6  
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    # (bsz, num_heads, tgt_len, tgt_len) @ (bsz, num_heads, tgt_len, head_dim) -> (bsz, num_heads, tgt_len, head_dim) 
    attn_output = attn_weight @ V1

    # print("Dot product attention  = ")
    # print(attn_weight.shape, attn_weight)

    # print(attn_output.shape)
    # print(bsz, tgt_len, embed_dim)
    
    # (bsz*tgt_len, embed_dim)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    print("Attention output = ")
    print(attn_weight.shape, attn_weight)

    return attn_output



def atten_product_needs_wts_true(Q, K, V, bsz, tgt_len, embed_dim, attn_mask):

    # *** For multi-head attention ***
    #  (bsz*num_heads, src_len , head_dim)
    
    B, Nt, E = Q.shape

    Q_scaled = Q / math.sqrt(E)

    if attn_mask is not None:
        temp_pdt_matrix = torch.baddbmm(attn_mask, Q_scaled, K.transpose(-2, -1))
    else:
        temp_pdt_matrix = torch.bmm(Q_scaled, K.transpose(-2, -1))

    attn_wt_matrix = torch.nn.functional.softmax(temp_pdt_matrix, dim=-1)

    attn_output = torch.bmm(attn_wt_matrix, V)

    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)


    sum_last_dim = attn_wt_matrix.sum(dim=-1)

    tolerance = 1e-6  
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"


    print("Encoder Attention output = ")
    print(attn_output)
    print()

    return attn_output, attn_wt_matrix


#### Function to get the Q,K,V matrices from the model's intialised weights

In [7]:
def get_qkv(query, key, value ,W, b):

    # embed_dim
    E = query.size(-1)

    if key is value:
        if query is key:
            
            # (src_len, bsz, embed_dim) @ (embed_dim*num_heads, embed_dim).T -> (src_len, bsz, embed_dim*num_heads)
            tempop1 = query@W.T

            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
            tempop1 = tempop1.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()

            # (src_len, bsz, embed_dim)
            return tempop1[0], tempop1[1], tempop1[2]
        

        else:

            # (embed_dim*1, embed_dim)
            # (embed_dim*2, embed_dim)
            W_q, W_kv = W.split([E, E * 2])

            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])

            # (src_len, bsz, embed_dim) @ (embed_dim*1, embed_dim).T -> (src_len, bsz, embed_dim)
            q_matmul = query@W_q.T

            # (src_len, bsz, embed_dim) @ (embed_dim*2, embed_dim).T -> (src_len, bsz, embed_dim*2)
            kv_matmul = key@W_kv.T

            kv_matmul = kv_matmul.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()

            # (src_len, bsz, embed_dim)
            return q_matmul, kv_matmul[0], kv_matmul[1]

    else:

        W_q, W_k, W_v = W.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)


        q_matmul = query@W_q.T
        k_matmul = key@W_k.T
        v_matmul = value@W_v.T

        return q_matmul, k_matmul, v_matmul



    

#### Encoder block's self attention output function

In [8]:
def encoder_block_attn_output(x, state_dict, layer_num, embed_dim, num_heads, need_weights = False, src_mask = None):

    # (src_len, bsz, embed_dim)
    query_enc = key_enc = value_enc = x

    tgt_len, bsz, embed_dim = x.shape
    
    # (embed_dim*num_heads, embed_dim)
    W_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads
    
    # (src_len, bsz, embed_dim)
    Q_enc,K_enc,V_enc = get_qkv(query_enc, key_enc, value_enc ,W_enc, b_enc)

    print("Q_enc_shape = ", Q_enc.shape)
    print("K_enc_shape = ", K_enc.shape)
    print("V_enc_shape = ", V_enc.shape)
    print()
    
    # (1, src_len, bsz, embed_dim)
    # Q_enc = Q_enc.unsqueeze(0)
    # K_enc = K_enc.unsqueeze(0)
    # V_enc = V_enc.unsqueeze(0)

    # (1, src_len, bsz, embed_dim) -> ( bsz*num_heads, src_len , head_dim)
    Q_enc = Q_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_enc = K_enc.reshape(K_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_enc = V_enc.reshape(V_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_enc_shape = ", Q_enc.shape)
    print("K_enc_shape = ", K_enc.shape)
    print("V_enc_shape = ", V_enc.shape)
    print()


    print("Q_enc_{} = ".format(layer_num))
    print(Q_enc)
    print()

    print("K_enc_{} = ".format(layer_num))
    print(K_enc)
    print()

    print("V_enc_{} = ".format(layer_num))
    print(V_enc)
    print()


    src_len = K_enc.size(1)

    attn_mask = src_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)


    if need_weights is False:

        attn_output = atten_product_needs_wts_false(Q_enc, V_enc, K_enc, bsz, head_dim, src_len, tgt_len, embed_dim, attn_mask)

        return attn_output, src_len, head_dim, None
    
    else:

        attn_enc_output,attn_wt_matrix_enc = atten_product_needs_wts_true(Q_enc, K_enc, V_enc, bsz, tgt_len, embed_dim,  attn_mask)

        return attn_enc_output, src_len, head_dim, attn_wt_matrix_enc

    

In [9]:
# need_weights = True

attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(pe_src_embeds, state_dict, layer_num = 0, need_weights = need_weights, embed_dim=embed_dim, num_heads=num_heads)


Q_enc_shape =  torch.Size([4, 2, 6])
K_enc_shape =  torch.Size([4, 2, 6])
V_enc_shape =  torch.Size([4, 2, 6])

Q_enc_shape =  torch.Size([4, 4, 3])
K_enc_shape =  torch.Size([4, 4, 3])
V_enc_shape =  torch.Size([4, 4, 3])

Q_enc_0 = 
tensor([[[-0.3633,  0.5583,  1.3890],
         [ 0.1844, -1.0462,  1.0679],
         [ 0.3350,  0.1886, -0.5119],
         [-0.1372, -0.8008,  1.7042]],

        [[ 0.4743, -0.5773, -0.2725],
         [ 0.4730,  0.1340,  1.1069],
         [-1.1150, -0.3167, -0.3792],
         [-1.1884, -0.4654,  0.0278]],

        [[ 0.2876, -0.0819, -0.5922],
         [-0.4108,  0.2878,  1.3087],
         [ 0.2876, -0.0819, -0.5922],
         [ 0.4681, -1.7669,  0.8374]],

        [[-1.2147, -0.0640, -0.3155],
         [ 0.3745, -0.3245, -0.2089],
         [-1.2147, -0.0640, -0.3155],
         [-0.8569,  0.8351,  0.8359]]])

K_enc_0 = 
tensor([[[-0.9253, -0.5284,  0.9858],
         [ 1.1106,  1.4398,  1.1074],
         [-0.7042, -0.5813,  0.3668],
         [-0.3563, -0.7

### Post self attention in encoder

#### Function to perform the linear layer calculations after encoder's self attention

In [10]:
def encoder_block_post_attn_output(x, attn_enc_output, state_dict, layer_num, bsz, tgt_len):

    # (bsz*src_len , embed_dim) @ (embed_dim , embed_dim).T -> (bsz*src_len , embed_dim)
    op_enc_1 = torch.matmul(attn_enc_output, state_dict["transformer.encoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]



    # (bsz*src_len , embed_dim) -> (src_len, bsz, embed_dim)
    attn_enc_output = op_enc_1.view(tgt_len, bsz, attn_enc_output.size(1))


    # Here src is the original passed inputs to the 1st transformer encoder layer which 
    # are pe_src_embeds 

    # (src_len, bsz, embed_dim)
    output_enc_1 = attn_enc_output + x

    #  (src_len, bsz, embed_dim) @ (embed_dim) -> (src_len, bsz, embed_dim) 
    linear_result_enc_1 = output_enc_1*state_dict["transformer.encoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm1.bias".format(layer_num)]

    # Layer normalization from Torch's implementation
    layernorm_enc_1 = torch.nn.LayerNorm(normalized_shape=linear_result_enc_1.shape[2:])
    linear_op_enc_1 = layernorm_enc_1(linear_result_enc_1)


    # Manual Layer Normalization
    x = linear_result_enc_1

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_enc_1.weight
    b = layernorm_enc_1.bias

    linear_result_enc_1f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_enc_1f.mean(dim=-1, keepdim=True)
    std = linear_result_enc_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_enc_1 = (linear_result_enc_1f - mean) / (std + epsilon) * w + b


    op_enc_1 = torch.matmul(normalized_result_enc_1, state_dict["transformer.encoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear1.bias".format(layer_num)]
    op_enc_1_relu = torch.nn.functional.relu(op_enc_1)
    op_enc_2 = torch.matmul(op_enc_1_relu, state_dict["transformer.encoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear2.bias".format(layer_num)]


    output_enc_2 = op_enc_2 + linear_op_enc_1
    output_enc_2_norm = output_enc_2*state_dict["transformer.encoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm2.bias".format(layer_num)]

    # Layer normalization from Torch's implementation
    layernorm_enc_final = torch.nn.LayerNorm(normalized_shape=output_enc_2_norm.shape[2:])
    output_enc_final = layernorm_enc_final(output_enc_2_norm)


    # Manual Layer Normalization 
    x = output_enc_2_norm

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_enc_final.weight
    b = layernorm_enc_final.bias

    linear_result_enc_2 = w*x + b

    epsilon = 1e-5  
    mean = linear_result_enc_2.mean(dim=-1, keepdim=True)
    std = linear_result_enc_2.std(dim=-1, unbiased=False, keepdim=True)
    output_enc_final = (linear_result_enc_2 - mean) / (std + epsilon) * w + b

    print("Final Encoder {} Output :".format(layer_num))
    print("norm2(norm1(x + self_atten(x)) + feed_fwd_op)\n")
    print(output_enc_final)
    print()

    # (src_len, bsz, embed_dim) 
    return output_enc_final


In [11]:
tgt_len, bsz, embed_dim = pe_src_embeds.shape
output_enc_final = encoder_block_post_attn_output(pe_src_embeds, attn_enc_output, state_dict, layer_num = 0 , bsz = bsz, tgt_len = tgt_len)

Final Encoder 0 Output :
norm2(norm1(x + self_atten(x)) + feed_fwd_op)

tensor([[[ 0.1315,  1.7982, -1.1204, -0.7758, -0.6844,  0.6509],
         [ 0.5618, -0.0283,  1.6605,  0.0642, -1.5605, -0.6977]],

        [[ 0.6983,  1.4650, -1.7656, -0.4646, -0.0901,  0.1570],
         [ 0.4833,  1.5877, -1.3182, -0.5150, -0.8981,  0.6603]],

        [[-0.7859,  0.4154,  1.7409, -0.4138, -1.3587,  0.4021],
         [ 0.5618, -0.0283,  1.6605,  0.0642, -1.5605, -0.6977]],

        [[ 0.4204,  0.9917, -0.7048, -0.1056, -1.7415,  1.1397],
         [ 0.0060,  0.1914,  0.7388, -1.8825,  1.3175, -0.3712]]],
       grad_fn=<AddBackward0>)



## Decoder functions to display the intermediate outputs and get the final outputs from the decoder


### Self attention outputs from a decoder block

In [12]:
def decoder_block_self_attn_output(x, state_dict, layer_num, tgt_mask = None,need_weights = False):

    # (tgt_len, bsz, embed_dim)
    query_dec = key_dec = value_dec = x

    tgt_len, bsz, embed_dim = x.shape
 
    # (embed_dim*num_heads, embed_dim)
    W_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads

    # (tgt_len, bsz, embed_dim)
    Q_dec,K_dec,V_dec = get_qkv(query_dec, key_dec, value_dec ,W_dec, b_dec)
    
    # Q_dec = Q_dec.unsqueeze(0)
    # K_dec = K_dec.unsqueeze(0)
    # V_dec = V_dec.unsqueeze(0)

    # (1, tgt_len, bsz, embed_dim)
    # print(Q_dec.shape, K_dec.shape , V_dec.shape)
    # print(tgt_len, bsz * num_heads, head_dim)

    # (1, tgt_len, bsz, embed_dim) -> ( bsz*num_heads, tgt_len , head_dim )
    Q_dec = Q_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_dec = K_dec.reshape(K_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_dec = V_dec.reshape(V_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec)
    print()

    src_len = K_dec.size(1)


    attn_mask = tgt_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

    
    
    if need_weights is False:
        attn_output = atten_product_needs_wts_false(Q = Q_dec, V = V_dec, K = K_dec, bsz = bsz, head_dim=head_dim, src_len=src_len, tgt_len=tgt_len, attn_mask = attn_mask, embed_dim=embed_dim)

        print("Decoder Self Attention = ")
        print(attn_output)
        print()

        op_dec_1 = torch.matmul(attn_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_output.size(1))

        return attn_dec_output, None
    
    else:

        attn_dec_output,attn_wt_matrix_dec = atten_product_needs_wts_true(Q=Q_dec, K=K_dec, V=V_dec, bsz=bsz, tgt_len=tgt_len, attn_mask = attn_mask, embed_dim=embed_dim)

        print("Decoder Attention output = ")
        print(attn_wt_matrix_dec)
        print()

        op_dec_1 = torch.matmul(attn_dec_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_dec_output.size(1))
    

        return attn_dec_output, attn_wt_matrix_dec


    

In [13]:
# need_weights = True

self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(pe_tgt_embeds, state_dict, layer_num = 0, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.4915,  0.1785,  0.3864],
         [ 1.1715,  1.5151,  0.7140],
         [-0.8424, -0.5799,  0.9247],
         [ 0.4130,  0.0187,  0.5372],
         [-0.1047, -0.4099,  1.0179]],

        [[-0.0688, -0.2718, -0.4068],
         [-0.2064,  0.5938, -0.7551],
         [-0.5829, -0.0967, -0.0997],
         [ 0.0358, -0.4574, -0.8994],
         [-0.7606, -0.3070, -0.4590]],

        [[ 0.9623, -0.3928,  1.0130],
         [-0.5729,  0.8946,  1.7594],
         [ 0.5452,  0.2343, -0.1199],
         [ 1.3363,  0.0377, -1.5454],
         [ 0.9184,  0.4523,  0.0457]],

        [[-0.0105,  0.2761, -1.4754],
         [-1.3176,  1.1325, -0.7708],
         [ 0.4692,  0.0739, -1.0632],
         [ 1.3846,  0.5773, -0.7389],
         [ 0.1132,  0.2065, -0.5631]]])

K_dec_0 = 
tensor([[[-0.8079, -0.7427, -1.1814],
         [-0.5335, -2.4177, -1.4567],
         [ 0.9248, -0.3809,  0.3459],
         [-1.2918, -1.4079, -1.5316],
         [-0.2605,  0.2826, -0.5938]],

        [[ 0.2199

#### Function to perform the linear layer calculations after decoder's self attention

In [14]:
def dec_post_self_attn(self_attn_dec, x, state_dict, layer_num):

    # (bsz*tgt_len , embed_dim) @ (embed_dim , embed_dim).T -> (bsz*tgt_len , embed_dim)
    output_dec_1 = self_attn_dec + x

    # (bsz*tgt_len , embed_dim) @ (embed_dim , embed_dim).T -> (bsz*tgt_len , embed_dim)
    linear_result_dec_1 = output_dec_1*state_dict["transformer.decoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm1.bias".format(layer_num)]

    layernorm_dec_1 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_1.shape[2:])
    linear_op_dec_1 = layernorm_dec_1(linear_result_dec_1)

    # From torch's implementation
    # print(linear_op_dec_1)

    x = linear_result_dec_1

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_dec_1.weight
    b = layernorm_dec_1.bias

    linear_result_dec_1f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_1f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_1 = (linear_result_dec_1f - mean) / (std + epsilon) * w + b


    print("Decoder_{} norm1(x + sa(x))".format(layer_num))
    print(normalized_result_dec_1)
    print()

    # (tgt_len, bsz, embed_dim)
    return normalized_result_dec_1

In [15]:
x_dec = dec_post_self_attn(self_attn_dec, pe_tgt_embeds, state_dict, layer_num = 0)

Decoder_0 norm1(x + sa(x))
tensor([[[ 0.2496, -0.2941, -1.1316, -0.6138, -0.2434,  2.0333],
         [ 1.1098,  0.7223, -0.5588, -0.5771, -1.6441,  0.9478]],

        [[ 0.4421, -0.5809, -1.6390,  0.7358, -0.4001,  1.4420],
         [-0.3757,  0.0057, -0.7252,  2.1689, -0.6989, -0.3747]],

        [[-0.6578,  0.5975,  1.4361,  0.5268, -1.6794, -0.2232],
         [ 1.6963,  0.5479, -1.5721, -0.0824, -0.5866, -0.0031]],

        [[ 0.0601,  0.4732, -1.6004, -0.0936, -0.5441,  1.7048],
         [ 2.1109,  0.1166, -0.5146, -0.9700, -0.5275, -0.2153]],

        [[-1.1550, -0.9021,  1.1832, -0.6537,  0.1087,  1.4188],
         [ 1.6840, -0.1874, -1.3775, -0.6023, -0.3716,  0.8547]]],
       grad_fn=<AddBackward0>)



## Cross attention in decoder 
### query, key, value =  x, mem, mem 

### Cross attention between the encoder's final output and the decoder layer 

In [16]:
memory = output_enc_final
x_dec

tensor([[[ 0.2496, -0.2941, -1.1316, -0.6138, -0.2434,  2.0333],
         [ 1.1098,  0.7223, -0.5588, -0.5771, -1.6441,  0.9478]],

        [[ 0.4421, -0.5809, -1.6390,  0.7358, -0.4001,  1.4420],
         [-0.3757,  0.0057, -0.7252,  2.1689, -0.6989, -0.3747]],

        [[-0.6578,  0.5975,  1.4361,  0.5268, -1.6794, -0.2232],
         [ 1.6963,  0.5479, -1.5721, -0.0824, -0.5866, -0.0031]],

        [[ 0.0601,  0.4732, -1.6004, -0.0936, -0.5441,  1.7048],
         [ 2.1109,  0.1166, -0.5146, -0.9700, -0.5275, -0.2153]],

        [[-1.1550, -0.9021,  1.1832, -0.6537,  0.1087,  1.4188],
         [ 1.6840, -0.1874, -1.3775, -0.6023, -0.3716,  0.8547]]],
       grad_fn=<AddBackward0>)

In [17]:
memory.shape

torch.Size([4, 2, 6])

In [18]:
def decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num, tgt_len, src_len, head_dim, memory_mask = None,need_weights = False):

    # (tgt_len, bsz, embed_dim)
    query_dec_mha = x_dec

    # (src_len, bsz, embed_dim)
    key_dec_mha, value_dec_mha = memory, memory

    tgt_len, bsz, embed_dim = query_dec_mha.shape

    W_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_weight".format(layer_num)]
    b_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_bias".format(layer_num)]


    ########## CHANGE FROM SELF ATTENTION ############

    Q_dec_mha,K_dec_mha,V_dec_mha = get_qkv(query_dec_mha, key_dec_mha, value_dec_mha ,W_dec_mha, b_dec_mha)


    #################################################

    # K_dec_mha = K_dec_mha.unsqueeze(0)
    # V_dec_mha = V_dec_mha.unsqueeze(0)
    # Q_dec_mha = Q_dec_mha.unsqueeze(0)


    Q_dec_mha = Q_dec_mha.reshape(Q_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    K_dec_mha = K_dec_mha.reshape(K_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_dec_mha = V_dec_mha.reshape(V_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec_mha)
    print()

    print("K_dec_{} = ".format(layer_num))
    print(K_dec_mha)
    print()

    print("V_dec_{} = ".format(layer_num))
    print(V_dec_mha)
    print()


    attn_mask = memory_mask
    if attn_mask is not None:

        # Ensuring attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)



    if need_weights is False:

        attn_output_dec_mha = atten_product_needs_wts_false(Q =Q_dec_mha, V=V_dec_mha, K=K_dec_mha, bsz=bsz, head_dim=head_dim, src_len=src_len, tgt_len=tgt_len, attn_mask=attn_mask, embed_dim=embed_dim)
        
        print("Cross attention in decoder_{}".format(layer_num))
        print(attn_output_dec_mha)
        print()

        op_dec_mha_1 = torch.matmul(attn_output_dec_mha, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]

        attn_dec_mha_output = op_dec_mha_1.view(tgt_len, bsz, attn_output_dec_mha.size(1))

        return attn_dec_mha_output, None
    
    else: 

    
        attn_dec_mha_output ,attn_wt_matrix_dec_mha = atten_product_needs_wts_true(Q=Q_dec_mha, K=K_dec_mha, V=V_dec_mha, bsz=bsz, tgt_len=tgt_len, attn_mask=attn_mask, embed_dim=embed_dim)

        print("Decoder mha Attention output = ")
        print(attn_dec_mha_output)
        print()


        op_dec_mha = torch.matmul(attn_dec_mha_output, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]
        attn_dec_output_mha = op_dec_mha.view(tgt_len, bsz, attn_dec_mha_output.size(1))


        return attn_dec_output_mha , attn_wt_matrix_dec_mha

        


In [19]:
# need_weights = True

attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = 0, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = need_weights)

Q_dec_0 = 
tensor([[[ 0.7862, -0.3004,  0.7572],
         [ 0.8531, -0.1968,  0.0593],
         [-0.5982,  0.8353,  0.0710],
         [ 1.0470, -0.0455,  0.7925],
         [-0.2752,  0.4758,  0.6568]],

        [[ 1.0111, -0.8068,  0.9556],
         [ 0.0128, -0.2789,  0.4802],
         [-0.5917,  0.0728, -0.2093],
         [ 0.8041, -0.9694,  0.5379],
         [ 0.9808, -0.5303,  1.2566]],

        [[ 0.3277, -0.3959,  0.4343],
         [ 0.2816,  0.6237, -0.6124],
         [ 0.5739, -0.8586, -0.1966],
         [-0.0360, -1.1861, -0.2934],
         [ 0.5500, -1.0157,  0.0207]],

        [[ 0.0763, -0.2905, -0.0025],
         [-1.1952,  0.3035, -0.5019],
         [-0.3886,  0.1664, -0.5765],
         [-0.3402,  0.5286, -0.4543],
         [ 0.1009,  0.0264,  0.0921]]], grad_fn=<TransposeBackward0>)

K_dec_0 = 
tensor([[[-0.1875, -0.4684,  0.4381],
         [-0.1823, -0.2832,  0.4897],
         [-0.0665, -0.2104,  0.0272],
         [-0.0594, -0.1930,  0.7234]],

        [[-0.5768,  1.431

#### Function to perform the linear layer calculations after encoder's cross attention

In [20]:
def decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num):

    # (tgt_len, bsz, embed_dim)
    output_dec_2 = attn_dec_mha_output + x_dec

    linear_result_dec_2 = output_dec_2*state_dict["transformer.decoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm2.bias".format(layer_num)]

    # Layer normalization from Torch's implementation 
    layernorm_dec_2 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_2.shape[2:])
    linear_op_dec_2 = layernorm_dec_2(linear_result_dec_2)

    x = linear_result_dec_2
    w = layernorm_dec_2.weight
    b = layernorm_dec_2.bias

    linear_result_dec_2f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_2f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_2f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_2 = (linear_result_dec_2f - mean) / (std + epsilon) * w + b

    print("norm2(x' + mha(x', mem)), \n where x' = Decoder_curr_layer norm1(x + sa(x))")
    print(normalized_result_dec_2)
    print("\n\n")

    x_dec2_norm = normalized_result_dec_2

    op_dec_1 = torch.matmul(x_dec2_norm, state_dict["transformer.decoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear1.bias".format(layer_num)]
    op_dec_1_relu = torch.nn.functional.relu(op_dec_1)
    op_dec_2 = torch.matmul(op_dec_1_relu, state_dict["transformer.decoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear2.bias".format(layer_num)]

    ff_dec = op_dec_2

    x_dec3_unorm = x_dec2_norm + ff_dec

    linear_result_dec_3 = x_dec3_unorm*state_dict["transformer.decoder.layers.0.norm3.weight"] + state_dict["transformer.decoder.layers.0.norm3.bias"]

    # Layer normalization from Torch's implementation 
    layernorm_dec_3 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_3.shape[2:])
    linear_op_dec_3 = layernorm_dec_3(linear_result_dec_3)

    # From torch's implementation
    x = linear_result_dec_3

    # Obtained layer norm weights and biases (learnable)
    w = layernorm_dec_3.weight
    b = layernorm_dec_3.bias

    linear_result_dec_3f = w*x + b

    epsilon = 1e-5  
    mean = linear_result_dec_3f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_3f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_3 = (linear_result_dec_3f - mean) / (std + epsilon) * w + b

    print("norm3(x'' + ff(x'')) \n where, x'' = Decoder_curr_layer norm2(x' + mha(x'))")
    print(normalized_result_dec_3)
    print()

    return normalized_result_dec_3

def feef_fwd_transformer(dec_output_final, state_dict):

    # (tgt_len, bsz, embed_dim) @ (vocab_size, embed_dim).T -> (tgt_len, bsz, vocab_size)
    final_op = dec_output_final@state_dict["fc.weight"].T + state_dict["fc.bias"]

    return final_op

In [21]:
final_attn_op_decoder = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = 0)

norm2(x' + mha(x', mem)), 
 where x' = Decoder_curr_layer norm1(x + sa(x))
tensor([[[ 0.2103, -0.5384, -0.6440, -1.0017, -0.0855,  2.0592],
         [ 1.1925,  0.6852, -0.5456, -0.8191, -1.4829,  0.9699]],

        [[ 0.5276, -0.9134, -1.2995,  0.2015, -0.2723,  1.7561],
         [-0.2317, -0.0252, -0.8662,  2.1575, -0.5692, -0.4652]],

        [[-0.5649,  0.2359,  1.7549,  0.0041, -1.5877,  0.1578],
         [ 1.7267,  0.5120, -1.5780, -0.2203, -0.4657,  0.0254]],

        [[-0.0028,  0.1408, -1.1372, -0.6672, -0.3611,  2.0276],
         [ 2.1018,  0.0778, -0.5685, -1.0199, -0.4339, -0.1573]],

        [[-0.8441, -0.9961,  1.1739, -0.9302,  0.1748,  1.4218],
         [ 1.6890, -0.2130, -1.3551, -0.7281, -0.2212,  0.8284]]],
       grad_fn=<AddBackward0>)



norm3(x'' + ff(x'')) 
 where, x'' = Decoder_curr_layer norm2(x' + mha(x'))
tensor([[[ 0.8615, -1.4745, -1.2859,  0.2651,  0.7044,  0.9293],
         [ 1.7893,  0.0161, -1.3267,  0.1703, -0.9563,  0.3071]],

        [[ 0.7938, -1.49

In [22]:
# (tgt_len, bsz, vocab_dim)
final_op = feef_fwd_transformer(final_attn_op_decoder, state_dict)

In [23]:
final_op

tensor([[[-1.1367,  0.4059,  0.0039,  0.2621, -0.4700, -0.9143,  0.1431,
           0.4767, -0.3693, -0.2621],
         [-0.2206,  1.4191,  0.4450, -0.6391, -0.4939, -0.7436, -0.3757,
           0.5601, -0.2733,  0.8637]],

        [[-1.0985,  0.5695,  0.2399,  0.1278, -0.5324, -0.9371,  0.1021,
           0.3644, -0.6032, -0.2809],
         [-0.1862,  1.0474,  0.6285,  0.2929, -0.0427, -0.4063,  0.1267,
           0.4304, -0.6193, -0.8473]],

        [[ 1.0816,  1.1890,  0.1632, -0.5463,  0.4493, -0.3544, -0.2048,
           0.0505,  0.5994,  0.7032],
         [-0.4685,  1.2632,  0.1055, -0.3486, -0.4561, -0.7998, -0.3027,
           0.7605, -0.1206,  0.6021]],

        [[-1.0653,  0.6070,  0.1960,  0.4041, -0.5071, -0.6479,  0.1185,
           0.8493, -0.4667, -0.3609],
         [-0.1655,  1.1405, -0.1165, -0.8707, -0.3151, -1.1322, -0.3957,
           0.1049,  0.2154,  1.2038]],

        [[-1.0150,  0.1429, -0.5758,  0.4336, -0.0169, -1.2373,  0.2980,
          -0.0419, -0.0243, -0.

In [24]:
output

tensor([[[-1.1367,  0.4059,  0.0039,  0.2621, -0.4700, -0.9143,  0.1431,
           0.4767, -0.3693, -0.2621],
         [-0.2206,  1.4191,  0.4450, -0.6391, -0.4939, -0.7436, -0.3757,
           0.5601, -0.2733,  0.8637]],

        [[-1.0985,  0.5695,  0.2399,  0.1278, -0.5324, -0.9371,  0.1021,
           0.3644, -0.6032, -0.2809],
         [-0.1862,  1.0474,  0.6285,  0.2929, -0.0427, -0.4063,  0.1267,
           0.4304, -0.6193, -0.8473]],

        [[ 1.0816,  1.1890,  0.1632, -0.5463,  0.4493, -0.3544, -0.2048,
           0.0505,  0.5994,  0.7032],
         [-0.4685,  1.2632,  0.1055, -0.3486, -0.4561, -0.7998, -0.3028,
           0.7605, -0.1206,  0.6021]],

        [[-1.0653,  0.6070,  0.1960,  0.4041, -0.5071, -0.6479,  0.1185,
           0.8493, -0.4667, -0.3609],
         [-0.1655,  1.1405, -0.1165, -0.8707, -0.3151, -1.1322, -0.3957,
           0.1049,  0.2154,  1.2038]],

        [[-1.0150,  0.1429, -0.5758,  0.4336, -0.0169, -1.2373,  0.2980,
          -0.0419, -0.0243, -0.

In [25]:
final_op.shape

torch.Size([5, 2, 10])

########################################################################

### Examaple 1 :- 

Transformer with 3 encoders and 3 decoders

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()

class TransformerModel(nn.Module):
    
    def __init__(self, src_vocab_size, tgt_vocab_size,max_seq_len, num_encoder_layers=1, num_decoder_layers=1, d_model=4, nhead=2):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=0
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):

        print("Source embeddings :- \n")
        print(self.src_embedding(src))
        print()

        print("Target embeddings :- \n")
        print(self.tgt_embedding(tgt))
        print()


        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        print("Positional encoded source embeddings:")
        print(src)
        print()

        print("Positional encoded target embeddings:")
        print(tgt)
        print()

        output = self.transformer(src, tgt)
        output = self.fc(output)
        
        return output
    

src_vocab_size = 10  # Source language vocabulary size
tgt_vocab_size = 10  # Target language vocabulary size
d_model = 12  # Dimension of the model
num_heads = 3
num_encoder_layers = 3
num_decoder_layers = 3

need_weights = False

src_mask = None
tgt_mask = None
memory_mask = None

max_seq_len = 5

embed_dim = d_model

# embed_dim = d_model//num_heads



model = TransformerModel(src_vocab_size, tgt_vocab_size, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, d_model=d_model, max_seq_len = max_seq_len, nhead=num_heads)
# Source sentence in the source language
src_sentence = torch.tensor([[0, 2, 5, 6], [1, 0, 6, 5], [2, 2, 1, 3], [3, 5, 8, 0]])
print(src_sentence.shape)
# Target sentence in the target language (translation of the source sentence)
tgt_sentence = torch.tensor([[1, 7, 0, 1], [3, 4, 9, 1], [5, 2, 0, 3], [8, 0, 4, 5], [6, 1, 3, 7], [2, 5, 9, 0]])  # Target sequence
print(tgt_sentence.shape)

# Forward pass
output = model(src_sentence, tgt_sentence)
print(output)


torch.Size([4, 4])
torch.Size([6, 4])
Source embeddings :- 

tensor([[[-2.0781,  2.4465,  0.0688,  0.6648,  0.9296, -0.5790, -0.0134,
          -0.2002,  0.2405, -0.5802, -0.2976, -0.6942],
         [ 0.3004,  0.1016, -0.0348, -0.2505,  1.5129, -0.1994,  1.4109,
           0.4083, -1.5015,  0.0792, -0.0561, -1.0517],
         [-1.6898, -1.1012, -0.6427,  0.1601, -0.7973,  1.7317, -0.0548,
           0.0518, -0.4189,  1.3654, -0.4802,  0.5831],
         [-1.3937, -0.1001, -0.4957,  0.4838, -1.1298, -0.0612,  0.4849,
          -0.5524, -0.8483, -0.7955,  0.1770, -0.1222]],

        [[-3.0403,  0.3559, -0.7012, -0.3400, -0.4425, -1.0635,  0.3796,
          -0.7286, -1.0257, -0.8489, -0.1872,  0.5907],
         [-2.0781,  2.4465,  0.0688,  0.6648,  0.9296, -0.5790, -0.0134,
          -0.2002,  0.2405, -0.5802, -0.2976, -0.6942],
         [-1.3937, -0.1001, -0.4957,  0.4838, -1.1298, -0.0612,  0.4849,
          -0.5524, -0.8483, -0.7955,  0.1770, -0.1222],
         [-1.6898, -1.1012, -0.642

In [27]:
def get_all_intermediate_outputs(src_sentence, tgt_sentence,d_model, model, num_encoder_layers , num_decoder_layers):

    state_dict = model.state_dict()

    pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence=src_sentence,  tgt_sentence=tgt_sentence, state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)
    print("###"*25)
    print("### Encoder Start ###")
    print()

    x_enc = pe_src_embeds

    for lno in range(num_encoder_layers):
        attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(x_enc, state_dict, layer_num = lno, need_weights = False, embed_dim = d_model, num_heads= num_heads, src_mask=None)

        if lno == 0:
            tgt_len, bsz, embed_dim = x_enc.shape

        output_enc_final = encoder_block_post_attn_output(x_enc, attn_enc_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

        x_enc = output_enc_final

    
    print("### Encoder Done ###")
    print("\n\n\n")
    print("### Decoder Start ###")

    x_dec = pe_tgt_embeds
    memory = x_enc

    for lno in range(num_decoder_layers):

        self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, need_weights = False, tgt_mask=None)
        x_dec = dec_post_self_attn(self_attn_dec, x_dec, state_dict, layer_num = lno)

        attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = lno, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = False, memory_mask=None)
        final_op = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = lno)

        print(pe_tgt_embeds.shape, final_op.shape)
        x_dec = final_op

    print("### Decoder Done ###")


    final_op = feef_fwd_transformer(final_op, state_dict)


    
    return final_op

    

In [28]:
# d_model = 12
# num_encoder_layers = 3
# num_decoder_layers = 3

final_op = get_all_intermediate_outputs(src_sentence, tgt_sentence, model = model, num_encoder_layers = num_encoder_layers , num_decoder_layers = num_encoder_layers, d_model=d_model)

Source sentence embedding
Word index: 0, Embedding: tensor([-2.0781,  2.4465,  0.0688,  0.6648,  0.9296, -0.5790, -0.0134, -0.2002,
         0.2405, -0.5802, -0.2976, -0.6942])
Word index: 2, Embedding: tensor([ 0.3004,  0.1016, -0.0348, -0.2505,  1.5129, -0.1994,  1.4109,  0.4083,
        -1.5015,  0.0792, -0.0561, -1.0517])
Word index: 5, Embedding: tensor([-1.6898, -1.1012, -0.6427,  0.1601, -0.7973,  1.7317, -0.0548,  0.0518,
        -0.4189,  1.3654, -0.4802,  0.5831])
Word index: 6, Embedding: tensor([-1.3937, -0.1001, -0.4957,  0.4838, -1.1298, -0.0612,  0.4849, -0.5524,
        -0.8483, -0.7955,  0.1770, -0.1222])
Word index: 1, Embedding: tensor([-3.0403,  0.3559, -0.7012, -0.3400, -0.4425, -1.0635,  0.3796, -0.7286,
        -1.0257, -0.8489, -0.1872,  0.5907])
Word index: 0, Embedding: tensor([-2.0781,  2.4465,  0.0688,  0.6648,  0.9296, -0.5790, -0.0134, -0.2002,
         0.2405, -0.5802, -0.2976, -0.6942])
Word index: 6, Embedding: tensor([-1.3937, -0.1001, -0.4957,  0.4838

In [29]:
output.shape, final_op.shape


(torch.Size([6, 4, 10]), torch.Size([6, 4, 10]))

In [30]:
output

tensor([[[-0.3228, -0.9053, -0.5082,  0.4718,  0.1404,  0.0900, -0.8195,
           0.2730,  0.3273,  1.0589],
         [-0.0916,  1.9037,  0.5857, -0.0275, -0.2269,  0.3961,  0.3599,
          -0.1530,  0.0368, -0.2285],
         [ 0.0024,  0.1919,  0.0615, -0.8913,  0.3367,  0.0978, -0.3729,
          -0.6543, -0.1429, -0.2311],
         [-0.0127,  0.4908,  0.0913, -0.7293,  0.2477,  0.1502, -0.3272,
          -0.7576,  0.1207, -0.1994]],

        [[-0.6171,  0.4608, -0.2076,  1.0943,  0.0396,  0.3247, -0.2315,
           0.3346,  0.7077,  0.9616],
         [-0.7705,  1.7848,  0.9436, -0.2763, -0.3984,  0.7689,  0.4744,
          -0.0182, -0.3521, -0.4475],
         [-0.0390,  0.0719, -0.0317, -0.9406,  0.0193,  0.2039, -0.3028,
          -0.6329, -0.0940, -0.2245],
         [-0.0127,  0.4908,  0.0913, -0.7293,  0.2477,  0.1502, -0.3272,
          -0.7576,  0.1207, -0.1994]],

        [[-0.9653,  0.2469, -0.3535,  0.9487,  0.0125,  0.2773, -0.3317,
           0.5061,  0.3918,  1.0237

In [31]:
final_op

tensor([[[-0.3228, -0.9053, -0.5082,  0.4718,  0.1404,  0.0900, -0.8195,
           0.2730,  0.3273,  1.0589],
         [-0.0916,  1.9037,  0.5857, -0.0275, -0.2269,  0.3961,  0.3599,
          -0.1530,  0.0368, -0.2285],
         [ 0.0024,  0.1919,  0.0615, -0.8913,  0.3367,  0.0978, -0.3729,
          -0.6543, -0.1429, -0.2311],
         [-0.0127,  0.4908,  0.0913, -0.7293,  0.2477,  0.1502, -0.3272,
          -0.7576,  0.1207, -0.1994]],

        [[-0.6171,  0.4608, -0.2076,  1.0943,  0.0396,  0.3247, -0.2314,
           0.3346,  0.7077,  0.9616],
         [-0.7705,  1.7848,  0.9436, -0.2763, -0.3984,  0.7689,  0.4744,
          -0.0182, -0.3521, -0.4475],
         [-0.0390,  0.0719, -0.0317, -0.9406,  0.0193,  0.2039, -0.3028,
          -0.6329, -0.0940, -0.2245],
         [-0.0127,  0.4908,  0.0913, -0.7293,  0.2477,  0.1502, -0.3272,
          -0.7576,  0.1207, -0.1994]],

        [[-0.9653,  0.2469, -0.3535,  0.9487,  0.0125,  0.2773, -0.3317,
           0.5061,  0.3918,  1.0237

### Example 2
### Training part of the transformer


In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        print(x.shape)
        return self.encoding[:, :x.size(1)].detach()



class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)

        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        print("Tgt mask shape = ", tgt_mask.shape)

        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        output = self.fc(output)
        
        return output
    

In [33]:
torch.manual_seed(0)

src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 1
num_decoder_layers = 1
d_ff = 20
max_seq_len = 5
dropout = 0

transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)

# Generate random sample data
# src_data = torch.randint(1, src_vocab_size, (max_seq_len , 3))  # (seq_length, batch_size,)
# tgt_data = torch.randint(1, tgt_vocab_size, ( max_seq_len, 3))  # (seq_length, batch_size)

src_data = torch.tensor([[0, 2, 4], [1, 0, 7], [2, 2, 0], [3, 5, 6], [6, 1, 9]])
tgt_data = torch.tensor([[1, 7, 9], [3, 4, 1], [5, 2, 8], [8, 0, 3], [4, 5, 9]])  # Target sequence


state_dict = transformer.state_dict()

In [34]:
# torch.manual_seed(0)

# src_vocab_size = 200
# tgt_vocab_size = 200
# d_model = 512
# num_heads = 8
# num_encoder_layers = 6
# num_decoder_layers = 6
# d_ff = 1024
# max_seq_len = 50
# dropout = 0

# transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)

# # Generate random sample data
# src_data = torch.randint(1, src_vocab_size, (max_seq_len ,32))  # (seq_length, batch_size,)
# tgt_data = torch.randint(1, tgt_vocab_size, (max_seq_len, 32))  # (seq_length, batch_size)

# # src_data = torch.tensor([[0, 2, 4], [1, 0, 7], [2, 2, 0], [3, 5, 6], [6, 1, 9]])
# # tgt_data = torch.tensor([[1, 7, 9], [3, 4, 1], [5, 2, 8], [8, 0, 3]])  # Target sequence


# state_dict = transformer.state_dict()

In [35]:
import copy

state_dict1 = copy.deepcopy(state_dict)

In [36]:
src_data.shape, tgt_data.shape

(torch.Size([5, 3]), torch.Size([5, 3]))

In [37]:
tgt_data.view(-1)
tgt_data.shape

torch.Size([5, 3])

In [38]:
src_data.shape, tgt_data.shape

(torch.Size([5, 3]), torch.Size([5, 3]))

In [39]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(1):

    optimizer.zero_grad()

    print("FWD PASS START\n")
    print("src_data shape = ", src_data.shape)
    print("tgt_data shape = ", tgt_data.shape)
    
    output = transformer(src_data, tgt_data[:-1, :])
    print("FWD PASS END\n")

    print("output shape = ",output.shape)

    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[1:, :].contiguous().view(-1))


    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

FWD PASS START

src_data shape =  torch.Size([5, 3])
tgt_data shape =  torch.Size([5, 3])
Tgt mask shape =  torch.Size([4, 4])
torch.Size([5, 3])
torch.Size([4, 3])
torch.Size([5, 3, 16]) torch.Size([4, 3, 16])
MASK =  None
query =  tensor([[[-1.1258, -0.1524, -0.2506,  0.5661,  0.8487,  1.6920, -0.3160,
          -1.1152,  0.3223, -0.2633,  0.3500,  1.3081,  0.1198,  2.2377,
           1.1168,  0.7527],
         [ 0.2279,  0.5719, -0.1817,  1.1988,  0.5395,  1.1074,  0.6724,
           1.4407, -0.0923,  1.7924, -0.2865,  1.0525,  0.5239,  3.3022,
          -1.4686, -0.5867],
         [ 0.3400,  0.5038,  1.7019,  2.0965, -1.2795,  3.5473, -0.4099,
           1.3336, -1.6093,  0.4501, -0.4735,  0.5003, -1.0650,  2.1149,
          -0.1400,  1.8058]],

        [[-1.3527, -0.6959,  0.5667,  1.7935,  0.5988, -0.5551, -0.3414,
           2.8530,  0.7502,  0.4145, -0.1734,  1.1835,  1.3894,  2.5863,
           0.9463,  0.1563],
         [-0.2844, -0.6121,  0.0604,  0.5165,  0.9485,  1.6870, -

In [40]:
# transformer.eval()

# # Generate random sample validation data
# val_src_data = torch.randint(1, src_vocab_size, (10 ,5))  # (seq_length, batch_size)
# val_tgt_data = torch.randint(1, tgt_vocab_size, (10, 5))  # (seq_length, batch_size)

# with torch.no_grad():

#     val_output = transformer(val_src_data, val_tgt_data)
#     val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data.contiguous().view(-1))
#     print(f"Validation Loss: {val_loss.item()}")

In [41]:
import copy

### Cross verifying the intermediate outputs for the 1st forward pass

In [42]:
def generate_mask(src, tgt):
    src_mask = None
    seq_length = tgt.size(0)
    nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()
    # tgt_mask = tgt_mask & nopeak_mask
    return src_mask, nopeak_mask

src_mask, tgt_mask = generate_mask(src_data, tgt_data)

In [43]:
src_mask, tgt_mask

(None,
 tensor([[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False, False]]))

In [44]:
def get_all_intermediate_outputs_mask(src_sentence, tgt_sentence,d_model, state_dict, num_encoder_layers , num_decoder_layers, tgt_mask, d_ff):

    pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence=src_sentence,  tgt_sentence=tgt_sentence, state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)
    print("###"*25)
    print("### Encoder Start ###")
    print()

    x_enc = pe_src_embeds

    for lno in range(num_encoder_layers):
        attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attn_output(x_enc, state_dict, layer_num = lno, need_weights = False, embed_dim = d_model, num_heads= num_heads, src_mask=None)

        if lno == 0:
            tgt_len, bsz, embed_dim = x_enc.shape

        output_enc_final = encoder_block_post_attn_output(x_enc, attn_enc_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

        x_enc = output_enc_final

    
    print("### Encoder Done ###")
    print("\n\n\n")
    print("### Decoder Start ###")

    x_dec = pe_tgt_embeds
    memory = x_enc

    for lno in range(num_decoder_layers):

        self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, need_weights = False, tgt_mask=tgt_mask)
        x_dec = dec_post_self_attn(self_attn_dec, x_dec, state_dict, layer_num = lno)

        attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num = lno, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = False, memory_mask=None)
        final_op = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = lno)

        print(pe_tgt_embeds.shape, final_op.shape)
        x_dec = final_op

    print("### Decoder Done ###")

    final_op = feef_fwd_transformer(final_op, state_dict)

    return final_op
    

In [45]:
# state_dict1 = transformer.state_dict

src_mask, tgt_mask = generate_mask(src_data, tgt_data[:-1, :])

In [46]:


final_op = get_all_intermediate_outputs_mask(src_data, tgt_data[:-1, :], state_dict = state_dict1, num_encoder_layers = num_encoder_layers , num_decoder_layers = num_encoder_layers, d_model=d_model,  d_ff = d_ff, tgt_mask = tgt_mask)


Source sentence embedding
Word index: 0, Embedding: tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152,
         0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168, -0.2473])
Word index: 2, Embedding: tensor([-0.6136,  0.0316, -0.4927,  0.2484,  0.4397,  0.1124,  0.6408,  0.4412,
        -0.1023,  0.7924, -0.2897,  0.0525,  0.5229,  2.3022, -1.4689, -1.5867])
Word index: 4, Embedding: tensor([-0.5692,  0.9200,  1.1108,  1.2899, -1.4782,  2.5672, -0.4731,  0.3356,
        -1.6293, -0.5497, -0.4798, -0.4997, -1.0670,  1.1149, -0.1407,  0.8058])
Word index: 1, Embedding: tensor([-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,  1.8530,
         0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463, -0.8437])
Word index: 0, Embedding: tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152,
         0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168, -0.2473])
Word index: 7, Embedding: tensor([-0.21