In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import math
import collections
import heapq
import itertools
import unicodedata
import matplotlib.pyplot as plt
import time
import pandas as pd

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Device: cuda
GPU: Tesla P100-PCIE-16GB
Memory: 17.06 GB


In [2]:
#inspired from assignemnt 2
import unicodedata
import collections
import heapq
import itertools
import numpy as np

def normalize_text(text):
   
    text = unicodedata.normalize("NFKC", text)
    return text

def pair_lex(pair):

    return pair[0] + "\0" + pair[1]

def get_word_freqs(text):
   
    words = text.split()
    return collections.Counter("‚ñÅ" + w for w in words if w)

def get_initial_splits(word_freqs):
    """
    Convert words to character sequences
    NO end-of-word markers - pure SentencePiece style
    """
    splits = {}
    for word in word_freqs:
        if not word:
            continue
        # Split into characters (‚ñÅ is part of the word)
        splits[word] = list(word)
    return splits

def train_custom_sp_tokenizer(text, vocab_size, add_special_tokens=True):
   
    print(f"üîß Training tokenizer with vocab_size={vocab_size}...")
    
    word_freqs = get_word_freqs(text)
    splits = get_initial_splits(word_freqs)
    
    # Cache word list for faster iteration
    word_list = list(word_freqs.keys())
    
    # Base vocabulary from all unique characters
    base_symbols = set()
    for symbols in splits.values():
        base_symbols.update(symbols)
    
    # Only 4 special tokens: <pad>, <unk>, <s>, </s>
    num_special = 4 if add_special_tokens else 0
    base_vocab_size = len(base_symbols)
    merges_needed = vocab_size - num_special - base_vocab_size
    
    print(f"   Base vocab: {base_vocab_size} symbols")
    print(f"   Merges needed: {merges_needed}")
    
    if merges_needed <= 0:
        vocab = (["<pad>", "<unk>", "<s>", "</s>"] if add_special_tokens else [])
        vocab.extend(sorted(base_symbols))
        return vocab[:vocab_size], {}
    
    # Build initial pair frequencies
    pair_freqs = collections.Counter()
    pair_to_positions = collections.defaultdict(list)  # pair -> [(word_idx, pos)]
    
    for word_idx, (word, symbols) in enumerate(splits.items()):
        freq = word_freqs[word]
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i+1])
            pair_freqs[pair] += freq
            pair_to_positions[pair].append((word_idx, i))
    
    # Priority queue - start with ALL pairs (freq >= 1)
    counter = itertools.count()
    heap = [(-freq, pair_lex(pair), next(counter), pair) 
            for pair, freq in pair_freqs.items() if freq >= 1]
    heapq.heapify(heap)
    
    print(f"   Initial heap size: {len(heap)} pairs")
    
    merges = {}
    merges_done = 0
    
    PRINT_EVERY = 1000
    
    while merges_done < merges_needed and heap:
        # Find valid best pair (lazy deletion)
        best_pair = None
        best_freq = 0
        
        while heap:
            negf, lex_key, _, pair = heapq.heappop(heap)
            f = -negf
            if pair in pair_freqs and pair_freqs[pair] == f:
                best_pair = pair
                best_freq = f
                break
        
        # Stop only if we truly can't find any more pairs
        if not best_pair or best_freq < 1:
            print(f"   Stopping: no more pairs with freq >= 1")
            break
        
        a, b = best_pair
        merged_token = a + b
        merges[best_pair] = merged_token
        merges_done += 1
        
        if merges_done % PRINT_EVERY == 0:
            print(f"   Progress: {merges_done}/{merges_needed} | Best freq: {best_freq}")
        
        # Batch frequency updates
        positions = pair_to_positions.pop(best_pair, [])
        freq_deltas = {}
        new_positions = collections.defaultdict(list)
        
        for word_idx, i in positions:
            word = word_list[word_idx]
            symbols = splits[word]
            
            if i >= len(symbols) - 1 or symbols[i] != a or symbols[i+1] != b:
                continue
            
            freq = word_freqs[word]
            
            # Track left neighbor changes
            if i > 0:
                old_left = (symbols[i-1], symbols[i])
                freq_deltas[old_left] = freq_deltas.get(old_left, 0) - freq
            
            # Track right neighbor changes
            if i + 2 < len(symbols):
                old_right = (symbols[i+1], symbols[i+2])
                freq_deltas[old_right] = freq_deltas.get(old_right, 0) - freq
            
            # Perform merge
            symbols[i] = merged_token
            del symbols[i+1]
            
            # Track new neighbors
            if i > 0:
                new_left = (symbols[i-1], symbols[i])
                freq_deltas[new_left] = freq_deltas.get(new_left, 0) + freq
                new_positions[new_left].append((word_idx, i-1))
            
            if i + 1 < len(symbols):
                new_right = (symbols[i], symbols[i+1])
                freq_deltas[new_right] = freq_deltas.get(new_right, 0) + freq
                new_positions[new_right].append((word_idx, i))
        
        # Apply frequency changes
        for pair, delta in freq_deltas.items():
            old_freq = pair_freqs.get(pair, 0)
            new_freq = old_freq + delta
            
            if new_freq <= 0:
                pair_freqs.pop(pair, None)
                pair_to_positions.pop(pair, None)
            else:
                pair_freqs[pair] = new_freq
                # Add pairs with frequency >= 1
                if new_freq >= 1:
                    heapq.heappush(heap, (-new_freq, pair_lex(pair), next(counter), pair))
        
        # Add new positions
        for pair, positions_list in new_positions.items():
            pair_to_positions[pair].extend(positions_list)
    
    # Build final vocabulary
    vocab = []
    if add_special_tokens:
        vocab = ["<pad>", "<unk>", "<s>", "</s>"]
    
    # Add base character symbols
    base_tokens = sorted(base_symbols)
    vocab.extend(base_tokens)
    
    # Add all merged tokens
    merge_tokens = list(merges.values())
    vocab.extend(merge_tokens)
    
    print(f"\nüìä Vocabulary breakdown:")
    print(f"   Special tokens: {4 if add_special_tokens else 0}")
    print(f"   Base symbols: {len(base_tokens)}")
    print(f"   Merged tokens: {len(merge_tokens)}")
    print(f"   Total before truncation: {len(vocab)}")
    
    # Ensure vocab is exactly vocab_size
    if len(vocab) > vocab_size:
        print(f"   ‚ö†Ô∏è  Truncating from {len(vocab)} to {vocab_size}")
        vocab = vocab[:vocab_size]
    elif len(vocab) < vocab_size:
        print(f"   ‚ö†Ô∏è  Warning: Only {len(vocab)} tokens generated (target: {vocab_size})")
    
    print(f"‚úÖ Final vocabulary: {len(vocab)} tokens")
    
    return vocab, merges

def tokenize_custom(text, merges):
   
    text = normalize_text(text)
    
    if not text:
        return []
    
    # Split into words and add space markers
    words = text.split()
    all_tokens = []
    
    for word in words:
        word = "‚ñÅ" + word
        symbols = list(word)
        
        # Apply BPE merges
        i = 0
        while i < len(symbols) - 1:
            pair = (symbols[i], symbols[i+1])
            if pair in merges:
                symbols[i] = merges[pair]
                del symbols[i+1]
                # Backtrack to check if new merge is possible
                if i > 0:
                    i -= 1
            else:
                i += 1
        
        all_tokens.extend(symbols)
    
    return all_tokens

def detokenize_custom(tokens):
    
    if not tokens:
        return ""
    text = "".join(tokens).replace("‚ñÅ", " ")
    return text.strip()

class CustomTokenizer:
    
    
    def __init__(self, vocab, merges, lang_tags=None):
        self.vocab = list(vocab)  # Make a copy
        self.merges = merges
        self.token2id = {token: idx for idx, token in enumerate(self.vocab)}
        self.id2token = {idx: token for idx, token in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)
        
        # Special token IDs
        self.pad_id = self.token2id.get("<pad>", 0)
        self.unk_id = self.token2id.get("<unk>", 1)
        self.bos_id = self.token2id.get("<s>", 2)
        self.eos_id = self.token2id.get("</s>", 3)
        
        # Language tag IDs (if provided, add them to vocab)
        self.lang_tag_ids = {}
        if lang_tags:
            for tag in lang_tags:
                if tag in self.token2id:
                    self.lang_tag_ids[tag] = self.token2id[tag]
                else:
                    # Add language tag to vocab
                    self.token2id[tag] = self.vocab_size
                    self.id2token[self.vocab_size] = tag
                    self.vocab.append(tag)
                    self.lang_tag_ids[tag] = self.vocab_size
                    self.vocab_size += 1
        
        # Pre-compute special IDs set for faster lookup
        self._special_ids = frozenset([
            self.pad_id, self.unk_id, self.bos_id, self.eos_id
        ] + list(self.lang_tag_ids.values()))
    
    def encode(self, text, add_bos=False, add_eos=False):
       
        tokens = tokenize_custom(text, self.merges)
        
        ids = []
        if add_bos:
            ids.append(self.bos_id)
        
        # Batch lookup with get (faster than exception handling)
        unk_id = self.unk_id
        token2id = self.token2id
        for token in tokens:
            ids.append(token2id.get(token, unk_id))
        
        if add_eos:
            ids.append(self.eos_id)
        
        return ids
    
    def decode(self, ids, skip_special_tokens=True):
        """Decode IDs to text"""
        if skip_special_tokens:
            tokens = [self.id2token.get(idx, "<unk>") 
                     for idx in ids if idx not in self._special_ids]
        else:
            tokens = [self.id2token.get(idx, "<unk>") for idx in ids]
        
        return detokenize_custom(tokens)
    
    def get_lang_tag_id(self, tag):
        """Get ID for language tag (for prefix ID approach)"""
        return self.lang_tag_ids.get(tag, self.unk_id)



‚úÖ CLEAN tokenizer loaded (NO </w>, PREFIX ID approach)!


In [3]:
print("\nüìÇ Loading training data...")
with open('/kaggle/input/codabenchnmt/train_data1.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

# Clean function to remove English characters from Indic text
def clean_indic_text(text):
    
    import re
    # Remove English letters but keep numbers and punctuation
    text = re.sub(r'[a-zA-Z]', '', text)
    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

# Prepare data WITHOUT language tags in text
src_texts = []
tgt_texts = []
lang_tags = []

LANG_TAG_HI = "<2hi>"
LANG_TAG_BN = "<2bn>"

# English-Hindi pairs
for key, value in train_data.get("English-Hindi", {}).get("Train", {}).items():
    src = value.get("source", "").strip()
    tgt = value.get("target", "").strip()
    
    if src and tgt:
        tgt = clean_indic_text(tgt)
        if tgt:  # Only add if target still has content after cleaning
            src_texts.append(src)
            tgt_texts.append(tgt)  # NO language tag in text
            lang_tags.append(LANG_TAG_HI)

# English-Bengali pairs
for key, value in train_data.get("English-Bengali", {}).get("Train", {}).items():
    src = value.get("source", "").strip()
    tgt = value.get("target", "").strip()
    
    if src and tgt:
        tgt = clean_indic_text(tgt)
        if tgt:
            src_texts.append(src)
            tgt_texts.append(tgt)  # NO language tag in text
            lang_tags.append(LANG_TAG_BN)

print(f"‚úÖ Loaded {len(src_texts)} sentence pairs")
print(f"   Hindi: {lang_tags.count(LANG_TAG_HI)}")
print(f"   Bengali: {lang_tags.count(LANG_TAG_BN)}")


üìÇ Loading training data...
‚úÖ Loaded 149629 sentence pairs
   Hindi: 80784
   Bengali: 68845


In [4]:
import pickle
import os

# Language tags (used ONLY for prefix IDs, NOT in training data)
LANG_TAG_HI = "<2hi>"
LANG_TAG_BN = "<2bn>"

# Vocab sizes
SRC_VOCAB_SIZE = 32000
TGT_VOCAB_SIZE = 50000

# Use full data
src_sample = src_texts
tgt_sample = tgt_texts

# Prepare combined text (NO language tags in text)
src_combined = " ".join(src_sample)
tgt_combined = " ".join(tgt_sample)

print(f"\nüìè Text lengths:")
print(f"   Source: {len(src_combined):,} chars")
print(f"   Target: {len(tgt_combined):,} chars")

# Delete old tokenizers to force retraining
import shutil
if os.path.exists('tokenizers'):
    print("\nüóëÔ∏è  Deleting old tokenizers...")
    shutil.rmtree('tokenizers')
    print("   Old tokenizers deleted!")

# Train source tokenizer
print(f"\nüîß Training source tokenizer (vocab={SRC_VOCAB_SIZE})...")
src_vocab, src_merges = train_custom_sp_tokenizer(
    src_combined, 
    SRC_VOCAB_SIZE, 
    add_special_tokens=True
)
src_tokenizer = CustomTokenizer(src_vocab, src_merges)

# Train target tokenizer (NO language tags in training)
print(f"\nüîß Training target tokenizer (vocab={TGT_VOCAB_SIZE})...")
tgt_vocab, tgt_merges = train_custom_sp_tokenizer(
    tgt_combined, 
    TGT_VOCAB_SIZE, 
    add_special_tokens=True
)
# Add language tags AFTER training for prefix ID approach
tgt_tokenizer = CustomTokenizer(tgt_vocab, tgt_merges, lang_tags=[LANG_TAG_HI, LANG_TAG_BN])

print(f"\n‚úÖ Tokenizers ready!")
print(f"   Source vocab: {src_tokenizer.vocab_size}")
print(f"   Target vocab: {tgt_tokenizer.vocab_size}")
print(f"   Language tags: {list(tgt_tokenizer.lang_tag_ids.keys())}")

# Save tokenizers
os.makedirs("tokenizers", exist_ok=True)
with open("tokenizers/src_tokenizer.pkl", "wb") as f:
    pickle.dump((src_vocab, src_merges), f)
with open("tokenizers/tgt_tokenizer.pkl", "wb") as f:
    pickle.dump((tgt_vocab, tgt_merges, [LANG_TAG_HI, LANG_TAG_BN]), f)
print("üíæ Tokenizers saved to tokenizers/")

# Test tokenization (NO language tags in text)
test_src = "The government announced new policies"
test_hi = "‡§∏‡§∞‡§ï‡§æ‡§∞ ‡§®‡•á ‡§®‡§à ‡§®‡•Ä‡§§‡§ø‡§Ø‡§æ‡§Ç ‡§ò‡•ã‡§∑‡§ø‡§§ ‡§ï‡•Ä‡§Ç"  # NO <2hi> tag
test_bn = "‡¶∏‡¶∞‡¶ï‡¶æ‡¶∞ ‡¶®‡¶§‡ßÅ‡¶® ‡¶®‡ßÄ‡¶§‡¶ø ‡¶ò‡ßã‡¶∑‡¶£‡¶æ ‡¶ï‡¶∞‡ßá‡¶õ‡ßá"  # NO <2bn> tag

print("\nüß™ Testing tokenization:")
print(f"Source: {test_src}")
src_tokens = [src_tokenizer.id2token[i] for i in src_tokenizer.encode(test_src)]
print(f"Tokens: {src_tokens[:15]}...")

print(f"\nHindi (NO tag in text): {test_hi}")
hi_tokens = [tgt_tokenizer.id2token[i] for i in tgt_tokenizer.encode(test_hi)]
print(f"Tokens: {hi_tokens}")

print(f"\nBengali (NO tag in text): {test_bn}")
bn_tokens = [tgt_tokenizer.id2token[i] for i in tgt_tokenizer.encode(test_bn)]
print(f"Tokens: {bn_tokens}")

# Show how language tags are used as prefix IDs
print(f"\nüè∑Ô∏è  Language Tag IDs (for prefix):")
print(f"   {LANG_TAG_HI}: {tgt_tokenizer.get_lang_tag_id(LANG_TAG_HI)}")
print(f"   {LANG_TAG_BN}: {tgt_tokenizer.get_lang_tag_id(LANG_TAG_BN)}")
print(f"\nüí° These IDs are prepended during training:")
print(f"   Target sequence: [<s>, LANG_TAG_ID, token_ids..., </s>]")




üìè Text lengths:
   Source: 14,504,290 chars
   Target: 13,656,730 chars

üîß Training source tokenizer (vocab=32000)...
üîß Training tokenizer with vocab_size=32000...
   Base vocab: 287 symbols
   Merges needed: 31709
   Initial heap size: 5190 pairs
   Progress: 1000/31709 | Best freq: 49429
   Progress: 2000/31709 | Best freq: 3257
   Progress: 3000/31709 | Best freq: 38730
   Progress: 4000/31709 | Best freq: 4028
   Progress: 5000/31709 | Best freq: 1743
   Progress: 6000/31709 | Best freq: 1083
   Progress: 7000/31709 | Best freq: 18346
   Progress: 8000/31709 | Best freq: 5377
   Progress: 9000/31709 | Best freq: 8282
   Progress: 10000/31709 | Best freq: 2499
   Progress: 11000/31709 | Best freq: 7294
   Progress: 12000/31709 | Best freq: 27518
   Progress: 13000/31709 | Best freq: 603
   Progress: 14000/31709 | Best freq: 624
   Progress: 15000/31709 | Best freq: 1903
   Progress: 16000/31709 | Best freq: 359
   Progress: 17000/31709 | Best freq: 334
   Progress: 18000/3

In [5]:
# Example Hindi and Bengali sentences
test_hi = f"‡§∏‡§∞‡§ï‡§æ‡§∞ ‡§®‡•á ‡§®‡§à ‡§®‡•Ä‡§§‡§ø‡§Ø‡§æ‡§Ç ‡§ò‡•ã‡§∑‡§ø‡§§ ‡§ï‡•Ä‡§Ç"
test_bn = f"‡¶è‡¶á ‡¶ú‡¶æ‡ßü‡¶ó‡¶æ‡¶ó‡ßÅ‡¶≤‡ßã ‡¶¶‡ßá‡¶ñ‡¶§‡ßá ‡¶≠‡ßÅ‡¶≤‡ßã ‡¶®‡¶æ ‡¶Ø‡ßá‡¶ñ‡¶æ‡¶®‡ßá ‡¶®‡¶∞‡ßç‡¶Æ‡¶¶‡¶æ ‡¶®‡¶¶‡ßÄ ‡¶Æ‡¶æ‡¶∞‡ßç‡¶¨‡ßá‡¶≤ ‡¶™‡¶æ‡¶•‡¶∞‡ßá‡¶∞ ‡¶™‡¶æ‡¶π‡¶æ‡ßú‡ßá‡¶∞ ‡¶Æ‡¶ß‡ßç‡¶Ø ‡¶¶‡¶ø‡ßü‡ßá ‡¶™‡ßç‡¶∞‡¶¨‡¶æ‡¶π‡¶ø‡¶§ ‡¶π‡¶ö‡ßç‡¶õ‡ßá ‡¶è‡¶¨‡¶Ç ‡¶®‡¶ø‡¶ú‡ßá‡¶∞ ‡¶∂‡¶æ‡¶®‡ßç‡¶§‡¶ø ‡¶ì ‡¶∏‡ßå‡¶®‡ßç‡¶¶‡¶∞‡ßç‡¶Ø‡¶ï‡ßá ‡¶Ö‡¶®‡¶æ‡¶∏‡¶ï‡ßç‡¶§‡¶ø‡¶§‡ßá ‡¶™‡¶∞‡¶ø‡¶£‡¶§ ‡¶ï‡¶∞‡¶õ‡ßá"

# Encode
hi_ids = tgt_tokenizer.encode(test_hi)
bn_ids = tgt_tokenizer.encode(test_bn)

# Convert IDs back to tokens
hi_tokens = [tgt_tokenizer.id2token[i] for i in hi_ids]
bn_tokens = [tgt_tokenizer.id2token[i] for i in bn_ids]

print("Hindi tokens:", hi_tokens)
print("Bengali tokens:", bn_tokens)

Hindi tokens: ['‚ñÅ‡§∏‡§∞‡§ï‡§æ‡§∞', '‚ñÅ‡§®', '‡•á', '‚ñÅ‡§®‡§à', '‚ñÅ‡§®‡•Ä‡§§‡§ø', '‡§Ø‡§æ‡§Ç', '‚ñÅ‡§ò‡•ã‡§∑', '‡§ø‡§§', '‚ñÅ‡§ï‡•Ä', '‡§Ç']
Bengali tokens: ['‚ñÅ‡¶è‡¶á', '‚ñÅ‡¶ú‡¶æ‡¶Ø‡¶º‡¶ó‡¶æ', '‡¶ó‡ßÅ', '‡¶≤‡ßã', '‚ñÅ‡¶¶‡ßá‡¶ñ‡¶§‡ßá', '‚ñÅ‡¶≠‡ßÅ‡¶≤', '‡ßã', '‚ñÅ‡¶®‡¶æ', '‚ñÅ‡¶Ø‡ßá‡¶ñ‡¶æ‡¶®‡ßá', '‚ñÅ‡¶®‡¶∞', '‡ßç‡¶Æ', '‡¶¶‡¶æ', '‚ñÅ‡¶®‡¶¶‡ßÄ', '‚ñÅ‡¶Æ‡¶æ', '‡¶∞‡ßç‡¶¨', '‡ßá‡¶≤', '‚ñÅ‡¶™‡¶æ‡¶•‡¶∞‡ßá‡¶∞', '‚ñÅ‡¶™‡¶æ‡¶π‡¶æ‡¶°‡¶º‡ßá‡¶∞', '‚ñÅ‡¶Æ‡¶ß‡ßç‡¶Ø', '‚ñÅ‡¶¶‡¶ø‡¶Ø‡¶º‡ßá', '‚ñÅ‡¶™‡ßç‡¶∞‡¶¨‡¶æ‡¶π‡¶ø‡¶§', '‚ñÅ‡¶π‡¶ö‡ßç‡¶õ‡ßá', '‚ñÅ‡¶è‡¶¨‡¶Ç', '‚ñÅ‡¶®‡¶ø‡¶ú‡ßá', '‡¶∞', '‚ñÅ‡¶∂‡¶æ', '‡¶®‡ßç‡¶§', '‡¶ø', '‚ñÅ‡¶ì', '‚ñÅ‡¶∏‡ßå‡¶®‡ßç‡¶¶‡¶∞‡ßç‡¶Ø', '‡¶ï‡ßá', '‚ñÅ‡¶Ö‡¶®‡¶æ', '‡¶∏‡¶ï', '‡ßç‡¶§‡¶ø', '‡¶§‡ßá', '‚ñÅ‡¶™‡¶∞‡¶ø‡¶£‡¶§', '‚ñÅ‡¶ï‡¶∞‡¶õ‡ßá']


In [5]:
class MultilingualNMTDataset(Dataset):
    
    def __init__(self, src_texts, tgt_texts, lang_tags, src_tokenizer, tgt_tokenizer, max_len=60):
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.lang_tags = lang_tags
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.src_texts)
    
    def __getitem__(self, idx):
        src = self.src_texts[idx]
        tgt = self.tgt_texts[idx]  # NO language tag in text
        lang_tag = self.lang_tags[idx]
        
        # Encode source: BOS + content + EOS
        src_ids = self.src_tokenizer.encode(src, add_bos=True, add_eos=False)
        src_ids = src_ids[:self.max_len-1] + [self.src_tokenizer.eos_id]
        
        # Get language tag ID
        lang_tag_id = self.tgt_tokenizer.get_lang_tag_id(lang_tag)
        
        # Encode target WITHOUT language tag in text
        tgt_ids = self.tgt_tokenizer.encode(tgt, add_bos=False, add_eos=False)
        
        # Build target: BOS + LANG_TAG + content + EOS
        tgt_ids = ([self.tgt_tokenizer.bos_id, lang_tag_id] + 
                   tgt_ids[:self.max_len-3] + 
                   [self.tgt_tokenizer.eos_id])
        
        return torch.LongTensor(src_ids), torch.LongTensor(tgt_ids)

def collate_batch(batch):
    src_batch, tgt_batch = zip(*batch)
    src_batch = pad_sequence(src_batch, batch_first=True, 
                             padding_value=src_tokenizer.pad_id)
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, 
                             padding_value=tgt_tokenizer.pad_id)
    return src_batch, tgt_batch


train_dataset = MultilingualNMTDataset(
    src_texts, tgt_texts, lang_tags,
    src_tokenizer, tgt_tokenizer,
    max_len=85
)

BATCH_SIZE = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=4,
    pin_memory=True
)

print(f"\n‚úÖ Dataset created: {len(train_dataset)} samples")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Batches per epoch: {len(train_loader)}")



‚úÖ Dataset created: 149629 samples
   Batch size: 64
   Batches per epoch: 2338


In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position =torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [7]:
class TransformerNMT(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, 
                 nhead=8, num_encoder_layers=4, num_decoder_layers=4,
                 dim_feedforward=2048, dropout=0.1, max_len=512):
        super().__init__()
        
        self.d_model =d_model
        self.src_vocab_size =src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        
        # Embeddings
        self.src_embedding =nn.Embedding(src_vocab_size, d_model, padding_idx=0)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model, padding_idx=0)
        
        # Positional encoding
        self.pos_encoder= PositionalEncoding(d_model, max_len, dropout)
        
        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        
        # Output projection
        self.fc_out =nn.Linear(d_model, tgt_vocab_size)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        for p in self.parameters():
            if p.dim()>1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, 
                src_padding_mask=None, tgt_padding_mask=None):
        # Embed and add positional encoding
        src_emb= self.pos_encoder(self.src_embedding(src) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoder(self.tgt_embedding(tgt) * math.sqrt(self.d_model))
        
        # Transformer forward
        output= self.transformer(
            src_emb, tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask
        )
        
        return self.fc_out(output)
    
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask



In [9]:
# Add this cell to check your data
src_lens = [len(src_tokenizer.encode(s)) for s in src_texts[:1000]]
tgt_lens = [len(tgt_tokenizer.encode(t)) for t in tgt_texts[:1000]]

print(f"Source lengths:")
print(f"  Mean: {np.mean(src_lens):.1f}")
print(f"  Median: {np.median(src_lens):.0f}")
print(f"  95th percentile: {np.percentile(src_lens, 95):.0f}")
print(f"  Max: {max(src_lens)}")

print(f"\nTarget lengths:")
print(f"  Mean: {np.mean(tgt_lens):.1f}")
print(f"  Median: {np.median(tgt_lens):.0f}")
print(f"  95th percentile: {np.percentile(tgt_lens, 95):.0f}")
print(f"  Max: {max(tgt_lens)}")

print(f"\n% Source truncated at 60: {sum(1 for l in src_lens if l > 83)/len(src_lens)*100:.1f}%")
print(f"% Source truncated at 100: {sum(1 for l in src_lens if l > 98)/len(src_lens)*100:.1f}%")

Source lengths:
  Mean: 31.4
  Median: 29
  95th percentile: 62
  Max: 127

Target lengths:
  Mean: 27.2
  Median: 24
  95th percentile: 56
  Max: 115

% Source truncated at 60: 0.8%
% Source truncated at 100: 0.3%


In [8]:
# Initialize model
model = TransformerNMT(
    src_vocab_size=src_tokenizer.vocab_size,
    tgt_vocab_size=tgt_tokenizer.vocab_size,
    d_model=512,
    nhead=8,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=2048,
    dropout=0.1
).to(device)

print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")


‚úÖ Model initialized
   Parameters: 47,547,366


In [9]:

PAD_IDX = tgt_tokenizer.pad_id
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

# Noam scheduler with state_dict support
class NoamScheduler:
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps= warmup_steps
        self.step_num = 0
    
    def step(self):
        self.step_num += 1
        lr = self.d_model**(-0.5) * min(
            self.step_num**(-0.5),
            self.step_num *self.warmup_steps ** (-1.5)
        )
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def state_dict(self):
        return {
            'd_model': self.d_model,
            'warmup_steps': self.warmup_steps,
            'step_num': self.step_num
        }

    def load_state_dict(self, state_dict):
        self.d_model = state_dict['d_model']
        self.warmup_steps = state_dict['warmup_steps']
        self.step_num = state_dict['step_num']

scheduler = NoamScheduler(optimizer, d_model=512, warmup_steps=4000)
scaler = torch.amp.GradScaler('cuda')

print("‚úÖ Training setup complete with full checkpoint support")

‚úÖ Training setup complete with full checkpoint support


In [12]:
import os
import time
import torch
import numpy as np
from torch import amp
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

# Training history for plots
training_history = {
    'epoch_loss': [],
    'epoch_ppl': [],
    'epoch_time': [],
    'batch_losses': []
}

best_loss = float('inf')
previous_ckpt = None  # To track and delete previous checkpoint

def train_epoch(model, loader, criterion, optimizer, scheduler, scaler, device, epoch, history):
    model.train()
    total_loss = 0
    epoch_start = time.time()
    
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    
    for batch_idx, (src, tgt) in enumerate(pbar):
        src, tgt = src.to(device), tgt.to(device)
        
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
        src_padding_mask = (src == PAD_IDX).to(torch.bool)
        tgt_padding_mask = (tgt_input == PAD_IDX).to(torch.bool)
        
        with amp.autocast('cuda'):
            output = model(
                src, tgt_input,
                tgt_mask=tgt_mask,
                src_padding_mask = (src == PAD_IDX).float(),
tgt_padding_mask = (tgt_input == PAD_IDX).float()
            )
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        
        history['batch_losses'].append(loss.item())
        total_loss += loss.item()
        pbar.set_postfix({
            'loss': loss.item(),
            'ppl': f'{np.exp(loss.item()):.2f}',
            'lr': f'{lr:.2e}'
        })
        
    epoch_time = time.time() - epoch_start
    avg_loss = total_loss / len(loader)
    avg_ppl = np.exp(avg_loss)
    
    history['epoch_loss'].append(avg_loss)
    history['epoch_ppl'].append(avg_ppl)
    history['epoch_time'].append(epoch_time)
    
    return avg_loss

# Training loop
NUM_EPOCHS = 15
print(f"\nüöÄ Starting training for {NUM_EPOCHS} epochs...")
start_time = time.time()

for epoch in range(1, NUM_EPOCHS + 1):
    avg_loss = train_epoch(model, train_loader, criterion, optimizer, 
                           scheduler, scaler, device, epoch, training_history)
    avg_ppl = np.exp(avg_loss)
    print(f"Epoch {epoch}/{NUM_EPOCHS} - Loss: {avg_loss:.4f} - PPL: {avg_ppl:.2f} - Time: {training_history['epoch_time'][-1]:.2f}s")
    
    # Save current checkpoint
    ckpt_path = f'model_epoch_{epoch}.pt'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': avg_loss,
        'history': training_history,
        'pad_idx': PAD_IDX,
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'numpy_rng_state': np.random.get_state()
    }, ckpt_path)

    # Delete previous checkpoint
    if previous_ckpt and os.path.exists(previous_ckpt):
        os.remove(previous_ckpt)
    previous_ckpt = ckpt_path

    # Save best model
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), 'best_model.pt')

total_training_time = time.time() - start_time
print(f"‚úÖ Training complete! Total time: {total_training_time/3600:.2f} hours")



üöÄ Starting training for 15 epochs...


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:42<00:00,  4.47it/s, loss=5.13, ppl=168.38, lr=4.08e-04]


Epoch 1/15 - Loss: 6.2557 - PPL: 520.98 - Time: 522.64s


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:42<00:00,  4.47it/s, loss=4.61, ppl=100.54, lr=6.46e-04]


Epoch 2/15 - Loss: 4.8124 - PPL: 123.03 - Time: 522.68s


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:43<00:00,  4.47it/s, loss=4.18, ppl=65.44, lr=5.28e-04]


Epoch 3/15 - Loss: 4.2545 - PPL: 70.42 - Time: 523.42s


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:43<00:00,  4.47it/s, loss=3.82, ppl=45.77, lr=4.57e-04]


Epoch 4/15 - Loss: 3.8722 - PPL: 48.05 - Time: 523.43s


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:43<00:00,  4.47it/s, loss=3.72, ppl=41.27, lr=4.09e-04]


Epoch 5/15 - Loss: 3.6135 - PPL: 37.09 - Time: 523.14s


Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:42<00:00,  4.47it/s, loss=3.48, ppl=32.49, lr=3.73e-04]


Epoch 6/15 - Loss: 3.4176 - PPL: 30.50 - Time: 522.67s


Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:43<00:00,  4.46it/s, loss=3.17, ppl=23.90, lr=3.45e-04]


Epoch 7/15 - Loss: 3.2648 - PPL: 26.18 - Time: 523.74s


Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:43<00:00,  4.47it/s, loss=3.18, ppl=24.16, lr=3.23e-04]


Epoch 8/15 - Loss: 3.1420 - PPL: 23.15 - Time: 523.38s


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:44<00:00,  4.46it/s, loss=3.02, ppl=20.57, lr=3.05e-04]


Epoch 9/15 - Loss: 3.0397 - PPL: 20.90 - Time: 524.34s


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:45<00:00,  4.45it/s, loss=2.97, ppl=19.52, lr=2.89e-04]


Epoch 10/15 - Loss: 2.9544 - PPL: 19.19 - Time: 525.08s


Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:43<00:00,  4.47it/s, loss=2.77, ppl=15.98, lr=2.76e-04]


Epoch 11/15 - Loss: 2.8805 - PPL: 17.82 - Time: 523.25s


Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:41<00:00,  4.48it/s, loss=2.87, ppl=17.64, lr=2.64e-04]


Epoch 12/15 - Loss: 2.8162 - PPL: 16.71 - Time: 521.55s


Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:41<00:00,  4.48it/s, loss=2.89, ppl=18.01, lr=2.53e-04]


Epoch 13/15 - Loss: 2.7600 - PPL: 15.80 - Time: 521.66s


Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:42<00:00,  4.48it/s, loss=2.84, ppl=17.07, lr=2.44e-04]


Epoch 14/15 - Loss: 2.7087 - PPL: 15.01 - Time: 522.01s


Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:42<00:00,  4.47it/s, loss=2.72, ppl=15.21, lr=2.36e-04]


Epoch 15/15 - Loss: 2.6637 - PPL: 14.35 - Time: 522.86s
‚úÖ Training complete! Total time: 2.19 hours


In [10]:
import os
import time
import torch
import numpy as np
from torch import amp
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm


# This will be overwritten if a checkpoint is successfully loaded
training_history = {
    'epoch_loss': [],
    'epoch_ppl': [],
    'epoch_time': [],
    'batch_losses': []
}

best_loss = float('inf')
previous_ckpt = None  # To track and delete previous checkpoint

def train_epoch(model, loader, criterion, optimizer, scheduler, scaler, device, epoch, history):
    model.train()
    total_loss = 0
    epoch_start = time.time()
    
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    
    for batch_idx, (src, tgt) in enumerate(pbar):
        src, tgt = src.to(device), tgt.to(device)
        
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
       
        src_padding_mask_bool = (src == PAD_IDX).to(torch.bool)
        tgt_padding_mask_bool = (tgt_input == PAD_IDX).to(torch.bool)
        
        with amp.autocast('cuda'):
            output = model(
                src, tgt_input,
                tgt_mask=tgt_mask,
              
                src_padding_mask = (src == PAD_IDX).float(), 
                tgt_padding_mask = (tgt_input == PAD_IDX).float()
            )
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        
        history['batch_losses'].append(loss.item())
        total_loss += loss.item()
        pbar.set_postfix({
            'loss': loss.item(),
            'ppl': f'{np.exp(loss.item()):.2f}',
            'lr': f'{lr:.2e}'
        })
        
    epoch_time = time.time() - epoch_start
    avg_loss = total_loss / len(loader)
    avg_ppl = np.exp(avg_loss)
    
    history['epoch_loss'].append(avg_loss)
    history['epoch_ppl'].append(avg_ppl)
    history['epoch_time'].append(epoch_time)
    
    return avg_loss


CHECKPOINT_TO_LOAD = '/kaggle/input/nmtmodelbest/model_epoch_15_best.pt'
START_EPOCH = 1
TOTAL_EPOCHS = 25 

# --- Load from checkpoint ---
if os.path.exists(CHECKPOINT_TO_LOAD):
    print(f"üîÑ Loading checkpoint: {CHECKPOINT_TO_LOAD}")
   
    checkpoint = torch.load(CHECKPOINT_TO_LOAD, map_location=device, weights_only=False)
    # Load model, optimizer, scheduler, and scaler states
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    
    # Restore training state variables
    START_EPOCH = checkpoint['epoch'] + 1
    training_history = checkpoint['history']
    
    # Restore best_loss from the loaded history
    if training_history['epoch_loss']:
        best_loss = min(training_history['epoch_loss'])
    
    # Restore RNG states for reproducibility
    torch.set_rng_state(checkpoint['torch_rng_state'].cpu())
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state'].cpu())
    np.random.set_state(checkpoint['numpy_rng_state'])
    
    # Set the previous checkpoint path to the one we just loaded
    previous_ckpt = CHECKPOINT_TO_LOAD
    
    print(f"‚úÖ Checkpoint loaded. Resuming from epoch {START_EPOCH}.")
    print(f"   Current best loss from history: {best_loss:.4f}")

else:
    print(f"‚ö†Ô∏è Checkpoint '{CHECKPOINT_TO_LOAD}' not found. Starting from scratch.")
   
if START_EPOCH > TOTAL_EPOCHS:
    print(f"Model already trained for {START_EPOCH - 1} epochs. No further training needed to reach {TOTAL_EPOCHS} epochs.")
else:
    print(f"\nüöÄ Starting training from epoch {START_EPOCH} to {TOTAL_EPOCHS}...")
    start_time = time.time()

    for epoch in range(START_EPOCH, TOTAL_EPOCHS + 1):
        avg_loss = train_epoch(model, train_loader, criterion, optimizer, 
                               scheduler, scaler, device, epoch, training_history)
        avg_ppl = np.exp(avg_loss)
        print(f"Epoch {epoch}/{TOTAL_EPOCHS} - Loss: {avg_loss:.4f} - PPL: {avg_ppl:.2f} - Time: {training_history['epoch_time'][-1]:.2f}s")
        
        # Save current checkpoint
        ckpt_path = f'model_epoch_{epoch}.pt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'loss': avg_loss,
            'history': training_history,
            'pad_idx': PAD_IDX,
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'numpy_rng_state': np.random.get_state()
        }, ckpt_path)

        # Delete previous checkpoint
        if previous_ckpt and os.path.exists(previous_ckpt):
            try:
                os.remove(previous_ckpt)
            except OSError as e:
               print(f"Warning: Could not delete previous checkpoint '{previous_ckpt}'. Error: {e}")
        previous_ckpt = ckpt_path

        # Save best model
        if avg_loss < best_loss:
            print(f"   üéâ New best model found! Loss improved from {best_loss:.4f} to {avg_loss:.4f}. Saving 'best_model.pt'.")
            best_loss = avg_loss
            torch.save(model.state_dict(), 'best_model.pt')

    total_training_time = time.time() - start_time
    print(f"‚úÖ Training complete! Total time for this session: {total_training_time/3600:.2f} hours")

üîÑ Loading checkpoint: /kaggle/input/nmtmodelbest/model_epoch_15_best.pt
‚úÖ Checkpoint loaded. Resuming from epoch 16.
   Current best loss from history: 2.6637

üöÄ Starting training from epoch 16 to 25...


Epoch 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:48<00:00,  4.42it/s, loss=2.7, ppl=14.94, lr=2.28e-04] 


Epoch 16/25 - Loss: 2.6230 - PPL: 13.78 - Time: 528.65s
   üéâ New best model found! Loss improved from 2.6637 to 2.6230. Saving 'best_model.pt'.


Epoch 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:48<00:00,  4.42it/s, loss=2.74, ppl=15.47, lr=2.22e-04]


Epoch 17/25 - Loss: 2.5846 - PPL: 13.26 - Time: 528.44s
   üéâ New best model found! Loss improved from 2.6230 to 2.5846. Saving 'best_model.pt'.


Epoch 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:48<00:00,  4.43it/s, loss=2.74, ppl=15.43, lr=2.15e-04]


Epoch 18/25 - Loss: 2.5504 - PPL: 12.81 - Time: 528.16s
   üéâ New best model found! Loss improved from 2.5846 to 2.5504. Saving 'best_model.pt'.


Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:48<00:00,  4.42it/s, loss=2.61, ppl=13.53, lr=2.10e-04]


Epoch 19/25 - Loss: 2.5188 - PPL: 12.41 - Time: 528.59s
   üéâ New best model found! Loss improved from 2.5504 to 2.5188. Saving 'best_model.pt'.


Epoch 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:49<00:00,  4.41it/s, loss=2.61, ppl=13.55, lr=2.04e-04]


Epoch 20/25 - Loss: 2.4888 - PPL: 12.05 - Time: 529.85s
   üéâ New best model found! Loss improved from 2.5188 to 2.4888. Saving 'best_model.pt'.


Epoch 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:48<00:00,  4.42it/s, loss=2.61, ppl=13.60, lr=1.99e-04]


Epoch 21/25 - Loss: 2.4615 - PPL: 11.72 - Time: 528.37s
   üéâ New best model found! Loss improved from 2.4888 to 2.4615. Saving 'best_model.pt'.


Epoch 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:49<00:00,  4.42it/s, loss=2.32, ppl=10.14, lr=1.95e-04]


Epoch 22/25 - Loss: 2.4358 - PPL: 11.42 - Time: 529.22s
   üéâ New best model found! Loss improved from 2.4615 to 2.4358. Saving 'best_model.pt'.


Epoch 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2338/2338 [08:50<00:00,  4.41it/s, loss=2.42, ppl=11.27, lr=1.91e-04]


Epoch 23/25 - Loss: 2.4134 - PPL: 11.17 - Time: 530.31s
   üéâ New best model found! Loss improved from 2.4358 to 2.4134. Saving 'best_model.pt'.


Epoch 24:   6%|‚ñã         | 148/2338 [00:33<08:17,  4.40it/s, loss=2.32, ppl=10.15, lr=1.90e-04]


KeyboardInterrupt: 

In [11]:
print("\nüìÇ Loading validation data...")
with open('/kaggle/input/codabenchnmt/test_data1_final.json', 'r', encoding='utf-8') as f:
    val_data = json.load(f)

# Load best model
checkpoint = torch.load('/kaggle/working/best_model.pt')
model.load_state_dict(checkpoint) # <-- Correct
model.eval()


üìÇ Loading validation data...


TransformerNMT(
  (src_embedding): Embedding(8792, 512, padding_idx=0)
  (tgt_embedding): Embedding(13286, 512, padding_idx=0)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm(

In [12]:
def translate_batch_greedy(src_texts, target_lang, model, src_tokenizer, tgt_tokenizer, 
                          device, batch_size=32, max_len=120):

    model.eval()
    all_translations = []
    
    lang_tag_id = tgt_tokenizer.get_lang_tag_id(target_lang)
    
    # Process in batches with progress bar
    num_batches = (len(src_texts) + batch_size - 1) // batch_size
    pbar = tqdm(range(0, len(src_texts), batch_size), 
                total=num_batches, 
                desc=f"Translating ({batch_size} per batch)")
    
    for i in pbar:
        batch_texts = src_texts[i:i+batch_size]
        
        # Tokenize batch
        src_ids_list = [src_tokenizer.encode(text, add_bos=True, add_eos=True) for text in batch_texts]
        src_batch = pad_sequence(
            [torch.LongTensor(ids) for ids in src_ids_list],
            batch_first=True,
            padding_value=src_tokenizer.pad_id
        ).to(device)
        
        with torch.no_grad():
            # Encode all sources at once
            # Use .float() to match training code
            src_padding_mask = (src_batch == src_tokenizer.pad_id).float()
            src_emb = model.pos_encoder(
                model.src_embedding(src_batch) * math.sqrt(model.d_model)
            )
            memory = model.transformer.encoder(src_emb, src_key_padding_mask=src_padding_mask)
            
            # Initialize decoder with BOS + LANG_TAG
            batch_sz = len(batch_texts)
            tgt_ids = torch.full((batch_sz, 2), tgt_tokenizer.pad_id, dtype=torch.long, device=device)
            tgt_ids[:, 0] = tgt_tokenizer.bos_id
            tgt_ids[:, 1] = lang_tag_id
            
            finished = torch.zeros(batch_sz, dtype=torch.bool, device=device)
            
            # Greedy decode
            for step in range(max_len - 2):
                tgt_emb = model.pos_encoder(
                    model.tgt_embedding(tgt_ids) * math.sqrt(model.d_model)
                )
                
                # MATCH TRAINING: Use same mask types as training
                tgt_mask = model.generate_square_subsequent_mask(tgt_ids.size(1)).to(device)
                
                # Use .float() for padding masks (same as training)
                tgt_padding_mask = (tgt_ids == tgt_tokenizer.pad_id).float()
                memory_padding_mask = (src_batch == src_tokenizer.pad_id).float()
                
                output = model.transformer.decoder(
                    tgt_emb, memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_padding_mask,
                    memory_key_padding_mask=memory_padding_mask
                )
                
                logits = model.fc_out(output[:, -1, :])  # Last position
                next_tokens = logits.argmax(dim=-1)
                
                # Mark finished sequences
                next_tokens = torch.where(
                    finished,
                    torch.full_like(next_tokens, tgt_tokenizer.pad_id),
                    next_tokens
                )
                finished |= (next_tokens == tgt_tokenizer.eos_id)
                
                # Append next tokens
                tgt_ids = torch.cat([tgt_ids, next_tokens.unsqueeze(1)], dim=1)
                
                # Stop if all sequences finished
                if finished.all():
                    break
            
            # Decode all sequences
            for seq in tgt_ids:
                translation = tgt_tokenizer.decode(seq.tolist(), skip_special_tokens=True)
                all_translations.append(translation)
        
        # Update progress bar
        pbar.set_postfix({'total': len(all_translations)})
    
    return all_translations

print("‚úÖ Fast batched decoding ready!")

‚úÖ Fast batched decoding ready!


In [None]:
predictions = []


print("\nüîÑ Translating Bengali validation set (batched)...")
val_bengali = val_data.get("English-Bengali", {}).get("Test", {})
bn_ids = list(val_bengali.keys())
bn_texts = [val_bengali[idx]["source"] for idx in bn_ids]

bn_translations = translate_batch_beam(
    bn_texts, LANG_TAG_BN, model,
    src_tokenizer, tgt_tokenizer, device,
    batch_size=32, beam_size=3  
)

for idx, translation in zip(bn_ids, bn_translations):
    predictions.append({
        'id': int(idx),
        'prediction': translation
    })

print(f"‚úÖ Bengali done: {len(bn_translations)} translations")


print("\nüîÑ Translating Hindi validation set (batched)...")
val_hindi = val_data.get("English-Hindi", {}).get("Test", {})
hi_ids = list(val_hindi.keys())
hi_texts = [val_hindi[idx]["source"] for idx in hi_ids]


hi_translations = translate_batch_beam(
    hi_texts, LANG_TAG_HI, model,
    src_tokenizer, tgt_tokenizer, device,
    batch_size=32, beam_size=3
)

for idx, translation in zip(hi_ids, hi_translations):
    predictions.append({
        'id': int(idx),
        'prediction': translation
    })


import pandas as pd
df = pd.DataFrame(predictions)
# NO SORTING - keep order as Bengali first, then Hindi
df.to_csv('predictions_multilingual.csv', index=False)
print(f"\n‚úÖ Predictions saved: {len(predictions)} translations")
print(f"   Order: Bengali first, then Hindi (as in val_data1.json)")



üîÑ Translating Bengali validation set (batched)...


Beam Search (32 sent, 3 beams):  41%|‚ñà‚ñà‚ñà‚ñà      | 251/615 [08:34<14:26,  2.38s/it, total=8032]

In [16]:
def translate_batch_beam(src_texts, target_lang, model, src_tokenizer, tgt_tokenizer, 
                         device, batch_size=16, beam_size=5, max_len=100):
   
    
    model.eval()
    all_translations = []
    
    lang_tag_id = tgt_tokenizer.get_lang_tag_id(target_lang)
    
    # Process in batches with progress bar
    num_batches = (len(src_texts) + batch_size - 1) // batch_size
    pbar = tqdm(range(0, len(src_texts), batch_size), 
                total=num_batches, 
                desc=f"Beam Search ({batch_size} sent, {beam_size} beams)")
    
    for i in pbar:
        batch_texts = src_texts[i:i+batch_size]
        batch_sz = len(batch_texts)
        
        # Tokenize batch
        src_ids_list = [src_tokenizer.encode(text, add_bos=True, add_eos=True) for text in batch_texts]
        src_batch = pad_sequence(
            [torch.LongTensor(ids) for ids in src_ids_list],
            batch_first=True,
            padding_value=src_tokenizer.pad_id
        ).to(device)
        
        with torch.no_grad():
            # Encode sources
            src_padding_mask = (src_batch == src_tokenizer.pad_id).float()
            src_emb = model.pos_encoder(
                model.src_embedding(src_batch) * math.sqrt(model.d_model)
            )
            memory = model.transformer.encoder(src_emb, src_key_padding_mask=src_padding_mask)
            
            # Expand memory and masks for beam search: (batch_sz*beam_size, seq_len, d_model)
            memory_beam = memory.unsqueeze(1).repeat(1, beam_size, 1, 1).view(batch_sz * beam_size, -1, model.d_model)
            src_padding_mask_beam = src_padding_mask.unsqueeze(1).repeat(1, beam_size, 1).view(batch_sz * beam_size, -1)
            
            # Initialize beams: (batch_sz, beam_size, seq_len)
            # Start with [BOS, LANG_TAG]
            beams = torch.full((batch_sz, beam_size, 2), tgt_tokenizer.pad_id, dtype=torch.long, device=device)
            beams[:, :, 0] = tgt_tokenizer.bos_id
            beams[:, :, 1] = lang_tag_id
            
            # Beam scores: (batch_sz, beam_size) - log probabilities
            beam_scores = torch.zeros(batch_sz, beam_size, device=device)
            beam_scores[:, 1:] = -1e9  # Only first beam is active initially
            
            # Track finished beams
            finished = torch.zeros(batch_sz, beam_size, dtype=torch.bool, device=device)
            
            # Decode step by step
            for step in range(max_len - 2):
                # Reshape beams for decoding: (batch_sz*beam_size, current_len)
                current_len = beams.size(2)
                beams_flat = beams.view(batch_sz * beam_size, current_len)
                
                # Embed and decode
                tgt_emb = model.pos_encoder(
                    model.tgt_embedding(beams_flat) * math.sqrt(model.d_model)
                )
                
                tgt_mask = model.generate_square_subsequent_mask(current_len).to(device)
                tgt_padding_mask = (beams_flat == tgt_tokenizer.pad_id).float()
                
                output = model.transformer.decoder(
                    tgt_emb, memory_beam,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_padding_mask,
                    memory_key_padding_mask=src_padding_mask_beam
                )
                
                # Get logits for next token: (batch_sz*beam_size, vocab_size)
                logits = model.fc_out(output[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)
                
                # Reshape: (batch_sz, beam_size, vocab_size)
                log_probs = log_probs.view(batch_sz, beam_size, -1)
                
                # Add to beam scores: (batch_sz, beam_size, vocab_size)
                # only allow PAD token
                vocab_size = log_probs.size(-1)
                scores = beam_scores.unsqueeze(2) + log_probs
                
                # Mask finished beams: force them to generate PAD
                finished_mask = finished.unsqueeze(2).expand(-1, -1, vocab_size)
                scores = scores.masked_fill(finished_mask, -1e9)
                scores[:, :, tgt_tokenizer.pad_id] = scores[:, :, tgt_tokenizer.pad_id].masked_fill(finished, 0)
                
                # Flatten and get top beam_size candidates: (batch_sz, beam_size*vocab_size)
                scores_flat = scores.view(batch_sz, -1)
                
                # Get top beam_size candidates
                top_scores, top_indices = scores_flat.topk(beam_size, dim=1)
                
                # Convert flat indices to (beam_idx, token_idx)
                prev_beam_idx = top_indices // vocab_size  # Which beam did this come from
                next_token_idx = top_indices % vocab_size  # Which token to append
                
                # Update beams
                # Gather previous beams: (batch_sz, beam_size, current_len)
                gathered_beams = torch.gather(
                    beams, 1, 
                    prev_beam_idx.unsqueeze(2).expand(-1, -1, current_len)
                )
                
                # Append new tokens: (batch_sz, beam_size, current_len+1)
                beams = torch.cat([gathered_beams, next_token_idx.unsqueeze(2)], dim=2)
                
                # Update scores
                beam_scores = top_scores
                
                # Update finished status
                finished = torch.gather(finished, 1, prev_beam_idx)
                finished |= (next_token_idx == tgt_tokenizer.eos_id)
                
                # Early stopping if all beams finished
                if finished.all():
                    break
            
            #  length penalty (avoid too short translations)
            # Normalized by length: score / (length ** alpha)
            # alpha=0.6 i
            lengths = beams.ne(tgt_tokenizer.pad_id).sum(dim=2).float()  # (batch_sz, beam_size)
            length_penalty = torch.pow(lengths, 0.6)
            normalized_scores = beam_scores / length_penalty
            
          
            best_beam_idx = normalized_scores.argmax(dim=1)  # (batch_sz,)
            best_beams = beams[torch.arange(batch_sz), best_beam_idx]  # (batch_sz, seq_len)
            
            # Decode
            for seq in best_beams:
                translation = tgt_tokenizer.decode(seq.tolist(), skip_special_tokens=True)
                all_translations.append(translation)
        
        pbar.set_postfix({'total': len(all_translations)})
    
    return all_translations

print("‚úÖ Fast batched beam search ready!")


‚úÖ Fast batched beam search ready!


In [17]:
USE_BEAM_SEARCH = True  

if USE_BEAM_SEARCH:
    DECODE_BATCH_SIZE = 32
    BEAM_SIZE = 5           
    decode_fn = translate_batch_beam
    print(f"\n‚öôÔ∏è  Using BATCHED BEAM SEARCH (batch={DECODE_BATCH_SIZE}, beams={BEAM_SIZE})")
else:
    DECODE_BATCH_SIZE = 64 
    BEAM_SIZE = 1
    decode_fn = translate_batch_greedy
    print(f"\n‚öôÔ∏è  Using BATCHED GREEDY DECODING (batch={DECODE_BATCH_SIZE}) - FASTEST")


‚öôÔ∏è  Using BATCHED BEAM SEARCH (batch=32, beams=5)


In [18]:
df = pd.read_csv('/kaggle/working/predictions_multilingual.csv')

print(f"\nüìä Prediction Summary:")
print(f"   Total predictions: {len(df)}")
print(f"   ID range: {df['id'].min()} - {df['id'].max()}")
print(f"   First 5 IDs: {df['id'].head().tolist()}")
print(f"   Last 5 IDs: {df['id'].tail().tolist()}")

with open('answer.csv', 'w', encoding='utf-8') as f:
  
    f.write("ID\tTranslation\n")
 
    for _, row in df.iterrows():
        f.write(f'{row["id"]}\t"{row["prediction"]}"\n')


print(f"   Format: Tab-separated with quoted translations")
print(f"   Order: Preserved from val_data1.json (Bengali ‚Üí Hindi)")
print(f"   Ready for Codabench upload!")

with open('answer.csv', 'r', encoding='utf-8') as f:
    for i, line in enumerate(f):
        if i < 3:
            print(f"   {line.rstrip()}")


üìä Prediction Summary:
   Total predictions: 42757
   ID range: 177039 - 563223
   First 5 IDs: [177039, 177040, 177041, 177042, 177043]
   Last 5 IDs: [563219, 563220, 563221, 563222, 563223]

‚úÖ Submission file created: answer.csv
   Format: Tab-separated with quoted translations
   Order: Preserved from val_data1.json (Bengali ‚Üí Hindi)
   Ready for Codabench upload!

üîç First 3 lines of answer.csv:
   ID	Translation
   177039	"‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡¶ó‡ßÅ‡¶≤‡¶ø ‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡¶ó‡ßÅ‡¶≤‡¶ø‡¶§‡ßá ‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡¶ó‡ßÅ‡¶≤‡¶ø ‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡¶ó‡ßÅ‡¶≤‡¶ø‡¶∞ ‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡¶ó‡ßÅ‡¶≤‡¶ø‡¶∞ ‡¶Æ‡¶ß‡ßç‡¶Ø‡ßá ‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡¶ó‡ßÅ‡¶≤‡¶ø‡¶∞ ‡¶è‡¶ï‡¶ü‡¶ø ‡¶π‡¶≤ ‡¶¨‡¶∞‡ßç‡¶§‡¶Æ‡¶æ‡¶® ‡¶á‡¶≠‡ßá‡¶®‡ßç‡¶ü‡•§"
   177040	"‡¶≠‡¶ó‡¶¨‡¶æ‡¶® ‡¶¨‡ßç‡¶∞‡¶π‡ßç‡¶Æ‡¶æ ‡¶§‡¶æ‡¶∞ ‡¶§‡¶™‡¶∏‡ßç‡¶Ø‡¶æ‡¶∞ ‡¶ú‡¶®‡ßç‡¶Ø ‡¶¶‡¶Ø‡¶º‡¶æ ‡¶ï‡¶∞‡ßá ‡¶ï‡¶ø‡¶®‡ßç‡¶§‡ßÅ ‡¶§‡¶æ‡¶