### Relevant Imports


In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = 'logs')
import re

from structure.transformer import Transformer
from structure.Dataset import English_Hindi_Dataset

from sub_modules.embedding import Language_Embedding
from sub_modules.masks import get_masks

  from .autonotebook import tqdm as notebook_tqdm


### Initializations


In [2]:
# Read data
read_max = 7_00_000 ######

# basics
batch_size = 512
sequence_length = 100
d_model = 512
num_of_sentences = 3_00_000
# transfomer
num_encoder_decoder_layers = 6
num_heads = 8
hidden_layers = 2048

dropout_ff = 0.3
dropout_attn = 0.2


### Dataset


In [3]:
dataset = English_Hindi_Dataset('Dataset/train.en/train.en', 
                                    'Dataset/train.hi/train.hi',
                                    num_of_sentences = num_of_sentences,
                                    max_sequence_length = sequence_length,
                                    read_max = read_max)

en_vocab_size = len(set(dataset.en_vocab))
hi_vocab_size = len(set(dataset.hi_vocab))

assert len(dataset) == num_of_sentences, f"Dataset is of length: {len(dataset)} but required sample :{num_of_sentences}"


Total unique characters: English-> 97 Hindi-> 174
	Dataset Cleaned
	Dataset Tokenized and Pading is Done


### Embeddings


In [4]:
# embeddings
embeddings = Language_Embedding(en_vocab_size, hi_vocab_size, d_model)

### Data Loader


In [5]:
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - int(0.8 * dataset_size)

train_dataset, val_dataset = random_split(dataset, [train_size,val_size])




### Model Initializations


In [6]:
# GPU for training
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f" Using: {device}")

model = Transformer(
    num_encoder_decoder_layers=num_encoder_decoder_layers,
    d_model=d_model,
    sequence_length=sequence_length,
    hidden_layers=hidden_layers,
    num_heads=num_heads,
    hi_voab_size=hi_vocab_size,
    dropout_ff=dropout_ff,
    dropout_attn=dropout_attn
).to(device)


# Loss
criterian = nn.CrossEntropyLoss(ignore_index= dataset.hindi_to_index[dataset.PADDING_TOKEN], reduction ='none')

# Parameter Initialization
for param in model.parameters():
    if param.dim()>1:
        nn.init.xavier_uniform_(param)
        
# optimizer 
optim = torch.optim.Adam(model.parameters(), lr= 1e-4)


 Using: cuda


### Model Training and Evaluation


In [7]:
model_save_path = "saved_models"  # Specify your directory to save models
os.makedirs(model_save_path, exist_ok=True)  # Create directory if it doesn't exist


def get_latest_model_checkpoint(model_save_path):
    model_files = os.listdir(model_save_path)
    model_epochs = [int(re.findall(r'model_epoch_(\d+).pt', file)[0]) for file in model_files if file.endswith('.pt')]
    
    if len(model_epochs)>0:
        latest_epoch = max(model_epochs)
        model_save_file = os.path.join(model_save_path, f"model_epoch_{latest_epoch}.pt")
        return latest_epoch, model_save_file
    else:
        return None, None
    
latest_epoch, model_save_file = get_latest_model_checkpoint(model_save_path)

if model_save_file:
    print(f"Loading model from {model_save_file}")
    model.load_state_dict(torch.load(model_save_file))
    current_epoch = latest_epoch + 1
else:
    print("No saved model found. Training from scratch.")
    current_epoch = 0

No saved model found. Training from scratch.


##### Training


In [None]:
best_val_loss = float('inf')
total_epochs = 100

for epoch in range(current_epoch, total_epochs + 1):
    print(f"Epoch -> {epoch}")
    total_loss = 0
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    train_data_iterator = iter(train_data_loader)

    for batch_num, batch in enumerate(tqdm(train_data_iterator, desc=f'Epoch {epoch}/{total_epochs}', unit='batch')):
        model.train()
        en_batch, hi_batch = batch
        en_batch = en_batch.to(device)
        hi_batch = hi_batch.to(device)

        ds_mask, es_mask, edc_mask = get_masks(dataset, en_batch, hi_batch)
        ds_mask, es_mask, edc_mask = ds_mask.to(device), es_mask.to(device), edc_mask.to(device)

        optim.zero_grad()

        en_batch_embedded, hi_batch_embedded = embeddings(en_batch, hi_batch)
        en_batch_embedded, hi_batch_embedded = en_batch_embedded.to(device), hi_batch_embedded.to(device)
        hi_prediction = model(en_batch_embedded, hi_batch_embedded, ds_mask, es_mask, edc_mask)

        # Prepare labels
        labels_untoken = [dataset.untokenize(hi_batch[index], dataset.index_to_hindi) for index in range(len(hi_batch))]
        labels = [dataset.tokenize(labels_untoken[index], dataset.hindi_to_index, start_token=False, end_token=True) for index in range(len(hi_batch))]
        labels = torch.stack(labels).to(device)

        # Calculate loss
        loss = criterian(
            hi_prediction.view(-1, hi_vocab_size),
            labels.view(-1)
        )

        # Mask padding tokens
        valid_indices = (labels.view(-1) != dataset.hindi_to_index[dataset.PADDING_TOKEN])
        loss = loss[valid_indices].mean()  # Calculate the mean loss over valid indices

        total_loss += loss.item()

        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optim.step()

        # Log loss periodically
        if batch_num % 300 == 0:
            writer.add_scalar('Loss/Batch', loss.item(), epoch * len(train_data_iterator) + batch_num)
    
    avg_loss = total_loss / (batch_num + 1)
    writer.add_scalar('Loss/Epoch', avg_loss, epoch)
    print(f"\t\tEpoch [{epoch + 1}/{total_epochs}], training Loss: {avg_loss:.4f}")
 
    # Validation Loop
    model.eval()
    val_loss = 0
    val_data_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False, pin_memory=True)
    val_data_iterator = iter(val_data_loader)
    with torch.no_grad():
        for val_batch_num, val_batch in enumerate(tqdm(val_data_iterator, desc=f'Validation Epoch {epoch }/{total_epochs}', unit='batch')):
            en_val_batch, hi_val_batch = val_batch
            en_val_batch = en_val_batch.to(device)
            hi_val_batch = hi_val_batch.to(device)
            
            ds_val_mask, es_val_mask, edc_val_mask = get_masks(dataset, en_val_batch, hi_val_batch)
            ds_val_mask, es_val_mask, edc_val_mask = ds_val_mask.to(device), es_val_mask.to(device), edc_val_mask.to(device)
            
            en_val_embedded, hi_val_embedded = embeddings(en_val_batch, hi_val_batch)
            en_val_embedded, hi_val_embedded = en_val_embedded.to(device), hi_val_embedded.to(device)
            
            hi_val_prediction = model(en_val_embedded, hi_val_embedded, ds_val_mask, es_val_mask, edc_val_mask)
            
            val_labels = [dataset.untokenize(hi_val_batch[index], dataset.index_to_hindi) for index in range(len(hi_val_batch))]
            val_labels = [dataset.tokenize(val_labels[index], dataset.hindi_to_index, start_token=False, end_token=True) for index in range(len(hi_val_batch))]
            val_labels = torch.stack(val_labels) 
            
            val_loss_batch = criterian(
                hi_val_prediction.view(-1, hi_vocab_size).to(device),
                val_labels.view(-1).to(device)
            ).to(device)
            
            valid_val_indices = torch.where(val_labels.view(-1) == dataset.hindi_to_index[dataset.PADDING_TOKEN], False, True)
            val_loss_batch = val_loss_batch.sum() / valid_val_indices.sum()
            
            val_loss += val_loss_batch.item()
    
    avg_val_loss = val_loss / (val_batch_num + 1)  # Average validation loss for the epoch
    writer.add_scalar('Loss/Validation_Epoch', avg_val_loss, epoch)


    ####### Print Epoch Losses #######
    print(f"\t\tEpoch [{epoch}/{total_epochs}], Validation Loss: {avg_val_loss:.4f}")
    
    print('\n')
    # Save Model
    model_save_file = os.path.join(model_save_path, f"model_epoch_{epoch }.pt")
    torch.save(model.state_dict(), model_save_file)

writer.close()

##### Evaluation


In [9]:
def translate(en_sentence):
    model.eval()
    en_sentence = (en_sentence,)
    hi_sentence = ("",)

    en_token = dataset.tokenize(en_sentence[0], dataset.english_to_index, start_token=False, end_token=False).unsqueeze(0).to(device)
    hi_token = dataset.tokenize(hi_sentence[0], dataset.hindi_to_index, start_token=True, end_token=False).unsqueeze(0).to(device)
    
    for word_counter in range(dataset.max_sequence_length):
        # print(f"Processing for {word_counter + 1} token")
    
        ds_mask, es_mask, edc_mask = get_masks(dataset, en_token, hi_token)
        ds_mask, es_mask, edc_mask = ds_mask.to(device), es_mask.to(device), edc_mask.to(device)
        
        en_embedded, hi_embedded = embeddings(en_token, hi_token)
        en_embedded, hi_embedded =  en_embedded.to(device), hi_embedded.to(device)
        
        predictions = model(en_embedded,
                            hi_embedded,
                            ds_mask, es_mask, edc_mask)
        next_token_prob_distribution = predictions[0][word_counter]
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = dataset.index_to_hindi[next_token_index]
        
        if next_token == dataset.END_TOKEN:
            break
        hi_sentence = (hi_sentence[0] + next_token, )
        hi_token = dataset.tokenize(hi_sentence[0], dataset.hindi_to_index, start_token=True, end_token=False).unsqueeze(0).to(device)
        # print(f"\t\t\t Predicted till now: {hi_sentence[0]}")
    
    return hi_sentence[0]
    

In [10]:
en = dataset.untokenize(dataset[0][0], dataset.index_to_english)
hi = dataset.untokenize(dataset[0][1], dataset.index_to_hindi)
translation = translate(en)


In [11]:
print(f"en sentence : {en}")
print(f"actual translation : {hi}")
print(f"predicted translation : {translation}")

en sentence : In reply, Pakistan got off to a solid start.
actual translation : जिसके जवाब में पाक ने अच्छी शुरुआत की थी.
predicted translation : जवाब में पाकिस्तान ने एक ठोस शुरुआत की है।


In [16]:
line1 = "I am so mad at you."
line2 = "This is a beautiful day to go out."
line3 = "India is situated on the right side of pakistan"
lines = [line1, line2, line3]

In [17]:
translations = []
for line in lines:
    translations.append(translate(line))

In [18]:
for index, (en, hi) in enumerate(zip(lines, translations)):
    print(f"{en} -> {hi}")

I am so mad at you. -> मैं तुम्हारे ऊपर बहुत बड़ा प्रयोग कर रहा हूं।
This is a beautiful day to go out. -> यह दिन बाहर जाने का बहुत खूबसूरत है।
India is situated on the right side of pakistan -> भारत पाकिस्तान के दाहिने तरफ से स्थित है।


In [19]:
translate("we can go to eat out today")

'आज हम खाने के लिए जा सकते हैं'

#### Save dictionaries

In [28]:
import pickle

# Save to a pickle file
with open('dicts.pkl', 'wb') as f:
    pickle.dump(dict, f)

# To load later:
with open('dicts.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)


In [29]:
loaded_dict.keys()

dict_keys(['en_vocab', 'hi_vocab', 'en_to_index', 'index_to_en', 'hi_to_index', 'index_to_hi'])

In [30]:
loaded_dict['index_to_hi']

{0: ' ',
 1: '!',
 2: '"',
 3: '#',
 4: '$',
 5: '%',
 6: '&',
 7: "'",
 8: '(',
 9: ')',
 10: '*',
 11: '+',
 12: ',',
 13: '-',
 14: '.',
 15: '/',
 16: '0',
 17: '1',
 18: '2',
 19: '3',
 20: '4',
 21: '5',
 22: '6',
 23: '7',
 24: '8',
 25: '9',
 26: ':',
 27: '<',
 28: '<END>',
 29: '<PADDING>',
 30: '<START>',
 31: '=',
 32: '>',
 33: '?',
 34: '@',
 35: '[',
 36: '\\',
 37: ']',
 38: '^',
 39: '_',
 40: '`',
 41: '{',
 42: '|',
 43: '}',
 44: '~',
 45: 'ऀ',
 46: 'ँ',
 47: 'ं',
 48: 'ः',
 49: 'ऄ',
 50: 'अ',
 51: 'आ',
 52: 'इ',
 53: 'ई',
 54: 'उ',
 55: 'ऊ',
 56: 'ऋ',
 57: 'ऌ',
 58: 'ऍ',
 59: 'ऎ',
 60: 'ए',
 61: 'ऐ',
 62: 'ऑ',
 63: 'ऒ',
 64: 'ओ',
 65: 'औ',
 66: 'क',
 67: 'ख',
 68: 'ग',
 69: 'घ',
 70: 'ङ',
 71: 'च',
 72: 'छ',
 73: 'ज',
 74: 'झ',
 75: 'ञ',
 76: 'ट',
 77: 'ठ',
 78: 'ड',
 79: 'ढ',
 80: 'ण',
 81: 'त',
 82: 'थ',
 83: 'द',
 84: 'ध',
 85: 'न',
 86: 'ऩ',
 87: 'प',
 88: 'फ',
 89: 'ब',
 90: 'भ',
 91: 'म',
 92: 'य',
 93: 'र',
 94: 'ऱ',
 95: 'ल',
 96: 'ळ',
 97: 'ऴ',
 98: 'व',
 

In [23]:


def translate(en_sentence):
    model.eval()
    en_sentence = (en_sentence,)
    hi_sentence = ("",)

    en_token = dataset.tokenize(en_sentence[0], dataset.english_to_index, start_token=False, end_token=False).unsqueeze(0).to(device)
    hi_token = dataset.tokenize(hi_sentence[0], dataset.hindi_to_index, start_token=True, end_token=False).unsqueeze(0).to(device)
    
    print(en_token)
    print(hi_token)
    for word_counter in range(dataset.max_sequence_length):
        # print(f"Processing for {word_counter + 1} token")
    
        ds_mask, es_mask, edc_mask = get_masks(dataset, en_token, hi_token)
        ds_mask, es_mask, edc_mask = ds_mask.to(device), es_mask.to(device), edc_mask.to(device)
        
        en_embedded, hi_embedded = embeddings(en_token, hi_token)
        en_embedded, hi_embedded =  en_embedded.to(device), hi_embedded.to(device)
        
        predictions = model(en_embedded,
                            hi_embedded,
                            ds_mask, es_mask, edc_mask)
        next_token_prob_distribution = predictions[0][word_counter]
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = dataset.index_to_hindi[next_token_index]
        
        if next_token == dataset.END_TOKEN:
            break
        hi_sentence = (hi_sentence[0] + next_token, )
        hi_token = dataset.tokenize(hi_sentence[0], dataset.hindi_to_index, start_token=True, end_token=False).unsqueeze(0).to(device)
        # print(f"\t\t\t Predicted till now: {hi_sentence[0]}")
    
    return hi_sentence[0]
    

In [24]:
line1 = "Hello, How are you?"
translate(line1 )

tensor([[42, 71, 78, 78, 81, 12,  0, 42, 81, 89,  0, 67, 84, 71,  0, 91, 81, 87,
         33, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29]], device='cuda:0')
tensor([[30, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29]], device='cuda:0')


'हैलो, तुम कैसे हो?'

In [26]:
### Dictionaries
import json

with open('dicts.json', 'r') as file:
    data = json.load(file)
    
data.keys()

en_vocab = data['en_vocab']
hi_vocab = data['hi_vocab']
en_to_index = data['en_to_index']
index_to_en = data['index_to_en']
hi_to_index = data['hi_to_index']
index_to_hi = data['index_to_hi']

In [27]:
index_to_hi

{'0': ' ',
 '1': '!',
 '2': '"',
 '3': '#',
 '4': '$',
 '5': '%',
 '6': '&',
 '7': "'",
 '8': '(',
 '9': ')',
 '10': '*',
 '11': '+',
 '12': ',',
 '13': '-',
 '14': '.',
 '15': '/',
 '16': '0',
 '17': '1',
 '18': '2',
 '19': '3',
 '20': '4',
 '21': '5',
 '22': '6',
 '23': '7',
 '24': '8',
 '25': '9',
 '26': ':',
 '27': '<',
 '28': '<END>',
 '29': '<PADDING>',
 '30': '<START>',
 '31': '=',
 '32': '>',
 '33': '?',
 '34': '@',
 '35': '[',
 '36': '\\',
 '37': ']',
 '38': '^',
 '39': '_',
 '40': '`',
 '41': '{',
 '42': '|',
 '43': '}',
 '44': '~',
 '45': 'ऀ',
 '46': 'ँ',
 '47': 'ं',
 '48': 'ः',
 '49': 'ऄ',
 '50': 'अ',
 '51': 'आ',
 '52': 'इ',
 '53': 'ई',
 '54': 'उ',
 '55': 'ऊ',
 '56': 'ऋ',
 '57': 'ऌ',
 '58': 'ऍ',
 '59': 'ऎ',
 '60': 'ए',
 '61': 'ऐ',
 '62': 'ऑ',
 '63': 'ऒ',
 '64': 'ओ',
 '65': 'औ',
 '66': 'क',
 '67': 'ख',
 '68': 'ग',
 '69': 'घ',
 '70': 'ङ',
 '71': 'च',
 '72': 'छ',
 '73': 'ज',
 '74': 'झ',
 '75': 'ञ',
 '76': 'ट',
 '77': 'ठ',
 '78': 'ड',
 '79': 'ढ',
 '80': 'ण',
 '81': 'त',
 '82': 