In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import logging
from torch.utils.data import Dataset, DataLoader, random_split
from bpe_tokenizer import BPETokenizer
import tqdm
from torch.cuda.amp import GradScaler, autocast
import pandas as pd
import json
from collections import Counter
from typing import Dict, Any
import torchmetrics
import torchinfo
import mlflow

In [4]:
logger = logging.getLogger("TA-EN NMT")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

In [5]:
class BlEU:

    def __init__(self, N_grams):
        self.n = N_grams

    def set_ngram(self, new_ngram):
        self.n = new_ngram
    
    def n_gram_precision(self, reference, hypothesis, n):
        ref_gram = Counter(zip(*[reference[i:] for i in range(n)]))
        hyp_gram = Counter(zip(*[hypothesis[i:] for i in range(n)]))

        overlap = {ngram : min(cnt, ref_gram[ngram]) for ngram,cnt in hyp_gram.items()}

        return sum(list(overlap.values())) / max(1, sum(ref_gram.values()))
    
    def brevity_penalty(self, reference,hypothesis):
        ref_len = len(reference)
        hyp_len = len(hypothesis)

        if hyp_len > ref_len:
            return 1
        else:

            return np.exp(1 - ref_len/hyp_len)
    def bleu_score(self,reference, hypothesis, max_n=4, weights=None):
        if weights is None:
            weights = [1.0 / max_n] * max_n

        precisions = [self.n_gram_precision(reference, hypothesis, n) for n in range(1, max_n + 1)]

        geometric_mean = self.exp(sum(w * np.log(p + 1e-10) for w, p in zip(weights, precisions)))

        bp = self.brevity_penalty(reference, hypothesis)

        return bp * geometric_mean

In [27]:
class TranslationDataset(Dataset):
    def __init__(self, datasets, src_bpe_path, trg_bpe_path):
        '''
        datasets: list[str] -> list of dataset csv files
        src_bpe_path: str -> path to source bpe file (path/to/src_bpe1.vocab.json do not include .vocab.json)
        trg_bpe_path: str -> path to target bpe file (path/to/trg_bpe1.vocab.json do not include .vocab.json)
        '''
        initial_csv = pd.read_csv(datasets[0])
        for dataset in datasets[1:]:
            initial_csv = pd.concat([initial_csv, pd.read_csv(dataset)])
        self.BPE_tokenizer_ta = BPETokenizer.load(src_bpe_path, "ta")
        self.BPE_tokenizer_en = BPETokenizer.load(trg_bpe_path, "en")
        self.tamil = initial_csv["ta"].tolist()
        self.english = initial_csv["en"].tolist()
        assert len(self.tamil) == len(self.english), "Tamil and English sentences are not of the same length"
    def __len__(self): return len(self.tamil)
    def __getitem__(self, idx):
        tamil_sentence = self.tamil[idx]
        english_sentence = self.english[idx]
        tamil_tokens = self.BPE_tokenizer_ta.tokenize(tamil_sentence)
        english_tokens = self.BPE_tokenizer_en.tokenize(english_sentence)
        tamil_tokens = torch.tensor(tamil_tokens,dtype=torch.long)
        english_tokens = torch.tensor(english_tokens, dtype = torch.long)
        trg_pad_mask = torch.full((len(tamil_tokens),), 0.,dtype = torch.bool)
        src_pad_mask = torch.full((len(tamil_tokens),), 0.,dtype = torch.bool)
        return tamil_tokens, english_tokens, src_pad_mask, trg_pad_mask

def collate_fn(batch):
    tamil_batch, english_batch, src_masks, trg_masks = zip(*batch)

    max_tamil_len = max(len(seq) for seq in tamil_batch)
    max_english_len = max(len(seq) for seq in english_batch)
    max_pad = max(max_tamil_len,max_english_len)
    padded_tamil = []
    padded_english = []
    src_pad_masks = []
    trg_pad_masks = []  
    for tamil_tokens, english_tokens in zip(tamil_batch, english_batch):
        tamil_len = len(tamil_tokens)
        english_len = len(english_tokens)
        padded_tamil.append(torch.cat([tamil_tokens, torch.zeros(max_pad - tamil_len, dtype=torch.long)]))
        padded_english.append(torch.cat([english_tokens, torch.zeros(max_pad - english_len, dtype=torch.long)]))
        src_pad_masks.append(torch.cat([torch.zeros(tamil_len), torch.ones(max_pad - tamil_len)]).bool())
        trg_pad_masks.append(torch.cat([torch.zeros(english_len), torch.ones(max_pad - english_len)]).bool())
    padded_tamil = torch.stack(padded_tamil)
    padded_english = torch.stack(padded_english)
    src_pad_masks = torch.stack(src_pad_masks)
    trg_pad_masks = torch.stack(trg_pad_masks)
    
    return padded_tamil, padded_english, src_pad_masks, trg_pad_masks

In [41]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, n_layers=6, d_ff=2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout),
            num_layers=n_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, n_heads, d_ff, dropout),
            num_layers=n_layers)
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = nn.Parameter(torch.zeros(10000, d_model))
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
    def forward(self, src, tgt, src_mask=None, tgt_mask=None,decoder_mask = None):
        batch_size, src_seq_len = src.shape
        batch_size, tgt_seq_len = tgt.shape
        src_mask = src_mask.permute(1,0)
        decoder_mask = decoder_mask.permute(1,0)
        src = self.src_emb(src) + self.pos_enc[:src_seq_len].unsqueeze(0).expand(batch_size, -1, -1)
        tgt = self.tgt_emb(tgt) + self.pos_enc[:tgt_seq_len].unsqueeze(0).expand(batch_size, -1, -1)
        memory = self.encoder(src, src_key_padding_mask=src_mask)
        logger.info(memory.shape)
        logger.info(decoder_mask.shape)
        logger.info(tgt.shape)
        logger.info(src_mask.shape)
        logger.info(src.shape)
        output = self.decoder(tgt, memory, memory_key_padding_mask=src_mask, tgt_mask=tgt_mask, tgt_key_padding_mask = decoder_mask)
        return self.fc_out(output)

In [35]:
class NMTTrainer:
    def __init__(self, model, train_loader, val_loader, device, bleu_ngram = None):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = None
        self.bleu = BlEU(bleu_ngram)
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.train_metrics = {
            'loss': [],
            'bleu': []
        }
        self.eval_metrics = {
            'loss' : []
        }

    def set_optimizer(self, optimizer_name, **kwargs):
        if optimizer_name == 'adam':
            self.optimizer = optim.Adam(self.model.parameters(), **kwargs)
        elif optimizer_name == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(), **kwargs)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        
    def _compute_loss(self, output, target):
        return self.criterion(output.view(-1, output.size(-1)), target.contiguous().view(-1))

    def train_epoch(self,epoch):
        self.model.train()
        total_loss = 0
        for tamil, english, src_mask, tgt_mask in self.train_loader:
            tamil, english = tamil.to(self.device), english.to(self.device)
            src_mask, tgt_mask = src_mask.to(self.device), tgt_mask.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(tamil, english[:, :-1],src_mask = src_mask, decoder_mask = tgt_mask[:,:-1])  
            loss = self._compute_loss(output, english[:, 1:])
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  
            self.optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(self.train_loader)
        self.train_metrics['loss'].append(avg_loss)
        logger.info(f"Training_loss {epoch} -> {avg_loss}")

    def evaluate(self):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for tamil, english, src_mask, tgt_mask in self.train_loader:
                tamil, english = tamil.to(self.device), english.to(self.device)
                src_mask, tgt_mask = src_mask.to(self.device), tgt_mask.to(self.device)
                output = self.model(tamil, english[:, :-1],src_mask = src_mask, decoder_mask = tgt_mask[:,:-1])  
                loss = self._compute_loss(output, english[:, 1:])
                total_loss += loss.item()
        avg_loss = total_loss / len(self.val_loader)
        self.eval_metrics['loss'].append(avg_loss)
        logger.info(f"Validation Loss: {avg_loss}")

    def calculate_bleu(self, references, hypotheses):
        for ref, hyp in zip(references,hypotheses):
            bleu_score = self.bleu.bleu_score(ref,hyp)
        self.metrics['bleu'].append(bleu_score)
        logger.info(f"BLEU Score: {bleu_score}")

    def train(self, num_epochs):
        for epoch in tqdm.tqdm(range(num_epochs)):
            logger.info(f"Epoch {epoch + 1}/{num_epochs}")
            self.train_epoch(epoch)
            if epoch%5 == 0:
                self.evaluate()

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TranslationDataset(["en-ta//pmindia.v1.ta-en 39k.csv", "en-ta//general_en_ta 87k.csv"], "src_bpe1", "trg_bpe1")
loader = DataLoader(dataset, batch_size=32, shuffle=True,drop_last=True, collate_fn=collate_fn)

In [29]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [30]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, drop_last=True   )

In [39]:
model = Transformer(src_vocab_size=21401, tgt_vocab_size=13465).to(device)



In [42]:
trainer = NMTTrainer(model, train_loader, val_loader, device)  # Assuming same loader for train/val for simplicity
trainer.set_optimizer('adam', lr=0.0001)

In [43]:
trainer.train(num_epochs=10)

  0%|          | 0/10 [00:00<?, ?it/s]2025-03-13 15:43:54,077 - TA-EN NMT - INFO - Epoch 1/10
2025-03-13 15:43:54,077 - TA-EN NMT - INFO - Epoch 1/10
2025-03-13 15:43:54,077 - TA-EN NMT - INFO - Epoch 1/10
  0%|          | 0/10 [00:00<?, ?it/s]


torch.Size([32, 109, 512])
torch.Size([108, 32])
torch.Size([32, 108, 512])
torch.Size([109, 32])
torch.Size([32, 109, 512])


RuntimeError: shape '[32, 864, 64]' is invalid for input of size 1785856