In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import math
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset, random_split
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
import wandb

In [2]:
# Authenticate this kernel with Weights & Biases; the captured output shows it prompting for an API key.
# NOTE(review): an interactive prompt blocks "Restart & Run All" — consider supplying the key via an
# environment variable / Kaggle secret instead of pasting it (confirm against the wandb.login docs).
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [10]:
# Start a W&B run for this experiment.
# `tags=` sets the actual run tag; the "tags" entry inside `config` is only a plain config value,
# so it is kept for backward compatibility with anything already reading it from the config.
wandb.init(
    project="LangGPT",
    tags=["Kaggle test run"],
    config={
        "architecture": "Transformers",
        "dataset": "https://huggingface.co/datasets/cfilt/iitb-english-hindi",
        "epochs": 10,
        # NOTE(review): the dataset cell subsamples 500000 training rows and uses the full
        # validation split; these two values do not reflect that — confirm and update.
        "Training Data": 100,
        "Validation Data": 50,
        "Version": "V1",
        "tags": "Kaggle test run",
    },
)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,██▆▆▄▄▅▄▅▄▄█▄▂▃▅▃▃▄█▃▃▄▄▅▄▂▃▃▃▃▅▄▃▃▂▅▄▅▁

0,1
loss,6.60059


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112598299991481, max=1.0…

In [13]:
# :: DATASET ::
# Download dataset from Hugging-face: https://huggingface.co/datasets/cfilt/iitb-english-hindi
print("INFO: Dataset download started.")
raw_train_dataset = load_dataset("cfilt/iitb-english-hindi", split="train")
raw_val_dataset = load_dataset("cfilt/iitb-english-hindi", split="validation")
raw_test_dataset = load_dataset("cfilt/iitb-english-hindi", split="test")
print("INFO: Dataset download complete.")


# Sub-sample the training split for faster experiments. The split is seeded so that
# "Restart & Run All" trains the tokenizers and model on the same rows every time,
# and the subset size is capped at the split size so random_split never gets a
# negative remainder.
TRAIN_SUBSET_SIZE = 500000
SPLIT_SEED = 42
num_train_rows = min(TRAIN_SUBSET_SIZE, len(raw_train_dataset))
raw_train_dataset, rt_to_skip = random_split(
    raw_train_dataset,
    [num_train_rows, len(raw_train_dataset) - num_train_rows],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(len(raw_train_dataset))
# The validation split is used in full (no sub-sampling).

INFO: Dataset download started.
INFO: Dataset download complete.
500000


In [None]:
# :: TOKENIZER ::
# [ Creating Source Language Tokenizer - English ]
# Special tokens reserved in the BPE vocabulary (list order fixes their ids):
#   [PAD]  - padding that brings every sequence to the same fixed length
#   [UNK]  - stands in for anything the BPE model cannot represent
#   [CLS]  - start-of-sentence marker
#   [SEP]  - end-of-sentence marker
#   [MASK] - reserved; not used by the code in this notebook so far
tokenizer_en = Tokenizer(BPE(unk_token="[UNK]"))
trainer_en = BpeTrainer(
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
    min_frequency=2,
)

def get_ds_iterator(raw_train_dataset, lang):
    """Lazily yield the `lang` side of every translation pair in `raw_train_dataset`.

    Used as the text source for `Tokenizer.train_from_iterator`, so the (sub-sampled)
    dataset never has to be materialised as one big list of strings.

    Args:
        raw_train_dataset: iterable of rows indexed as row["translation"][lang].
        lang (str): language key to extract, e.g. "en" or "hi".

    Yields:
        str: one sentence per row, in iteration order.
    """
    yield from (row["translation"][lang] for row in raw_train_dataset)
    

# Pre-split raw text with the `Whitespace` pre-tokenizer before BPE merges are learned.
tokenizer_en.pre_tokenizer = Whitespace()
print("INFO: source tokenizer initialized")

print("INFO: source tokenizer training started...")
start_time = time.time()
tokenizer_en.train_from_iterator(get_ds_iterator(raw_train_dataset, "en"), trainer=trainer_en)
print("INFO: source tokenizer training completed!")
print(f"INFO: time taken: {time.time() - start_time}s")


# Save tokenizer for future use. The log message reports the path that is actually
# written (it previously printed an unrelated /models/... path).
tokenizer_en_path = "/kaggle/working/tokenizer_en.json"
tokenizer_en.save(tokenizer_en_path)
print(f"INFO: source tokenizer saved into: {tokenizer_en_path}")


# [ Creating Target Language Tokenizer - Hindi ]
# Same configuration and special-token list as the source tokenizer, so the special
# token ids ([PAD]=0, [UNK]=1, ...) agree between the two vocabularies.
tokenizer_hi = Tokenizer(BPE(unk_token="[UNK]"))
trainer_hi = BpeTrainer(
    min_frequency=2, special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
)

# Pre-split raw text with the `Whitespace` pre-tokenizer before BPE merges are learned.
tokenizer_hi.pre_tokenizer = Whitespace()
print("INFO: target tokenizer initialized")

print("INFO: target tokenizer training started...")
start_time = time.time()
tokenizer_hi.train_from_iterator(get_ds_iterator(raw_train_dataset, "hi"), trainer=trainer_hi)
print("INFO: target tokenizer training completed!")
print(f"INFO: time taken: {time.time() - start_time}s")

# Save tokenizer for future use. The log message now names the *target* tokenizer and
# reports the path that is actually written.
tokenizer_hi_path = "/kaggle/working/tokenizer_hi.json"
tokenizer_hi.save(tokenizer_hi_path)
print(f"INFO: target tokenizer saved into: {tokenizer_hi_path}")

# Reload both tokenizers from the files written above, so everything downstream uses
# the saved artifacts rather than the in-memory training objects.
tokenizer_en, tokenizer_hi = (
    Tokenizer.from_file("/kaggle/working/tokenizer_en.json"),
    Tokenizer.from_file("/kaggle/working/tokenizer_hi.json"),
)

# Record the vocabulary size of the source (en) and target (hi) tokenizers.
source_vocab_size, target_vocab_size = (
    tokenizer_en.get_vocab_size(),
    tokenizer_hi.get_vocab_size(),
)
print(f"INFO: source tokenizer vocab size = {source_vocab_size}")
print(f"INFO: target tokenizer vocab size = {target_vocab_size}")

INFO: source tokenizer initialized
INFO: source tokenizer training started...


In [None]:
# :: TRAINING ::
st_time = time.time()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"INFO: device = {device}")

print("INFO: line:23 -> EncodeDataset")
# Dataset that turns raw (en, hi) translation pairs into fixed-length model inputs.
class EncodeDataset(Dataset):
    def __init__(self, raw_dataset, max_seq_len):
        """
        Constructor to initialise class variables

        Args:
            raw_dataset (Dataset): raw translation rows; each row is indexed as
                row['translation']['en'] and row['translation']['hi'].
            max_seq_len (int): fixed length every returned sequence is padded (or
                truncated) to. Must be at least 2 to hold [CLS] and [SEP].

        Raises:
            ValueError: if max_seq_len < 2.
        """
        super().__init__()
        if max_seq_len < 2:
            raise ValueError(
                f"max_seq_len must be >= 2 to hold [CLS] and [SEP], got {max_seq_len}"
            )
        self.raw_dataset = raw_dataset
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.raw_dataset)

    def __getitem__(self, index):
        """
        Encode the translation pair at raw_dataset[index].

        Args:
            index (int): position of the pair in raw_dataset.

        Returns:
            dict with:
            - encoder_input: int64 (max_seq_len,) = [CLS] + source ids + [SEP] + padding
            - decoder_input: int64 (max_seq_len,) = [CLS] + target ids + padding
            - target_label:  int64 (max_seq_len,) = target ids + [SEP] + padding
            - encoder_mask:  integer (1, 1, max_seq_len); 0 at padding positions
            - decoder_mask:  integer (1, max_seq_len, max_seq_len); padding mask AND causal mask
            - source_text / target_text: the raw strings

        Sentences too long to fit (after adding special tokens) are truncated. Without
        this, the padding count went negative and the returned tensors came out longer
        than max_seq_len, which breaks batching and the positional-encoding lookup
        (max_seq_len is measured on the training split only).
        """
        raw_text = self.raw_dataset[index]
        source_text = raw_text['translation']['en']
        target_text = raw_text['translation']['hi']

        # Both tokenizers were trained with the same special_tokens list in the same
        # order, so these ids are identical in either vocabulary.
        cls_id = tokenizer_hi.token_to_id("[CLS]")
        sep_id = tokenizer_hi.token_to_id("[SEP]")
        pad_id = tokenizer_hi.token_to_id("[PAD]")

        # Source needs room for [CLS] and [SEP]; target only for one special token
        # ([CLS] in decoder_input, [SEP] in target_label).
        source_ids = tokenizer_en.encode(source_text).ids[: self.max_seq_len - 2]
        target_ids = tokenizer_hi.encode(target_text).ids[: self.max_seq_len - 1]

        num_source_padding = self.max_seq_len - len(source_ids) - 2
        num_target_padding = self.max_seq_len - len(target_ids) - 1

        cls = torch.tensor([cls_id], dtype=torch.int64)
        sep = torch.tensor([sep_id], dtype=torch.int64)
        source = torch.tensor(source_ids, dtype=torch.int64)
        target = torch.tensor(target_ids, dtype=torch.int64)
        encoder_padding = torch.full((num_source_padding,), pad_id, dtype=torch.int64)
        decoder_padding = torch.full((num_target_padding,), pad_id, dtype=torch.int64)

        encoder_input = torch.cat([cls, source, sep, encoder_padding], dim=0)
        decoder_input = torch.cat([cls, target, decoder_padding], dim=0)
        target_label = torch.cat([target, sep, decoder_padding], dim=0)

        # Padding positions get 0 so attention ignores them.
        encoder_mask = (encoder_input != pad_id).unsqueeze(0).unsqueeze(0).int()

        # Decoder additionally hides future positions (causal / look-ahead mask).
        decoder_mask = (decoder_input != pad_id).unsqueeze(0).unsqueeze(
            0
        ).int() & causal_mask(decoder_input.size(0))

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "target_label": target_label,
            "encoder_mask": encoder_mask,
            "decoder_mask": decoder_mask,
            "source_text": source_text,
            "target_text": target_text,
        }


# Causal (look-ahead) mask: position i may only attend to positions <= i. Wherever the
# mask is False, attention later fills the score with a large negative value before the
# softmax, so those future tokens receive (near-)zero attention weight.
def causal_mask(size):
    """Return a boolean lower-triangular mask of shape (1, size, size).

    The leading dimension of 1 lets the mask broadcast over a batch/head axis.
    """
    return torch.ones(1, size, size).tril().bool()


# Longest tokenized sentence (token count, before special tokens are added) on each
# side of the training split; used below to size every padded sequence.
max_seq_len_source = 0
max_seq_len_target = 0

for data in raw_train_dataset:
    pair = data['translation']
    max_seq_len_source = max(max_seq_len_source, len(tokenizer_en.encode(pair['en']).ids))
    max_seq_len_target = max(max_seq_len_target, len(tokenizer_hi.encode(pair['hi']).ids))

print(f"Max sequence length of source: {max_seq_len_source}")
print(f"Max sequence length of target: {max_seq_len_target}")

# One shared length for source and target: the longer side plus 20 tokens of headroom,
# which covers the [CLS]/[SEP] tokens EncodeDataset adds, with slack to spare.
max_seq_len = max(max_seq_len_source, max_seq_len_target) + 20


print("INFO: Encoding dataset started.")
# Wrap the raw splits. EncodeDataset only stores references here; the actual
# tokenization happens lazily in __getitem__.
train_dataset, val_dataset = (
    EncodeDataset(raw_train_dataset, max_seq_len),
    EncodeDataset(raw_val_dataset, max_seq_len),
)
print("INFO: Encoding dataset complete.")


# --------------------------------------------------- #
# Input Embedding and Positional Encoding

print("INFO: line 152 -> EmbeddingLayer")


class EmbeddingLayer(nn.Module):
    """Look up a d_model-wide vector for each token id and scale it by sqrt(d_model).

    Args:
        d_model: embedding width.
        vocab_size: number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input):
        scale = math.sqrt(self.d_model)
        return scale * self.embedding(input)
    

print("INFO: line 169 -> PositionalEncoding")


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedding and apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding width (odd widths are supported).
        max_seq_len: longest sequence this module can encode.
        dropout_rate: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_seq_len: int, dropout_rate: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)

        pe = torch.zeros(max_seq_len, d_model)
        # (max_seq_len, 1) column of positions.
        pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i / d_model), computed as exp(-2i * ln(10000) / d_model).
        # This is the paper's formula evaluated in log space, not a different formula.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Even columns get sin, odd columns get cos. For odd d_model there is one fewer
        # odd column than even column, so div_term is trimmed for the cos half
        # (for even d_model the slice is the whole div_term).
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term[: d_model // 2])

        # Shape (1, max_seq_len, d_model) so it broadcasts over the batch dimension.
        # Registered as a buffer rather than an nn.Parameter: it is fixed, not learned.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, input_embdding):
        """Add the encoding for the first seq_len positions to a (batch, seq_len, d_model) input.

        Raises:
            ValueError: if seq_len exceeds the max_seq_len given at construction.
        """
        seq_len = input_embdding.shape[1]
        if seq_len > self.pe.shape[1]:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_seq_len {self.pe.shape[1]}"
            )
        # The encoding is constant, so no gradient is tracked for it.
        input_embdding = input_embdding + self.pe[:, :seq_len, :].requires_grad_(False)
        return self.dropout(input_embdding)

    
# --------------------------------------------------- #
# Multi-head Attention Block

print("INFO: line 208 -> MultiHeadAttention")


class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across `num_heads` parallel heads.

    Args:
        d_model: model width; must be divisible by num_heads.
        num_heads: number of attention heads (each of width d_k = d_model // num_heads).
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_rate: float):
        super().__init__()
        # Dropout on the attention weights to reduce overfitting.
        self.dropout = nn.Dropout(dropout_rate)

        # Learnable projections for query, key, value and the merged output (no bias).
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by number of heads"

        # Width of each individual head.
        self.d_k = d_model // num_heads

    def _split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        return x.view(x.shape[0], x.shape[1], self.num_heads, self.d_k).transpose(1, 2)

    def _attend(self, query, key, value, encoder_mask):
        # Scores: (batch, heads, q_len, d_k) @ (batch, heads, d_k, kv_len) -> (batch, heads, q_len, kv_len)
        attention_score = (query @ key.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Mask and softmax in float32 so the -1e9 fill value cannot overflow a half-precision tensor.
        attention_score = attention_score.float()
        if encoder_mask is not None:
            attention_score = attention_score.masked_fill(encoder_mask == 0, -1e9)

        attention_weights = self.dropout(F.softmax(attention_score, dim=-1))

        # (batch, heads, q_len, kv_len) @ (batch, heads, kv_len, d_k) -> (batch, heads, q_len, d_k)
        return attention_weights.to(value.dtype) @ value

    def forward(self, q, k, v, encoder_mask):
        """
        Args:
            q: query input, (batch, q_len, d_model).
            k, v: key / value inputs, (batch, kv_len, d_model).
            encoder_mask: optional mask broadcastable to (batch, num_heads, q_len, kv_len);
                positions where it equals 0 are excluded from attention. None disables masking.

        Returns:
            (batch, q_len, d_model) attention output after the W_o projection.
        """
        query = self._split_heads(self.W_q(q))
        key = self._split_heads(self.W_k(k))
        value = self._split_heads(self.W_v(v))

        # The score/softmax/weighted-sum step previously always ran inside the deprecated
        # torch.cuda.amp.autocast(). That context only affects CUDA ops, so CUDA mixed precision
        # is kept for CUDA tensors, and CPU tensors now run without the CUDA-only context.
        if query.is_cuda:
            with torch.autocast(device_type="cuda"):
                attention_output = self._attend(query, key, value, encoder_mask)
        else:
            attention_output = self._attend(query, key, value, encoder_mask)

        # Merge heads back: (batch, heads, seq_len, d_k) -> (batch, seq_len, heads, d_k) -> (batch, seq_len, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            attention_output.shape[0], -1, self.num_heads * self.d_k
        )

        # Final output projection; same shape as the embedding input.
        return self.W_o(attention_output)

    
# --------------------------------------------------- #
# Feed Forward, Layer Norm and Add & Norm

print("INFO: line 282 -> FeedForward")


class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model), applied
    independently at every sequence position.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_1 = nn.Linear(d_model, d_ff)
        self.layer_2 = nn.Linear(d_ff, d_model)

    def forward(self, input):
        hidden = torch.relu(self.layer_1(input))
        hidden = self.dropout(hidden)
        return self.layer_2(hidden)

    
print("INFO: line 298 -> LayerNorm")


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale (gamma) and shift (beta).

    Normalises by the standard deviation computed with `Tensor.std` (its default estimator),
    plus eps, as the original implementation did.

    Args:
        eps: added to the standard deviation to avoid division by zero.
        d_model: size of the normalised (last) dimension; gamma/beta have this length.
            Defaults to 512 so existing `LayerNorm()` callers are unchanged; previously the
            size was hard-coded, which broke any model whose width was not 512.
    """

    def __init__(self, eps: float = 1e-5, d_model: int = 512):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale and shift, initialised to the identity transform.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, input):
        mean = input.mean(dim=-1, keepdim=True)
        std = input.std(dim=-1, keepdim=True)

        return self.gamma * (input - mean) / (std + self.eps) + self.beta


print("INFO: line 317 -> AddAndNorm")


class AddAndNorm(nn.Module):
    """Residual wrapper that normalises *before* the sub-layer.

    Returns input + dropout(sub_layer(layer_norm(input))).
    """

    def __init__(self, dropout_rate: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = LayerNorm()

    def forward(self, input, sub_layer):
        normed = self.layer_norm(input)
        residual = self.dropout(sub_layer(normed))
        return input + residual
    
    
# --------------------------------------------------- #
# Encode block and Encoder

print("INFO: line 333 -> EncoderBlock")


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each in an AddAndNorm residual."""

    def __init__(self, multihead_attention: MultiHeadAttention, feed_forward: FeedForward, dropout_rate: float) -> None:
        super().__init__()
        self.multihead_attention = multihead_attention
        self.feed_forward = feed_forward
        # One residual/norm wrapper per sub-layer.
        self.addnorm_1 = AddAndNorm(dropout_rate)
        self.addnorm_2 = AddAndNorm(dropout_rate)

    def forward(self, encoder_input, encoder_mask):
        def self_attend(x):
            # Self-attention: query, key and value all come from the same sequence.
            return self.multihead_attention(x, x, x, encoder_mask)

        hidden = self.addnorm_1(encoder_input, self_attend)
        return self.addnorm_2(hidden, self.feed_forward)
    
    
print("INFO: line 353 -> Encoder")


class Encoder(nn.Module):
    """Stack of encoder blocks followed by a final LayerNorm."""

    def __init__(self, encoderblocklist: nn.ModuleList):
        super().__init__()

        self.encoderblocklist = encoderblocklist
        self.layer_norm = LayerNorm()

    def forward(self, encoder_input, encoder_mask):
        hidden = encoder_input
        for block in self.encoderblocklist:
            hidden = block(hidden, encoder_mask)

        # The normalised result is what the decoder's cross-attention uses as key and value.
        return self.layer_norm(hidden)
    

# --------------------------------------------------- #
# Decoder block, Decoder and Projection

print("INFO: line 378 -> DecoderBlock")


class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the encoder output,
    then feed-forward, each in its own AddAndNorm residual."""

    def __init__(self, masked_multihead_attention: MultiHeadAttention, cross_multihead_attention: MultiHeadAttention, feed_forward: FeedForward, dropout_rate: float) -> None:
        super().__init__()
        self.masked_multihead_attention = masked_multihead_attention
        self.cross_multihead_attention = cross_multihead_attention
        self.feed_forward = feed_forward
        self.addnorm_1 = AddAndNorm(dropout_rate)
        self.addnorm_2 = AddAndNorm(dropout_rate)
        self.addnorm_3 = AddAndNorm(dropout_rate)

    def forward(self, decoder_input, encoder_output, encoder_mask, decoder_mask):
        def masked_self_attend(x):
            # Self-attention restricted by decoder_mask (padding + causal).
            return self.masked_multihead_attention(x, x, x, decoder_mask)

        def cross_attend(x):
            # Queries from the decoder, keys/values from the encoder output.
            return self.cross_multihead_attention(x, encoder_output, encoder_output, encoder_mask)

        hidden = self.addnorm_1(decoder_input, masked_self_attend)
        hidden = self.addnorm_2(hidden, cross_attend)
        return self.addnorm_3(hidden, self.feed_forward)
    
    
print("INFO: line 407 -> Decoder")


class Decoder(nn.Module):
    """Stack of decoder blocks followed by a final LayerNorm."""

    def __init__(self, decoderblocklist: nn.ModuleList):
        super().__init__()

        self.decoderblocklist = decoderblocklist
        self.layer_norm = LayerNorm()

    def forward(self, decoder_input, encoder_output, encoder_mask, decoder_mask):
        hidden = decoder_input
        for block in self.decoderblocklist:
            hidden = block(hidden, encoder_output, encoder_mask, decoder_mask)
        return self.layer_norm(hidden)
    
    
print("INFO: line 425 -> ProjectionLayer")


class ProjectionLayer(nn.Module):
    """Map decoder hidden states to unnormalised scores over the target vocabulary."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.projection_layer = nn.Linear(d_model, vocab_size)

    def forward(self, decoder_output):
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size).
        # Raw logits are returned: no softmax / log-softmax is applied here.
        logits = self.projection_layer(decoder_output)
        return logits

    
    
# --------------------------------------------------- #
# Transformer

print("INFO: line 445 -> Transformer")


class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    encode  : source ids -> embedding -> positional encoding -> encoder stack
    decode  : target ids -> embedding -> positional encoding -> decoder stack
              (cross-attending to the encoder output)
    project : decoder hidden states -> vocabulary scores via the projection layer
    """

    def __init__(self,
                 encoder: Encoder,
                 decoder: Decoder,
                 source_embed: EmbeddingLayer,
                 target_embed: EmbeddingLayer,
                 source_pos: PositionalEncoding,
                 target_pos: PositionalEncoding,
                 projection_layer: ProjectionLayer
    ) -> None:
        super().__init__()

        # Attribute assignment order fixes the order of parameters()/state_dict entries,
        # which build_model's weight initialisation iterates over — keep it unchanged.
        self.source_embed = source_embed
        self.source_pos = source_pos
        self.encoder = encoder

        self.target_embed = target_embed
        self.target_pos = target_pos
        self.decoder = decoder

        self.projection_layer = projection_layer

    def encode(self, encoder_input, encoder_mask):
        embedded = self.source_pos(self.source_embed(encoder_input))
        return self.encoder(embedded, encoder_mask)

    def decode(self, encoder_output, encoder_mask, decoder_input, decoder_mask):
        embedded = self.target_pos(self.target_embed(decoder_input))
        return self.decoder(embedded, encoder_output, encoder_mask, decoder_mask)

    def project(self, decoder_output):
        return self.projection_layer(decoder_output)
    
    
# INFO: BUILD MODEL BLOCK INFO:

print("INFO: line 497 -> Model build started.")


def build_model(
    source_vocab_size: int,
    target_vocab_size: int,
    source_seq_len: int,
    target_seq_len: int,
    d_model: int=512,
    num_blocks: int=6,
    num_heads: int=8,
    dropout_rate: float=0.1,
    d_ff: int=2048
) -> Transformer:
    """Assemble an encoder-decoder Transformer and initialise its weights.

    Args:
        source_vocab_size / target_vocab_size: vocabulary sizes of the two tokenizers.
        source_seq_len / target_seq_len: longest sequences the positional encodings cover.
        d_model: model width (must be divisible by num_heads).
        num_blocks: number of encoder blocks and of decoder blocks.
        num_heads: attention heads per attention layer.
        dropout_rate: dropout probability used throughout.
        d_ff: hidden width of the feed-forward layers.

    Returns:
        The assembled Transformer (on the CPU; move it with .to(device)).
    """
    # Token embeddings for each language.
    source_embed = EmbeddingLayer(d_model, source_vocab_size)
    target_embed = EmbeddingLayer(d_model, target_vocab_size)

    # Positional encodings.
    source_pos = PositionalEncoding(d_model, source_seq_len, dropout_rate)
    target_pos = PositionalEncoding(d_model, target_seq_len, dropout_rate)

    # Encoder: num_blocks x (self-attention + feed-forward).
    encoderblocklist = []
    for _ in range(num_blocks):
        multihead_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        feed_forward = FeedForward(d_model, d_ff, dropout_rate)
        encoder_block = EncoderBlock(multihead_attention, feed_forward, dropout_rate)
        encoderblocklist.append(encoder_block)
    encoder = Encoder(nn.ModuleList(encoderblocklist))

    # Decoder: num_blocks x (masked self-attention + cross-attention + feed-forward).
    decoderblocklist = []
    for _ in range(num_blocks):
        masked_multihead_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        cross_multihead_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        feed_forward = FeedForward(d_model, d_ff, dropout_rate)
        decoder_block = DecoderBlock(masked_multihead_attention, cross_multihead_attention, feed_forward, dropout_rate)
        decoderblocklist.append(decoder_block)
    decoder = Decoder(nn.ModuleList(decoderblocklist))

    # Hidden state -> target-vocabulary scores.
    projection_layer = ProjectionLayer(d_model, target_vocab_size)

    model = Transformer(
        encoder,
        decoder,
        source_embed,
        target_embed,
        source_pos,
        target_pos,
        projection_layer
    )

    # The LayerNorm()s created inside AddAndNorm/Encoder/Decoder size gamma/beta to a
    # fixed default (512) rather than d_model, so any other d_model would fail at the
    # first forward pass. Resize them here to honour the d_model argument.
    for module in model.modules():
        if isinstance(module, LayerNorm) and module.gamma.numel() != d_model:
            module.gamma = nn.Parameter(torch.ones(d_model))
            module.beta = nn.Parameter(torch.zeros(d_model))

    # Xavier-uniform initialisation for every weight matrix (1-D biases/norm params untouched).
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return model



# Build the model sized to both tokenizers' vocabularies and the shared max_seq_len,
# then move it onto the selected device. This is the model used for training,
# validation and later translation.
model = build_model(
    source_vocab_size=tokenizer_en.get_vocab_size(),
    target_vocab_size=tokenizer_hi.get_vocab_size(),
    source_seq_len=max_seq_len,
    target_seq_len=max_seq_len,
    d_model=512,
)
model = model.to(device)

# Show the architecture and let W&B track the model.
print(model)
wandb.watch(model)

# INFO: END BUILD MODEL BLOCK INFO:
print("INFO: line 562 -> Model build completed.")


# INFO: TRAIN MODEL BLOCK INFO:


def run_validation(model, validation_ds, tokenizer_en, tokenizer_hi, max_seq_len, device, print_msg, global_step, num_examples=2):
    """Greedily decode a few validation examples and print them next to their references.

    Args:
        model: the Transformer, optionally wrapped in ``nn.DataParallel``.
        validation_ds: iterable of batch dicts with ``encoder_input``, ``encoder_mask``,
            ``source_text`` and ``target_text``. The decoder input below is built as a
            single (1, t) sequence, so batches are assumed to hold one example
            (BATCH_SIZE = 1).
        tokenizer_en: source tokenizer (not used here; kept for interface compatibility).
        tokenizer_hi: target tokenizer; supplies the [CLS]/[SEP] ids and decodes predictions.
        max_seq_len: maximum length of the decoded sequence, including the leading [CLS].
        device: torch device the tensors are moved to.
        print_msg: callable used to emit each output line (e.g. ``tqdm.write``).
        global_step: current training step (currently unused).
        num_examples: number of validation batches to decode before stopping.
    """
    # nn.DataParallel exposes the wrapped network as `.module`; a plain model does not,
    # so fall back to the model itself on CPU / single-GPU runs.
    net = getattr(model, "module", model)
    model.eval()

    # Special-token ids never change, so look them up once.
    cls_id = tokenizer_hi.token_to_id('[CLS]')
    sep_id = tokenizer_hi.token_to_id('[SEP]')

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)

            # Encode the source once; the result is reused at every decoding step.
            encoder_output = net.encode(encoder_input, encoder_mask)

            # Decoding starts from the [CLS] token.
            decoder_input = torch.empty(1, 1).fill_(cls_id).type_as(encoder_input).to(device)

            # Greedy decoding: keep appending the most likely next token until [SEP]
            # is produced or the sequence reaches max_seq_len.
            while decoder_input.size(1) < max_seq_len:
                # The causal mask must grow with the decoder input.
                decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)
                out = net.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)

                # Project only the last position to get next-token scores.
                prob = net.project(out[:, -1])

                # argmax keeps the batch axis: (B, vocab) -> (B, 1), ready to concatenate.
                next_word = prob.argmax(dim=1, keepdim=True)
                decoder_input = torch.cat([decoder_input, next_word], dim=1)

                if (next_word == sep_id).any():
                    break

            # The prediction is the decoder input accumulated up to (and including) [SEP].
            model_out = decoder_input.squeeze(0)

            source_text = batch["source_text"][0]
            target_text = batch["target_text"][0]
            model_out_text = tokenizer_hi.decode(model_out.detach().cpu().tolist())

            print_msg('-'*55)
            print_msg(f'Source Text: {source_text}')
            print_msg(f'Target Text: {target_text}')
            print_msg(f'Predicted by langGPT: {model_out_text}')

            if count >= num_examples:
                break


# Constants
EPOCHS = 20
BATCH_SIZE = 1  # run_validation builds a single (1, t) decoder sequence, so keep this at 1
ACCUMULATION_STEPS = 4  # number of batches whose gradients are combined per optimizer step
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("INFO: Dataloader started.")
# DataLoader wrappers for the training and validation datasets, used later during training and validation.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
print("INFO: Dataloader complete.")

# Adam keeps per-parameter state and updates the parameters from the computed gradients.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-9)
# GradScaler scales the loss for mixed-precision (autocast) training.
scaler = GradScaler()

# Learning-rate scheduler. It is advanced once per *optimizer* step, i.e. roughly once
# every ACCUMULATION_STEPS batches, so the schedule length must be counted in optimizer
# steps (rounded up per epoch) rather than in batches. Otherwise the linear decay only
# covers a fraction of its intended range.
num_training_steps = EPOCHS * math.ceil(len(train_dataloader) / ACCUMULATION_STEPS)
lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

# Cross-entropy between the projection output and the target labels. The labels are
# target-language (Hindi) token ids, so the padding id to ignore comes from tokenizer_hi.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_hi.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

# Multi-GPU: wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

# Function to clear GPU cache
def clear_gpu_cache():
    """Release cached, unused CUDA memory and collect CUDA IPC handles.

    Does nothing when CUDA is unavailable. Without this guard, ``torch.cuda.ipc_collect``
    would try to initialise CUDA and raise on CPU-only machines.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def train_model(preload_epoch=None):
    """Train for EPOCHS epochs, running validation and saving a checkpoint after each one.

    Uses mixed precision (autocast + GradScaler) and gradient accumulation. Gradients
    from ACCUMULATION_STEPS consecutive batches are averaged before each optimizer step,
    and any leftover batches at the end of an epoch also trigger a step.

    Args:
        preload_epoch: if not None, resume from ``/kaggle/working/model_{preload_epoch}.pt``.
            This restores the model, optimizer and global step, plus the scheduler and
            scaler state when the checkpoint contains them. Training then continues
            with epoch ``preload_epoch + 1``.
    """
    # nn.DataParallel exposes the wrapped network as `.module`; a plain model does not.
    # NOTE(review): calling the wrapped module's methods directly bypasses DataParallel's
    # forward(), so batches are not split across GPUs. Confirm whether that is intended.
    net = getattr(model, "module", model)
    initial_epoch = 0
    global_step = 0

    # Resume from the checkpoint of `preload_epoch` and continue with the next epoch.
    if preload_epoch is not None:
        model_filename = f"/kaggle/working/model_{preload_epoch}.pt"
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
        # Older checkpoints did not store these; in that case fresh state is used.
        if 'scheduler_state_dict' in state:
            lr_scheduler.load_state_dict(state['scheduler_state_dict'])
        if 'scaler_state_dict' in state:
            scaler.load_state_dict(state['scaler_state_dict'])

    num_batches_per_epoch = len(train_dataloader)

    for epoch in range(initial_epoch, EPOCHS):
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        optimizer.zero_grad(set_to_none=True)
        running_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(batch_iterator, start=1):
            encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
            target_label = batch['target_label'].to(device) # (B, seq_len)

            with autocast():
                # Run the tensors through the encoder, the decoder and the projection layer.
                encoder_output = net.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
                decoder_output = net.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
                projection_output = net.project(decoder_output) # (B, seq_len, vocab_size)

                loss = loss_fn(projection_output.view(-1, tokenizer_hi.get_vocab_size()), target_label.view(-1))

            # Divide so the accumulated gradient is the mean over the accumulation window.
            scaler.scale(loss / ACCUMULATION_STEPS).backward()

            # Step every ACCUMULATION_STEPS batches, and on the epoch's final batch so
            # leftover gradients are applied instead of being wiped by the next zero_grad.
            if batch_idx % ACCUMULATION_STEPS == 0 or batch_idx == num_batches_per_epoch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                lr_scheduler.step()

            global_step += 1

            # One .item() per batch (each call synchronises with the GPU).
            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1
            batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})
            wandb.log({"loss": loss_value}, step=global_step)
            # NOTE(review): emptying the CUDA cache every batch adds overhead; keep only if needed.
            clear_gpu_cache()

        # Mean training loss over the epoch (not just the final batch).
        epoch_loss = running_loss / max(num_batches, 1)
        # W&B steps must not decrease. The per-batch logs above already used
        # step=global_step, so log the epoch metric at that same step (with the epoch
        # as a field) instead of step=epoch, which W&B would drop.
        wandb.log({"epoch loss": epoch_loss, "epoch": epoch}, step=global_step)

        # VALIDATION BLOCK: runs after every training epoch.
        run_validation(model, val_dataloader, tokenizer_en, tokenizer_hi, max_seq_len, device, lambda msg: batch_iterator.write(msg), global_step)

        # Save a checkpoint at the end of every epoch. It includes the scheduler and
        # scaler state so a resumed run continues the same LR schedule and loss scaling.
        model_filename = f"/kaggle/working/model_{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'global_step': global_step
        }, model_filename)

# Train our model: EPOCHS epochs, each followed by validation and a checkpoint.
try:
    train_model(preload_epoch=None)
except BaseException:
    # Close the W&B run as failed rather than leaving it open in the kernel, then re-raise.
    wandb.finish(exit_code=1)
    raise

print("INFO: Model training completed.")
print(f"INFO: Time: {time.time()}")

print(
    f"INFO: Total time taken(including loading dataset, training tokenizer, building the model, validating the model): {time.time() - st_time}s"
)

wandb.finish()