### Relevant Imports


In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = 'logs')
import re

from structure.transformer import Transformer
from structure.Dataset import English_Hindi_Dataset

from sub_modules.embedding import Language_Embedding
from sub_modules.masks import get_masks

### Initializations


In [2]:
# Read data
read_max = 2_00_000 ######

# basics
batch_size = 256
sequence_length = 200
d_model = 512
num_of_sentences = 1_00_000
# transfomer
num_encoder_decoder_layers = 1
num_heads = 8
hidden_layers = 2048
num_encoder_decoder_layers = 1
dropout_ff = 0.1
dropout_attn = 0.1


### Dataset


In [3]:
dataset = English_Hindi_Dataset('Dataset/train.en/train.en', 
                                    'Dataset/train.hi/train.hi',
                                    num_of_sentences = num_of_sentences,
                                    max_sequence_length = sequence_length,
                                    read_max = read_max)

en_vocab_size = len(set(dataset.en_vocab))
hi_vocab_size = len(set(dataset.hi_vocab))


Total unique characters: English-> 97 Hindi-> 174
	Dataset Cleaned
	Dataset Tokenized and Pading is Done


### Embeddings


In [4]:
# embeddings
embeddings = Language_Embedding(en_vocab_size, hi_vocab_size, d_model)

### Data Loader


In [5]:
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - int(0.8 * dataset_size)

train_dataset, val_dataset = random_split(dataset, [train_size,val_size])

train_data_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False, pin_memory=True)

### Model Initializations


In [6]:
# GPU for training
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f" Using: {device}")

model = Transformer(
    num_encoder_decoder_layers=num_encoder_decoder_layers,
    d_model=d_model,
    sequence_length=sequence_length,
    hidden_layers=hidden_layers,
    num_heads=num_heads,
    hi_voab_size=hi_vocab_size,
    dropout_ff=dropout_ff,
    dropout_attn=dropout_attn
).to(device)


# Loss
criterian = nn.CrossEntropyLoss(ignore_index= dataset.hindi_to_index[dataset.PADDING_TOKEN], reduction ='none')

# Parameter Initialization
for param in model.parameters():
    if param.dim()>1:
        nn.init.xavier_uniform_(param)
        
# optimizer 
optim = torch.optim.Adam(model.parameters(), lr= 1e-4)


 Using: cuda


### Training


In [7]:
model_save_path = "saved_models"  # Specify your directory to save models
os.makedirs(model_save_path, exist_ok=True)  # Create directory if it doesn't exist


def get_latest_model_checkpoint(model_save_path):
    model_files = os.listdir(model_save_path)
    model_epochs = [int(re.findall(r'model_epoch_(\d+).pt', file)[0]) for file in model_files if file.endswith('.pt')]
    
    if len(model_epochs)>0:
        latest_epoch = max(model_epochs)
        model_save_file = os.path.join(model_save_path, f"model_epoch_{latest_epoch}.pt")
        return latest_epoch, model_save_file
    else:
        return None, None
    
latest_epoch, model_save_file = get_latest_model_checkpoint(model_save_path)

if model_save_file:
    print(f"Loading model from {model_save_file}")
    model.load_state_dict(torch.load(model_save_file))
    current_epoch = latest_epoch + 1
else:
    print("No saved model found. Training from scratch.")
    current_epoch = 0

Loading model from saved_models\model_epoch_1.pt


  model.load_state_dict(torch.load(model_save_file))


In [7]:
best_val_loss = float('inf')
total_epochs = 10

for epoch in range(current_epoch, total_epochs + 1, 1):
    print(f"Epoch -> {epoch}")
    total_loss = 0
    
    train_data_iterator = iter(train_data_loader)
    for batch_num, batch in enumerate(tqdm(train_data_iterator, desc=f'Epoch {epoch + 1}/{total_epochs}', unit='batch')):
        model.train()
        en_batch, hi_batch = batch
        en_batch = en_batch.to(device)
        hi_batch = hi_batch.to(device)
        
        ds_mask, es_mask, edc_mask = get_masks(dataset, en_batch, hi_batch)
        ds_mask, es_mask, edc_mask = ds_mask.to(device), es_mask.to(device), edc_mask.to(device)
        optim.zero_grad()
        
        en_batch_embedded, hi_batch_embedded = embeddings(en_batch, hi_batch)
        en_batch_embedded, hi_batch_embedded = en_batch_embedded.to(device), hi_batch_embedded.to(device)
        hi_prediction = model(en_batch_embedded, hi_batch_embedded, ds_mask, es_mask, edc_mask)
        
        labels = [dataset.untokenize(hi_batch[index], dataset.index_to_hindi) for index in range(len(hi_batch))]
        labels = [dataset.tokenize(labels[index], dataset.hindi_to_index, start_token=False, end_token=True) for index in range(len(hi_batch))]
        labels = torch.stack(labels) 
        loss = criterian(
            hi_prediction.view(-1, hi_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        
        valid_indices = torch.where(labels.view(-1) == dataset.hindi_to_index[dataset.PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indices.sum()
        
        total_loss += loss.item()
        
        loss.backward()
        optim.step()
        
        per_batch = 300
        if batch_num % per_batch == 0:
            writer.add_scalar(f'Loss/Batch', loss.item(), epoch * len(train_data_iterator) + batch_num)
    
    writer.add_scalar('Loss/Epoch', total_loss/(batch_num + 1), epoch)
    print(f"\t\tEpoch [{epoch+1}/{total_epochs}], training Loss: {total_loss/(batch_num + 1):.4f}")
    
    # Validation Loop
    model.eval()
    val_loss = 0
    val_data_iterator = iter(val_data_loader)
    with torch.no_grad():
        for val_batch_num, val_batch in enumerate(tqdm(val_data_iterator, desc=f'Validation Epoch {epoch + 1}/{total_epochs}', unit='batch')):
            en_val_batch, hi_val_batch = val_batch
            en_val_batch = en_val_batch.to(device)
            hi_val_batch = hi_val_batch.to(device)
            
            ds_val_mask, es_val_mask, edc_val_mask = get_masks(dataset, en_val_batch, hi_val_batch)
            ds_val_mask, es_val_mask, edc_val_mask = ds_val_mask.to(device), es_val_mask.to(device), edc_val_mask.to(device)
            
            en_val_embedded, hi_val_embedded = embeddings(en_val_batch, hi_val_batch)
            en_val_embedded, hi_val_embedded = en_val_embedded.to(device), hi_val_embedded.to(device)
            
            hi_val_prediction = model(en_val_embedded, hi_val_embedded, ds_val_mask, es_val_mask, edc_val_mask)
            
            val_labels = [dataset.untokenize(hi_val_batch[index], dataset.index_to_hindi) for index in range(len(hi_val_batch))]
            val_labels = [dataset.tokenize(val_labels[index], dataset.hindi_to_index, start_token=False, end_token=True) for index in range(len(hi_val_batch))]
            val_labels = torch.stack(val_labels) 
            
            val_loss_batch = criterian(
                hi_val_prediction.view(-1, hi_vocab_size).to(device),
                val_labels.view(-1).to(device)
            ).to(device)
            
            valid_val_indices = torch.where(val_labels.view(-1) == dataset.hindi_to_index[dataset.PADDING_TOKEN], False, True)
            val_loss_batch = val_loss_batch.sum() / valid_val_indices.sum()
            
            val_loss += val_loss_batch.item()
    
    avg_val_loss = val_loss / (val_batch_num + 1)  # Average validation loss for the epoch
    writer.add_scalar('Loss/Validation_Epoch', avg_val_loss, epoch)

    ####### Print Epoch Losses #######
    print(f"\t\tEpoch [{epoch+1}/{total_epochs}], Validation Loss: {avg_val_loss:.4f}")
    
    
    print('\n')
    # Save Model
    model_save_file = os.path.join(model_save_path, f"model_epoch_{epoch + 1}.pt")
    torch.save(model.state_dict(), model_save_file)

writer.close()

Epoch -> 2


Epoch 3/10:  25%|██▌       | 1/4 [00:06<00:20,  6.81s/batch]