In [1]:
# Full Colab-ready script: Transformer + Gated Multi-Scale CNN (GMSCNN)
# Save as main_gmsc_transformer.py or paste into a Colab cell and run.
# Assumes spm_en.model and spm_hi.model exist in working directory OR will train on train split.

# =====================
# Install & Imports
# =====================
!pip install -q datasets

!pip install -q datasets sentencepiece sacrebleu nltk torch torchvision torchaudio tqdm

import os
import random
from pathlib import Path
from datetime import datetime
import json
import math
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datasets import load_dataset
import sentencepiece as spm
import sacrebleu
from nltk.translate.meteor_score import meteor_score
import nltk
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

# =====================
# Config (editable)
# =====================
SEED = 42  # shared by torch, Python's `random`, and the dataset shuffle

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

# --- Vocabulary, batching and sequence lengths ---
VOCAB_SIZE = 32000   # SentencePiece vocabulary size, per language
BATCH_SIZE = 64      # reduce if the GPU runs out of memory
MAX_LEN = 64         # every src/tgt example is truncated/padded to this many ids
MAX_GEN_LEN = 64     # upper bound on decoding steps at inference

# --- Optimisation ---
EPOCHS = 20
CLIP = 1.0             # max gradient norm
LEARNING_RATE = 3e-4   # peak learning rate of the OneCycle schedule

# --- Beam search ---
BEAM_SIZE = 5
LENGTH_PENALTY = 0.6   # exponent in the beam-score length normalisation

# --- Files (relative to the working directory) ---
SP_EN_PATH = Path("spm_en.model")
SP_HI_PATH = Path("spm_hi.model")
BEST_MODEL_PATH = Path("best_model_gmsc.pt")
CHECKPOINT_PATH = Path("checkpoint_gmsc.pt")
METRICS_PATH = Path("metrics_gmsc.pt")            # default destination for final metrics
PARTIAL_PATH = Path("metrics_gmsc.partial.pt")    # resumable in-progress evaluation state

# --- Misc ---
USE_LABEL_SMOOTHING = True
LABEL_SMOOTHING = 0.1
SAVE_EVERY = 10000   # evaluation writes PARTIAL_PATH every this many sentences

# Reproducibility: seed Python's and torch's generators.
random.seed(SEED)
torch.manual_seed(SEED)

# =====================
# Dataset: load samanantar (English-Hindi)
# =====================
print("Loading dataset (ai4bharat/samanantar, hi)... (this can take a minute)")
full_dataset = load_dataset("ai4bharat/samanantar", "hi", split="train").shuffle(seed=SEED)

# Upper bound on sentence pairs used; lower it on resource-limited machines.
NUM_EXAMPLES = min(1_000_000, len(full_dataset))
subset = full_dataset.select(range(NUM_EXAMPLES))

# Contiguous 80% / 10% / 10% train / validation / test slices of the
# (already shuffled) subset.
train_end, val_end = int(0.8 * len(subset)), int(0.9 * len(subset))
train_data, val_data, test_data = (
    subset.select(range(start, stop))
    for start, stop in ((0, train_end), (train_end, val_end), (val_end, len(subset)))
)
print("Subset sizes:", len(train_data), len(val_data), len(test_data))

# =====================
# Train SentencePiece if missing
# =====================
def write_lines(dataset_split, src_path, tgt_path):
    """Dump a parallel split to two UTF-8 text files, one sentence per line.

    The "src" side is stripped and lower-cased (NMTDataset lower-cases source
    text too); the "tgt" side is only stripped.
    """
    with open(src_path, "w", encoding="utf-8") as src_file:
        with open(tgt_path, "w", encoding="utf-8") as tgt_file:
            for example in dataset_split:
                src_file.write(f"{example['src'].strip().lower()}\n")
                tgt_file.write(f"{example['tgt'].strip()}\n")

if not SP_EN_PATH.exists() or not SP_HI_PATH.exists():
    print("Creating train.en / train.hi for SentencePiece training...")
    write_lines(train_data, "train.en", "train.hi")
    print("Training SentencePiece models (this may take a while)...")
    # SentencePiece defaults to unk=0, bos=1, eos=2 and *no* pad piece, but
    # the rest of this script treats id 0 as padding (padding masks,
    # Embedding padding_idx, CrossEntropyLoss ignore_index). Reserve id 0 for
    # <pad> explicitly and move <unk> to 3, otherwise every unknown token is
    # silently masked out and ignored by the loss.
    _SPM_SPECIAL_IDS = "--pad_id=0 --bos_id=1 --eos_id=2 --unk_id=3"
    spm.SentencePieceTrainer.Train(
        f"--input=train.en --model_prefix=spm_en --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=1.0 --model_type=unigram {_SPM_SPECIAL_IDS}"
    )
    spm.SentencePieceTrainer.Train(
        f"--input=train.hi --model_prefix=spm_hi --vocab_size={VOCAB_SIZE} "
        f"--character_coverage=0.9995 --model_type=unigram {_SPM_SPECIAL_IDS}"
    )

sp_en = spm.SentencePieceProcessor()
sp_hi = spm.SentencePieceProcessor()
sp_en.load(str(SP_EN_PATH))
sp_hi.load(str(SP_HI_PATH))

# Special ids assumed by the whole script. They stay fixed so models and
# checkpoints built against them keep loading.
PAD_EN, BOS_EN, EOS_EN = 0, 1, 2
PAD_HI, BOS_HI, EOS_HI = 0, 1, 2

# Models trained by earlier runs may use SentencePiece's default layout
# (id 0 = <unk>, pad disabled). Surface the mismatch instead of hiding it.
for _lang, _sp, _expected in (("en", sp_en, (PAD_EN, BOS_EN, EOS_EN)),
                              ("hi", sp_hi, (PAD_HI, BOS_HI, EOS_HI))):
    _actual = (_sp.pad_id(), _sp.bos_id(), _sp.eos_id())
    if _actual != _expected:
        print(f"WARNING: spm_{_lang} has (pad, bos, eos) ids {_actual} but this "
              f"script assumes {_expected}; id {_expected[0]} is masked as padding "
              "everywhere. Delete the .model files (and checkpoints trained with "
              "them) to retrain with --pad_id=0 --unk_id=3.")

# =====================
# Dataset & DataLoader
# =====================
class NMTDataset(Dataset):
    """Map-style dataset yielding fixed-length ``(src_ids, tgt_ids)`` LongTensors.

    Each side is framed as ``[BOS] + pieces + [EOS]``, with the SentencePiece
    pieces cut to ``max_len - 2`` entries, then right-padded with the PAD id
    up to ``max_len``. Source text is lower-cased before encoding; target
    text is encoded as-is.
    """

    def __init__(self, dataset, src_sp, tgt_sp, max_len=MAX_LEN):
        self.dataset = dataset
        self.src_sp, self.tgt_sp = src_sp, tgt_sp
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def _frame(self, pieces, bos, eos, pad):
        """Wrap ``pieces`` in BOS/EOS, truncate, and pad to ``self.max_len``."""
        framed = [bos, *pieces[:self.max_len - 2], eos]
        padding = [pad] * (self.max_len - len(framed))
        return torch.tensor(framed + padding, dtype=torch.long)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        src_pieces = self.src_sp.encode(example["src"].lower())
        tgt_pieces = self.tgt_sp.encode(example["tgt"])
        return (
            self._frame(src_pieces, BOS_EN, EOS_EN, PAD_EN),
            self._frame(tgt_pieces, BOS_HI, EOS_HI, PAD_HI),
        )

def get_loader(dataset_split, shuffle=True):
    """Wrap a raw split in NMTDataset and return a batched DataLoader.

    Uses the module-level English/Hindi SentencePiece processors and
    BATCH_SIZE, two worker processes, and pinned host memory.
    """
    nmt_dataset = NMTDataset(dataset_split, sp_en, sp_hi)
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=shuffle,
                         num_workers=2, pin_memory=True)
    return DataLoader(nmt_dataset, **loader_kwargs)

# One loader per split; only the training split is shuffled.
train_loader, val_loader, test_loader = (
    get_loader(train_data, shuffle=True),
    get_loader(val_data, shuffle=False),
    get_loader(test_data, shuffle=False),
)

# =====================
# Masks utilities
# =====================
def create_padding_mask(seq, lang='en'):
    """Return a bool tensor shaped like ``seq`` that is True at padding ids.

    ``lang == 'en'`` compares against PAD_EN; any other value uses PAD_HI.
    """
    if lang == 'en':
        return seq == PAD_EN
    return seq == PAD_HI

def generate_square_subsequent_mask(sz, device=None):
    """Return an ``(sz, sz)`` float32 causal mask for decoder self-attention.

    Entries strictly above the diagonal are ``-inf`` (future positions are
    blocked) and every other entry is ``0.0``.

    Args:
        sz: target sequence length.
        device: where to allocate the mask; ``None`` means the global DEVICE.
            The mask is built directly on that device instead of on the CPU
            followed by a copy, which matters because decoding calls this
            once per generated token.
    """
    if device is None:
        device = DEVICE
    return torch.full((sz, sz), float('-inf'), device=device).triu(diagonal=1)

# =====================
# Gated Multi-Scale CNN
# =====================
class GatedMultiScaleCNN(nn.Module):
    """Residual block fusing 1-D convolutions of kernel width 3, 5 and 7.

    For every position a small MLP on the block input produces three softmax
    weights that mix the three convolution outputs. The mixture is passed
    through ReLU, added back to the input, and layer-normalised. "Same"
    padding keeps the sequence length unchanged.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Attribute names are state_dict keys of saved checkpoints; keep them.
        self.conv3 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)
        self.conv7 = nn.Conv1d(embed_dim, embed_dim, kernel_size=7, padding=3)
        self.gate_proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, 3)
        )
        self.activation = nn.ReLU()
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, key_padding_mask=None):
        """Apply the block.

        Args:
            x: ``(batch, seq_len, embed_dim)`` input.
            key_padding_mask: optional bool ``(batch, seq_len)`` tensor, True at
                padding positions. When given, padded positions are zeroed
                before the convolutions so their (position-encoded) values
                cannot leak into the receptive field of neighbouring real
                tokens. ``None`` keeps the original unmasked behaviour.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        conv_in = x
        if key_padding_mask is not None:
            conv_in = x.masked_fill(key_padding_mask.unsqueeze(-1), 0.0)
        x_t = conv_in.transpose(1, 2)  # (b, d, s): Conv1d expects channels first
        branches = [conv(x_t).transpose(1, 2) for conv in (self.conv3, self.conv5, self.conv7)]
        stacked = torch.stack(branches, dim=-1)                              # (b, s, d, 3)
        gates = F.softmax(self.gate_proj(residual), dim=-1).unsqueeze(2)     # (b, s, 1, 3)
        fused = self.activation((stacked * gates).sum(-1))                   # (b, s, d)
        return self.norm(fused + residual)

# =====================
# Hybrid Transformer Model (GMSCNN encoder)
# =====================
class HybridTransformerModel(nn.Module):
    """Encoder-decoder translator built on ``nn.Transformer`` (batch_first).

    Source embeddings pass through a GatedMultiScaleCNN before the Transformer
    encoder. One learned positional table of MAX_LEN rows is added to both
    source and target embeddings, so longer sequences fail to broadcast.
    Attribute names (src_emb, tgt_emb, pos_enc, cnn_encoder, transformer,
    fc_out) are the state_dict keys of saved checkpoints.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8,
                 num_layers=3, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        # Keep this construction order: it determines random-initialisation
        # order and the parameter ordering optimizer checkpoints rely on.
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=PAD_EN)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=PAD_HI)
        # Zero-initialised, learned, shared by source and target.
        self.pos_enc = nn.Parameter(torch.zeros(1, MAX_LEN, d_model))
        self.cnn_encoder = GatedMultiScaleCNN(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def _embed(self, table, ids):
        """Token embedding plus the first ``ids.size(1)`` positional rows."""
        return table(ids) + self.pos_enc[:, :ids.size(1), :]

    def encode(self, src, src_key_padding_mask):
        """Return encoder memory for ``src`` token ids."""
        features = self.cnn_encoder(self._embed(self.src_emb, src))
        return self.transformer.encoder(features, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask):
        """Return decoder hidden states (pre-projection) for ``tgt`` token ids."""
        return self.transformer.decoder(
            self._embed(self.tgt_emb, tgt),
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_mask=None):
        """Return vocabulary logits of shape ``(batch, tgt_len, tgt_vocab)``."""
        hidden = self.decode(
            tgt,
            self.encode(src, src_key_padding_mask),
            tgt_mask,
            memory_key_padding_mask,
            tgt_key_padding_mask,
        )
        return self.fc_out(hidden)

# =====================
# Initialize model, loss, optimizer, scheduler
# =====================
model = HybridTransformerModel(len(sp_en), len(sp_hi), d_model=512, nhead=8, num_layers=3).to(DEVICE)

# Padding targets never contribute to the loss. label_smoothing=0.0 is
# nn.CrossEntropyLoss's default, i.e. smoothing disabled.
criterion = nn.CrossEntropyLoss(
    ignore_index=PAD_HI,
    label_smoothing=LABEL_SMOOTHING if USE_LABEL_SMOOTHING else 0.0,
)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# The schedule is stepped once per training batch, so its horizon is
# batches-per-epoch times EPOCHS.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=EPOCHS,
    steps_per_epoch=len(train_loader),
)
scaler = torch.amp.GradScaler()

# Training-state defaults for a fresh run; replaced below when a checkpoint
# from an earlier run is present.
start_epoch = 1
best_val_loss = float("inf")
epochs_no_improve = 0
PATIENCE = 5   # epochs without validation improvement before early stopping

if CHECKPOINT_PATH.exists():
    print("Loading checkpoint:", CHECKPOINT_PATH)
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    # Restore every stateful training component, in the original order.
    for _component, _key in ((model, 'model_state_dict'),
                             (optimizer, 'optimizer_state_dict'),
                             (scheduler, 'scheduler_state_dict'),
                             (scaler, 'scaler_state_dict')):
        _component.load_state_dict(ckpt[_key])
    start_epoch = ckpt['epoch'] + 1
    best_val_loss = ckpt.get('best_val_loss', best_val_loss)
    epochs_no_improve = ckpt.get('epochs_no_improve', 0)
    print("Resumed from epoch", start_epoch)

# =====================
# Greedy & Beam decoding utilities
# =====================
@torch.no_grad()
def greedy_decode(model, src_sentence_ids, max_len=MAX_GEN_LEN):
    """Translate one source sentence by taking the arg-max token at each step.

    Args:
        model: the encoder-decoder model; it is switched to eval mode here.
        src_sentence_ids: LongTensor of shape (1, seq_len) with English ids.
        max_len: maximum number of tokens to generate.

    Returns:
        The decoded Hindi text with BOS/EOS/PAD ids removed and surrounding
        whitespace stripped ("" when nothing but special ids was produced).
    """
    model.eval()
    src = src_sentence_ids.to(DEVICE)
    src_mask = create_padding_mask(src, 'en')
    memory = model.encode(src, src_mask)
    ys = torch.tensor([[BOS_HI]], dtype=torch.long, device=DEVICE)
    for _ in range(max_len):
        # generate_square_subsequent_mask already returns a tensor on DEVICE.
        tgt_mask = generate_square_subsequent_mask(ys.size(1))
        out = model.decode(ys, memory, tgt_mask, src_mask, create_padding_mask(ys, 'hi'))
        next_word = model.fc_out(out[:, -1, :]).argmax(-1).item()
        ys = torch.cat([ys, torch.tensor([[next_word]], dtype=torch.long, device=DEVICE)], dim=1)
        if next_word == EOS_HI:
            break
    # Take row 0 rather than squeeze(): with max_len=0, ys is (1, 1), and
    # squeeze().tolist() would return a bare int that cannot be iterated.
    special_ids = {BOS_HI, EOS_HI, PAD_HI}
    tokens = [t for t in ys[0].tolist() if t not in special_ids]
    return sp_hi.decode(tokens).strip() if tokens else ""

# Simple beam search (length-penalty)
class Beam:
    """One partial hypothesis tracked by beam search.

    Attributes:
        tokens: target ids generated so far (beam_decode starts them with BOS).
        logprob: summed log-probability of the generated tokens.
        state: optional extra payload; beam_decode never sets it.
    """

    def __init__(self, tokens, logprob, state=None):
        self.tokens, self.logprob, self.state = tokens, logprob, state

@torch.no_grad()
def beam_decode(model, src_sentence_ids, beam_size=BEAM_SIZE, max_len=MAX_GEN_LEN, length_penalty=LENGTH_PENALTY):
    """Translate one source sentence with length-normalised beam search.

    Runs under ``torch.no_grad()`` (like greedy_decode) so no autograd graph
    is built for the many per-beam decoder calls.

    Args:
        model: the encoder-decoder model; it is switched to eval mode here.
        src_sentence_ids: LongTensor of shape (1, seq_len) with English ids.
        beam_size: hypotheses kept per step and candidates expanded per beam.
        max_len: maximum number of decoding steps.
        length_penalty: exponent alpha in ``logprob / (5 + len) ** alpha``.

    Returns:
        The decoded Hindi text of the best-scoring hypothesis with
        BOS/EOS/PAD ids removed ("" if none remain).
    """
    def score(beam):
        # Length-normalised log-probability used both for pruning and the final pick.
        return beam.logprob / ((5 + len(beam.tokens)) ** length_penalty)

    model.eval()
    src = src_sentence_ids.to(DEVICE)
    src_mask = create_padding_mask(src, 'en')
    memory = model.encode(src, src_mask)

    beams = [Beam(tokens=[BOS_HI], logprob=0.0)]
    for _ in range(max_len):
        candidates = []
        for beam in beams:
            if beam.tokens[-1] == EOS_HI:
                # Finished hypotheses compete unchanged with new expansions.
                candidates.append(beam)
                continue
            tgt = torch.tensor([beam.tokens], dtype=torch.long, device=DEVICE)
            tgt_mask = generate_square_subsequent_mask(tgt.size(1))
            out = model.decode(tgt, memory, tgt_mask, src_mask, create_padding_mask(tgt, 'hi'))
            log_probs = F.log_softmax(model.fc_out(out[:, -1, :]), dim=-1).squeeze(0)  # (vocab,)
            top = torch.topk(log_probs, beam_size)
            for token, lp in zip(top.indices.tolist(), top.values.tolist()):
                candidates.append(Beam(tokens=beam.tokens + [token], logprob=beam.logprob + lp))

        beams = sorted(candidates, key=score, reverse=True)[:beam_size]
        if all(b.tokens[-1] == EOS_HI for b in beams):
            break

    best = max(beams, key=score)
    tokens = [t for t in best.tokens if t not in (BOS_HI, EOS_HI, PAD_HI)]
    return sp_hi.decode(tokens) if tokens else ""

# =====================
# Evaluation utilities (BLEU, METEOR, TER) with incremental save
# =====================
def evaluate_and_save(model, dataset_split, use_beam=False, metrics_path=METRICS_PATH, partial_path=PARTIAL_PATH, save_every=SAVE_EVERY):
    """Decode every example of ``dataset_split`` and report BLEU, TER and METEOR.

    Hypotheses come from beam search (``use_beam=True``) or greedy decoding,
    one sentence at a time. Every ``save_every`` sentences, and after the last
    one, references, hypotheses and per-sentence METEOR scores are written to
    ``partial_path`` so an interrupted run can resume. At the end the rounded
    corpus metrics plus a timestamp are saved to ``metrics_path`` and the
    partial file is removed.

    Returns:
        dict with keys "timestamp", "BLEU", "TER" and "METEOR" (METEOR x 100).
    """
    refs, hyps, meteor_scores, start_idx = [], [], [], 0
    if partial_path.exists():
        print("Loading partial results:", partial_path)
        data = torch.load(partial_path)
        # Greedy and beam runs share PARTIAL_PATH by default, so only resume
        # from a partial made with the same decoding mode. Partials written
        # before the mode was recorded are assumed to match (old behaviour).
        if data.get("use_beam", use_beam) == use_beam:
            refs, hyps, meteor_scores = data["refs"], data["hyps"], data["meteor_scores"]
            start_idx = data.get("last_idx", 0)
            print("Resuming from index", start_idx)
        else:
            print("Ignoring partial results from a different decoding mode (they will be overwritten).")

    model.to(DEVICE)
    model.eval()

    # iterate over dataset_split indices for reproducibility
    for idx in tqdm(range(start_idx, len(dataset_split)), desc="Evaluating", unit="sent"):
        ex = dataset_split[idx]
        src_text = ex["src"].lower()
        ref_text = ex["tgt"]
        # Same framing as NMTDataset: BOS + truncated pieces + EOS, padded to MAX_LEN.
        src_ids = [BOS_EN] + sp_en.encode(src_text)[:MAX_LEN-2] + [EOS_EN]
        src_ids += [PAD_EN] * (MAX_LEN - len(src_ids))
        src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0)

        if use_beam:
            pred_text = beam_decode(model, src_tensor, beam_size=BEAM_SIZE, max_len=MAX_GEN_LEN, length_penalty=LENGTH_PENALTY)
        else:
            pred_text = greedy_decode(model, src_tensor, max_len=MAX_GEN_LEN)

        refs.append(ref_text)
        hyps.append(pred_text)
        # Sentence-level METEOR over whitespace tokens.
        meteor_scores.append(meteor_score([ref_text.split()], pred_text.split()))

        if (idx + 1) % save_every == 0 or (idx + 1) == len(dataset_split):
            torch.save({
                "refs": refs,
                "hyps": hyps,
                "meteor_scores": meteor_scores,
                "last_idx": idx + 1,
                "use_beam": use_beam,
            }, partial_path)
            print(f"Saved partial at sentence {idx+1}")

    # Compute corpus metrics
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    ter_metric = sacrebleu.metrics.TER()
    ter = ter_metric.corpus_score(hyps, [refs])
    meteor_avg = sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0.0

    results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "BLEU": round(bleu.score, 2),
        "TER": round(ter.score, 2),
        "METEOR": round(meteor_avg * 100, 2)
    }

    torch.save(results, metrics_path)
    # The partial is only a resume aid: failing to delete it must not lose the
    # metrics just saved, but it should not pass silently either.
    try:
        partial_path.unlink(missing_ok=True)
    except OSError as err:
        print(f"Warning: could not remove partial results {partial_path}: {err}")

    print("Final metrics:", results)
    return results

# =====================
# Training loop (with validation + checkpointing)
# =====================
def _seq2seq_loss(model, src, tgt):
    """Return the teacher-forced cross-entropy for one ``(src, tgt)`` batch.

    Moves the batch to DEVICE, feeds ``tgt[:, :-1]`` to the decoder under a
    causal mask, and scores the output against ``tgt[:, 1:]`` with the global
    ``criterion`` (which ignores PAD_HI). The training and validation passes
    of ``train_loop`` share this helper so both compute the loss identically.
    """
    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
    src_mask = create_padding_mask(src, 'en')
    tgt_input, tgt_out = tgt[:, :-1], tgt[:, 1:]
    logits = model(
        src, tgt_input,
        src_key_padding_mask=src_mask,
        tgt_key_padding_mask=create_padding_mask(tgt_input, 'hi'),
        memory_key_padding_mask=src_mask,
        tgt_mask=generate_square_subsequent_mask(tgt_input.size(1)),
    )  # (b, seq_len, vocab)
    return criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))


def train_loop(model, train_loader, val_loader, start_epoch=1, epochs=EPOCHS):
    """Train with mixed precision and gradient clipping, validate, checkpoint, early-stop.

    Uses the module-level optimizer, scheduler (stepped per batch), scaler and
    criterion, and updates the globals ``best_val_loss`` and
    ``epochs_no_improve``. After each epoch, a new best validation loss saves
    the weights to BEST_MODEL_PATH. A resumable checkpoint is then written to
    CHECKPOINT_PATH. Training stops once PATIENCE consecutive epochs pass
    without improvement.
    """
    global best_val_loss, epochs_no_improve
    amp_device = "cuda" if DEVICE.startswith("cuda") else "cpu"
    for epoch in range(start_epoch, epochs + 1):
        # Checked before each epoch so a run resumed from a checkpoint that
        # already reached the patience limit does not train one more epoch.
        if epochs_no_improve >= PATIENCE:
            print("Early stopping triggered.")
            break

        model.train()
        train_ce = 0.0
        for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=amp_device):
                loss = _seq2seq_loss(model, src, tgt)
            scaler.scale(loss).backward()
            # Unscale first so CLIP applies to the true (unscaled) gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            train_ce += loss.item()

        avg_train_loss = train_ce / len(train_loader)
        print(f"Epoch {epoch} | Train CE Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for src, tgt in val_loader:
                val_loss += _seq2seq_loss(model, src, tgt).item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} | Val Loss: {avg_val_loss:.4f}")

        # Save best
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print("New best model saved!")
        else:
            epochs_no_improve += 1

        # Checkpoint only AFTER updating best_val_loss / epochs_no_improve.
        # Saving earlier stored stale values: on resume a worse model could
        # then beat the stale best and overwrite BEST_MODEL_PATH, and the
        # early-stopping counter lagged by one epoch.
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_val_loss": best_val_loss,
            "epochs_no_improve": epochs_no_improve,
        }, CHECKPOINT_PATH)

    print("Training finished.")

# =====================
# Run training
# =====================
train_loop(model, train_loader, val_loader, start_epoch=start_epoch, epochs=EPOCHS)

# Evaluate with the best-validation weights when available. The file holds a
# plain model.state_dict(), so the restricted weights_only loader suffices;
# it avoids unpickling arbitrary objects and torch.load's FutureWarning.
if BEST_MODEL_PATH.exists():
    print("Loading best model for evaluation:", BEST_MODEL_PATH)
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE, weights_only=True))
else:
    print("No best model found, using last model state.")

# =====================
# Run evaluation (greedy-decode metrics; beam search is used for the sample translations below)
# =====================
def print_metrics(title, path):
    """Load a metrics dict saved by ``evaluate_and_save`` and print each entry.

    Args:
        title: label shown in the header line.
        path: ``pathlib.Path`` of the saved metrics file.

    Returns:
        The loaded metrics dict.
    """
    # The saved dict holds only str/float values, so the restricted
    # weights_only loader is sufficient and silences torch.load's FutureWarning.
    metrics = torch.load(path, weights_only=True)
    print(f"\n✅ {title} results ({path.name}):")
    for k, v in metrics.items():
        print(f"  {k}: {v}")
    return metrics

# --- Greedy evaluation ---
# Metrics are computed once and cached on disk; later runs just reprint them.
greedy_path = Path("metrics_gmsc_greedy.pt")
if not greedy_path.exists():
    print("Evaluating (greedy decode) on test split...")
    metrics_greedy = evaluate_and_save(model, test_data, use_beam=False, metrics_path=greedy_path)
    print_metrics("Greedy Decode", greedy_path)
else:
    print("Greedy metrics already exist — skipping evaluation.")
    metrics_greedy = print_metrics("Greedy Decode", greedy_path)


# Save sample translations
# Compare greedy and beam-search output on a few random test sentences.
# min() keeps random.sample from raising ValueError when the test split has
# fewer than NUM_SAMPLES examples.
NUM_SAMPLES = 10
sample_indices = random.sample(range(len(test_data)), min(NUM_SAMPLES, len(test_data)))
samples = []
for idx in sample_indices:
    ex = test_data[idx]
    src_text = ex["src"].lower()
    # Same framing as NMTDataset: BOS + truncated pieces + EOS, padded to MAX_LEN.
    src_ids = [BOS_EN] + sp_en.encode(src_text)[:MAX_LEN-2] + [EOS_EN]
    src_ids += [PAD_EN] * (MAX_LEN - len(src_ids))
    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0)

    samples.append({
        "src": src_text,
        "ref": ex["tgt"],
        "pred_greedy": greedy_decode(model, src_tensor),
        "pred_beam": beam_decode(model, src_tensor),
    })

with open("sample_translations_gmsc.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)

print("Saved sample translations to sample_translations_gmsc.json")
print("Done.")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Loading dataset (ai4bharat/samanantar, hi)... (this can take a minute)
Subset sizes: 800000 100000 100000
Loading checkpoint: checkpoint_gmsc.pt


  ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)


Resumed from epoch 21
Training finished.
Loading best model for evaluation: best_model_gmsc.pt
Greedy metrics already exist — skipping evaluation.

✅ Greedy Decode results (metrics_gmsc_greedy.pt):
  timestamp: 2025-10-27 17:02:48
  BLEU: 19.53
  TER: 71.66
  METEOR: 39.31


  model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
  metrics = torch.load(path)
  output = torch._nested_tensor_from_mask(


Saved sample translations to sample_translations_gmsc.json
Done.
