# English-to-Urdu Machine Translation with a Transformer from Scratch

This notebook implements a complete machine translation pipeline to translate text from English to Urdu. The core of this project is a Transformer model built from scratch using PyTorch, based on the architecture originally proposed by Vaswani et al. in "Attention is All You Need."

### Project Objectives

1.  **Build a Transformer Model**: Construct the encoder-decoder architecture from the ground up.
2.  **Train on a Parallel Corpus**: Preprocess and train the model on a curated English-Urdu dataset.
3.  **Evaluate Performance**: Measure translation quality using the BLEU score and analyze sample translations.

## 1. Environment Setup

This section configures the environment by importing necessary libraries and setting up the device for training. Key steps include:

-   **Importing Libraries**: Core libraries such as `torch`, `pandas`, and `numpy` are imported.
-   **Device Configuration**: The code detects and selects the appropriate device (`MPS` for Apple Silicon, `CUDA` for NVIDIA GPUs, or `CPU`). A fallback to the CPU is enabled for MPS to prevent errors with unsupported operations.
-   **Directory Creation**: Directories for saving model checkpoints and results are created to keep the project organized.

In [None]:
# Core libraries
import os
import math
import time
import random
import warnings
import json
import unicodedata
from pathlib import Path
from collections import Counter

# Fix for MPS unsupported operations - enable CPU fallback
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Data handling
import pandas as pd
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Tokenization and NLP
import nltk
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
import re

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Progress tracking
from tqdm.auto import tqdm

# Suppress warnings
warnings.filterwarnings('ignore')

# Download NLTK data for BLEU and METEOR
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)

try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet', quiet=True)

try:
    nltk.data.find('corpora/omw-1.4')
except LookupError:
    nltk.download('omw-1.4', quiet=True)

print("All libraries imported successfully!")
print(" MPS CPU fallback enabled for unsupported operations")
print(" NLTK data downloaded (punkt, wordnet, omw)")

In [None]:
# Random seeds for reproducibility (disabled here; uncomment to make shuffling and weight initialization repeatable)
# SEED = 42
# random.seed(SEED)
# np.random.seed(SEED)
# torch.manual_seed(SEED)
# if torch.backends.mps.is_available():
#     torch.mps.manual_seed(SEED)

# Configure device - prioritize MPS for Apple Silicon
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print(" Using MPS (Metal Performance Shaders) - Apple Silicon GPU")
    print(f"  PyTorch version: {torch.__version__}")
    print(f"  MPS backend available: {torch.backends.mps.is_available()}")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(" Using CUDA")
    print(f"  Device: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print(" Using CPU")

print(f"\nActive device: {device}")

# Create directories for checkpoints and results
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("results", exist_ok=True)
print("\n Directory structure created")

## 2. Model and Training Configuration

This section centralizes all hyperparameters and configuration settings for the model, data, and training process.

-   **`config` Dictionary**: A dictionary that stores all key parameters for easy access and modification.
-   **Data Paths**: The file paths for the English and Urdu corpora are defined separately, alongside the loading code in Section 3.
-   **Model Architecture**: Defines the core parameters of the Transformer, including embedding dimensions, number of attention heads, and layer counts.
-   **Training Hyperparameters**: Sets the learning rate, batch size, and number of training epochs.
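
The `warmup_steps` setting feeds a Noam learning-rate schedule, the warmup scheme from "Attention is All You Need". As a rough sketch of how that schedule behaves, the standalone function below evaluates the formula for `d_model=512` and `warmup_steps=1000`. The name `noam_lr` is illustrative only and is not part of the notebook.

```python
def noam_lr(step, d_model=512, warmup_steps=1000):
    """Noam schedule: linear warmup followed by inverse-square-root decay."""
    step = max(step, 1)  # guard against 0 ** -0.5
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate rises linearly until warmup_steps, then decays.
for step in (1, 500, 1000, 4000):
    print(f"step {step:>5}: lr = {noam_lr(step):.3e}")
```

With these values the rate peaks near 1.4e-3 at step 1000. If the scheduler is stepped during training, the optimizer's own `learning_rate` entry only applies until the first scheduler step overwrites it.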

In [None]:
# Model and training configuration
config = {
    # Transformer architecture
    'd_model': 512,                     # Embedding dimension for token representations
    'nhead': 8,                         # Number of attention heads
    'num_encoder_layers': 4,            # Encoder stack depth
    'num_decoder_layers': 4,            # Decoder stack depth
    'dim_feedforward': 2048,            # Feed-forward hidden size
    'dropout': 0.3,                     # Dropout probability throughout the model
    'max_seq_length': 128,              # Maximum sequence length supported by positional encoding

    # Training parameters
    'batch_size': 128,                  # Mini-batch size for data loaders
    'gradient_accumulation_steps': 2,   # Steps to accumulate gradients before each optimizer step
    'num_epochs': 100,                  # Maximum number of training epochs
    'learning_rate': 5e-6,              # Initial optimizer learning rate; replaced once the Noam scheduler steps
    'warmup_steps': 1000,               # Warmup steps for the Noam scheduler
    'max_grad_norm': 1.0,               # Gradient clipping threshold
    'early_stopping_patience': 3,       # Early-stopping patience measured in epochs
    'weight_decay': 0.0001,             # L2 regularization (weight decay) coefficient

    # Vocabulary
    'vocab_size_en': 10000,             # Maximum vocabulary size for English tokens
    'vocab_size_ur': 10000,             # Maximum vocabulary size for Urdu tokens

    # Special tokens
    'pad_token': '<pad>',
    'sos_token': '<sos>',
    'eos_token': '<eos>',
    'unk_token': '<unk>',

    # Checkpoint settings
    'checkpoint_dir': 'checkpoints',
    'save_every_n_epochs': 20,          # Frequency (in epochs) for periodic checkpointing

    # Evaluation
    'beam_size': 5,                     # Beam width used during evaluation when beam search is enabled
}

# Print configuration summary
print("=" * 60)
print("MODEL CONFIGURATION")
print("=" * 60)
for key, value in config.items():
    print(f"{key:30s}: {value}")
print("=" * 60)
print("\nKey training utilities:")
print("   Gradient accumulation for stable updates")
print("   Early stopping to avoid overfitting")
print("   Regular checkpointing for reproducibility")
print("=" * 60)

## 3. Data Loading and Preprocessing

This section covers loading the parallel corpus and applying initial cleaning steps.

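-   **Load Data**: The English and Urdu text files are read line by line into a list of `(english, urdu)` sentence pairs. A CSV file with recognizable column names is supported as an alternative.
-   **Clean Text**:
    -   Urdu text is Unicode-normalized (NFKC) and stripped of Arabic diacritics.
    -   English text is lowercased, and characters other than letters, digits, and basic punctuation are removed. Extra whitespace is collapsed in both languages.
-   **Verify**: Dataset statistics and sample pairs are printed, and the tokenizer is run on short test sentences to confirm the preprocessing behaves as expected.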
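
To make the effect of the Urdu cleaning steps concrete, the sketch below applies NFKC normalization and diacritic removal to two small inputs. The helper name `strip_and_normalize` and the sample strings are assumptions made for illustration. The character range follows the diacritic range used in this notebook.

```python
import re
import unicodedata

# Arabic diacritics (harakat): U+064B to U+065F, plus U+0670
DIACRITICS = re.compile(r'[\u064B-\u065F\u0670]')


def strip_and_normalize(text):
    """Apply NFKC, drop diacritics, and collapse whitespace."""
    text = unicodedata.normalize('NFKC', text)
    text = DIACRITICS.sub('', text)
    return re.sub(r'\s+', ' ', text).strip()


# A word written with two fatha marks (U+064E)
with_marks = '\u0645\u064E\u0631\u062D\u064E\u0628\u0627'
# U+FEEB is a presentation form of heh; NFKC folds it to the base letter U+0647
presentation_form = '\uFEEB'

print(len(with_marks), len(strip_and_normalize(with_marks)))
print(hex(ord(strip_and_normalize(presentation_form))))
```

Folding presentation forms and removing optional marks means that visually equivalent spellings map to the same vocabulary entry.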

In [None]:
# Dataset paths
DATA_DIR = Path("data")
EN_FILE = DATA_DIR / "english-corpus.txt"
UR_FILE = DATA_DIR / "urdu-corpus.txt"


def load_parallel_corpus(file_path=None, en_file=None, ur_file=None, max_samples=None):
    """Load an English–Urdu parallel corpus from text files or a CSV."""
    pairs = []

    if en_file and ur_file and en_file.exists() and ur_file.exists():
        with open(en_file, "r", encoding="utf-8") as f_en, open(ur_file, "r", encoding="utf-8") as f_ur:
            for en_line, ur_line in zip(f_en, f_ur):
                en = en_line.strip()
                ur = ur_line.strip()
                if en and ur:
                    pairs.append((en, ur))
                    if max_samples and len(pairs) >= max_samples:
                        break
        return pairs

    if file_path and file_path.exists():
        df = pd.read_csv(file_path)

        en_col = next((col for col in df.columns if col.lower().strip() in {"english", "en", "source", "src"}), None)
        ur_col = next((col for col in df.columns if col.lower().strip() in {"urdu", "ur", "target", "tgt"}), None)
        if en_col is None or ur_col is None:
            raise ValueError("Could not identify English and Urdu columns in the CSV file.")

        for _, row in df.iterrows():
            en = str(row[en_col]).strip()
            ur = str(row[ur_col]).strip()
            if en and ur:
                pairs.append((en, ur))
                if max_samples and len(pairs) >= max_samples:
                    break
        return pairs

    raise FileNotFoundError("Dataset not found. Ensure parallel text files or a CSV are placed in the data directory.")


# Load the dataset
parallel_data = load_parallel_corpus(en_file=EN_FILE, ur_file=UR_FILE, max_samples=None)
print(f"Loaded {len(parallel_data):,} parallel sentence pairs")

# Display dataset statistics
if parallel_data:
    print("\n" + "=" * 60)
    print("DATASET STATISTICS")
    print("=" * 60)
    print(f"Total pairs: {len(parallel_data):,}")

    en_lengths = [len(en.split()) for en, _ in parallel_data]
    ur_lengths = [len(ur.split()) for _, ur in parallel_data]

    print(f"\nEnglish sentences:")
    print(f"  Average length: {np.mean(en_lengths):.2f} words")
    print(f"  Max length: {max(en_lengths)} words")
    print(f"  Min length: {min(en_lengths)} words")
    print(f"  Median length: {np.median(en_lengths):.2f} words")

    print(f"\nUrdu sentences:")
    print(f"  Average length: {np.mean(ur_lengths):.2f} words")
    print(f"  Max length: {max(ur_lengths)} words")
    print(f"  Min length: {min(ur_lengths)} words")
    print(f"  Median length: {np.median(ur_lengths):.2f} words")

    # Show length distribution
    long_en = sum(1 for l in en_lengths if l > config['max_seq_length'])
    long_ur = sum(1 for l in ur_lengths if l > config['max_seq_length'])
    print(f"\nLength distribution:")
    print(f"  Sentences > {config['max_seq_length']} words (EN): {long_en} ({long_en/len(en_lengths)*100:.1f}%)")
    print(f"  Sentences > {config['max_seq_length']} words (UR): {long_ur} ({long_ur/len(ur_lengths)*100:.1f}%)")

    print("\nSample pairs:")
    for i in range(min(5, len(parallel_data))):
        en, ur = parallel_data[i]
        print(f"  Pair {i+1} EN: {en}")
        print(f"           UR: {ur}")

In [None]:
def normalize_urdu(text):
    """
    Normalize Urdu text for better processing.
    
    Args:
        text: Urdu text string
    
    Returns:
        Normalized Urdu text
    """
    # Normalize Unicode characters (NFKC normalization)
    text = unicodedata.normalize('NFKC', text)
    
    # Remove Arabic diacritics (harakat) which can cause tokenization issues
    # Range: U+064B to U+065F and U+0670
    text = re.sub(r'[\u064B-\u065F\u0670]', '', text)
    
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text


def normalize_text(text, is_urdu=False):
    """
    Normalize and clean text for English or Urdu.
    
    Args:
        text: Input text string
        is_urdu: Whether the text is in Urdu
    
    Returns:
        Cleaned text
    """
    if is_urdu:
        # Use specialized Urdu normalization
        return normalize_urdu(text)
    else:
        # English normalization
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Convert to lowercase
        text = text.lower()
        # Remove special characters but keep basic punctuation
        text = re.sub(r'[^a-z0-9\s.,!?\'-]', '', text)
    
    return text


def simple_tokenize(text, is_urdu=False):
    """
    Simple word-level tokenization.
    For production, use more sophisticated tokenizers like sentencepiece or transformers.
    
    Args:
        text: Input text
        is_urdu: Whether the text is Urdu
    
    Returns:
        List of tokens
    """
    text = normalize_text(text, is_urdu)
    
    # Split on whitespace only; punctuation stays attached to the adjacent word
    tokens = re.findall(r'\S+', text)
    
    return tokens

# Test tokenization
print("Testing tokenization:")
print("=" * 60)

test_en = "Hello, how are you today?"
test_ur = "ہیلو، آپ آج کیسے ہیں؟"

en_tokens = simple_tokenize(test_en, is_urdu=False)
ur_tokens = simple_tokenize(test_ur, is_urdu=True)

print(f"English: {test_en}")
print(f"Tokens:  {en_tokens}")
print(f"Count:   {len(en_tokens)}")

print(f"\nUrdu: {test_ur}")
print(f"Tokens: {ur_tokens}")
print(f"Count:  {len(ur_tokens)}")

# Test normalization
print(f"\nUrdu normalization test:")
test_ur_with_diacritics = "مَرحَبا"  # With diacritics
normalized = normalize_urdu(test_ur_with_diacritics)
print(f"  Original:   {test_ur_with_diacritics}")
print(f"  Normalized: {normalized}")
print(f"  Length before: {len(test_ur_with_diacritics)}, after: {len(normalized)}")

print("=" * 60)
print(" Tokenization and normalization ready!")
print("  • Urdu text undergoes Unicode normalization (NFKC)")
print("  • Arabic diacritics are removed for consistency")
print("  • English text is lowercased and cleaned")

## 4. Vocabulary and Tokenization

This section defines the `Vocabulary` class, which is responsible for converting text into numerical tokens that the model can process.

-   **Build Vocabulary**: The class is created with a maximum size and a minimum token frequency. `build_vocab` then counts tokens across the sentences and maps the most frequent ones to integer indices. The special tokens `<pad>`, `<sos>`, `<eos>`, and `<unk>` occupy the first four indices.
-   **Tokenize and Detokenize**: `encode` converts a sentence into a sequence of token indices, and `decode` converts indices back into human-readable text.

In [None]:
class Vocabulary:
    """Build vocabulary from corpus with special tokens."""
    
    def __init__(self, max_size=10000, min_freq=2):
        self.max_size = max_size
        self.min_freq = min_freq
        self.token2idx = {}
        self.idx2token = {}
        self.token_freq = Counter()
        
        # Add special tokens
        self.pad_token = config['pad_token']
        self.sos_token = config['sos_token']
        self.eos_token = config['eos_token']
        self.unk_token = config['unk_token']
        
        # Initialize with special tokens
        special_tokens = [self.pad_token, self.sos_token, self.eos_token, self.unk_token]
        for token in special_tokens:
            self.add_token(token)
    
    def add_token(self, token):
        """Add a token to vocabulary."""
        if token not in self.token2idx:
            idx = len(self.token2idx)
            self.token2idx[token] = idx
            self.idx2token[idx] = token
    
    def build_vocab(self, sentences, is_urdu=False):
        """Build vocabulary from list of sentences."""
        # Count token frequencies
        for sentence in sentences:
            tokens = simple_tokenize(sentence, is_urdu)
            self.token_freq.update(tokens)
        
        # Add most frequent tokens (excluding special tokens)
        most_common = self.token_freq.most_common(self.max_size - 4)  # -4 for special tokens
        
        for token, freq in most_common:
            if freq >= self.min_freq and token not in self.token2idx:
                self.add_token(token)
        
        print(f"Built vocabulary with {len(self.token2idx)} tokens")
        print(f"  Most common: {self.token_freq.most_common(10)}")
    
    def encode(self, text, is_urdu=False):
        """Convert text to list of token indices."""
        tokens = simple_tokenize(text, is_urdu)
        indices = [self.token2idx.get(token, self.token2idx[self.unk_token]) for token in tokens]
        return indices
    
    def decode(self, indices):
        """Convert list of indices back to text."""
        tokens = [self.idx2token.get(idx, self.unk_token) for idx in indices]
        # Remove special tokens for display
        tokens = [t for t in tokens if t not in [self.pad_token, self.sos_token, self.eos_token]]
        return ' '.join(tokens)
    
    def __len__(self):
        return len(self.token2idx)

# Build vocabularies
print("Building vocabularies...")
print("=" * 60)

# Extract English and Urdu sentences
english_sentences = [en for en, _ in parallel_data]
urdu_sentences = [ur for _, ur in parallel_data]

# Create vocabularies
vocab_en = Vocabulary(max_size=config['vocab_size_en'], min_freq=1)
vocab_ur = Vocabulary(max_size=config['vocab_size_ur'], min_freq=1)

vocab_en.build_vocab(english_sentences, is_urdu=False)
print()
vocab_ur.build_vocab(urdu_sentences, is_urdu=True)

print("\n" + "=" * 60)
print("VOCABULARY STATISTICS")
print("=" * 60)
print(f"English vocabulary size: {len(vocab_en)}")
print(f"Urdu vocabulary size: {len(vocab_ur)}")

# Test encoding/decoding
test_sentence = english_sentences[0] if english_sentences else "hello world"
encoded = vocab_en.encode(test_sentence, is_urdu=False)
decoded = vocab_en.decode(encoded)

print(f"\nTest encoding/decoding:")
print(f"  Original: {test_sentence}")
print(f"  Encoded:  {encoded[:20]}...")  # Show first 20 indices
print(f"  Decoded:  {decoded}")
print("=" * 60)

## 5. Custom `Dataset` and `DataLoader`

This section prepares the data for training by creating a custom `Dataset` and wrapping it in a `DataLoader`.

-   **`TranslationDataset`**: A custom class that holds tokenized English and Urdu sentence pairs.
-   **Collation Function**: The `collate_fn` pads sequences to the same length within each batch, ensuring uniform tensor shapes.
-   **`DataLoader`**: A `DataLoader` is created to efficiently batch and shuffle the data, making it ready for the training loop.

In [None]:
class TranslationDataset(Dataset):
    """Dataset for English-Urdu parallel sentences."""

    def __init__(self, pairs, vocab_src, vocab_tgt, max_length=100):
        self.pairs = pairs
        self.vocab_src = vocab_src
        self.vocab_tgt = vocab_tgt
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src_text, tgt_text = self.pairs[idx]

        # Encode source and target
        src_indices = self.vocab_src.encode(src_text, is_urdu=False)
        tgt_indices = self.vocab_tgt.encode(tgt_text, is_urdu=True)

        # Validate that sequences are not empty
        if len(src_indices) == 0:
            src_indices = [self.vocab_src.token2idx[self.vocab_src.unk_token]]
        if len(tgt_indices) == 0:
            tgt_indices = [self.vocab_tgt.token2idx[self.vocab_tgt.unk_token]]

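        # Truncate if too long; reserve two target positions for SOS and EOS
        # so the final sequence never exceeds max_length (the positional-encoding limit)
        src_indices = src_indices[:self.max_length]
        tgt_indices = tgt_indices[:self.max_length - 2]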

        # Add SOS and EOS tokens to target
        sos_idx = self.vocab_tgt.token2idx[self.vocab_tgt.sos_token]
        eos_idx = self.vocab_tgt.token2idx[self.vocab_tgt.eos_token]

        tgt_indices = [sos_idx] + tgt_indices + [eos_idx]

        return {
            'src': torch.tensor(src_indices, dtype=torch.long),
            'tgt': torch.tensor(tgt_indices, dtype=torch.long),
            'src_text': src_text,
            'tgt_text': tgt_text
        }


def collate_fn(batch):
    """
    Collate function to handle variable-length sequences with error handling.

    Args:
        batch: List of samples from dataset

    Returns:
        Dictionary with padded batches
    """
    # Filter out None items and validate
    valid_batch = []
    for item in batch:
        if item is None:
            print("  WARNING: Skipping None item in batch")
            continue
        if 'src' not in item or 'tgt' not in item:
            print("  WARNING: Skipping invalid item (missing src or tgt)")
            continue
        if len(item['src']) == 0 or len(item['tgt']) == 0:
            print("  WARNING: Skipping item with empty sequence")
            continue
        valid_batch.append(item)

    # Handle empty batch
    if len(valid_batch) == 0:
        raise ValueError("Empty batch after filtering invalid items!")

    # Extract sequences
    src_batch = [item['src'] for item in valid_batch]
    tgt_batch = [item['tgt'] for item in valid_batch]

    # Pad sequences (both vocabularies place <pad> at index 0, so one index serves source and target)
    pad_idx = vocab_en.token2idx[vocab_en.pad_token]

    try:
        src_padded = pad_sequence(src_batch, batch_first=True, padding_value=pad_idx)
        tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=pad_idx)
    except Exception as e:
        print(f" Error in padding: {e}")
        print(f"   Batch size: {len(valid_batch)}")
        print(f"   Source lengths: {[len(s) for s in src_batch]}")
        print(f"   Target lengths: {[len(t) for t in tgt_batch]}")
        raise

    return {
        'src': src_padded,
        'tgt': tgt_padded,
        'src_texts': [item['src_text'] for item in valid_batch],
        'tgt_texts': [item['tgt_text'] for item in valid_batch]
    }


# Split data into train/val/test
train_size = int(0.8 * len(parallel_data))
val_size = int(0.1 * len(parallel_data))
test_size = len(parallel_data) - train_size - val_size

# Shuffle data
random.shuffle(parallel_data)

train_pairs = parallel_data[:train_size]
val_pairs = parallel_data[train_size:train_size + val_size]
test_pairs = parallel_data[train_size + val_size:]

print("=" * 60)
print("DATASET SPLITS")
print("=" * 60)
print(f"Total pairs:    {len(parallel_data):,}")
print(f"Training:       {len(train_pairs):,} ({len(train_pairs)/len(parallel_data)*100:.1f}%)")
print(f"Validation:     {len(val_pairs):,} ({len(val_pairs)/len(parallel_data)*100:.1f}%)")
print(f"Test:           {len(test_pairs):,} ({len(test_pairs)/len(parallel_data)*100:.1f}%)")
print("=" * 60)

# Create datasets
train_dataset = TranslationDataset(train_pairs, vocab_en, vocab_ur, config['max_seq_length'])
val_dataset = TranslationDataset(val_pairs, vocab_en, vocab_ur, config['max_seq_length'])
test_dataset = TranslationDataset(test_pairs, vocab_en, vocab_ur, config['max_seq_length'])

# Create dataloaders with error handling
try:
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,          # Single-worker loading for cross-platform consistency
        pin_memory=False        # Disabled to preserve compatibility with CPU/MPS execution
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )

    print(f"\n Created DataLoaders with batch size {config['batch_size']}")
    print(f"  Training batches:   {len(train_loader):,}")
    print(f"  Validation batches: {len(val_loader):,}")
    print(f"  Test batches:       {len(test_loader):,}")

    # Test a batch to ensure everything works
    print(f"\nTesting batch loading...")
    sample_batch = next(iter(train_loader))
    print(f" Sample batch shapes:")
    print(f"  Source: {sample_batch['src'].shape}")
    print(f"  Target: {sample_batch['tgt'].shape}")
    print(f"  Batch loaded successfully!")

except Exception as e:
    print(f" Error creating DataLoaders: {e}")
    raise

print("=" * 60)

## 6. Transformer Model Architecture

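This section breaks down the Transformer model. The architecture is composed of several key components that work together to process sequential data, capture contextual relationships, and generate translations. In the code below, `TransformerTranslator` supplies its own embeddings, positional encoding, and output projection, and delegates the encoder-decoder stack to PyTorch's `nn.Transformer`.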

### 6.1. Core Components

#### Multi-Head Attention (`MultiHeadAttention`)
The fundamental building block of the Transformer. This mechanism allows the model to weigh the significance of different words when producing a representation of a sequence. Instead of a single attention function, it runs multiple attention mechanisms in parallel ("heads"), allowing the model to jointly attend to information from different representational subspaces.
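
As a rough illustration of the mechanism, the dependency-free sketch below computes scaled dot-product attention on plain Python lists and then splits the feature dimension into heads. It is a simplification, not the notebook's implementation: the function names are illustrative, and the learned query, key, value, and output projections of a real multi-head layer are omitted.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v, causal=False):
    """q, k, v are lists of vectors; returns (outputs, attention weights)."""
    d_k = len(k[0])
    weights = []
    for i, q_i in enumerate(q):
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d_k) for k_j in k]
        if causal:
            # Position i may only attend to positions <= i
            scores = [s if j <= i else float('-inf') for j, s in enumerate(scores)]
        weights.append(softmax(scores))
    outputs = [
        [sum(w * v_j[c] for w, v_j in zip(w_row, v)) for c in range(len(v[0]))]
        for w_row in weights
    ]
    return outputs, weights


def multi_head_attention(x, num_heads):
    """Self-attention per head on equal slices of the feature dimension."""
    d_head = len(x[0]) // num_heads
    heads = []
    for h in range(num_heads):
        part = [row[h * d_head:(h + 1) * d_head] for row in x]
        out, _ = scaled_dot_product_attention(part, part, part)
        heads.append(out)
    # Concatenate the head outputs position by position
    return [[val for head in heads for val in head[i]] for i in range(len(x))]
```

Each head sees only its own slice of the features, which is what lets different heads focus on different relationships between tokens.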

#### Position-wise Feed-Forward Network (`PositionwiseFeedForward`)
A fully connected feed-forward network that is applied to each position separately and identically. It consists of two linear transformations with a ReLU activation in between, enabling the model to learn more complex transformations.
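
Written out, the feed-forward sub-layer from "Attention is All You Need" takes the following form. The dimensions assume this notebook's `d_model` of 512 and `dim_feedforward` of 2048.

```latex
\mathrm{FFN}(x) = \max(0,\; x W_1 + b_1)\, W_2 + b_2,
\qquad W_1 \in \mathbb{R}^{512 \times 2048},\quad W_2 \in \mathbb{R}^{2048 \times 512}
```

The same weights are applied at every sequence position, so the layer mixes features within a token but not across tokens.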

#### Positional Encoding (`PositionalEncoding`)
Since the Transformer contains no recurrence or convolution, it relies on positional encodings to inject information about the relative or absolute position of tokens in the sequence. These encodings are added to the input embeddings at the bottoms of the encoder and decoder stacks.
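
For reference, the sinusoidal encoding defined in the paper assigns each position $pos$ and each pair of feature indices $(2i, 2i+1)$ the values:

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
```

Even feature indices use the sine and odd indices use the cosine, with wavelengths that grow geometrically across the feature dimension.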

In [None]:
class PositionalEncoding(nn.Module):
    """Positional encoding for Transformer as described in the paper."""
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # Add batch dimension
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class TransformerTranslator(nn.Module):
    """Transformer model for machine translation."""
    
    def __init__(self, 
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_seq_length=100):
        super().__init__()
        
        self.d_model = d_model
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length, dropout)
        
        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        # Output projection
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights using Xavier uniform initialization."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def generate_square_subsequent_mask(self, sz):
        """Generate mask for decoder to prevent attending to future tokens."""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
    
    def create_padding_mask(self, seq, pad_idx):
        """Create mask for padding tokens."""
        return (seq == pad_idx)
    
    def forward(self, src, tgt, src_padding_mask=None, tgt_padding_mask=None):
        """
        Args:
            src: Source sequences (batch_size, src_seq_len)
            tgt: Target sequences (batch_size, tgt_seq_len)
            src_padding_mask: Padding mask for source (batch_size, src_seq_len)
            tgt_padding_mask: Padding mask for target (batch_size, tgt_seq_len)
        
        Returns:
            Output logits (batch_size, tgt_seq_len, tgt_vocab_size)
        """
        # Embed and add positional encoding
        src_emb = self.pos_encoder(self.src_embedding(src) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoder(self.tgt_embedding(tgt) * math.sqrt(self.d_model))
        
        # Generate target mask (causal mask)
        tgt_seq_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len).to(tgt.device)
        
        # Forward pass through transformer
        output = self.transformer(
            src_emb, 
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask
        )
        
        # Project to vocabulary
        output = self.fc_out(output)
        
        return output

# Initialize model
model = TransformerTranslator(
    src_vocab_size=len(vocab_en),
    tgt_vocab_size=len(vocab_ur),
    d_model=config['d_model'],
    nhead=config['nhead'],
    num_encoder_layers=config['num_encoder_layers'],
    num_decoder_layers=config['num_decoder_layers'],
    dim_feedforward=config['dim_feedforward'],
    dropout=config['dropout'],
    max_seq_length=config['max_seq_length']
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("=" * 60)
print("MODEL ARCHITECTURE")
print("=" * 60)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024 / 1024:.2f} MB")
print("=" * 60)
print(model)
print("=" * 60)

### 6.4. Encoder Layer

-   **`EncoderLayer`**: A single layer of the encoder stack. It consists of a multi-head self-attention mechanism followed by a position-wise feed-forward network. Residual connections and layer normalization are applied after each sub-layer.
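
The residual-and-normalize pattern described above can be written compactly as follows, with $x$ as the layer input. This is the post-norm arrangement from the original paper. It is assumed here to match the default behavior of `nn.Transformer`, where `norm_first` is `False`.

```latex
\begin{aligned}
x' &= \mathrm{LayerNorm}\bigl(x + \mathrm{SelfAttention}(x)\bigr) \\
y  &= \mathrm{LayerNorm}\bigl(x' + \mathrm{FFN}(x')\bigr)
\end{aligned}
```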

In [None]:
class NoamScheduler:
    """Learning rate scheduler with warmup as described in 'Attention is All You Need'."""
    
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0
    
    def step(self):
        """Update learning rate."""
        self.step_num += 1
        lr = self._get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    
    def _get_lr(self):
        """Calculate learning rate based on step number."""
        step = max(self.step_num, 1)
        return (self.d_model ** -0.5) * min(step ** -0.5, step * (self.warmup_steps ** -1.5))

# Loss function (ignore padding tokens)
pad_idx = vocab_en.token2idx[vocab_en.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=0.2)

# Optimizer
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=config['learning_rate'], 
    betas=(0.9, 0.98), 
    eps=1e-9,
    weight_decay=config.get('weight_decay', 0.0001)
)

# Learning rate scheduler
scheduler = NoamScheduler(optimizer, config['d_model'], config['warmup_steps'])

print("=" * 60)
print("TRAINING CONFIGURATION")
print("=" * 60)
print(f"Optimizer: Adam")
print(f"Initial learning rate: {config['learning_rate']}")
print(f"Warmup steps: {config['warmup_steps']}")
print(f"Gradient clipping: {config['max_grad_norm']}")
print(f"Loss function: CrossEntropyLoss with label smoothing ({criterion.label_smoothing})")
print(f"Padding index (ignored): {pad_idx}")
print("=" * 60)

### 6.5. Encoder

-   **`Encoder`**: The full encoder, composed of a stack of `EncoderLayer` instances. It processes the input sequence and generates a context-rich representation.

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, filepath):
    """
    Save model checkpoint with comprehensive state.
    
    Args:
        model: The model to save
        optimizer: The optimizer state
        scheduler: The scheduler state
        epoch: Current epoch number
        train_loss: Training loss
        val_loss: Validation loss
        filepath: Path to save checkpoint
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_step': scheduler.step_num,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': config,
        'vocab_en_size': len(vocab_en),
        'vocab_ur_size': len(vocab_ur),
    }
    
    try:
        # Create directory if it doesn't exist
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, filepath)
        print(f"   Checkpoint saved: {filepath}")
    except Exception as e:
        print(f"   Failed to save checkpoint: {e}")


def load_checkpoint(filepath, model, optimizer=None, scheduler=None, strict=True):
    """
    Load model checkpoint with comprehensive error handling.
    
    Args:
        filepath: Path to checkpoint file
        model: Model to load weights into
        optimizer: Optimizer to load state into (optional)
        scheduler: Scheduler to load state into (optional)
        strict: Whether to strictly enforce state dict keys match
    
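    Returns:
        Dictionary with epoch, train_loss, val_loss; None if the file is missing or
        if loading fails with strict=False. With strict=True, errors raised after
        the file-existence check are re-raised.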
    """
    if not os.path.exists(filepath):
        print(f" Checkpoint file not found: {filepath}")
        return None
    
    try:
        # Load checkpoint
        print(f"Loading checkpoint from: {filepath}")
        checkpoint = torch.load(filepath, map_location=device)
        
        # Validate checkpoint structure
        required_keys = ['model_state_dict', 'epoch']
        missing_keys = [key for key in required_keys if key not in checkpoint]
        
        if missing_keys:
            print(f"  WARNING: Checkpoint missing keys: {missing_keys}")
            if strict:
                raise ValueError(f"Invalid checkpoint structure. Missing: {missing_keys}")
            return None
        
        # Load model state
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"   Model state loaded")
        except RuntimeError as e:
            print(f"    WARNING: Model state dict mismatch: {e}")
            if strict:
                raise
            # Try loading with strict=False
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
            print(f"   Model state loaded (non-strict)")
        
        # Load optimizer state (optional)
        if optimizer and 'optimizer_state_dict' in checkpoint:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                print(f"   Optimizer state loaded")
            except Exception as e:
                print(f"    WARNING: Could not load optimizer state: {e}")
        
        # Load scheduler state (optional)
        if scheduler and 'scheduler_step' in checkpoint:
            try:
                scheduler.step_num = checkpoint['scheduler_step']
                print(f"   Scheduler state loaded (step: {checkpoint['scheduler_step']})")
            except Exception as e:
                print(f"    WARNING: Could not load scheduler state: {e}")
        
        # Print checkpoint info
        print(f"\n{'='*60}")
        print(f"CHECKPOINT INFO")
        print(f"{'='*60}")
        print(f"  Epoch: {checkpoint['epoch']}")
        if 'train_loss' in checkpoint:
            print(f"  Train Loss: {checkpoint['train_loss']:.4f}")
        if 'val_loss' in checkpoint:
            print(f"  Val Loss: {checkpoint['val_loss']:.4f}")
        if 'vocab_en_size' in checkpoint:
            print(f"  English Vocab Size: {checkpoint['vocab_en_size']:,}")
        if 'vocab_ur_size' in checkpoint:
            print(f"  Urdu Vocab Size: {checkpoint['vocab_ur_size']:,}")
        print(f"{'='*60}\n")
        
        return {
            'epoch': checkpoint['epoch'],
            'train_loss': checkpoint.get('train_loss', None),
            'val_loss': checkpoint.get('val_loss', None)
        }
        
    except FileNotFoundError:
        print(f" Checkpoint file not found: {filepath}")
        return None
    except Exception as e:
        print(f" Error loading checkpoint: {e}")
        print(f"   Error type: {type(e).__name__}")
        if strict:
            raise
        return None


print("=" * 60)
print("CHECKPOINT MANAGEMENT")
print("=" * 60)
print(" Checkpoint management functions defined")
print("  • save_checkpoint(): Saves model, optimizer, scheduler state")
print("  • load_checkpoint(): Loads with comprehensive error handling")
print("  • Validates checkpoint structure before loading")
print("  • Graceful fallback for missing or corrupt checkpoints")
print("=" * 60)

### 6.6. Decoder Layer

-   **`DecoderLayer`**: A single layer of the decoder stack. It includes two multi-head attention mechanisms: one for self-attention on the target sequence and another for cross-attention on the encoder's output. It also has a position-wise feed-forward network.

In [None]:
class EarlyStopping:
    """
    Stop training when the validation loss stops improving.
    
    Args:
        patience: Number of epochs to wait for improvement before stopping
        min_delta: Minimum change to qualify as improvement
        verbose: Whether to print messages
    """
    
    def __init__(self, patience=5, min_delta=0.0001, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_epoch = 0
    
    def __call__(self, val_loss, epoch):
        """
        Check if training should stop.
        
        Args:
            val_loss: Current validation loss
            epoch: Current epoch number
        
        Returns:
            bool: True if training should stop
        """
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_epoch = epoch
            if self.verbose:
                print(f"   Initial validation loss: {val_loss:.4f}")
        elif val_loss < self.best_loss - self.min_delta:
            # Improvement
            if self.verbose:
                improvement = self.best_loss - val_loss
                print(f"   Validation loss improved by {improvement:.4f} ({self.best_loss:.4f} → {val_loss:.4f})")
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            # No improvement
            self.counter += 1
            if self.verbose:
                print(f"    No improvement for {self.counter}/{self.patience} epochs")
            
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"\n{'='*60}")
                    print(f" EARLY STOPPING TRIGGERED")
                    print(f"{'='*60}")
                    print(f"No improvement for {self.patience} consecutive epochs")
                    print(f"Best validation loss: {self.best_loss:.4f} (Epoch {self.best_epoch})")
                    print(f"{'='*60}\n")
        
        return self.early_stop
    
    def reset(self):
        """Reset the early stopping state."""
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_epoch = 0


def print_memory_stats(device):
    """
    Print memory statistics for the current device.
    
    Args:
        device: torch.device object
    """
    try:
        if device.type == 'mps':
            # MPS (Metal Performance Shaders) memory tracking
            if hasattr(torch.mps, 'current_allocated_memory'):
                allocated = torch.mps.current_allocated_memory() / 1024**3  # Convert to GB
                print(f"   MPS Memory Allocated: {allocated:.2f} GB")
            else:
                print(f"   MPS memory tracking not available in this PyTorch version")
        elif device.type == 'cuda':
            # CUDA memory tracking
            allocated = torch.cuda.memory_allocated(device) / 1024**3
            reserved = torch.cuda.memory_reserved(device) / 1024**3
            print(f"   GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
        else:
            # CPU - no specific memory tracking needed
            pass
    except Exception as e:
        # Gracefully handle if memory stats are unavailable
        if device.type != 'cpu':
            print(f"   Memory stats unavailable: {e}")


def clear_memory_cache(device):
    """
    Clear memory cache for the device.
    
    Args:
        device: torch.device object
    """
    try:
        if device.type == 'mps':
            if hasattr(torch.mps, 'empty_cache'):
                torch.mps.empty_cache()
        elif device.type == 'cuda':
            torch.cuda.empty_cache()
    except Exception:
        pass  # Silently fail if cache clearing not available


# Test early stopping
print("=" * 60)
print("EARLY STOPPING & MEMORY MONITORING")
print("=" * 60)

# Test early stopping logic
early_stopping = EarlyStopping(patience=3, verbose=False)
test_losses = [0.5, 0.48, 0.47, 0.49, 0.50, 0.51, 0.52]

print("\nTesting Early Stopping with patience=3:")
print("Loss sequence: [0.5, 0.48, 0.47, 0.49, 0.50, 0.51, 0.52]")
print()

for epoch, loss in enumerate(test_losses, 1):
    stop = early_stopping(loss, epoch)
    status = "STOP" if stop else "continue"
    print(f"  Epoch {epoch}: loss={loss:.2f}, counter={early_stopping.counter} → {status}")
    if stop:
        print(f"  → Would stop at epoch {epoch}")
        break

print(f"\n Early stopping implemented with patience={config['early_stopping_patience']}")

# Test memory monitoring
print(f"\nTesting memory monitoring on device: {device}")
print_memory_stats(device)

print("=" * 60)
print(" Early stopping and memory monitoring ready!")
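
The test above exercises `patience`; the effect of `min_delta` is easier to see in isolation. The sketch below is a stripped-down re-implementation of the same comparison rule (the function name `stop_epoch` and the loss values are made up for illustration), showing that small but steady improvements count as "no improvement" when they fall below `min_delta`.

```python
def stop_epoch(losses, patience, min_delta):
    """Return the 1-based epoch at which early stopping triggers, or None."""
    best = None
    counter = 0
    for epoch, loss in enumerate(losses, 1):
        if best is None or loss < best - min_delta:
            # Improvement large enough to reset the patience counter
            best = loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None


# Losses that improve by only 0.0002-0.0004 per epoch
losses = [0.5000, 0.4996, 0.4993, 0.4991, 0.4989]

print(stop_epoch(losses, patience=3, min_delta=0.001))   # improvements too small
print(stop_epoch(losses, patience=3, min_delta=0.0001))  # improvements accepted
```

With `min_delta=0.001` none of the later epochs beats the first loss by enough, so training would stop after three such epochs; with `min_delta=0.0001` every epoch counts as an improvement and the function returns `None`.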

### 6.7. Decoder

-   **`Decoder`**: The full decoder, composed of a stack of `DecoderLayer` instances. It generates the translated sequence token by token, attending to the encoder's output at each step.

### 6.8. Assembling the Transformer

-   **`Transformer`**: The final model that combines the `Encoder` and `Decoder`. It takes the source and target sequences as input and produces the final translation output. It also handles the creation of masks to prevent the model from attending to future tokens in the target sequence.
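
The mask that blocks attention to future tokens can be pictured as a triangular boolean matrix. The sketch below builds it with plain Python lists and combines it with a padding mask. The helper names and the convention that `True` means "blocked" are assumptions for illustration, since the model's own mask code is not shown in this section.

```python
def causal_mask(size):
    """mask[query][key] is True when the key lies to the right of the query."""
    return [[key > query for key in range(size)] for query in range(size)]


def padding_mask(token_ids, pad_idx):
    """True at every position that holds the padding token."""
    return [token == pad_idx for token in token_ids]


def combined_mask(token_ids, pad_idx):
    """Block a key if it is a future position OR a padding position."""
    size = len(token_ids)
    causal = causal_mask(size)
    pad = padding_mask(token_ids, pad_idx)
    return [[causal[q][k] or pad[k] for k in range(size)] for q in range(size)]


# Hypothetical target sequence of three tokens, the last one padding (index 0)
for row in combined_mask([5, 6, 0], pad_idx=0):
    print(["x" if blocked else "." for blocked in row])
```

In PyTorch the causal part of this pattern is commonly built with `torch.triu(torch.ones(n, n), diagonal=1).bool()`, which yields the same upper-triangular layout as `causal_mask(n)`.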

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, pad_idx, 
                gradient_accumulation_steps=1):
    """
    Train for one epoch with gradient accumulation support.
    
    Args:
        model: The Transformer model
        dataloader: Training data loader
        criterion: Loss function
        optimizer: Optimizer
        scheduler: Learning rate scheduler
        device: Device to train on
        pad_idx: Padding token index
        gradient_accumulation_steps: Number of steps to accumulate gradients
    
    Returns:
        Average loss for the epoch
    """
    model.train()
    total_loss = 0
    accumulated_loss = 0
    
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    optimizer.zero_grad()  # Zero gradients at the start
    
    for batch_idx, batch in enumerate(progress_bar):
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)
        
        # Create padding masks
        src_padding_mask = (src == pad_idx)
        tgt_padding_mask = (tgt == pad_idx)
        
        # Target input and output (shifted by one)
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        tgt_padding_mask_input = tgt_padding_mask[:, :-1]
        
        # Forward pass
        output = model(src, tgt_input, src_padding_mask, tgt_padding_mask_input)
        
        # Calculate loss
        output = output.reshape(-1, output.size(-1))
        tgt_output = tgt_output.reshape(-1)
        
        loss = criterion(output, tgt_output)
        
        # Normalize loss by accumulation steps
        loss = loss / gradient_accumulation_steps
        
        # Backward pass
        loss.backward()
        
        accumulated_loss += loss.item()
        
        # Update weights every gradient_accumulation_steps
        if (batch_idx + 1) % gradient_accumulation_steps == 0 or (batch_idx + 1) == len(dataloader):
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
            
            # Update weights
            optimizer.step()
            
            # Update learning rate
            current_lr = scheduler.step()
            
            # Zero gradients for next accumulation
            optimizer.zero_grad()
            
            # Track total loss (denormalize)
            total_loss += accumulated_loss * gradient_accumulation_steps
            
            # Update progress bar; for a full window, accumulated_loss already holds the
            # mean loss over the window (each term was divided by the step count)
            effective_batch = config['batch_size'] * gradient_accumulation_steps
            progress_bar.set_postfix({
                'loss': f'{accumulated_loss:.4f}',
                'lr': f'{current_lr:.2e}',
                'eff_bs': effective_batch
            })
            
            accumulated_loss = 0
    
    return total_loss / len(dataloader)


def validate(model, dataloader, criterion, device, pad_idx):
    """
    Validate the model.
    
    Args:
        model: The Transformer model
        dataloader: Validation data loader
        criterion: Loss function
        device: Device to run on
        pad_idx: Padding token index
    
    Returns:
        Average validation loss
    """
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validation", leave=False)
        
        for batch in progress_bar:
            src = batch['src'].to(device)
            tgt = batch['tgt'].to(device)
            
            # Create padding masks
            src_padding_mask = (src == pad_idx)
            tgt_padding_mask = (tgt == pad_idx)
            
            # Target input and output
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            tgt_padding_mask_input = tgt_padding_mask[:, :-1]
            
            # Forward pass
            output = model(src, tgt_input, src_padding_mask, tgt_padding_mask_input)
            
            # Calculate loss
            output = output.reshape(-1, output.size(-1))
            tgt_output = tgt_output.reshape(-1)
            
            loss = criterion(output, tgt_output)
            total_loss += loss.item()
            
            progress_bar.set_postfix({'loss': loss.item()})
    
    return total_loss / len(dataloader)


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, 
                num_epochs, device, pad_idx, checkpoint_dir, early_stopping=None,
                gradient_accumulation_steps=1):
    """
    Full training loop with early stopping, memory monitoring, and progress tracking.
    
    Args:
        model: The Transformer model
        train_loader: Training data loader
        val_loader: Validation data loader
        criterion: Loss function
        optimizer: Optimizer
        scheduler: Learning rate scheduler
        num_epochs: Maximum number of epochs to train
        device: Device to train on
        pad_idx: Padding token index
        checkpoint_dir: Directory to save checkpoints
        early_stopping: EarlyStopping instance (optional)
        gradient_accumulation_steps: Number of steps to accumulate gradients
    
    Returns:
        Tuple of (train_losses, val_losses, stopped_early)
    """
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    start_time_total = time.time()
    
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Total epochs: {num_epochs}")
    print(f"Batch size: {config['batch_size']}")
    print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"Effective batch size: {config['batch_size'] * gradient_accumulation_steps}")
    print(f"Training batches per epoch: {len(train_loader)}")
    print(f"Validation batches per epoch: {len(val_loader)}")
    
    # Estimate training time
    print(f"\nEstimated time per epoch: ~{len(train_loader) * 0.5 / 60:.1f}-{len(train_loader) * 1.0 / 60:.1f} minutes")
    print(f"Estimated total time: ~{num_epochs * len(train_loader) * 0.5 / 3600:.1f}-{num_epochs * len(train_loader) * 1.0 / 3600:.1f} hours")
    print(f"(Actual time depends on hardware and batch size)")
    
    if early_stopping:
        print(f"\nEarly stopping enabled: patience={early_stopping.patience}")
    
    print("=" * 60)
    
    stopped_early = False
    
    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()
        
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{num_epochs}")
        print(f"{'='*60}")
        
        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning rate: {current_lr:.2e}")
        
        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler, 
                                device, pad_idx, gradient_accumulation_steps)
        
        # Validate
        val_loss = validate(model, val_loader, criterion, device, pad_idx)
        
        # Record losses
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        epoch_time = time.time() - epoch_start_time
        elapsed_total = time.time() - start_time_total
        
        # Print progress
        print(f"\nResults:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Epoch Time: {epoch_time:.1f}s ({epoch_time/60:.1f}m)")
        print(f"  Total Time: {elapsed_total/60:.1f}m ({elapsed_total/3600:.2f}h)")
        
        # Estimate remaining time
        if epoch < num_epochs:
            avg_epoch_time = elapsed_total / epoch
            remaining_epochs = num_epochs - epoch
            eta = avg_epoch_time * remaining_epochs
            print(f"  ETA: ~{eta/60:.1f}m ({eta/3600:.2f}h)")
        
        # Memory stats
        print_memory_stats(device)
        
        # Clear cache periodically
        if epoch % 5 == 0:
            clear_memory_cache(device)
        
        # Save periodic checkpoint
        if epoch % config['save_every_n_epochs'] == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_simple_{epoch:03d}.pt")
            save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, checkpoint_path)
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(checkpoint_dir, "best_model_simple.pt")
            save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, best_model_path)
            print(f"   New best model saved! (Val Loss: {val_loss:.4f})")
        
        # Early stopping check
        if early_stopping:
            if early_stopping(val_loss, epoch):
                stopped_early = True
                print(f"\nStopping early at epoch {epoch}")
                print(f"Best validation loss: {early_stopping.best_loss:.4f} (Epoch {early_stopping.best_epoch})")
                break
    
    total_time = time.time() - start_time_total
    
    print("\n" + "=" * 60)
    print("TRAINING COMPLETED")
    print("=" * 60)
    print(f"Total epochs: {len(train_losses)}")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Total training time: {total_time/60:.1f}m ({total_time/3600:.2f}h)")
    print(f"Average time per epoch: {total_time/len(train_losses):.1f}s")
    print(f"Stopped early: {stopped_early}")
    print("=" * 60)
    
    return train_losses, val_losses, stopped_early


print("=" * 60)
print("TRAINING FUNCTIONS")
print("=" * 60)
print(" train_epoch(): Training with gradient accumulation")
print(" validate(): Validation with progress tracking")
print(" train_model(): Full training loop with:")
print("  • Early stopping support")
print("  • Memory monitoring")
print("  • Progress tracking with ETA")
print("  • Automatic checkpoint saving")
print("  • Learning rate display")
print("=" * 60)
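
`train_epoch` divides each batch loss by `gradient_accumulation_steps` before calling `backward()`. A minimal sketch of why, using a hypothetical one-parameter model `y = w * x` with a hand-written mean-squared-error gradient (none of these names come from the notebook): summing the scaled micro-batch gradients reproduces the gradient of one large batch.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5

# Gradient computed on the full batch of four examples
full_grad = grad_mse(w, xs, ys)

# Same data split into two micro-batches of two examples each
accumulation_steps = 2
micro_batch_size = 2
accumulated_grad = 0.0
for start in range(0, len(xs), micro_batch_size):
    end = start + micro_batch_size
    # Scaling by 1 / accumulation_steps mirrors `loss / gradient_accumulation_steps`
    accumulated_grad += grad_mse(w, xs[start:end], ys[start:end]) / accumulation_steps

print(f"full batch:  {full_grad:.6f}")
print(f"accumulated: {accumulated_grad:.6f}")
```

The match is exact here because both micro-batches have the same size. With `CrossEntropyLoss(ignore_index=...)` the mean is taken over non-padding tokens, so micro-batches with different token counts are weighted slightly differently and the equivalence is only approximate.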

## 7. Model Initialization

This section initializes the Transformer model with the specified hyperparameters and moves it to the selected device.

-   **Instantiate Model**: The `Transformer` class is instantiated with parameters from the `config` dictionary.
-   **Move to Device**: The model is moved to the GPU (`mps` or `cuda`) or CPU for training.
-   **Initialize Weights**: Model weights are initialized with a uniform distribution to ensure stable training from the start.

In [None]:
# Initialize early stopping
early_stopping = EarlyStopping(
    patience=config['early_stopping_patience'],
    min_delta=0.001,
    verbose=True
)

# Check for existing checkpoint and offer to resume
checkpoint_path = os.path.join(config['checkpoint_dir'], 'best_model_simple.pt')
start_epoch = 1
resume_training = False

if os.path.exists(checkpoint_path):
    print(f"\n{'='*60}")
    print(f"CHECKPOINT FOUND")
    print(f"{'='*60}")
    print(f"Found existing checkpoint: {checkpoint_path}")
    print(f"\nOptions:")
    print(f"  1. Resume training from checkpoint")
    print(f"  2. Start fresh training (will backup existing checkpoint)")
    print(f"  3. Skip training and use existing model for evaluation")
    
    # For automated execution, check modification time
    checkpoint_time = os.path.getmtime(checkpoint_path)
    age_hours = (time.time() - checkpoint_time) / 3600
    
    print(f"\nCheckpoint age: {age_hours:.1f} hours")
    
    # Auto-decide based on how far the checkpoint got (the age above is informational only)
    try:
        ckpt_info = load_checkpoint(checkpoint_path, model, strict=False)
        if ckpt_info and ckpt_info['epoch'] >= config['num_epochs'] * 0.8:
            print(f"\n Checkpoint appears complete (epoch {ckpt_info['epoch']}/{config['num_epochs']})")
            print(f"  Skipping training and using this model for evaluation.")
            resume_training = False
            train_losses = []
            val_losses = []
            skip_training = True
        else:
            print(f"\n  Checkpoint incomplete (epoch {ckpt_info['epoch'] if ckpt_info else 0}/{config['num_epochs']})")
            print(f"  Will resume training...")
            resume_training = True
            skip_training = False
            if ckpt_info:
                start_epoch = ckpt_info['epoch'] + 1
    except Exception:
        print(f"\n  Could not load checkpoint. Starting fresh training...")
        skip_training = False
        resume_training = False
else:
    print(f"\nNo existing checkpoint found. Starting fresh training...")
    skip_training = False
    resume_training = False

# Execute training if needed
if not skip_training:
    print(f"\n{'='*60}")
    print(f"INITIATING TRAINING")
    print(f"{'='*60}")
    
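    # Load checkpoint if resuming. train_model() takes no start epoch, so its epoch
    # counter restarts at 1; only the weights, optimizer state, and scheduler step carry over.
    if resume_training:
        load_checkpoint(checkpoint_path, model, optimizer, scheduler, strict=False)
        print(f"Resuming with weights from the checkpoint (epoch counter restarts at 1)")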
    
    # Run training
    train_losses, val_losses, stopped_early = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=config['num_epochs'],
        device=device,
        pad_idx=pad_idx,
        checkpoint_dir=config['checkpoint_dir'],
        early_stopping=early_stopping,
        gradient_accumulation_steps=config.get('gradient_accumulation_steps', 1)
    )
    
    # Save training history
    training_history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'stopped_early': stopped_early,
        'total_epochs': len(train_losses),
        'best_val_loss': min(val_losses) if val_losses else None,
        'config': config
    }
    
    history_path = os.path.join('results', 'training_history.json')
    os.makedirs('results', exist_ok=True)
    with open(history_path, 'w') as f:
        # default=str stringifies any value json cannot encode (e.g. Path objects in config)
        json.dump(training_history, f, indent=2, default=str)
    
    print(f"\n Training history saved to {history_path}")
else:
    print(f"\n Skipping training - using existing model")
    # Try to load training history if it exists
    history_path = os.path.join('results', 'training_history.json')
    if os.path.exists(history_path):
        with open(history_path, 'r') as f:
            training_history = json.load(f)
            train_losses = training_history.get('train_losses', [])
            val_losses = training_history.get('val_losses', [])
        print(f" Loaded training history from {history_path}")
    else:
        train_losses = []
        val_losses = []

print(f"\n{'='*60}")
print(f"TRAINING PHASE COMPLETE")
print(f"{'='*60}")

## 8. Training Loop

This section defines the training and evaluation loops.

-   **`train_epoch`**: A function that performs a single epoch of training. It iterates through the `DataLoader`, computes the model's output, calculates the loss, and updates the model's weights using backpropagation.
-   **`validate`**: A function that evaluates the model on the validation set. It computes the loss under `torch.no_grad()`, without backpropagation, to monitor performance.
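
Both loops rely on teacher forcing: the decoder receives the target without its last token (`tgt[:, :-1]`) and is scored against the target without its first token (`tgt[:, 1:]`). A small sketch with one hypothetical padded sequence, using made-up token indices, shows how the two views line up.

```python
PAD, SOS, EOS = 0, 1, 2  # hypothetical special-token indices

# One padded target sequence: <sos> w5 w6 w7 <eos> <pad>
tgt = [SOS, 5, 6, 7, EOS, PAD]

tgt_input = tgt[:-1]   # what the decoder sees at each position
tgt_output = tgt[1:]   # what it is asked to predict at that position

for seen, expected in zip(tgt_input, tgt_output):
    note = " (ignored by the loss)" if expected == PAD else ""
    print(f"input {seen} -> predict {expected}{note}")
```

Each position is trained to predict the next token, and positions whose expected token is padding contribute nothing because the criterion is built with `ignore_index=pad_idx`.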

In [None]:
def translate_sentence(model, sentence, vocab_src, vocab_tgt, device, max_length=50):
    """
    Translate a single sentence using greedy decoding.
    
    Args:
        model: Trained Transformer model
        sentence: Source sentence (string)
        vocab_src: Source vocabulary
        vocab_tgt: Target vocabulary
        device: Device to run on
        max_length: Maximum translation length
    
    Returns:
        Translated sentence (string)
    """
    model.eval()
    
    with torch.no_grad():
        # Encode source sentence
        src_indices = vocab_src.encode(sentence, is_urdu=False)
        src_tensor = torch.tensor(src_indices, dtype=torch.long).unsqueeze(0).to(device)
        
        # Get special token indices
        sos_idx = vocab_tgt.token2idx[vocab_tgt.sos_token]
        eos_idx = vocab_tgt.token2idx[vocab_tgt.eos_token]
        pad_idx = vocab_tgt.token2idx[vocab_tgt.pad_token]
        
        # Start with SOS token
        tgt_indices = [sos_idx]
        
        # Generate translation token by token
        for _ in range(max_length):
            tgt_tensor = torch.tensor(tgt_indices, dtype=torch.long).unsqueeze(0).to(device)
            
            # Create padding masks (the source mask uses the source vocabulary's pad index)
            src_padding_mask = (src_tensor == vocab_src.token2idx[vocab_src.pad_token])
            tgt_padding_mask = (tgt_tensor == pad_idx)
            
            # Forward pass
            output = model(src_tensor, tgt_tensor, src_padding_mask, tgt_padding_mask)
            
            # Get prediction for last token
            next_token_logits = output[0, -1, :]
            next_token = next_token_logits.argmax().item()
            
            # Add to sequence
            tgt_indices.append(next_token)
            
            # Stop if EOS token is generated
            if next_token == eos_idx:
                break
        
        # Decode to text
        translation = vocab_tgt.decode(tgt_indices)
        
        return translation


def translate_batch(model, sentences, vocab_src, vocab_tgt, device, max_length=50):
    """
    Translate a batch of sentences.
    
    Args:
        model: Trained model
        sentences: List of source sentences
        vocab_src: Source vocabulary
        vocab_tgt: Target vocabulary
        device: Device to run on
        max_length: Maximum translation length
    
    Returns:
        List of translated sentences
    """
    translations = []
    
    for sentence in tqdm(sentences, desc="Translating"):
        translation = translate_sentence(model, sentence, vocab_src, vocab_tgt, device, max_length)
        translations.append(translation)
    
    return translations

print(" Translation functions defined")
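
The control flow of `translate_sentence` (start from the SOS token, take the highest-scoring next token, stop at EOS or at `max_length`) can be tried without a trained model. The sketch below swaps the Transformer for a hypothetical lookup-table scorer; `greedy_decode`, `toy_scores`, and the four-token vocabulary are all invented for illustration.

```python
def greedy_decode(next_token_scores, sos_idx, eos_idx, max_length):
    """Greedy decoding: repeatedly append the highest-scoring next token."""
    tokens = [sos_idx]
    for _ in range(max_length):
        scores = next_token_scores(tokens)
        # argmax over the vocabulary
        next_token = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(next_token)
        if next_token == eos_idx:
            break
    return tokens


def toy_scores(prefix):
    """Hypothetical scorer over the vocabulary {0: <sos>, 1: <eos>, 2: 'a', 3: 'b'}.

    The scores depend only on the last token of the prefix.
    """
    table = {
        0: [0.0, 0.1, 0.7, 0.2],    # after <sos>, prefer 'a'
        2: [0.0, 0.2, 0.1, 0.7],    # after 'a', prefer 'b'
        3: [0.0, 0.9, 0.05, 0.05],  # after 'b', prefer <eos>
    }
    return table[prefix[-1]]


print(greedy_decode(toy_scores, sos_idx=0, eos_idx=1, max_length=10))
print(greedy_decode(toy_scores, sos_idx=0, eos_idx=1, max_length=2))
```

The second call shows the length cap: decoding stops after two generated tokens even though EOS has not been produced. Greedy decoding commits to the locally best token at every step, so it can miss sequences that a beam search would find.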

## 9. Main Training Execution

This is the main block where the training process is executed.

-   **Optimizer and Scheduler**: An `Adam` optimizer and the warmup-based `NoamScheduler` are initialized.
-   **Training Loop**: The code iterates through the specified number of epochs, calling `train_epoch` and `validate` in each epoch.
-   **Checkpointing**: The model with the best validation loss is saved to a file, allowing for inference or resuming training later.
-   **History Tracking**: Training and validation losses are stored in a history dictionary for later analysis.

In [None]:
def compute_ter(reference, hypothesis):
    """
    Compute Translation Error Rate (TER) using edit distance.
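    Simplified version: word-level Levenshtein distance divided by the reference length,
    without the block-shift operation of full TER. Production code would use a
    specialized TER implementation.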
    
    Args:
        reference: Reference translation (string)
        hypothesis: Hypothesis translation (string)
    
    Returns:
        TER score (lower is better)
    """
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()
    
    # Compute edit distance (Levenshtein distance)
    m, n = len(ref_tokens), len(hyp_tokens)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if ref_tokens[i-1] == hyp_tokens[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
    
    edit_distance = dp[m][n]
    ter = edit_distance / max(len(ref_tokens), 1)
    
    return ter


def compute_all_metrics(references, hypotheses, max_n=4):
    """
    Compute comprehensive translation metrics including BLEU, METEOR, and TER.
    
    Args:
        references: List of reference translations (strings)
        hypotheses: List of predicted translations (strings)
        max_n: Maximum n-gram for BLEU (default: 4)
    
    Returns:
        Dictionary with all computed metrics
    """
    metrics = {}
    
    # Tokenize for BLEU/METEOR
    ref_tokens = [[ref.split()] for ref in references]
    hyp_tokens = [hyp.split() for hyp in hypotheses]
    
    # Smoothing function for BLEU
    smoothing = SmoothingFunction().method1
    
    # Compute BLEU scores for different n-grams
    for n in range(1, max_n + 1):
        weights = [1.0/n] * n + [0.0] * (4 - n)
        
        scores = []
        for ref, hyp in zip(ref_tokens, hyp_tokens):
            score = sentence_bleu(ref, hyp, weights=weights, smoothing_function=smoothing)
            scores.append(score)
        
        metrics[f'BLEU-{n}'] = np.mean(scores)
        metrics[f'BLEU-{n}_std'] = np.std(scores)
    
    # Corpus-level BLEU-4
    corpus_bleu_score = corpus_bleu(ref_tokens, hyp_tokens, smoothing_function=smoothing)
    metrics['Corpus-BLEU-4'] = corpus_bleu_score
    
    # METEOR score
    try:
        meteor_scores = []
        for ref, hyp in zip(references, hypotheses):
            # METEOR expects tokenized inputs
            score = meteor_score([ref.split()], hyp.split())
            meteor_scores.append(score)
        
        metrics['METEOR'] = np.mean(meteor_scores)
        metrics['METEOR_std'] = np.std(meteor_scores)
    except Exception as e:
        print(f"  WARNING: Could not compute METEOR score: {e}")
        metrics['METEOR'] = None
        metrics['METEOR_std'] = None
    
    # TER (Translation Error Rate)
    ter_scores = []
    for ref, hyp in zip(references, hypotheses):
        ter = compute_ter(ref, hyp)
        ter_scores.append(ter)
    
    metrics['TER'] = np.mean(ter_scores)
    metrics['TER_std'] = np.std(ter_scores)
    
    # Additional statistics
    metrics['num_samples'] = len(references)
    metrics['avg_ref_length'] = np.mean([len(ref.split()) for ref in references])
    metrics['avg_hyp_length'] = np.mean([len(hyp.split()) for hyp in hypotheses])
    
    return metrics


def evaluate_model(model, test_pairs, vocab_src, vocab_tgt, device, num_samples=None):
    """
    Evaluate model on test set and compute comprehensive metrics.
    
    Args:
        model: Trained model
        test_pairs: List of (source, target) sentence pairs
        vocab_src: Source vocabulary
        vocab_tgt: Target vocabulary
        device: Device to run on
        num_samples: Number of samples to evaluate (None for all)
    
    Returns:
        Dictionary with metrics and sample translations
    """
    # Load the best checkpoint if one exists.
    # Note: relies on the notebook-level `config` dict and `load_checkpoint` helper.
    best_model_path = os.path.join(config['checkpoint_dir'], 'best_model_simple.pt')
    if os.path.exists(best_model_path):
        print(f"\n{'='*60}")
        print("Loading best model for evaluation...")
        load_checkpoint(best_model_path, model, strict=False)
        print(f"{'='*60}\n")
    
    if num_samples:
        test_pairs = test_pairs[:num_samples]
    
    source_sentences = [src for src, _ in test_pairs]
    reference_translations = [tgt for _, tgt in test_pairs]
    
    # Generate translations
    print(f"Translating {len(source_sentences):,} test sentences...")
    print("This may take several minutes...\n")
    
    predicted_translations = translate_batch(model, source_sentences, vocab_src, vocab_tgt, device)
    
    # Compute all metrics
    print("Computing evaluation metrics...")
    metrics = compute_all_metrics(reference_translations, predicted_translations)
    
    return {
        'metrics': metrics,
        'source': source_sentences,
        'reference': reference_translations,
        'predicted': predicted_translations
    }


print("=" * 60)
print("EVALUATION FUNCTIONS")
print("=" * 60)
print(" compute_ter(): Translation Error Rate computation")
print(" compute_all_metrics(): BLEU, METEOR, TER")
print(" evaluate_model(): Comprehensive evaluation with all metrics")
print("=" * 60)

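## 10. Test-Set Evaluation and Results Export

The model is evaluated on the full test set, and the resulting metrics (BLEU-1 through BLEU-4, corpus-level BLEU-4, METEOR, and TER) are printed as a table. The metrics, the model configuration, the dataset split sizes, and the training and validation loss history are then saved to a JSON file for persistent storage, so the results and learning curves can be analyzed later without retraining the model.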
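
As a minimal sketch of this kind of JSON round trip, the example below writes a small loss history to a temporary file and reads it back. The loss values, the file name, and the `save_history` and `load_history` helpers are illustrative assumptions, not part of the notebook.

```python
import json
import tempfile
from pathlib import Path


def save_history(path, train_losses, val_losses):
    """Write per-epoch losses to a JSON file (hypothetical helper)."""
    history = {
        # Cast to plain floats so NumPy or tensor scalars would also serialize
        'train_losses': [float(x) for x in train_losses],
        'val_losses': [float(x) for x in val_losses],
        'best_val_loss': float(min(val_losses)) if val_losses else None,
    }
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2, ensure_ascii=False)


def load_history(path):
    """Read the loss history back as a plain dictionary."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Made-up losses, written to a throwaway directory
with tempfile.TemporaryDirectory() as tmp_dir:
    history_path = Path(tmp_dir) / 'results' / 'history.json'
    save_history(history_path, [4.2, 3.1, 2.7], [4.5, 3.4, 3.0])
    restored = load_history(history_path)

print(restored['best_val_loss'])
```

Storing plain lists of floats keeps the file readable from any tool, and an empty history is recorded as `null` rather than raising an error.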

In [None]:
# Evaluate on FULL test set
print("\n" + "=" * 60)
print("STARTING EVALUATION ON FULL TEST SET")
print("=" * 60)
print(f"Test set size: {len(test_pairs):,} pairs")
# Rough estimate only: assumes 0.5-1.0 seconds per sentence
print(f"This will take approximately {len(test_pairs) * 0.5 / 60:.1f}-{len(test_pairs) * 1.0 / 60:.1f} minutes")
print("=" * 60)

evaluation_results = evaluate_model(
    model, 
    test_pairs, 
    vocab_en, 
    vocab_ur, 
    device, 
    num_samples=None  # Evaluate on ALL test samples
)

# Display metrics in formatted table
print("\n" + "=" * 60)
print("EVALUATION RESULTS - COMPREHENSIVE METRICS")
print("=" * 60)

metrics = evaluation_results['metrics']

# Main metrics table
print("\n Primary Metrics:")
print("-" * 60)
print(f"{'Metric':<20} {'Score':>10} {'Std Dev':>10}")
print("-" * 60)

# BLEU scores
for n in range(1, 5):
    score = metrics.get(f'BLEU-{n}', 0)
    std = metrics.get(f'BLEU-{n}_std', 0)
    print(f"{'BLEU-' + str(n):<20} {score:>10.4f} {std:>10.4f}")

print(f"{'Corpus-BLEU-4':<20} {metrics.get('Corpus-BLEU-4', 0):>10.4f} {'N/A':>10}")

# METEOR
if metrics.get('METEOR') is not None:
    print(f"{'METEOR':<20} {metrics['METEOR']:>10.4f} {metrics.get('METEOR_std', 0):>10.4f}")
else:
    print(f"{'METEOR':<20} {'N/A':>10} {'N/A':>10}")

# TER
print(f"{'TER (lower=better)':<20} {metrics.get('TER', 0):>10.4f} {metrics.get('TER_std', 0):>10.4f}")

print("-" * 60)

# Statistics
print("\n Dataset Statistics:")
print("-" * 60)
print(f"Number of samples: {metrics.get('num_samples', 0):,}")
print(f"Avg reference length: {metrics.get('avg_ref_length', 0):.2f} words")
print(f"Avg hypothesis length: {metrics.get('avg_hyp_length', 0):.2f} words")
print("-" * 60)

# Save results to JSON
results_to_save = {
    # Convert NumPy floats to plain floats for JSON; leave ints and None untouched
    'metrics': {k: float(v) if isinstance(v, (float, np.floating)) else v
                for k, v in metrics.items()},
    'model_config': config,
    'dataset_info': {
        'total_pairs': len(parallel_data),
        'train_pairs': len(train_pairs),
        'val_pairs': len(val_pairs),
        'test_pairs': len(test_pairs)
    },
    'training_info': {
        'train_losses': train_losses if 'train_losses' in locals() else [],
        'val_losses': val_losses if 'val_losses' in locals() else [],
        'total_epochs': len(train_losses) if 'train_losses' in locals() else 0,
        'best_val_loss': min(val_losses) if 'val_losses' in locals() and val_losses else None
    }
}

results_path = os.path.join('results', 'evaluation_results.json')
os.makedirs('results', exist_ok=True)
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(results_to_save, f, indent=2, ensure_ascii=False)

print(f"\n Evaluation results saved to: {results_path}")
print("=" * 60)

## 11. Sample Translations and Quality Analysis

This section inspects individual translations from the first 15 test sentences and scores each one with sentence-level BLEU and TER.

-   **Ranked Samples**: The five highest-scoring and five lowest-scoring translations (by BLEU) are printed, followed by five randomly chosen samples.
-   **Quality Analysis**: Summary statistics of the BLEU scores, the predicted-to-reference length ratio, and a breakdown into quality bands are reported for the same samples.
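
The cell below scores each sample with `compute_ter`. As a rough sketch of the idea, the example here computes a word-level edit distance with the same dynamic-programming recurrence and divides it by the reference length. The function names `word_edit_distance` and `error_rate` are hypothetical, and the sketch ignores the block shifts that standard TER allows.

```python
def word_edit_distance(ref_tokens, hyp_tokens):
    """Minimum number of word insertions, deletions, and substitutions."""
    m, n = len(ref_tokens), len(hyp_tokens)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Base cases: turning a prefix into an empty sequence, and vice versa
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if ref_tokens[i - 1] == hyp_tokens[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]  # words match, no edit needed
            else:
                dp[i][j] = 1 + min(dp[i - 1][j],      # deletion
                                   dp[i][j - 1],      # insertion
                                   dp[i - 1][j - 1])  # substitution
    return dp[m][n]


def error_rate(reference, hypothesis):
    """Edit distance normalized by reference length (lower is better)."""
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()
    return word_edit_distance(ref_tokens, hyp_tokens) / max(len(ref_tokens), 1)


# One substitution plus one deletion against a four-word reference
print(error_rate("a b c d", "a x c"))
```

Because the score is normalized by the reference length only, a hypothesis much longer than its reference can score above 1.0.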

In [None]:
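# Display sample translations with detailed metrics.
# Only the first 15 test pairs are scored here, so the rankings and statistics
# below describe this subset, not the full test set.
num_samples_to_show = min(15, len(evaluation_results['source']))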

print("\n" + "=" * 80)
print("SAMPLE TRANSLATIONS WITH QUALITY SCORES")
print("=" * 80)

# Compute metrics for each sample
sample_metrics = []
for i in range(num_samples_to_show):
    src = evaluation_results['source'][i]
    ref = evaluation_results['reference'][i]
    pred = evaluation_results['predicted'][i]
    
    # Compute individual metrics
    bleu = sentence_bleu([ref.split()], pred.split(), 
                        smoothing_function=SmoothingFunction().method1)
    ter = compute_ter(ref, pred)
    
    # Length info
    src_len = len(src.split())
    ref_len = len(ref.split())
    pred_len = len(pred.split())
    
    sample_metrics.append({
        'index': i,
        'src': src,
        'ref': ref,
        'pred': pred,
        'bleu': bleu,
        'ter': ter,
        'src_len': src_len,
        'ref_len': ref_len,
        'pred_len': pred_len
    })

# Sort by BLEU score to show best and worst
sample_metrics_sorted = sorted(sample_metrics, key=lambda x: x['bleu'], reverse=True)

# Show top 5 best
print("\n TOP 5 BEST TRANSLATIONS (by BLEU score):")
print("-" * 80)
for i, sample in enumerate(sample_metrics_sorted[:5], 1):
    print(f"\n{i}. Original Index: {sample['index'] + 1}")
    print(f"   Source (EN):     {sample['src']}")
    print(f"   Reference (UR):  {sample['ref']}")
    print(f"   Predicted (UR):  {sample['pred']}")
    print(f"    BLEU: {sample['bleu']:.4f} | TER: {sample['ter']:.4f} | Lengths: {sample['src_len']}→{sample['pred_len']} (ref: {sample['ref_len']})")
    print("-" * 80)

# Show bottom 5 worst
print("\n  BOTTOM 5 WORST TRANSLATIONS (by BLEU score):")
print("-" * 80)
for i, sample in enumerate(sample_metrics_sorted[-5:], 1):
    print(f"\n{i}. Original Index: {sample['index'] + 1}")
    print(f"   Source (EN):     {sample['src']}")
    print(f"   Reference (UR):  {sample['ref']}")
    print(f"   Predicted (UR):  {sample['pred']}")
    print(f"    BLEU: {sample['bleu']:.4f} | TER: {sample['ter']:.4f} | Lengths: {sample['src_len']}→{sample['pred_len']} (ref: {sample['ref_len']})")
    print("-" * 80)

# Show some random samples
print("\n RANDOM SAMPLES:")
print("-" * 80)
# `random` is already imported at the top of the notebook
random_samples = random.sample(sample_metrics, min(5, len(sample_metrics)))
for i, sample in enumerate(random_samples, 1):
    print(f"\n{i}. Original Index: {sample['index'] + 1}")
    print(f"   Source (EN):     {sample['src']}")
    print(f"   Reference (UR):  {sample['ref']}")
    print(f"   Predicted (UR):  {sample['pred']}")
    print(f"    BLEU: {sample['bleu']:.4f} | TER: {sample['ter']:.4f} | Lengths: {sample['src_len']}→{sample['pred_len']} (ref: {sample['ref_len']})")
    print("-" * 80)

# Analysis of common patterns
print("\n" + "=" * 80)
print("TRANSLATION QUALITY ANALYSIS")
print("=" * 80)

# BLEU distribution
bleu_scores = [s['bleu'] for s in sample_metrics]
print(f"\nBLEU Score Distribution:")
print(f"  Mean: {np.mean(bleu_scores):.4f}")
print(f"  Median: {np.median(bleu_scores):.4f}")
print(f"  Std Dev: {np.std(bleu_scores):.4f}")
print(f"  Min: {min(bleu_scores):.4f}")
print(f"  Max: {max(bleu_scores):.4f}")

# Length analysis
length_ratios = [s['pred_len'] / max(s['ref_len'], 1) for s in sample_metrics]
print(f"\nLength Ratio (Predicted/Reference):")
print(f"  Mean: {np.mean(length_ratios):.4f}")
print(f"  Median: {np.median(length_ratios):.4f}")

# Quality categories
excellent = sum(1 for s in bleu_scores if s >= 0.7)
good = sum(1 for s in bleu_scores if 0.4 <= s < 0.7)
fair = sum(1 for s in bleu_scores if 0.2 <= s < 0.4)
poor = sum(1 for s in bleu_scores if s < 0.2)

print(f"\nQuality Distribution:")
print(f"  Excellent (≥0.7): {excellent:3d} ({excellent/len(bleu_scores)*100:5.1f}%)")
print(f"  Good (0.4-0.7):   {good:3d} ({good/len(bleu_scores)*100:5.1f}%)")
print(f"  Fair (0.2-0.4):   {fair:3d} ({fair/len(bleu_scores)*100:5.1f}%)")
print(f"  Poor (<0.2):      {poor:3d} ({poor/len(bleu_scores)*100:5.1f}%)")

print("=" * 80)

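## 12. Training Curves and Loss Analysis

This section visualizes the model's training progress by plotting the training and validation losses over epochs.

-   **Plot Losses**: The in-memory `train_losses` and `val_losses` lists are plotted with `matplotlib`, and the epoch with the lowest validation loss is marked. The figure is saved to `results/training_loss.png`.
-   **Loss Gap**: A second figure plots the difference between validation and training loss, which helps in diagnosing overfitting, and is saved to `results/training_analysis.png`.
-   **Training Summary**: The total number of epochs, the best epoch, the final losses, and the relative improvement in validation loss are printed.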
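
The cell below marks the epoch with the lowest validation loss and plots the gap between validation and training loss. A minimal sketch of those two calculations, using made-up loss values and a hypothetical `summarize_losses` helper, might look like this:

```python
def summarize_losses(train_losses, val_losses):
    """Find the best epoch and the per-epoch validation-minus-training gap."""
    # Index of the smallest validation loss (first occurrence on ties)
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    gaps = [val - train for val, train in zip(val_losses, train_losses)]
    return {
        'best_epoch': best_index + 1,  # epochs are reported 1-based
        'best_val_loss': val_losses[best_index],
        'final_gap': gaps[-1],
        'gaps': gaps,
    }


# Made-up history: training loss keeps falling while validation loss turns up
summary = summarize_losses(
    train_losses=[4.0, 3.0, 2.0, 1.5],
    val_losses=[4.2, 3.3, 3.1, 3.4],
)
print(summary['best_epoch'], summary['best_val_loss'])
```

A gap that keeps widening after the best epoch, as in this toy history, is the pattern usually read as overfitting.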

In [None]:
# Visualize training progress
if 'train_losses' in locals() and 'val_losses' in locals() and len(train_losses) > 0:
    print("\nGenerating training curves...")
    
    plt.figure(figsize=(12, 6))
    
    epochs = range(1, len(train_losses) + 1)
    
    # Plot both losses on same figure
    plt.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2, marker='o', 
             markersize=4, markevery=max(1, len(epochs)//20))
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2, marker='s',
             markersize=4, markevery=max(1, len(epochs)//20))
    
    # Find and mark best epoch
    best_epoch = np.argmin(val_losses) + 1
    best_val_loss = min(val_losses)
    plt.plot(best_epoch, best_val_loss, 'g*', markersize=20, 
             label=f'Best (Epoch {best_epoch}: {best_val_loss:.4f})')
    
    # Formatting
    plt.xlabel('Epoch', fontsize=14, fontweight='bold')
    plt.ylabel('Loss', fontsize=14, fontweight='bold')
    plt.title('Training and Validation Loss Over Epochs', fontsize=16, fontweight='bold', pad=20)
    plt.legend(fontsize=12, loc='best', framealpha=0.9)
    plt.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    
    # Add min/max annotations
    plt.text(0.02, 0.98, 
             f'Min Train Loss: {min(train_losses):.4f}\nMin Val Loss: {best_val_loss:.4f}',
             transform=plt.gca().transAxes,
             fontsize=10,
             verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    
    # Save plot with high DPI
    plot_path = os.path.join('results', 'training_loss.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f" Training curves saved to: {plot_path}")
    
    # Also create a separate plot for loss dynamics
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Left plot: Loss curves
    ax1.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
    ax1.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Loss', fontsize=12, fontweight='bold')
    ax1.set_title('Loss Curves', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Right plot: Loss difference (overfitting indicator)
    loss_diff = [val - train for val, train in zip(val_losses, train_losses)]
    ax2.plot(epochs, loss_diff, 'purple', linewidth=2, marker='o', markersize=3)
    ax2.axhline(y=0, color='k', linestyle='--', linewidth=1, alpha=0.5)
    ax2.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Val Loss - Train Loss', fontsize=12, fontweight='bold')
    ax2.set_title('Overfitting Indicator', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
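    # A positive gap hints at overfitting; a non-positive gap is not by itself
    # a sign of underfitting, so it is labeled neutrally.
    ax2.fill_between(epochs, 0, loss_diff, where=[d > 0 for d in loss_diff],
                     alpha=0.3, color='red', label='Val > Train (possible overfitting)')
    ax2.fill_between(epochs, 0, loss_diff, where=[d <= 0 for d in loss_diff],
                     alpha=0.3, color='green', label='Val <= Train')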
    ax2.legend(fontsize=10)
    
    plt.tight_layout()
    detailed_plot_path = os.path.join('results', 'training_analysis.png')
    plt.savefig(detailed_plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f" Detailed analysis saved to: {detailed_plot_path}")
    
    # Print summary statistics
    print("\n" + "=" * 60)
    print("TRAINING SUMMARY")
    print("=" * 60)
    print(f"Total epochs trained: {len(train_losses)}")
    print(f"Best epoch: {best_epoch}")
    print(f"Final train loss: {train_losses[-1]:.4f}")
    print(f"Final val loss: {val_losses[-1]:.4f}")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Loss improvement: {val_losses[0]:.4f} → {best_val_loss:.4f} ({(val_losses[0] - best_val_loss)/val_losses[0]*100:.1f}%)")
    
    # Check for overfitting
    final_diff = val_losses[-1] - train_losses[-1]
    if final_diff > 0.5:
        print(f"\n  WARNING: Possible overfitting detected (val-train diff: {final_diff:.4f})")
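    elif final_diff < -0.1:
        # Validation loss below training loss can occur when dropout is active only during training
        print(f"\n  NOTE: Validation loss is below training loss (val-train diff: {final_diff:.4f})")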
    else:
        print(f"\n Good fit (val-train diff: {final_diff:.4f})")
    
    print("=" * 60)
    
else:
    print("\n  No training history available to plot.")
    print("   Train the model first to generate loss curves.")

## 13. Interactive Translation

This section uses the trained model to translate a few example English sentences into Urdu.

-   **Translation**: Each sentence in `test_sentences` is passed to `translate_sentence`, which tokenizes the input, generates the translation token by token with greedy decoding, and detokenizes the output into a human-readable string.
-   **Custom Input**: To translate your own sentences, edit the `test_sentences` list and run the cell again.
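
The cell below calls `translate_sentence`, which this notebook describes as using greedy decoding. As a toy illustration of that strategy only, the sketch here picks the highest-scoring next token at every step until an end token appears. The `greedy_decode` function, the score table, and the romanized placeholder tokens are all hypothetical stand-ins for the trained model and its vocabulary.

```python
def greedy_decode(next_token_scores, bos, eos, max_len=10):
    """Repeatedly append the highest-scoring next token until `eos` or `max_len`."""
    tokens = [bos]
    for _ in range(max_len):
        scores = next_token_scores(tokens)    # dict: candidate token -> score
        best = max(scores, key=scores.get)    # greedy choice, no beam search
        tokens.append(best)
        if best == eos:
            break
    # Strip the start token, and the end token if one was produced
    return tokens[1:-1] if tokens[-1] == eos else tokens[1:]


# Hypothetical stand-in for a trained decoder: scores depend on the last token only
TOY_TABLE = {
    '<bos>': {'aap': 0.7, 'main': 0.2, '<eos>': 0.1},
    'aap': {'kaise': 0.6, 'hain': 0.3, '<eos>': 0.1},
    'kaise': {'hain': 0.8, '<eos>': 0.2},
    'hain': {'<eos>': 0.9, 'aap': 0.1},
}


def toy_scores(prefix):
    return TOY_TABLE[prefix[-1]]


translation = ' '.join(greedy_decode(toy_scores, '<bos>', '<eos>'))
print(translation)
```

Greedy decoding is fast because it keeps a single hypothesis, but it can miss a better overall sequence that starts with a lower-scoring token.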

In [None]:
# Interactive translation
# Change this to any English sentence you want to translate
test_sentences = [
    "Hello, how are you?",
    "What is your name?",
    "Thank you very much.",
    "How old are you?",
    "Where are you from?",
]

print("=" * 80)
print("INTERACTIVE TRANSLATION")
print("=" * 80)

for i, sentence in enumerate(test_sentences, 1):
    translation = translate_sentence(model, sentence, vocab_en, vocab_ur, device)
    print(f"\n{i}. English:  {sentence}")
    print(f"   Urdu:     {translation}")

print("\n" + "=" * 80)
print("\nTo translate your own sentences:")
print("1. Modify the 'test_sentences' list above")
print("2. Run this cell again")
print("=" * 80)