In [None]:
# =====================
# Imports & Setup
# =====================
!pip install -q datasets sentencepiece sacrebleu torch torchvision torchaudio tqdm

import os, random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datasets import load_dataset
import sentencepiece as spm
import sacrebleu

# =====================
# Config
# =====================
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 16000
BATCH_SIZE = 64
MAX_LEN = 64
MAX_GEN_LEN = 64
EPOCHS = 30
CLIP = 1.0
LEARNING_RATE = 3e-4
BEAM_SIZE = 5
BEST_MODEL_PATH = Path("best_model_baseline.pt")
CHECKPOINT_PATH = Path("checkpoint_baseline.pt")
PATIENCE = 5  # early stopping patience

torch.manual_seed(SEED)
random.seed(SEED)

# =====================
# Load dataset (up to 1M)
# =====================
full_dataset = load_dataset("ai4bharat/samanantar", "hi", split="train")
full_dataset = full_dataset.shuffle(seed=SEED)

NUM_EXAMPLES = min(1_000_000, len(full_dataset))
subset = full_dataset.select(range(NUM_EXAMPLES))

# Split 80% train, 10% val, 10% test
train_end = int(0.8 * len(subset))
val_end = int(0.9 * len(subset))
train_data = subset.select(range(0, train_end))
val_data = subset.select(range(train_end, val_end))
test_data = subset.select(range(val_end, len(subset)))

print("Dataset sizes:", len(train_data), len(val_data), len(test_data))

# =====================
# SentencePiece
# =====================
SP_EN_MODEL = Path("spm_en.model")
SP_HI_MODEL = Path("spm_hi.model")

def write_lines(dataset_split, src_path, tgt_path):
    with open(src_path, "w", encoding="utf-8") as sf, open(tgt_path, "w", encoding="utf-8") as tf:
        for ex in dataset_split:
            sf.write(ex["src"].strip().lower() + "\n")
            tf.write(ex["tgt"].strip() + "\n")

write_lines(train_data, "train.en", "train.hi")

if not SP_EN_MODEL.exists() or not SP_HI_MODEL.exists():
    print("Training SentencePiece...")
    spm.SentencePieceTrainer.Train(
        f"--input=train.en --model_prefix=spm_en --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=1.0 --model_type=unigram"
    )
    spm.SentencePieceTrainer.Train(
        f"--input=train.hi --model_prefix=spm_hi --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=0.9995 --model_type=unigram"
    )

sp_en = spm.SentencePieceProcessor()
sp_hi = spm.SentencePieceProcessor()
sp_en.load(str(SP_EN_MODEL))
sp_hi.load(str(SP_HI_MODEL))

PAD_EN, BOS_EN, EOS_EN = 0, 1, 2
PAD_HI, BOS_HI, EOS_HI = 0, 1, 2

# =====================
# Dataset & DataLoader
# =====================
class NMTDataset(Dataset):
    def __init__(self, dataset, src_sp, tgt_sp, max_len=MAX_LEN):
        self.dataset = dataset
        self.src_sp = src_sp
        self.tgt_sp = tgt_sp
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        src_text = self.dataset[idx]["src"].lower()
        tgt_text = self.dataset[idx]["tgt"]
        src_ids = [BOS_EN] + self.src_sp.encode(src_text)[:self.max_len-2] + [EOS_EN]
        tgt_ids = [BOS_HI] + self.tgt_sp.encode(tgt_text)[:self.max_len-2] + [EOS_HI]
        src_ids += [PAD_EN] * (self.max_len - len(src_ids))
        tgt_ids += [PAD_HI] * (self.max_len - len(tgt_ids))
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

def get_loader(dataset_split, shuffle=True):
    return DataLoader(NMTDataset(dataset_split, sp_en, sp_hi),
                      batch_size=BATCH_SIZE, shuffle=shuffle)

train_loader = get_loader(train_data)
val_loader = get_loader(val_data)
test_loader = get_loader(test_data, shuffle=False)

# =====================
# Masks
# =====================
def create_padding_mask(seq, lang='en'):
    pad_id = PAD_EN if lang == 'en' else PAD_HI
    return (seq == pad_id)

def generate_square_subsequent_mask(sz):
    return torch.triu(torch.ones(sz, sz, device=DEVICE) * float('-inf'), diagonal=1)

# =====================
# Transformer Model
# =====================
class TransformerModel(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8,
                 num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=PAD_EN)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD_HI)
        self.pos_enc = nn.Parameter(torch.zeros(1, MAX_LEN, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward, dropout=dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_key_padding_mask):
        src_emb = self.src_emb(src) + self.pos_enc[:, :src.size(1), :]
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_emb(tgt) + self.pos_enc[:, :tgt.size(1), :]
        return self.transformer.decoder(
            tgt_emb, memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_mask=None):
        memory = self.encode(src, src_key_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask)
        return self.fc_out(output)

# =====================
# Training Setup
# =====================
model = TransformerModel(len(sp_en), len(sp_hi)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_HI)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=EPOCHS
)
scaler = torch.amp.GradScaler()

# ---- Resume if checkpoint exists ----
start_epoch = 1
best_val_loss = float("inf")
epochs_no_improve = 0

if CHECKPOINT_PATH.exists():
    print(f"üîÑ Resuming from checkpoint: {CHECKPOINT_PATH}")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Handle optional scaler
    if "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
        print("‚úÖ Loaded scaler state from checkpoint.")
    else:
        print("‚ö†Ô∏è No scaler state found in checkpoint ‚Äî continuing without it.")
    start_epoch = checkpoint.get("epoch", 0) + 1
    best_val_loss = checkpoint.get("best_val_loss", float("inf"))
    epochs_no_improve = checkpoint.get("epochs_no_improve", 0)
    print(f"Resumed from epoch {start_epoch-1}, best val loss = {best_val_loss:.4f}")

# =====================
# Training Loop
# =====================
for epoch in range(start_epoch, EPOCHS + 1):
    # ---- Train ----
    model.train()
    train_loss = 0
    for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", unit="batch"):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        src_mask = create_padding_mask(src, 'en')
        tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
        tgt_key_padding_mask = create_padding_mask(tgt_input, 'hi')

        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda"):
            logits = model(
                src, tgt_input,
                src_key_padding_mask=src_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_mask,
                tgt_mask=tgt_mask
            )
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    # ---- Validation ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in val_loader:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            src_mask = create_padding_mask(src, 'en')
            tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
            tgt_key_padding_mask = create_padding_mask(tgt_input, 'hi')

            logits = model(
                src, tgt_input,
                src_key_padding_mask=src_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_mask,
                tgt_mask=tgt_mask
            )
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # ---- Save Checkpoint ----
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_val_loss": best_val_loss,
        "epochs_no_improve": epochs_no_improve,
    }, CHECKPOINT_PATH)

    # ---- Early Stopping ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print("‚úÖ New best model saved!")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print("‚èπÔ∏è Early stopping triggered.")
            break

print("üéØ Training complete.")


  from .autonotebook import tqdm as notebook_tqdm


Dataset sizes: 800000 100000 100000
üîÑ Resuming from checkpoint: checkpoint_baseline.pt


  checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)


‚úÖ Loaded scaler state from checkpoint.
Resumed from epoch 26, best val loss = 2.4229


Epoch 27/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12500/12500 [13:56<00:00, 14.94batch/s]
  output = torch._nested_tensor_from_mask(


Epoch 27 | Train Loss: 1.5976 | Val Loss: 2.4865


Epoch 28/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12500/12500 [13:56<00:00, 14.93batch/s]


Epoch 28 | Train Loss: 1.5685 | Val Loss: 2.4935


Epoch 29/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12500/12500 [13:54<00:00, 14.98batch/s]


Epoch 29 | Train Loss: 1.5502 | Val Loss: 2.4999
‚èπÔ∏è Early stopping triggered.
üéØ Training complete.


In [None]:
# =====================
# Save Full Checkpoint
# =====================
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'best_val_loss': best_val_loss,
    'config': {
        'VOCAB_SIZE': VOCAB_SIZE,
        'MAX_LEN': MAX_LEN,
        'EPOCHS': EPOCHS,
        'LEARNING_RATE': LEARNING_RATE,
        'CLIP': CLIP,
        'BATCH_SIZE': BATCH_SIZE
    }
}, "checkpoint_baseline.pt")

print("‚úÖ Full checkpoint saved as 'checkpoint_baseline.pt'")


‚úÖ Full checkpoint saved as 'checkpoint_baseline.pt'


In [None]:
import json

config = {
    "VOCAB_SIZE": VOCAB_SIZE,
    "MAX_LEN": MAX_LEN,
    "EPOCHS": EPOCHS,
    "LEARNING_RATE": LEARNING_RATE,
    "CLIP": CLIP,
    "BATCH_SIZE": BATCH_SIZE,
    "DEVICE": DEVICE
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)

print("‚úÖ Config saved to config.json")


‚úÖ Config saved to config.json


In [None]:
# =====================
# ‚úÖ Evaluation on Test Set
# =====================
from torch.nn.functional import log_softmax
import math

# ---- Load best model ----
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
model.eval()
print("‚úÖ Loaded best model for evaluation.")

# =====================
# Greedy Decoding
# =====================
def greedy_decode(model, src, max_len=MAX_GEN_LEN):
    model.eval()
    src_mask = create_padding_mask(src, 'en')
    memory = model.encode(src, src_mask)
    ys = torch.full((src.size(0), 1), BOS_HI, dtype=torch.long, device=DEVICE)

    for i in range(max_len - 1):
        tgt_mask = generate_square_subsequent_mask(ys.size(1))
        tgt_key_padding_mask = create_padding_mask(ys, 'hi')

        out = model.decode(
            ys, memory, tgt_mask=tgt_mask,
            memory_key_padding_mask=src_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        out = model.fc_out(out[:, -1, :])
        next_word = out.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)

        # Stop if EOS is reached for all
        if torch.all(next_word.squeeze() == EOS_HI):
            break
    return ys


# =====================
# Beam Search Decoding
# =====================
def beam_search_decode(model, src, beam_size=BEAM_SIZE, max_len=MAX_GEN_LEN):
    model.eval()
    src_mask = create_padding_mask(src, 'en')
    memory = model.encode(src, src_mask)

    batch_size = src.size(0)
    results = []

    for i in range(batch_size):
        beams = [(torch.tensor([[BOS_HI]], device=DEVICE), 0.0)]
        for _ in range(max_len - 1):
            new_beams = []
            for seq, score in beams:
                if seq[0, -1].item() == EOS_HI:
                    new_beams.append((seq, score))
                    continue

                tgt_mask = generate_square_subsequent_mask(seq.size(1))
                tgt_key_padding_mask = create_padding_mask(seq, 'hi')
                out = model.decode(
                    seq, memory[i:i+1],
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_mask[i:i+1],
                    tgt_key_padding_mask=tgt_key_padding_mask
                )
                logits = model.fc_out(out[:, -1, :])
                log_probs = log_softmax(logits, dim=-1)
                topk_log_probs, topk_indices = torch.topk(log_probs, beam_size)

                for k in range(beam_size):
                    next_seq = torch.cat([seq, topk_indices[:, k].unsqueeze(1)], dim=1)
                    new_beams.append((next_seq, score + topk_log_probs[0, k].item()))

            beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]

            if all(seq[0, -1].item() == EOS_HI for seq, _ in beams):
                break

        best_seq = beams[0][0]
        results.append(best_seq)
    return results


# =====================
# Evaluate BLEU on Test Set
# =====================
references, hypotheses = [], []
for batch_idx, (src, tgt) in enumerate(tqdm(test_loader, desc="Evaluating", unit="batch")):
    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
    with torch.no_grad():
        pred = greedy_decode(model, src)
    for i in range(src.size(0)):
        tgt_tokens = tgt[i].tolist()
        pred_tokens = pred[i].tolist()

        tgt_text = sp_hi.decode([t for t in tgt_tokens if t not in [PAD_HI, BOS_HI, EOS_HI]])
        pred_text = sp_hi.decode([t for t in pred_tokens if t not in [PAD_HI, BOS_HI, EOS_HI]])

        references.append([tgt_text])
        hypotheses.append(pred_text)

# ---- Compute BLEU ----
bleu = sacrebleu.corpus_bleu(hypotheses, list(zip(*references)))
print(f"üåç Test BLEU score: {bleu.score:.2f}")

# =====================
# Show Some Examples
# =====================
print("\nüîç Sample Translations:")
for i in range(5):
    src_ids, tgt_ids = test_data[i]["src"], test_data[i]["tgt"]
    src_tensor = torch.tensor([[BOS_EN] + sp_en.encode(src_ids.lower())[:MAX_LEN-2] + [EOS_EN]], device=DEVICE)
    pred = greedy_decode(model, src_tensor)
    pred_tokens = pred[0].tolist()
    pred_text = sp_hi.decode([t for t in pred_tokens if t not in [PAD_HI, BOS_HI, EOS_HI]])

    print(f"\nEN: {test_data[i]['src']}")
    print(f"HI (Reference): {test_data[i]['tgt']}")
    print(f"HI (Predicted): {pred_text}")
    print("-" * 50)


  model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))


‚úÖ Loaded best model for evaluation.


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [09:34<00:00,  2.72batch/s]


üåç Test BLEU score: 8.58

üîç Sample Translations:

EN: Its movement was captured on CCTV.
HI (Reference): ‡§â‡§®‡§ï‡•Ä ‡§Ø‡§π ‡§π‡§∞‡§ï‡§§ ‡§∏‡•Ä‡§∏‡•Ä‡§ü‡•Ä‡§µ‡•Ä ‡§Æ‡•á‡§Ç ‡§ï‡•à‡§¶ ‡§π‡•ã ‡§ó‡§à‡•§
HI (Predicted): ‡§Ø‡§π ‡§∏‡•Ä‡§∏‡•Ä‡§ü‡•Ä‡§µ‡•Ä ‡§Æ‡•á‡§Ç ‡§ï‡•à‡§¶ ‡§π‡•ã ‡§ó‡§Ø‡§æ‡•§
--------------------------------------------------

EN: "The two leaders ""discussed bilateral ties, including development partnership and cooperation in counter-terrorism and international fora,"" he said in the tweet."
HI (Reference): ‡§ü‡•ç‡§µ‡•Ä‡§ü ‡§Æ‡•á‡§Ç ‡§â‡§®‡•ç‡§π‡•ã‡§Ç‡§®‡•á ‡§ï‡§π‡§æ ‡§ï‡§ø ‡§¶‡•ã‡§®‡•ã‡§Ç ‡§®‡•á‡§§‡§æ‡§ì‡§Ç ‡§®‡•á ‚Äò‚Äò‡§µ‡§ø‡§ï‡§æ‡§∏ ‡§∏‡§æ‡§ù‡•á‡§¶‡§æ‡§∞‡•Ä ‡§î‡§∞ ‡§Ü‡§§‡§Ç‡§ï‡§µ‡§æ‡§¶ ‡§ï‡•á ‡§ñ‡§ø‡§≤‡§æ‡§´ ‡§§‡§•‡§æ ‡§Ö‡§Ç‡§§‡§∞‡§∞‡§æ‡§∑‡•ç‡§ü‡•ç‡§∞‡•Ä‡§Ø ‡§Æ‡§Ç‡§ö‡•ã‡§Ç ‡§™‡§∞ ‡§∏‡§π‡§Ø‡•ã‡§ó ‡§∏‡§Æ‡•á‡§§ ‡§¶‡•ç‡§µ‡§ø‡§™‡§ï‡•ç‡§∑‡•Ä‡§Ø ‡§∏‡§Ç‡§¨‡§Ç‡§ß‡•ã‡§Ç‚Äô‚Äô ‡§™‡§∞ ‡§ö‡§∞‡•ç‡§ö‡§æ ‡§ï‡•Ä‡•§
HI (Predicted): ‡§â‡§®‡•ç‡§π‡•ã‡§Ç‡§®‡•á ‡§ü‡•ç‡§µ‡•Ä‡§ü ‡§ï‡§ø‡§Ø‡

In [None]:
# =====================
# Imports & Setup
# =====================
!pip install -q datasets sentencepiece sacrebleu torch torchvision torchaudio tqdm

import os, random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datasets import load_dataset
import sentencepiece as spm
import sacrebleu

# =====================
# Config
# =====================
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 16000
BATCH_SIZE = 64
MAX_LEN = 64
MAX_GEN_LEN = 64
EPOCHS = 30
CLIP = 1.0
LEARNING_RATE = 3e-4
BEAM_SIZE = 5
BEST_MODEL_PATH = Path("best_model_cnn.pt")
CHECKPOINT_PATH = Path("checkpoint_cnn.pt")
PATIENCE = 5

torch.manual_seed(SEED)
random.seed(SEED)

# =====================
# Load dataset (up to 1M)
# =====================
full_dataset = load_dataset("ai4bharat/samanantar", "hi", split="train")
full_dataset = full_dataset.shuffle(seed=SEED)

NUM_EXAMPLES = min(1_000_000, len(full_dataset))
subset = full_dataset.select(range(NUM_EXAMPLES))

# Split 80/10/10
train_end = int(0.8 * len(subset))
val_end = int(0.9 * len(subset))
train_data = subset.select(range(0, train_end))
val_data = subset.select(range(train_end, val_end))
test_data = subset.select(range(val_end, len(subset)))

print("Dataset sizes:", len(train_data), len(val_data), len(test_data))

# =====================
# SentencePiece
# =====================
SP_EN_MODEL = Path("spm_en.model")
SP_HI_MODEL = Path("spm_hi.model")

def write_lines(dataset_split, src_path, tgt_path):
    with open(src_path, "w", encoding="utf-8") as sf, open(tgt_path, "w", encoding="utf-8") as tf:
        for ex in dataset_split:
            sf.write(ex["src"].strip().lower() + "\n")
            tf.write(ex["tgt"].strip() + "\n")

write_lines(train_data, "train.en", "train.hi")

if not SP_EN_MODEL.exists() or not SP_HI_MODEL.exists():
    print("Training SentencePiece...")
    spm.SentencePieceTrainer.Train(
        f"--input=train.en --model_prefix=spm_en --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=1.0 --model_type=unigram"
    )
    spm.SentencePieceTrainer.Train(
        f"--input=train.hi --model_prefix=spm_hi --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=0.9995 --model_type=unigram"
    )

sp_en = spm.SentencePieceProcessor()
sp_hi = spm.SentencePieceProcessor()
sp_en.load(str(SP_EN_MODEL))
sp_hi.load(str(SP_HI_MODEL))

PAD_EN, BOS_EN, EOS_EN = 0, 1, 2
PAD_HI, BOS_HI, EOS_HI = 0, 1, 2

# =====================
# Dataset & DataLoader
# =====================
class NMTDataset(Dataset):
    def __init__(self, dataset, src_sp, tgt_sp, max_len=MAX_LEN):
        self.dataset = dataset
        self.src_sp = src_sp
        self.tgt_sp = tgt_sp
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        src_text = self.dataset[idx]["src"].lower()
        tgt_text = self.dataset[idx]["tgt"]
        src_ids = [BOS_EN] + self.src_sp.encode(src_text)[:self.max_len-2] + [EOS_EN]
        tgt_ids = [BOS_HI] + self.tgt_sp.encode(tgt_text)[:self.max_len-2] + [EOS_HI]
        src_ids += [PAD_EN] * (self.max_len - len(src_ids))
        tgt_ids += [PAD_HI] * (self.max_len - len(tgt_ids))
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

def get_loader(dataset_split, shuffle=True):
    return DataLoader(NMTDataset(dataset_split, sp_en, sp_hi),
                      batch_size=BATCH_SIZE, shuffle=shuffle)

train_loader = get_loader(train_data)
val_loader = get_loader(val_data)
test_loader = get_loader(test_data, shuffle=False)

# =====================
# Masks
# =====================
def create_padding_mask(seq, lang='en'):
    pad_id = PAD_EN if lang == 'en' else PAD_HI
    return (seq == pad_id)

def generate_square_subsequent_mask(sz):
    return torch.triu(torch.ones(sz, sz, device=DEVICE) * float('-inf'), diagonal=1)

# =====================
# CNN Feature Extractor (2-layer)
# =====================
class CNNFeatureExtractor(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        residual = x
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = x.transpose(1, 2)
        return self.norm(x + residual)

# =====================
# Hybrid CNN + Transformer Model
# =====================
class HybridTransformerModel(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8,
                 num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=PAD_EN)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD_HI)
        self.pos_enc = nn.Parameter(torch.zeros(1, MAX_LEN, d_model))
        self.cnn_encoder = CNNFeatureExtractor(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_key_padding_mask):
        src_emb = self.src_emb(src) + self.pos_enc[:, :src.size(1), :]
        src_cnn = self.cnn_encoder(src_emb)
        return self.transformer.encoder(src_cnn, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_emb(tgt) + self.pos_enc[:, :tgt.size(1), :]
        return self.transformer.decoder(
            tgt_emb, memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_mask=None):
        memory = self.encode(src, src_key_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask)
        return self.fc_out(output)

# =====================
# Training Setup
# =====================
model = HybridTransformerModel(len(sp_en), len(sp_hi)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_HI)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=EPOCHS
)
scaler = torch.amp.GradScaler()

start_epoch = 1
best_val_loss = float("inf")
epochs_no_improve = 0

# =====================
# Resume from Checkpoint (if available)
# =====================
if CHECKPOINT_PATH.exists():
    print(f"üîÅ Resuming from {CHECKPOINT_PATH}...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])

    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    epochs_no_improve = checkpoint['epochs_no_improve']

    print(f"‚úÖ Checkpoint loaded ‚Äî Resuming from epoch {start_epoch}")
else:
    print("üöÄ Starting training from scratch")

# =====================
# Training Loop
# =====================
for epoch in range(start_epoch, EPOCHS + 1):
    model.train()
    train_loss = 0
    for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", unit="batch"):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        src_mask = create_padding_mask(src, 'en')
        tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
        tgt_key_padding_mask = create_padding_mask(tgt_input, 'hi')

        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda"):
            logits = model(
                src, tgt_input,
                src_key_padding_mask=src_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_mask,
                tgt_mask=tgt_mask
            )
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    # ---- Validation ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in val_loader:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            src_mask = create_padding_mask(src, 'en')
            tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
            tgt_key_padding_mask = create_padding_mask(tgt_input, 'hi')

            logits = model(
                src, tgt_input,
                src_key_padding_mask=src_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_mask,
                tgt_mask=tgt_mask
            )
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # ---- Save Checkpoint ----
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_val_loss": best_val_loss,
        "epochs_no_improve": epochs_no_improve,
        "config": {
            "MODEL_TYPE": "HybridCNNTransformer",
            "VOCAB_SIZE": VOCAB_SIZE,
            "MAX_LEN": MAX_LEN,
            "EPOCHS": EPOCHS,
            "LEARNING_RATE": LEARNING_RATE,
            "CLIP": CLIP,
            "BATCH_SIZE": BATCH_SIZE
        }
    }, CHECKPOINT_PATH)

    # ---- Early Stopping ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print("‚úÖ New best model saved!")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print("‚èπÔ∏è Early stopping triggered.")
            break

print("üéØ Training complete.")


  from .autonotebook import tqdm as notebook_tqdm


Dataset sizes: 800000 100000 100000
üîÅ Resuming from checkpoint_cnn.pt...


  checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)


‚úÖ Checkpoint loaded ‚Äî Resuming from epoch 23


Epoch 23/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12500/12500 [12:50<00:00, 16.22batch/s]
  output = torch._nested_tensor_from_mask(


Epoch 23 | Train Loss: 1.5992 | Val Loss: 2.4764


Epoch 24/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12500/12500 [12:51<00:00, 16.21batch/s]


Epoch 24 | Train Loss: 1.5465 | Val Loss: 2.4910
‚èπÔ∏è Early stopping triggered.
üéØ Training complete.


In [None]:
# =====================
# BLEU Evaluation + Translation
# =====================
import torch
import sacrebleu
from tqdm import tqdm
import random

# Load SentencePiece models again (if running separately)
sp_en = spm.SentencePieceProcessor()
sp_hi = spm.SentencePieceProcessor()
sp_en.load("spm_en.model")
sp_hi.load("spm_hi.model")

PAD_EN, BOS_EN, EOS_EN = 0, 1, 2
PAD_HI, BOS_HI, EOS_HI = 0, 1, 2

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reload the best model
model = HybridTransformerModel(len(sp_en), len(sp_hi)).to(DEVICE)
model.load_state_dict(torch.load("best_model_cnn.pt", map_location=DEVICE))
model.eval()

# =====================
# Greedy / Beam Decoding
# =====================
@torch.no_grad()
def translate_sentence(sentence, model, sp_en, sp_hi, max_len=64, beam_size=5):
    model.eval()
    src_ids = [BOS_EN] + sp_en.encode(sentence.lower())[:max_len-2] + [EOS_EN]
    src = torch.tensor(src_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
    src_mask = create_padding_mask(src, 'en')

    memory = model.encode(src, src_mask)

    # ---- Start with BOS token ----
    tgt = torch.tensor([[BOS_HI]], dtype=torch.long, device=DEVICE)

    for _ in range(max_len):
        tgt_mask = generate_square_subsequent_mask(tgt.size(1))
        tgt_key_padding_mask = create_padding_mask(tgt, 'hi')

        output = model.decode(
            tgt, memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=src_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        logits = model.fc_out(output[:, -1, :])
        next_token = logits.argmax(-1).item()

        tgt = torch.cat([tgt, torch.tensor([[next_token]], device=DEVICE)], dim=1)
        if next_token == EOS_HI:
            break

    decoded = sp_hi.decode([t for t in tgt.squeeze().tolist() if t not in [BOS_HI, EOS_HI, PAD_HI]])
    return decoded


# =====================
# BLEU Evaluation
# =====================
refs, hyps = [], []

print("üîç Evaluating BLEU on test set...")
for src, tgt in tqdm(test_loader, desc="Evaluating", unit="batch"):
    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
    for i in range(src.size(0)):
        src_text = sp_en.decode([t for t in src[i].tolist() if t not in [BOS_EN, EOS_EN, PAD_EN]])
        tgt_text = sp_hi.decode([t for t in tgt[i].tolist() if t not in [BOS_HI, EOS_HI, PAD_HI]])
        pred_text = translate_sentence(src_text, model, sp_en, sp_hi)

        refs.append(tgt_text)
        hyps.append(pred_text)

# sacrebleu expects a list of references (list of list)
bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(f"\nüåç BLEU Score on Test Set: {bleu.score:.2f}\n")

# =====================
# Qualitative Examples
# =====================
sample_indices = random.sample(range(len(test_data)), 5)
print("‚ú® Sample Translations:")
for idx in sample_indices:
    src_text = test_data[idx]["src"]
    ref_text = test_data[idx]["tgt"]
    pred_text = translate_sentence(src_text, model, sp_en, sp_hi)
    print(f"\nEN: {src_text}")
    print(f"HI (Ref): {ref_text}")
    print(f"HI (Pred): {pred_text}")
    print("-" * 50)


  model.load_state_dict(torch.load("best_model_cnn.pt", map_location=DEVICE))


üîç Evaluating BLEU on test set...


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [1:49:34<00:00,  4.21s/batch]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



üåç BLEU Score on Test Set: 19.00

‚ú® Sample Translations:

EN: lahore: The Pakistan English press has showered heap of praise on legendary Indian batsman Sachin Tendulkar in their editorials, saying the game of cricket will surely be poorer without him.
HI (Ref): ‡§≤‡§æ‡§π‡•å‡§∞ ‡§™‡§æ‡§ï‡§ø‡§∏‡•ç‡§§‡§æ‡§® ‡§ï‡•Ä ‡§Ö‡§ó‡•ç‡§∞‡•á‡§Ç‡§ú‡•Ä ‡§™‡•ç‡§∞‡•á‡§∏ ‡§®‡•á ‡§Ö‡§™‡§®‡•á ‡§∏‡§Ç‡§™‡§æ‡§¶‡§ï‡•Ä‡§Ø ‡§Æ‡•á‡§Ç ‡§Æ‡§π‡§æ‡§® ‡§≠‡§æ‡§∞‡§§‡•Ä‡§Ø ‡§¨‡§≤‡•ç‡§≤‡•á‡§¨‡§æ‡§ú ‡§∏‡§ö‡§ø‡§® ‡§§‡•á‡§Ç‡§¶‡•Å‡§≤‡§ï‡§∞ ‡§ï‡•Ä ‡§§‡§æ‡§∞‡•Ä‡§´‡•ã‡§Ç ‡§ï‡•á ‡§™‡•Å‡§≤ ‡§¨‡§æ‡§Ç‡§ß‡•á ‡§π‡•à ‡§î‡§∞ ‡§≤‡§ø‡§ñ‡§æ ‡§π‡•à, ‚Äò‡§â‡§®‡§ï‡•á ‡§¨‡§ø‡§®‡§æ ‡§ï‡•ç‡§∞‡§ø‡§ï‡•á‡§ü ‡§ñ‡•á‡§≤ ‡§®‡§ø‡§∂‡•ç‡§ö‡§ø‡§§ ‡§∞‡•Ç‡§™ ‡§∏‡•á ‡§¶‡§∞‡§ø‡§¶‡•ç‡§∞‚Äô ‡§π‡•ã ‡§ú‡§æ‡§Ø‡•á‡§ó‡§æ‡•§ ‡§π‡§æ‡§≤‡§æ‡§Ç‡§ï‡§ø ‡§â‡§∞‡•ç‡§¶‡•Ç ‡§™‡•ç‡§∞‡•á‡§∏ ‡§Æ‡•á‡§Ç ‡§â‡§®‡§ï‡•á ‡§¨‡§æ‡§∞‡•á ‡§Æ‡•á‡§Ç ‡§ú‡•ç‡§Ø‡§æ‡§¶‡§æ ‡§ï‡•Å‡§õ ‡§®‡§π‡•Ä‡§Ç ‡§≤‡§ø‡§ñ‡§æ ‡§ó‡§Ø‡§æ ‡§π‡•à ‡§≤‡•á‡§ï‡§ø‡§® ‡§Ö‡§Ç‡§ó‡•ç‡§∞‡•á‡§ú‡•Ä ‡§ï‡•á ‡§Ö‡§ñ‡§¨‡§æ‡§∞‡•ã‡§Ç 

In [None]:
# ============================================================
# Full Evaluation: Baseline Transformer + Hybrid CNN + BLEU
# ============================================================

import torch
import torch.nn as nn
import sentencepiece as spm
import sacrebleu
from pathlib import Path
from tqdm import tqdm

# =====================
# Config
# =====================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 64  # pos_enc length
VOCAB_SIZE = 32000  # adjust to your SP vocab
PAD_EN = PAD_HI = 0
BOS_EN = BOS_HI = 1
EOS_EN = EOS_HI = 2

BASELINE_PATH = "checkpoint_baseline.pt"
HYBRID_PATH = "checkpoint_hybrid_cnn.pt"

SP_EN_MODEL = Path("spm_en.model")
SP_HI_MODEL = Path("spm_hi.model")

# =====================
# Load SentencePiece
# =====================
sp_en = spm.SentencePieceProcessor()
sp_hi = spm.SentencePieceProcessor()
sp_en.load(str(SP_EN_MODEL))
sp_hi.load(str(SP_HI_MODEL))

# =====================
# Transformer & Hybrid Models
# =====================
class TransformerModel(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8, num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=PAD_EN)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD_HI)
        self.pos_enc = nn.Parameter(torch.zeros(1, MAX_LEN, d_model))
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_key_padding_mask=None):
        src_emb = self.src_emb(src) + self.pos_enc[:, :src.size(1), :]
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
        tgt_emb = self.tgt_emb(tgt) + self.pos_enc[:, :tgt.size(1), :]
        return self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask,
                                        memory_key_padding_mask=memory_key_padding_mask,
                                        tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_mask=None):
        memory = self.encode(src, src_key_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask)
        return self.fc_out(output)

class CNNFeatureExtractor(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.conv1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, 3, padding=1)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        residual = x
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = x.transpose(1, 2)
        return self.norm(x + residual)

class HybridTransformerModel(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8, num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=PAD_EN)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD_HI)
        self.pos_enc = nn.Parameter(torch.zeros(1, MAX_LEN, d_model))
        self.cnn_encoder = CNNFeatureExtractor(d_model)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_key_padding_mask=None):
        src_emb = self.src_emb(src) + self.pos_enc[:, :src.size(1), :]
        src_cnn = self.cnn_encoder(src_emb)
        return self.transformer.encoder(src_cnn, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
        tgt_emb = self.tgt_emb(tgt) + self.pos_enc[:, :tgt.size(1), :]
        return self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask,
                                        memory_key_padding_mask=memory_key_padding_mask,
                                        tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_mask=None):
        memory = self.encode(src, src_key_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask)
        return self.fc_out(output)

# =====================
# Load checkpoints
# =====================
baseline_model = TransformerModel(VOCAB_SIZE, VOCAB_SIZE).to(DEVICE)
hybrid_model = HybridTransformerModel(VOCAB_SIZE, VOCAB_SIZE).to(DEVICE)

def load_checkpoint(model, path):
    checkpoint = torch.load(path, map_location=DEVICE)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    # Adjust pos_enc size
    if "pos_enc" in state_dict and state_dict["pos_enc"].shape[1] != MAX_LEN:
        state_dict["pos_enc"] = state_dict["pos_enc"][:, :MAX_LEN, :]
    model.load_state_dict(state_dict)
    return model

baseline_model = load_checkpoint(baseline_model, BASELINE_PATH)
hybrid_model = load_checkpoint(hybrid_model, HYBRID_PATH)

baseline_model.eval()
hybrid_model.eval()

# =====================
# Greedy translation function
# =====================
@torch.no_grad()
def translate(model, sentence, sp_en, sp_hi, max_len=MAX_LEN):
    src_ids = [BOS_EN] + sp_en.encode(sentence)[:max_len-2] + [EOS_EN]
    src = torch.tensor(src_ids, device=DEVICE).unsqueeze(0)
    tgt = torch.tensor([[BOS_HI]], device=DEVICE)

    for _ in range(max_len):
        out = model(src, tgt)
        next_token = out[:, -1, :].argmax(-1).unsqueeze(1)
        tgt = torch.cat([tgt, next_token], dim=1)
        if next_token.item() == EOS_HI:
            break

    decoded = sp_hi.decode([t for t in tgt[0].tolist() if t not in [BOS_HI, EOS_HI, PAD_HI]])
    return decoded

# =====================
# 20 sample sentences (EN + reference HI)
# =====================
test_samples = new_test_samples = [
    ("The city is preparing for heavy rainfall this week.", "‡§∂‡§π‡§∞ ‡§á‡§∏ ‡§∏‡§™‡•ç‡§§‡§æ‡§π ‡§≠‡§æ‡§∞‡•Ä ‡§µ‡§∞‡•ç‡§∑‡§æ ‡§ï‡•á ‡§≤‡§ø‡§è ‡§§‡•à‡§Ø‡§æ‡§∞‡•Ä ‡§ï‡§∞ ‡§∞‡§π‡§æ ‡§π‡•à‡•§"),
    ("Researchers developed a new method for early disease detection.", "‡§∂‡•ã‡§ß‡§ï‡§∞‡•ç‡§§‡§æ‡§ì‡§Ç ‡§®‡•á ‡§ú‡§≤‡•ç‡§¶‡•Ä ‡§∞‡•ã‡§ó ‡§™‡§π‡§ö‡§æ‡§® ‡§ï‡•á ‡§≤‡§ø‡§è ‡§è‡§ï ‡§®‡§à ‡§µ‡§ø‡§ß‡§ø ‡§µ‡§ø‡§ï‡§∏‡§ø‡§§ ‡§ï‡•Ä‡•§"),
    ("The economy is showing signs of steady growth.", "‡§Ö‡§∞‡•ç‡§•‡§µ‡•ç‡§Ø‡§µ‡§∏‡•ç‡§•‡§æ ‡§∏‡•ç‡§•‡§ø‡§∞ ‡§µ‡§ø‡§ï‡§æ‡§∏ ‡§ï‡•á ‡§∏‡§Ç‡§ï‡•á‡§§ ‡§¶‡§ø‡§ñ‡§æ ‡§∞‡§π‡•Ä ‡§π‡•à‡•§"),
    ("Students participated enthusiastically in the science fair.", "‡§õ‡§æ‡§§‡•ç‡§∞‡•ã‡§Ç ‡§®‡•á ‡§µ‡§ø‡§ú‡•ç‡§û‡§æ‡§® ‡§Æ‡•á‡§≤‡§æ ‡§Æ‡•á‡§Ç ‡§â‡§§‡•ç‡§∏‡§æ‡§π‡§™‡•Ç‡§∞‡•ç‡§µ‡§ï ‡§≠‡§æ‡§ó ‡§≤‡§ø‡§Ø‡§æ‡•§"),
    ("The festival attracted tourists from all over the world.", "‡§§‡•ç‡§Ø‡•ã‡§π‡§æ‡§∞ ‡§®‡•á ‡§¶‡•Å‡§®‡§ø‡§Ø‡§æ ‡§≠‡§∞ ‡§∏‡•á ‡§™‡§∞‡•ç‡§Ø‡§ü‡§ï‡•ã‡§Ç ‡§ï‡•ã ‡§Ü‡§ï‡§∞‡•ç‡§∑‡§ø‡§§ ‡§ï‡§ø‡§Ø‡§æ‡•§"),
    ("Solar energy installations are increasing rapidly in rural areas.", "‡§ó‡•ç‡§∞‡§æ‡§Æ‡•Ä‡§£ ‡§ï‡•ç‡§∑‡•á‡§§‡•ç‡§∞‡•ã‡§Ç ‡§Æ‡•á‡§Ç ‡§∏‡•å‡§∞ ‡§ä‡§∞‡•ç‡§ú‡§æ ‡§∏‡•ç‡§•‡§æ‡§™‡§®‡§æ ‡§§‡•á‡§ú‡•Ä ‡§∏‡•á ‡§¨‡§¢‡§º ‡§∞‡§π‡•Ä ‡§π‡•à‡•§"),
    ("The new policy aims to reduce air pollution in cities.", "‡§®‡§à ‡§®‡•Ä‡§§‡§ø ‡§ï‡§æ ‡§â‡§¶‡•ç‡§¶‡•á‡§∂‡•ç‡§Ø ‡§∂‡§π‡§∞‡•ã‡§Ç ‡§Æ‡•á‡§Ç ‡§µ‡§æ‡§Ø‡•Å ‡§™‡•ç‡§∞‡§¶‡•Ç‡§∑‡§£ ‡§ï‡•ã ‡§ï‡§Æ ‡§ï‡§∞‡§®‡§æ ‡§π‡•à‡•§"),
    ("Artificial intelligence can improve healthcare diagnostics.", "‡§ï‡•É‡§§‡•ç‡§∞‡§ø‡§Æ ‡§¨‡•Å‡§¶‡•ç‡§ß‡§ø‡§Æ‡§§‡•ç‡§§‡§æ ‡§∏‡•ç‡§µ‡§æ‡§∏‡•ç‡§•‡•ç‡§Ø ‡§¶‡•á‡§ñ‡§≠‡§æ‡§≤ ‡§®‡§ø‡§¶‡§æ‡§® ‡§Æ‡•á‡§Ç ‡§∏‡•Å‡§ß‡§æ‡§∞ ‡§ï‡§∞ ‡§∏‡§ï‡§§‡•Ä ‡§π‡•à‡•§"),
    ("The government announced relief measures for flood-affected areas.", "‡§∏‡§∞‡§ï‡§æ‡§∞ ‡§®‡•á ‡§¨‡§æ‡§¢‡§º ‡§™‡•ç‡§∞‡§≠‡§æ‡§µ‡§ø‡§§ ‡§ï‡•ç‡§∑‡•á‡§§‡•ç‡§∞‡•ã‡§Ç ‡§ï‡•á ‡§≤‡§ø‡§è ‡§∞‡§æ‡§π‡§§ ‡§â‡§™‡§æ‡§Ø‡•ã‡§Ç ‡§ï‡•Ä ‡§ò‡•ã‡§∑‡§£‡§æ ‡§ï‡•Ä‡•§"),
    ("Wildlife conservation is critical for maintaining biodiversity.", "‡§ú‡§Ç‡§ó‡§≤‡•Ä ‡§ú‡•Ä‡§µ‡§® ‡§∏‡§Ç‡§∞‡§ï‡•ç‡§∑‡§£ ‡§ú‡•à‡§µ ‡§µ‡§ø‡§µ‡§ø‡§ß‡§§‡§æ ‡§¨‡§®‡§æ‡§è ‡§∞‡§ñ‡§®‡•á ‡§ï‡•á ‡§≤‡§ø‡§è ‡§Æ‡§π‡§§‡•ç‡§µ‡§™‡•Ç‡§∞‡•ç‡§£ ‡§π‡•à‡•§"),
    ("The company launched a new smartphone model last week.", "‡§ï‡§Ç‡§™‡§®‡•Ä ‡§®‡•á ‡§™‡§ø‡§õ‡§≤‡•á ‡§∏‡§™‡•ç‡§§‡§æ‡§π ‡§è‡§ï ‡§®‡§Ø‡§æ ‡§∏‡•ç‡§Æ‡§æ‡§∞‡•ç‡§ü‡§´‡•ã‡§® ‡§Æ‡•â‡§°‡§≤ ‡§≤‡•â‡§®‡•ç‡§ö ‡§ï‡§ø‡§Ø‡§æ‡•§"),
    ("International trade agreements influence domestic markets.", "‡§Ö‡§Ç‡§§‡§∞‡§∞‡§æ‡§∑‡•ç‡§ü‡•ç‡§∞‡•Ä‡§Ø ‡§µ‡•ç‡§Ø‡§æ‡§™‡§æ‡§∞ ‡§∏‡§Æ‡§ù‡•å‡§§‡•á ‡§ò‡§∞‡•á‡§≤‡•Ç ‡§¨‡§æ‡§ú‡§æ‡§∞‡•ã‡§Ç ‡§ï‡•ã ‡§™‡•ç‡§∞‡§≠‡§æ‡§µ‡§ø‡§§ ‡§ï‡§∞‡§§‡•á ‡§π‡•à‡§Ç‡•§"),
    ("The team developed an innovative software solution.", "‡§ü‡•Ä‡§Æ ‡§®‡•á ‡§è‡§ï ‡§®‡§µ‡•ã‡§®‡•ç‡§Æ‡•á‡§∑‡•Ä ‡§∏‡•â‡§´‡•ç‡§ü‡§µ‡•á‡§Ø‡§∞ ‡§∏‡§Æ‡§æ‡§ß‡§æ‡§® ‡§µ‡§ø‡§ï‡§∏‡§ø‡§§ ‡§ï‡§ø‡§Ø‡§æ‡•§"),
    ("Urban transport systems are facing challenges due to population growth.", "‡§ú‡§®‡§∏‡§Ç‡§ñ‡•ç‡§Ø‡§æ ‡§µ‡•É‡§¶‡•ç‡§ß‡§ø ‡§ï‡•á ‡§ï‡§æ‡§∞‡§£ ‡§∂‡§π‡§∞‡•Ä ‡§™‡§∞‡§ø‡§µ‡§π‡§® ‡§™‡•ç‡§∞‡§£‡§æ‡§≤‡•Ä ‡§ö‡•Å‡§®‡•å‡§§‡§ø‡§Ø‡•ã‡§Ç ‡§ï‡§æ ‡§∏‡§æ‡§Æ‡§®‡§æ ‡§ï‡§∞ ‡§∞‡§π‡•Ä ‡§π‡•à‡•§"),
    ("The government is investing in renewable energy projects.", "‡§∏‡§∞‡§ï‡§æ‡§∞ ‡§®‡§µ‡•Ä‡§ï‡§∞‡§£‡•Ä‡§Ø ‡§ä‡§∞‡•ç‡§ú‡§æ ‡§™‡§∞‡§ø‡§Ø‡•ã‡§ú‡§®‡§æ‡§ì‡§Ç ‡§Æ‡•á‡§Ç ‡§®‡§ø‡§µ‡•á‡§∂ ‡§ï‡§∞ ‡§∞‡§π‡•Ä ‡§π‡•à‡•§"),
    ("Students are encouraged to engage in extracurricular activities.", "‡§õ‡§æ‡§§‡•ç‡§∞‡•ã‡§Ç ‡§ï‡•ã ‡§™‡§æ‡§†‡•ç‡§Ø‡•á‡§§‡§∞ ‡§ó‡§§‡§ø‡§µ‡§ø‡§ß‡§ø‡§Ø‡•ã‡§Ç ‡§Æ‡•á‡§Ç ‡§≠‡§æ‡§ó ‡§≤‡•á‡§®‡•á ‡§ï‡•á ‡§≤‡§ø‡§è ‡§™‡•ç‡§∞‡•ã‡§§‡•ç‡§∏‡§æ‡§π‡§ø‡§§ ‡§ï‡§ø‡§Ø‡§æ ‡§ú‡§æ‡§§‡§æ ‡§π‡•à‡•§"),
    ("The company reported a decline in operating costs this quarter.", "‡§ï‡§Ç‡§™‡§®‡•Ä ‡§®‡•á ‡§á‡§∏ ‡§§‡§ø‡§Æ‡§æ‡§π‡•Ä ‡§Æ‡•á‡§Ç ‡§∏‡§Ç‡§ö‡§æ‡§≤‡§® ‡§≤‡§æ‡§ó‡§§ ‡§Æ‡•á‡§Ç ‡§ï‡§Æ‡•Ä ‡§ï‡•Ä ‡§∞‡§ø‡§™‡•ã‡§∞‡•ç‡§ü ‡§¶‡•Ä‡•§"),
    ("Climate change poses a threat to coastal communities.", "‡§ú‡§≤‡§µ‡§æ‡§Ø‡•Å ‡§™‡§∞‡§ø‡§µ‡§∞‡•ç‡§§‡§® ‡§§‡§ü‡•Ä‡§Ø ‡§∏‡§Æ‡•Å‡§¶‡§æ‡§Ø‡•ã‡§Ç ‡§ï‡•á ‡§≤‡§ø‡§è ‡§ñ‡§§‡§∞‡§æ ‡§â‡§§‡•ç‡§™‡§®‡•ç‡§® ‡§ï‡§∞‡§§‡§æ ‡§π‡•à‡•§"),
    ("The research team published their findings in a leading journal.", "‡§Ö‡§®‡•Å‡§∏‡§Ç‡§ß‡§æ‡§® ‡§ü‡•Ä‡§Æ ‡§®‡•á ‡§Ö‡§™‡§®‡•á ‡§®‡§ø‡§∑‡•ç‡§ï‡§∞‡•ç‡§∑ ‡§è‡§ï ‡§™‡•ç‡§∞‡§Æ‡•Å‡§ñ ‡§ú‡§∞‡•ç‡§®‡§≤ ‡§Æ‡•á‡§Ç ‡§™‡•ç‡§∞‡§ï‡§æ‡§∂‡§ø‡§§ ‡§ï‡§ø‡§è‡•§"),
    ("Global cooperation is necessary to tackle pandemics.", "‡§Æ‡§π‡§æ‡§Æ‡§æ‡§∞‡•Ä ‡§∏‡•á ‡§®‡§ø‡§™‡§ü‡§®‡•á ‡§ï‡•á ‡§≤‡§ø‡§è ‡§µ‡•à‡§∂‡•ç‡§µ‡§ø‡§ï ‡§∏‡§π‡§Ø‡•ã‡§ó ‡§Ü‡§µ‡§∂‡•ç‡§Ø‡§ï ‡§π‡•à‡•§"),
    ("India launched its first indigenous satellite.", "‡§≠‡§æ‡§∞‡§§ ‡§®‡•á ‡§Ö‡§™‡§®‡§æ ‡§™‡§π‡§≤‡§æ ‡§∏‡•ç‡§µ‡§¶‡•á‡§∂‡•Ä ‡§â‡§™‡§ó‡•ç‡§∞‡§π ‡§≤‡•â‡§®‡•ç‡§ö ‡§ï‡§ø‡§Ø‡§æ‡•§"),
    ("The prime minister met foreign delegates at the summit.", "‡§™‡•ç‡§∞‡§ß‡§æ‡§®‡§Æ‡§Ç‡§§‡•ç‡§∞‡•Ä ‡§®‡•á ‡§∂‡§ø‡§ñ‡§∞ ‡§∏‡§Æ‡•ç‡§Æ‡•á‡§≤‡§® ‡§Æ‡•á‡§Ç ‡§µ‡§ø‡§¶‡•á‡§∂‡•Ä ‡§™‡•ç‡§∞‡§§‡§ø‡§®‡§ø‡§ß‡§ø‡§Ø‡•ã‡§Ç ‡§∏‡•á ‡§Æ‡•Å‡§≤‡§æ‡§ï‡§æ‡§§ ‡§ï‡•Ä‡•§"),
    ("This research focuses on low-resource machine translation.", "‡§Ø‡§π ‡§∂‡•ã‡§ß ‡§ï‡§Æ ‡§∏‡§Ç‡§∏‡§æ‡§ß‡§® ‡§µ‡§æ‡§≤‡•Ä ‡§Æ‡§∂‡•Ä‡§® ‡§Ö‡§®‡•Å‡§µ‡§æ‡§¶ ‡§™‡§∞ ‡§ï‡•á‡§Ç‡§¶‡•ç‡§∞‡§ø‡§§ ‡§π‡•à‡•§"),
    ("Its movement was captured on CCTV.", "‡§á‡§∏‡§ï‡•á ‡§Ü‡§Ç‡§¶‡•ã‡§≤‡§® ‡§ï‡•ã ‡§∏‡•Ä‡§∏‡•Ä‡§ü‡•Ä‡§µ‡•Ä ‡§Æ‡•á‡§Ç ‡§ï‡•à‡§¶ ‡§ï‡§ø‡§Ø‡§æ ‡§ó‡§Ø‡§æ‡•§"),
    ("The two leaders discussed bilateral ties and cooperation in counter-terrorism.", "‡§¶‡•ã ‡§®‡•á‡§§‡§æ‡§ì‡§Ç ‡§®‡•á ‡§¶‡•ç‡§µ‡§ø‡§™‡§ï‡•ç‡§∑‡•Ä‡§Ø ‡§∏‡§Ç‡§¨‡§Ç‡§ß‡•ã‡§Ç ‡§î‡§∞ ‡§Ü‡§§‡§Ç‡§ï‡§µ‡§æ‡§¶ ‡§®‡§ø‡§∞‡•ã‡§ß‡§ï ‡§∏‡§π‡§Ø‡•ã‡§ó ‡§™‡§∞ ‡§ö‡§∞‡•ç‡§ö‡§æ ‡§ï‡•Ä‡•§"),
    ("The data shows significant improvement in translation accuracy.", "‡§°‡•á‡§ü‡§æ ‡§Ö‡§®‡•Å‡§µ‡§æ‡§¶ ‡§ï‡•Ä ‡§∏‡§ü‡•Ä‡§ï‡§§‡§æ ‡§Æ‡•á‡§Ç ‡§Æ‡§π‡§§‡•ç‡§µ‡§™‡•Ç‡§∞‡•ç‡§£ ‡§∏‡•Å‡§ß‡§æ‡§∞ ‡§¶‡§ø‡§ñ‡§æ‡§§‡§æ ‡§π‡•à‡•§"),
    ("He emphasized the importance of sustainable development.", "‡§â‡§®‡•ç‡§π‡•ã‡§Ç‡§®‡•á ‡§∏‡§§‡§§ ‡§µ‡§ø‡§ï‡§æ‡§∏ ‡§ï‡•á ‡§Æ‡§π‡§§‡•ç‡§µ ‡§™‡§∞ ‡§ú‡•ã‡§∞ ‡§¶‡§ø‡§Ø‡§æ‡•§"),
    ("The company reported record profits this quarter.", "‡§ï‡§Ç‡§™‡§®‡•Ä ‡§®‡•á ‡§á‡§∏ ‡§§‡§ø‡§Æ‡§æ‡§π‡•Ä ‡§Æ‡•á‡§Ç ‡§∞‡§ø‡§ï‡•â‡§∞‡•ç‡§° ‡§≤‡§æ‡§≠ ‡§ï‡•Ä ‡§∏‡•Ç‡§ö‡§®‡§æ ‡§¶‡•Ä‡•§"),
    ("Students are encouraged to participate in research projects.", "‡§õ‡§æ‡§§‡•ç‡§∞‡•ã‡§Ç ‡§ï‡•ã ‡§Ö‡§®‡•Å‡§∏‡§Ç‡§ß‡§æ‡§® ‡§™‡§∞‡§ø‡§Ø‡•ã‡§ú‡§®‡§æ‡§ì‡§Ç ‡§Æ‡•á‡§Ç ‡§≠‡§æ‡§ó ‡§≤‡•á‡§®‡•á ‡§ï‡•á ‡§≤‡§ø‡§è ‡§™‡•ç‡§∞‡•ã‡§§‡•ç‡§∏‡§æ‡§π‡§ø‡§§ ‡§ï‡§ø‡§Ø‡§æ ‡§ú‡§æ‡§§‡§æ ‡§π‡•à‡•§"),
    ("The new model achieved higher BLEU scores than the baseline.", "‡§®‡§Ø‡§æ ‡§Æ‡•â‡§°‡§≤ ‡§¨‡•á‡§∏‡§≤‡§æ‡§á‡§® ‡§ï‡•Ä ‡§§‡•Å‡§≤‡§®‡§æ ‡§Æ‡•á‡§Ç ‡§â‡§ö‡•ç‡§ö BLEU ‡§∏‡•ç‡§ï‡•ã‡§∞ ‡§™‡•ç‡§∞‡§æ‡§™‡•ç‡§§ ‡§ï‡§∞‡§§‡§æ ‡§π‡•à‡•§"),
    ("Climate change is affecting global agriculture.", "‡§ú‡§≤‡§µ‡§æ‡§Ø‡•Å ‡§™‡§∞‡§ø‡§µ‡§∞‡•ç‡§§‡§® ‡§µ‡•à‡§∂‡•ç‡§µ‡§ø‡§ï ‡§ï‡•É‡§∑‡§ø ‡§ï‡•ã ‡§™‡•ç‡§∞‡§≠‡§æ‡§µ‡§ø‡§§ ‡§ï‡§∞ ‡§∞‡§π‡§æ ‡§π‡•à‡•§"),
    ("Vaccination campaigns have reduced disease incidence.", "‡§ü‡•Ä‡§ï‡§æ‡§ï‡§∞‡§£ ‡§Ö‡§≠‡§ø‡§Ø‡§æ‡§®‡•ã‡§Ç ‡§®‡•á ‡§∞‡•ã‡§ó ‡§ï‡•Ä ‡§ò‡§ü‡§®‡§æ‡§ì‡§Ç ‡§ï‡•ã ‡§ï‡§Æ ‡§ï‡§ø‡§Ø‡§æ ‡§π‡•à‡•§"),
    ("Artificial intelligence is transforming industries rapidly.", "‡§ï‡•É‡§§‡•ç‡§∞‡§ø‡§Æ ‡§¨‡•Å‡§¶‡•ç‡§ß‡§ø‡§Æ‡§§‡•ç‡§§‡§æ ‡§â‡§¶‡•ç‡§Ø‡•ã‡§ó‡•ã‡§Ç ‡§ï‡•ã ‡§§‡•á‡§ú‡•Ä ‡§∏‡•á ‡§¨‡§¶‡§≤ ‡§∞‡§π‡•Ä ‡§π‡•à‡•§"),
    ("The movie received critical acclaim.", "‡§á‡§∏ ‡§´‡§ø‡§≤‡•ç‡§Æ ‡§®‡•á ‡§Ü‡§≤‡•ã‡§ö‡§®‡§æ‡§§‡•ç‡§Æ‡§ï ‡§™‡•ç‡§∞‡§∂‡§Ç‡§∏‡§æ ‡§™‡•ç‡§∞‡§æ‡§™‡•ç‡§§ ‡§ï‡•Ä‡•§"),
    ("Electric vehicles are becoming more popular worldwide.", "‡§µ‡§ø‡§¶‡•ç‡§Ø‡•Å‡§§ ‡§µ‡§æ‡§π‡§® ‡§¶‡•Å‡§®‡§ø‡§Ø‡§æ ‡§≠‡§∞ ‡§Æ‡•á‡§Ç ‡§Ö‡§ß‡§ø‡§ï ‡§≤‡•ã‡§ï‡§™‡•ç‡§∞‡§ø‡§Ø ‡§π‡•ã ‡§∞‡§π‡•á ‡§π‡•à‡§Ç‡•§"),
    ("Renewable energy sources are crucial for sustainability.", "‡§®‡§µ‡•Ä‡§ï‡§∞‡§£‡•Ä‡§Ø ‡§ä‡§∞‡•ç‡§ú‡§æ ‡§∏‡•ç‡§∞‡•ã‡§§ ‡§∏‡§§‡§§‡§§‡§æ ‡§ï‡•á ‡§≤‡§ø‡§è ‡§Æ‡§π‡§§‡•ç‡§µ‡§™‡•Ç‡§∞‡•ç‡§£ ‡§π‡•à‡§Ç‡•§"),
    ("The government announced new education policies.", "‡§∏‡§∞‡§ï‡§æ‡§∞ ‡§®‡•á ‡§®‡§à ‡§∂‡§ø‡§ï‡•ç‡§∑‡§æ ‡§®‡•Ä‡§§‡§ø‡§Ø‡•ã‡§Ç ‡§ï‡•Ä ‡§ò‡•ã‡§∑‡§£‡§æ ‡§ï‡•Ä‡•§"),
    ("Space exploration has advanced significantly in recent years.", "‡§Ö‡§Ç‡§§‡§∞‡§ø‡§ï‡•ç‡§∑ ‡§Ö‡§®‡•ç‡§µ‡•á‡§∑‡§£ ‡§®‡•á ‡§π‡§æ‡§≤ ‡§ï‡•á ‡§µ‡§∞‡•ç‡§∑‡•ã‡§Ç ‡§Æ‡•á‡§Ç ‡§Æ‡§π‡§§‡•ç‡§µ‡§™‡•Ç‡§∞‡•ç‡§£ ‡§™‡•ç‡§∞‡§ó‡§§‡§ø ‡§ï‡•Ä ‡§π‡•à‡•§"),
    ("The sports team won the national championship.", "‡§ñ‡•á‡§≤ ‡§ü‡•Ä‡§Æ ‡§®‡•á ‡§∞‡§æ‡§∑‡•ç‡§ü‡•ç‡§∞‡•Ä‡§Ø ‡§ö‡•à‡§Ç‡§™‡§ø‡§Ø‡§®‡§∂‡§ø‡§™ ‡§ú‡•Ä‡§§‡•Ä‡•§"),
    ("Global trade agreements impact economic growth.", "‡§µ‡•à‡§∂‡•ç‡§µ‡§ø‡§ï ‡§µ‡•ç‡§Ø‡§æ‡§™‡§æ‡§∞ ‡§∏‡§Æ‡§ù‡•å‡§§‡•á ‡§Ü‡§∞‡•ç‡§•‡§ø‡§ï ‡§µ‡•É‡§¶‡•ç‡§ß‡§ø ‡§ï‡•ã ‡§™‡•ç‡§∞‡§≠‡§æ‡§µ‡§ø‡§§ ‡§ï‡§∞‡§§‡•á ‡§π‡•à‡§Ç‡•§"),
]


# =====================
# Translate & save results
# =====================
results = []
for en, hi_ref in test_samples:
    baseline_out = translate(baseline_model, en, sp_en, sp_hi)
    hybrid_out = translate(hybrid_model, en, sp_en, sp_hi)

    print(f"\nEN: {en}")
    print(f"Reference HI: {hi_ref}")
    print(f"Baseline: {baseline_out}")
    print(f"Hybrid CNN: {hybrid_out}")

    results.append({
        "EN": en,
        "Reference": hi_ref,
        "Baseline": baseline_out,
        "2 layer CNN": hybrid_out
    })

# Save results to file
with open("translation_comparison.txt", "w", encoding="utf-8") as f:
    for r in results:
        f.write(f"EN: {r['EN']}\nReference HI: {r['Reference']}\nBaseline: {r['Baseline']}\nHybrid CNN: {r['Hybrid']}\n\n")

print("\n‚úÖ Translations saved to 'translation_comparison.txt'")

# =====================
# Compute BLEU for Hybrid model
# =====================
refs = [[r["Reference"]] for r in results]  # list of list
hyps = [r["Hybrid"] for r in results]

bleu = sacrebleu.corpus_bleu(hyps, refs)
print(f"\nüåç BLEU Score (Hybrid CNN vs Reference): {bleu.score:.2f}")


  checkpoint = torch.load(path, map_location=DEVICE)



EN: The city is preparing for heavy rainfall this week.
Reference HI: ‡§∂‡§π‡§∞ ‡§á‡§∏ ‡§∏‡§™‡•ç‡§§‡§æ‡§π ‡§≠‡§æ‡§∞‡•Ä ‡§µ‡§∞‡•ç‡§∑‡§æ ‡§ï‡•á ‡§≤‡§ø‡§è ‡§§‡•à‡§Ø‡§æ‡§∞‡•Ä ‡§ï‡§∞ ‡§∞‡§π‡§æ ‡§π‡•à‡•§
Baseline: ‡§á‡§∏ ‡§∏‡§™‡•ç‡§§‡§æ‡§π ‡§§‡•á‡§ú ‡§¨‡§æ‡§∞‡§ø‡§∂ ‡§ï‡•Ä ‡§§‡•à‡§Ø‡§æ‡§∞‡§ø‡§Ø‡§æ‡§Ç ‡§ö‡§≤ ‡§∞‡§π‡•Ä ‡§π‡•à‡§Ç‡•§
Hybrid CNN: ‡§á‡§∏ ‡§∏‡§™‡•ç‡§§‡§æ‡§π ‡§§‡•á‡§ú ‡§¨‡§æ‡§∞‡§ø‡§∂ ‡§ï‡•Ä ‡§§‡•à‡§Ø‡§æ‡§∞‡•Ä ‡§ï‡§∞ ‡§∞‡§π‡§æ ‡§π‡•à‡•§

EN: Researchers developed a new method for early disease detection.
Reference HI: ‡§∂‡•ã‡§ß‡§ï‡§∞‡•ç‡§§‡§æ‡§ì‡§Ç ‡§®‡•á ‡§ú‡§≤‡•ç‡§¶‡•Ä ‡§∞‡•ã‡§ó ‡§™‡§π‡§ö‡§æ‡§® ‡§ï‡•á ‡§≤‡§ø‡§è ‡§è‡§ï ‡§®‡§à ‡§µ‡§ø‡§ß‡§ø ‡§µ‡§ø‡§ï‡§∏‡§ø‡§§ ‡§ï‡•Ä‡•§
Baseline: """ ""‡§ë‡§ï‡•ç‡§∏‡§´‡•ã‡§∞‡•ç‡§°‡•á‡§¨‡§≤‡•ç‡§∏ ‡§®‡•á ‡§∂‡•Å‡§∞‡•Å‡§Ü‡§§‡•Ä ‡§¨‡•Ä‡§Æ‡§æ‡§∞‡•Ä ‡§ï‡§æ ‡§™‡§§‡§æ ‡§≤‡§ó‡§æ‡§®‡•á ‡§ï‡•Ä ‡§è‡§ï ‡§®‡§Ø‡§æ ‡§§‡§∞‡•Ä‡§ï‡§æ ‡§µ‡§ø‡§ï‡§∏‡§ø‡§§ ‡§ï‡§ø‡§Ø‡§æ ‡§π‡•à."""
Hybrid CNN: ‡§∞‡§æ‡§´‡•à‡§≤‡•á ‡§®‡•á ‡§ú‡§≤‡•ç‡§¶‡•Ä ‡§ï‡•Ä ‡§™‡§π‡§ö‡§æ‡§® ‡§ï‡•á ‡§≤‡§ø‡§è ‡§è‡

In [1]:
import sys
print(sys.executable)


C:\Users\Ayush\.conda\envs\mt_env\python.exe


In [2]:
# =====================
# Install & Imports
# =====================
!pip install -q datasets sentencepiece sacrebleu torch torchvision torchaudio tqdm

import os, random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datasets import load_dataset
import sentencepiece as spm
import sacrebleu

# =====================
# Config
# =====================
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 16000
BATCH_SIZE = 64
MAX_LEN = 64
MAX_GEN_LEN = 64
EPOCHS = 20
CLIP = 1.0
LEARNING_RATE = 3e-4
BEAM_SIZE = 5
BEST_MODEL_PATH = Path("best_model_multiscale.pt")
CHECKPOINT_PATH = Path("checkpoint_multiscale.pt")

PATIENCE = 5

torch.manual_seed(SEED)
random.seed(SEED)

# =====================
# Load dataset
# =====================
full_dataset = load_dataset("ai4bharat/samanantar", "hi", split="train")
full_dataset = full_dataset.shuffle(seed=SEED)

NUM_EXAMPLES = min(1_000_000, len(full_dataset))
subset = full_dataset.select(range(NUM_EXAMPLES))

# Split 80/10/10
train_end = int(0.8 * len(subset))
val_end = int(0.9 * len(subset))
train_data = subset.select(range(0, train_end))
val_data = subset.select(range(train_end, val_end))
test_data = subset.select(range(val_end, len(subset)))
print("Dataset sizes:", len(train_data), len(val_data), len(test_data))

# =====================
# SentencePiece
# =====================
SP_EN_MODEL = Path("spm_en.model")
SP_HI_MODEL = Path("spm_hi.model")

def write_lines(dataset_split, src_path, tgt_path):
    with open(src_path, "w", encoding="utf-8") as sf, open(tgt_path, "w", encoding="utf-8") as tf:
        for ex in dataset_split:
            sf.write(ex["src"].strip().lower() + "\n")
            tf.write(ex["tgt"].strip() + "\n")

write_lines(train_data, "train.en", "train.hi")

if not SP_EN_MODEL.exists() or not SP_HI_MODEL.exists():
    print("Training SentencePiece...")
    spm.SentencePieceTrainer.Train(
        f"--input=train.en --model_prefix=spm_en --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=1.0 --model_type=unigram"
    )
    spm.SentencePieceTrainer.Train(
        f"--input=train.hi --model_prefix=spm_hi --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=0.9995 --model_type=unigram"
    )

sp_en = spm.SentencePieceProcessor()
sp_hi = spm.SentencePieceProcessor()
sp_en.load(str(SP_EN_MODEL))
sp_hi.load(str(SP_HI_MODEL))

PAD_EN, BOS_EN, EOS_EN = 0, 1, 2
PAD_HI, BOS_HI, EOS_HI = 0, 1, 2

# =====================
# Dataset & DataLoader
# =====================
class NMTDataset(Dataset):
    def __init__(self, dataset, src_sp, tgt_sp, max_len=MAX_LEN):
        self.dataset = dataset
        self.src_sp = src_sp
        self.tgt_sp = tgt_sp
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        src_text = self.dataset[idx]["src"].lower()
        tgt_text = self.dataset[idx]["tgt"]
        src_ids = [BOS_EN] + self.src_sp.encode(src_text)[:self.max_len-2] + [EOS_EN]
        tgt_ids = [BOS_HI] + self.tgt_sp.encode(tgt_text)[:self.max_len-2] + [EOS_HI]
        src_ids += [PAD_EN] * (self.max_len - len(src_ids))
        tgt_ids += [PAD_HI] * (self.max_len - len(tgt_ids))
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

def get_loader(dataset_split, shuffle=True):
    return DataLoader(NMTDataset(dataset_split, sp_en, sp_hi),
                      batch_size=BATCH_SIZE, shuffle=shuffle)

train_loader = get_loader(train_data)
val_loader = get_loader(val_data)
test_loader = get_loader(test_data, shuffle=False)

# =====================
# Masks
# =====================
def create_padding_mask(seq, lang='en'):
    pad_id = PAD_EN if lang == 'en' else PAD_HI
    return (seq == pad_id)

def generate_square_subsequent_mask(sz):
    return torch.triu(torch.ones(sz, sz, device=DEVICE) * float('-inf'), diagonal=1)

# =====================
# Multi-Scale CNN
# =====================
class MultiScaleCNN(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.conv3 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)
        self.conv7 = nn.Conv1d(embed_dim, embed_dim, kernel_size=7, padding=3)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        residual = x
        x = x.transpose(1, 2)
        out3 = self.conv3(x)
        out5 = self.conv5(x)
        out7 = self.conv7(x)
        x = out3 + out5 + out7
        x = self.relu(x)
        x = x.transpose(1, 2)
        return self.norm(x + residual)

# =====================
# Hybrid Transformer + MultiScale CNN
# =====================
class HybridTransformerModel(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8,
                 num_layers=3, dim_feedforward=1024, dropout=0.1, temperature=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=PAD_EN)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD_HI)
        self.pos_enc = nn.Parameter(torch.zeros(1, MAX_LEN, d_model))
        self.cnn_encoder = MultiScaleCNN(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab)
        self.temperature = temperature

    def encode(self, src, src_key_padding_mask):
        src_emb = self.src_emb(src) + self.pos_enc[:, :src.size(1), :]
        src_cnn = self.cnn_encoder(src_emb)
        return self.transformer.encoder(src_cnn, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_emb(tgt) + self.pos_enc[:, :tgt.size(1), :]
        return self.transformer.decoder(
            tgt_emb, memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_mask=None):
        memory = self.encode(src, src_key_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask)
        return self.fc_out(output)

    # -----------------------
    # Contrastive loss
    # -----------------------
    def contrastive_loss(self, anchor, positive, negative):
        anchor = anchor.mean(dim=1)
        positive = positive.mean(dim=1)
        negative = negative.mean(dim=1)
        pos_sim = torch.cosine_similarity(anchor, positive, dim=-1)
        neg_sim = torch.cosine_similarity(anchor, negative, dim=-1)
        loss = -torch.log(torch.exp(pos_sim / self.temperature) / (torch.exp(pos_sim / self.temperature) + torch.exp(neg_sim / self.temperature)))
        return loss.mean()

# =====================
# Initialize model
# =====================
model = HybridTransformerModel(len(sp_en), len(sp_hi)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_HI)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=EPOCHS
)
scaler = torch.amp.GradScaler()

start_epoch = 1
best_val_loss = float("inf")
epochs_no_improve = 0

# =====================
# Resume checkpoint
# =====================
if CHECKPOINT_PATH.exists():
    print(f"Resuming from {CHECKPOINT_PATH}...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    epochs_no_improve = checkpoint['epochs_no_improve']
    print(f"Checkpoint loaded ‚Äî Resuming from epoch {start_epoch}")
else:
    print("Starting training from scratch")

# =====================
# Training Loop
# =====================
for epoch in range(start_epoch, EPOCHS + 1):
    model.train()
    train_loss = 0
    contrastive_loss_total = 0

    for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", unit="batch"):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        src_mask = create_padding_mask(src, 'en')
        tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
        tgt_key_padding_mask = create_padding_mask(tgt_input, 'hi')

        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda"):
            logits = model(
                src, tgt_input,
                src_key_padding_mask=src_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_mask,
                tgt_mask=tgt_mask
            )
            ce_loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            # -----------------------------
            # Contrastive loss
            # -----------------------------
            neg_tgt_input = tgt_input[torch.randperm(tgt_input.size(0))].to(DEVICE)
            memory = model.encode(src, src_mask)
            pos_repr = model.decode(tgt_input, memory, tgt_mask, src_mask, tgt_key_padding_mask)
            neg_repr = model.decode(neg_tgt_input, memory, tgt_mask, src_mask, tgt_key_padding_mask)
            cl_loss = model.contrastive_loss(pos_repr, pos_repr, neg_repr)

            loss = ce_loss + 0.1 * cl_loss
            contrastive_loss_total += cl_loss.item()

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        train_loss += ce_loss.item()

    avg_train_loss = train_loss / len(train_loader)
    avg_cl_loss = contrastive_loss_total / len(train_loader)
    print(f"Epoch {epoch} | CE Loss: {avg_train_loss:.4f} | Contrastive Loss: {avg_cl_loss:.4f}")

    # ---- Validation ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in val_loader:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            src_mask = create_padding_mask(src, 'en')
            tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
            tgt_key_padding_mask = create_padding_mask(tgt_input, 'hi')

            logits = model(
                src, tgt_input,
                src_key_padding_mask=src_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_mask,
                tgt_mask=tgt_mask
            )
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch} | Val Loss: {avg_val_loss:.4f}")

    # ---- Save Checkpoint & Early Stopping ----
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_val_loss": best_val_loss,
        "epochs_no_improve": epochs_no_improve,
    }, CHECKPOINT_PATH)

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print("New best model saved!")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print("Early stopping triggered.")
            break

print("Training complete ‚úÖ")


  from .autonotebook import tqdm as notebook_tqdm


Dataset sizes: 800000 100000 100000
Resuming from checkpoint_multiscale.pt...


  checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)


Checkpoint loaded ‚Äî Resuming from epoch 22
Training complete ‚úÖ
