### Relevant Imports


In [3]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = 'logs')
import re

from structure.transformer import Transformer
from structure.Dataset import English_Hindi_Dataset

from sub_modules.embedding import Language_Embedding
from sub_modules.masks import get_masks

### Initializations


In [4]:
# Read data
read_max = 6_00_000 ######

# basics
batch_size = 3000
sequence_length = 100
d_model = 512
num_of_sentences = 3_00_000
# transfomer
# num_encoder_decoder_layers = 6
num_heads = 8
hidden_layers = 2048
#################################################
num_encoder_decoder_layers = 1
#########################
dropout_ff = 0.3
dropout_attn = 0.2


### Dataset


In [5]:
dataset = English_Hindi_Dataset('Dataset/train.en/train.en', 
                                    'Dataset/train.hi/train.hi',
                                    num_of_sentences = num_of_sentences,
                                    max_sequence_length = sequence_length,
                                    read_max = read_max)

en_vocab_size = len(set(dataset.en_vocab))
hi_vocab_size = len(set(dataset.hi_vocab))


Total unique characters: English-> 97 Hindi-> 174
	Dataset Cleaned
	Dataset Tokenized and Pading is Done


### Embeddings


In [6]:
# embeddings
embeddings = Language_Embedding(en_vocab_size, hi_vocab_size, d_model)

### Data Loader


In [7]:
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - int(0.8 * dataset_size)

train_dataset, val_dataset = random_split(dataset, [train_size,val_size])




### Model Initializations


In [8]:
# GPU for training
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f" Using: {device}")

model = Transformer(
    num_encoder_decoder_layers=num_encoder_decoder_layers,
    d_model=d_model,
    sequence_length=sequence_length,
    hidden_layers=hidden_layers,
    num_heads=num_heads,
    hi_voab_size=hi_vocab_size,
    dropout_ff=dropout_ff,
    dropout_attn=dropout_attn
).to(device)


# Loss
criterian = nn.CrossEntropyLoss(ignore_index= dataset.hindi_to_index[dataset.PADDING_TOKEN], reduction ='none')

# Parameter Initialization
for param in model.parameters():
    if param.dim()>1:
        nn.init.xavier_uniform_(param)
        
# optimizer 
optim = torch.optim.Adam(model.parameters(), lr= 1e-4)


 Using: cuda


### Model Training and Evaluation


In [9]:
model_save_path = "saved_models"  # Specify your directory to save models
os.makedirs(model_save_path, exist_ok=True)  # Create directory if it doesn't exist


def get_latest_model_checkpoint(model_save_path):
    model_files = os.listdir(model_save_path)
    model_epochs = [int(re.findall(r'model_epoch_(\d+).pt', file)[0]) for file in model_files if file.endswith('.pt')]
    
    if len(model_epochs)>0:
        latest_epoch = max(model_epochs)
        model_save_file = os.path.join(model_save_path, f"model_epoch_{latest_epoch}.pt")
        return latest_epoch, model_save_file
    else:
        return None, None
    
latest_epoch, model_save_file = get_latest_model_checkpoint(model_save_path)

if model_save_file:
    print(f"Loading model from {model_save_file}")
    model.load_state_dict(torch.load(model_save_file))
    current_epoch = latest_epoch + 1
else:
    print("No saved model found. Training from scratch.")
    current_epoch = 0

Loading model from saved_models\model_epoch_20.pt


  model.load_state_dict(torch.load(model_save_file))


##### Training


In [10]:
# best_val_loss = float('inf')
# total_epochs = 1

# for epoch in range(current_epoch, total_epochs + 1):
#     print(f"Epoch -> {epoch}")
#     total_loss = 0
#     train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
#     train_data_iterator = iter(train_data_loader)

#     for batch_num, batch in enumerate(tqdm(train_data_iterator, desc=f'Epoch {epoch}/{total_epochs}', unit='batch')):
#         model.train()
#         en_batch, hi_batch = batch
#         en_batch = en_batch.to(device)
#         hi_batch = hi_batch.to(device)

#         ds_mask, es_mask, edc_mask = get_masks(dataset, en_batch, hi_batch)
#         ds_mask, es_mask, edc_mask = ds_mask.to(device), es_mask.to(device), edc_mask.to(device)

#         optim.zero_grad()

#         en_batch_embedded, hi_batch_embedded = embeddings(en_batch, hi_batch)
#         en_batch_embedded, hi_batch_embedded = en_batch_embedded.to(device), hi_batch_embedded.to(device)
#         hi_prediction = model(en_batch_embedded, hi_batch_embedded, ds_mask, es_mask, edc_mask)

#         # Prepare labels
#         labels_untoken = [dataset.untokenize(hi_batch[index], dataset.index_to_hindi) for index in range(len(hi_batch))]
#         labels = [dataset.tokenize(labels_untoken[index], dataset.hindi_to_index, start_token=False, end_token=True) for index in range(len(hi_batch))]
#         labels = torch.stack(labels).to(device)
#         print(labels_untoken[0])
#         print(labels[0])
#         break

#         # Calculate loss
#         loss = criterian(
#             hi_prediction.view(-1, hi_vocab_size),
#             labels.view(-1)
#         )

#         # Mask padding tokens
#         valid_indices = (labels.view(-1) != dataset.hindi_to_index[dataset.PADDING_TOKEN])
#         loss = loss[valid_indices].mean()  # Calculate the mean loss over valid indices

#         total_loss += loss.item()

#         loss.backward()
        
#         # Gradient clipping
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

#         optim.step()

#         # Log loss periodically
#         if batch_num % 300 == 0:
#             writer.add_scalar('Loss/Batch', loss.item(), epoch * len(train_data_iterator) + batch_num)
#     break
#     avg_loss = total_loss / (batch_num + 1)
#     writer.add_scalar('Loss/Epoch', avg_loss, epoch)
#     print(f"\t\tEpoch [{epoch + 1}/{total_epochs}], training Loss: {avg_loss:.4f}")
 
#     # Validation Loop
#     model.eval()
#     val_loss = 0
#     val_data_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False, pin_memory=True)
#     val_data_iterator = iter(val_data_loader)
#     with torch.no_grad():
#         for val_batch_num, val_batch in enumerate(tqdm(val_data_iterator, desc=f'Validation Epoch {epoch }/{total_epochs}', unit='batch')):
#             en_val_batch, hi_val_batch = val_batch
#             en_val_batch = en_val_batch.to(device)
#             hi_val_batch = hi_val_batch.to(device)
            
#             ds_val_mask, es_val_mask, edc_val_mask = get_masks(dataset, en_val_batch, hi_val_batch)
#             ds_val_mask, es_val_mask, edc_val_mask = ds_val_mask.to(device), es_val_mask.to(device), edc_val_mask.to(device)
            
#             en_val_embedded, hi_val_embedded = embeddings(en_val_batch, hi_val_batch)
#             en_val_embedded, hi_val_embedded = en_val_embedded.to(device), hi_val_embedded.to(device)
            
#             hi_val_prediction = model(en_val_embedded, hi_val_embedded, ds_val_mask, es_val_mask, edc_val_mask)
            
#             val_labels = [dataset.untokenize(hi_val_batch[index], dataset.index_to_hindi) for index in range(len(hi_val_batch))]
#             val_labels = [dataset.tokenize(val_labels[index], dataset.hindi_to_index, start_token=False, end_token=True) for index in range(len(hi_val_batch))]
#             val_labels = torch.stack(val_labels) 
            
#             val_loss_batch = criterian(
#                 hi_val_prediction.view(-1, hi_vocab_size).to(device),
#                 val_labels.view(-1).to(device)
#             ).to(device)
            
#             valid_val_indices = torch.where(val_labels.view(-1) == dataset.hindi_to_index[dataset.PADDING_TOKEN], False, True)
#             val_loss_batch = val_loss_batch.sum() / valid_val_indices.sum()
            
#             val_loss += val_loss_batch.item()
    
#     avg_val_loss = val_loss / (val_batch_num + 1)  # Average validation loss for the epoch
#     writer.add_scalar('Loss/Validation_Epoch', avg_val_loss, epoch)

#     ####### Print Epoch Losses #######
#     print(f"\t\tEpoch [{epoch}/{total_epochs}], Validation Loss: {avg_val_loss:.4f}")
    
    
#     print('\n')
#     # Save Model
#     model_save_file = os.path.join(model_save_path, f"model_epoch_{epoch }.pt")
#     torch.save(model.state_dict(), model_save_file)

# writer.close()

##### Evaluation


In [16]:
model.eval()
def translate(en_sentence):
    en_sentence = (en_sentence,)
    hi_sentence = ("",)

    en_token = dataset.tokenize(en_sentence[0], dataset.english_to_index, start_token=False, end_token=False).unsqueeze(0).to(device)
    hi_token = dataset.tokenize(hi_sentence[0], dataset.hindi_to_index, start_token=True, end_token=False).unsqueeze(0).to(device)
    
    for word_counter in range(dataset.max_sequence_length):
        print(f"Processing for {word_counter + 1} token")
    
        ds_mask, es_mask, edc_mask = get_masks(dataset, en_token, hi_token)
        ds_mask, es_mask, edc_mask = ds_mask.to(device), es_mask.to(device), edc_mask.to(device)
        
        en_embedded, hi_embedded = embeddings(en_token, hi_token)
        en_embedded, hi_embedded =  en_embedded.to(device), hi_embedded.to(device)
        
        predictions = model(en_embedded,
                            hi_embedded,
                            ds_mask, es_mask, edc_mask)
        next_token_prob_distribution = predictions[0][word_counter]
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = dataset.index_to_hindi[next_token_index]
        
        hi_sentence = (hi_sentence[0] + next_token, )
        print(hi_sentence)
        
        if next_token == dataset.END_TOKEN:
            break
        hi_token = dataset.tokenize(hi_sentence[0], dataset.hindi_to_index, start_token=True, end_token=False).unsqueeze(0).to(device)
        print(f"\t\t\t Predicted till now: {hi_sentence[0]}")

    return hi_sentence[0]
    

In [18]:
translation = translate("This is a beautiful day to go out.")
translation

Processing for 1 token
('क',)
			 Predicted till now: क
Processing for 2 token
('कक',)
			 Predicted till now: कक
Processing for 3 token
('ककक',)
			 Predicted till now: ककक
Processing for 4 token
('कककक',)
			 Predicted till now: कककक
Processing for 5 token
('ककककक',)
			 Predicted till now: ककककक
Processing for 6 token
('कककककग',)
			 Predicted till now: कककककग
Processing for 7 token
('कककककग ',)
			 Predicted till now: कककककग 
Processing for 8 token
('कककककग  ',)
			 Predicted till now: कककककग  
Processing for 9 token
('कककककग   ',)
			 Predicted till now: कककककग   
Processing for 10 token
('कककककग    ',)
			 Predicted till now: कककककग    
Processing for 11 token
('कककककग     ',)
			 Predicted till now: कककककग     
Processing for 12 token
('कककककग      ',)
			 Predicted till now: कककककग      
Processing for 13 token
('कककककग       ',)
			 Predicted till now: कककककग       
Processing for 14 token
('कककककग        ',)
			 Predicted till now: कककककग        
Processing for 15 token
('ककक

'कककककग                      <END>'

In [19]:
dataset[0]

(tensor([43, 80,  0, 84, 71, 82, 78, 91, 12,  0, 50, 67, 77, 75, 85, 86, 67, 80,
          0, 73, 81, 86,  0, 81, 72, 72,  0, 86, 81,  0, 67,  0, 85, 81, 78, 75,
         70,  0, 85, 86, 67, 84, 86, 14, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29]),
 tensor([ 30,  73, 108, 101,  66, 116,   0,  73,  98, 107,  89,   0,  91, 116,
          47,   0,  87, 107,  66,   0,  85, 116,   0,  50,  71, 122,  72, 109,
           0,  99, 110,  93, 110,  51,  81,   0,  66, 109,   0,  82, 109,  14,
          28,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,
          29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,
          29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,
          29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  

In [20]:
dataset[1]

(tensor([54, 74, 71,  0, 37, 81, 80, 73, 84, 71, 85, 85,  0, 78, 71, 67, 70, 71,
         84,  0, 84, 71, 82, 84, 71, 85, 71, 80, 86, 85,  0, 53, 75, 88, 67, 73,
         67, 80, 73, 67,  0, 46, 81, 77,  0, 53, 67, 68, 74, 67,  0, 85, 71, 73,
         79, 71, 80, 86,  0, 72, 84, 81, 79,  0, 54, 67, 79, 75, 78,  0, 48, 67,
         70, 87, 14, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29]),
 tensor([ 30,  66, 107,  47,  68, 122,  93, 116, 101,   0,  85, 116,  81, 107,
           0,  81,  91, 108,  95,  85, 107,  78, 110,   0, 101, 116,   0,  99,
         108,  98,  68,  47,  68, 107,   0,  95, 120,  66, 101,  90, 107,   0,
          66, 122, 100, 116,  81, 122,  93,   0,  66, 107,   0,  87, 122,  93,
          81, 108,  85, 108,  84, 108,  81, 122,  98,   0,  66,  93,  81, 116,
           0, 102, 117,  47,  14,  28,  29,  29,  29,  29,  29,  29,  29,  29,
          29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  

In [21]:
en = dataset.untokenize(dataset[0][0], dataset.index_to_english)
hi = dataset.untokenize(dataset[0][1], dataset.index_to_hindi)

In [22]:
en, hi

('In reply, Pakistan got off to a solid start.',
 'जिसके जवाब में पाक ने अच्छी शुरुआत की थी.')

In [23]:
translation = translate(en)
translation

Processing for 1 token
('क',)
			 Predicted till now: क
Processing for 2 token
('कग',)
			 Predicted till now: कग
Processing for 3 token
('कगक',)
			 Predicted till now: कगक
Processing for 4 token
('कगकग',)
			 Predicted till now: कगकग
Processing for 5 token
('कगकग ',)
			 Predicted till now: कगकग 
Processing for 6 token
('कगकग  ',)
			 Predicted till now: कगकग  
Processing for 7 token
('कगकग   ',)
			 Predicted till now: कगकग   
Processing for 8 token
('कगकग    ',)
			 Predicted till now: कगकग    
Processing for 9 token
('कगकग     ',)
			 Predicted till now: कगकग     
Processing for 10 token
('कगकग      ',)
			 Predicted till now: कगकग      
Processing for 11 token
('कगकग       ',)
			 Predicted till now: कगकग       
Processing for 12 token
('कगकग        ',)
			 Predicted till now: कगकग        
Processing for 13 token
('कगकग         ',)
			 Predicted till now: कगकग         
Processing for 14 token
('कगकग          ',)
			 Predicted till now: कगकग          
Processing for 15 token
('कगक

'कगकग                                         <END>'