# Project: English-to-Urdu Machine Translation

This notebook presents a complete pipeline for fine-tuning a Transformer-based model for English-to-Urdu translation. The implementation is heavily optimized for modern Apple Silicon (M-series chips) using PyTorch's Metal Performance Shaders (MPS) backend.

The primary goal is to demonstrate an effective and efficient training process that fulfills the core requirements of building, training, and evaluating a sequence-to-sequence model.

### Key Features & Optimizations

1.  **Fine-Tuning a Pre-trained Model**: Leverages the `Helsinki-NLP/opus-mt-en-ur` model, a powerful pre-trained baseline, to achieve high-quality translations with minimal training time.

2.  **MPS-Optimized Training**: The entire workflow is tailored for Apple Silicon GPUs.
    -   **Device Placement**: All tensors and model components are explicitly placed on the `mps` device.
    -   **Mixed-Precision Training (AMP)**: Uses `torch.amp.autocast(device_type="mps")` to accelerate computation and reduce memory usage by performing operations in `float16`.

3.  **Efficient Data Pipeline**:
    -   **Pre-tokenization**: The dataset is tokenized once and cached, eliminating this bottleneck from the training loop.
    -   **Asynchronous Data Loading**: The `DataLoader` is configured with multiple workers (`num_workers`) and batch prefetching (`prefetch_factor`) to ensure the GPU is always saturated with data.

4.  **Memory Efficiency**:
    -   **Gradient Checkpointing**: Enabled to significantly reduce the memory footprint of the model, allowing for larger batch sizes.

5.  **Custom Training Loop**: A transparent, from-scratch training loop is implemented to provide full control over the training process and integrate MPS-specific optimizations, replacing the high-level Hugging Face `Trainer`.

## 1. Environment Setup & Configuration

In [None]:
# Core libraries
import os
import time
import warnings
from pathlib import Path

# Data handling
import pandas as pd
import numpy as np

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

# Hugging Face Transformers & Datasets
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    get_scheduler,
)
from datasets import Dataset as HFDataset, load_dataset

# Evaluation metrics
import evaluate
from tqdm.auto import tqdm

# Silence library warnings to keep the notebook output readable.
# NOTE(review): this also hides deprecation and device-support warnings.
warnings.filterwarnings('ignore')

# ============================================================
# MPS-ONLY OPTIMIZATIONS
# ============================================================
# Every later cell targets the 'mps' device, so say so up front when the
# backend is missing instead of reporting success unconditionally.
if not torch.backends.mps.is_available():
    print("WARNING: MPS backend is not available; cells that use the 'mps' device will fail.")

# New tensors and modules are created on 'mps' unless a device is given.
torch.set_default_device("mps")
device = torch.device("mps")
print("MPS set as default device.")


# Create necessary directories
Path("checkpoints_optimized").mkdir(exist_ok=True)
Path("data").mkdir(exist_ok=True)

## 2. Configuration

In [None]:
# Single source of truth for the hyper-parameters read by the cells below.
config = dict(
    # Pretrained checkpoint to fine-tune
    model_name="Helsinki-NLP/opus-mt-en-ur",

    # Optimisation settings
    batch_size=32,             # larger batch, relying on the memory optimizations
    num_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,

    # DataLoader settings
    num_workers=4,             # worker processes that load batches in the background
    prefetch_factor=2,         # batches each worker keeps ready in advance
    # NOTE(review): pinned memory is a CUDA feature; confirm it has any effect on MPS
    pin_memory=True,

    # Token limits for source and target sequences
    max_source_length=128,
    max_target_length=128,

    # Where checkpoints are meant to be written
    checkpoint_dir='./checkpoints_optimized',
)

print("Configuration loaded.")

## 3. Data Loading & Preprocessing

**Optimization**:
- **Load from text files**: Directly load the parallel corpus.
- **Pre-tokenization**: Tokenize the entire dataset once using `map` with `batched=True`. This is much faster than tokenizing on-the-fly.
- **Caching**: The `datasets` library automatically caches the processed data, so subsequent runs are instantaneous.

In [None]:
def _read_parallel_pairs(en_file, ur_file, max_samples=None):
    """Read aligned sentence pairs from two line-parallel UTF-8 text files.

    Lines are paired by position. A pair is dropped when either side is empty
    after stripping whitespace. Reading stops as soon as ``max_samples`` pairs
    have been kept; a falsy ``max_samples`` (``None`` or ``0``) means no limit.

    Returns:
        A list of ``{'english': str, 'urdu': str}`` dicts.
    """
    pairs = []
    with open(en_file, 'r', encoding='utf-8') as f_en, \
         open(ur_file, 'r', encoding='utf-8') as f_ur:
        # zip stops at the shorter file, so surplus lines in the longer one
        # are ignored.
        for en_line, ur_line in zip(f_en, f_ur):
            # Count kept pairs, not raw lines, so blank lines do not eat into
            # the requested sample budget.
            if max_samples and len(pairs) >= max_samples:
                break
            en = en_line.strip()
            ur = ur_line.strip()
            if en and ur:
                pairs.append({'english': en, 'urdu': ur})
    return pairs


def load_and_preprocess_data(en_file, ur_file, tokenizer, max_samples=None,
                             test_size=0.1, seed=42):
    """Load parallel data, tokenize it, and split it into train/validation sets.

    Sequence length limits are read from the module-level ``config``
    (``max_source_length`` and ``max_target_length``).

    Args:
        en_file: Path to the English corpus, one sentence per line.
        ur_file: Path to the Urdu corpus, line-aligned with ``en_file``.
        tokenizer: Hugging Face tokenizer used for both source and target text.
        max_samples: Maximum number of non-empty pairs to keep. ``None`` or
            ``0`` keeps everything.
        test_size: Fraction of pairs held out for validation.
        seed: Seed for the split, so the validation set is the same on every run.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of tokenized datasets with
        ``input_ids``, ``attention_mask`` and ``labels`` columns.

    Raises:
        ValueError: If no non-empty sentence pair could be read.
    """
    pairs = _read_parallel_pairs(en_file, ur_file, max_samples)
    if not pairs:
        raise ValueError(
            f"No non-empty sentence pairs found in {en_file} and {ur_file}."
        )

    print(f"Loaded {len(pairs):,} sentence pairs.")

    # Convert to HuggingFace Dataset
    raw_dataset = HFDataset.from_pandas(pd.DataFrame(pairs))

    pad_token_id = tokenizer.pad_token_id

    def preprocess_function(examples):
        """Tokenize a batch of English-Urdu pairs into model inputs and labels."""
        inputs = tokenizer(
            examples['english'],
            max_length=config['max_source_length'],
            truncation=True,
            padding='max_length',  # Pad to max_length for consistent tensor shapes
        )
        # `text_target` routes the text through the target-side tokenizer; it
        # replaces the deprecated `as_target_tokenizer()` context manager.
        labels = tokenizer(
            text_target=examples['urdu'],
            max_length=config['max_target_length'],
            truncation=True,
            padding='max_length',
        )
        # Padded label positions must be -100 so the loss skips them; otherwise
        # fixed-length padding dominates the training signal.
        inputs['labels'] = [
            [token if token != pad_token_id else -100 for token in sequence]
            for sequence in labels['input_ids']
        ]
        return inputs

    # Tokenize the dataset
    tokenized_dataset = raw_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=['english', 'urdu'],
        desc="Tokenizing dataset"
    )

    # Split the dataset (seeded for a reproducible validation set)
    splits = tokenized_dataset.train_test_split(test_size=test_size, seed=seed)
    train_dataset = splits['train']
    val_dataset = splits['test']

    print(f"Train samples: {len(train_dataset):,}")
    print(f"Validation samples: {len(val_dataset):,}")

    return train_dataset, val_dataset

# Tokenizer that matches the pretrained checkpoint named in the config
tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

# Build the tokenized train/validation splits from the corpus files in ./data
DATA_DIR = Path("data")
train_dataset, val_dataset = load_and_preprocess_data(
    DATA_DIR / "english-corpus.txt",
    DATA_DIR / "urdu-corpus.txt",
    tokenizer,
    # max_samples=20000  # Using a subset for faster demonstration
)


## 4. Model & DataLoaders

**Optimization Details**:
- **Pre-trained Transformer**: We use a complete sequence-to-sequence model from Hugging Face (`Helsinki-NLP/opus-mt-en-ur`), which includes the encoder-decoder Transformer architecture.
- **Gradient Checkpointing**: `model.gradient_checkpointing_enable()` is called to reduce memory usage. This technique trades extra computation during the backward pass (activations are recomputed instead of stored) for a significant reduction in memory, allowing for larger batch sizes.
- **Optimized `DataLoader`**:
    - `num_workers > 0` enables multi-process data loading to prevent the CPU from being a bottleneck.
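    - `pin_memory=True` requests page-locked host memory, which speeds up host-to-device copies on CUDA; on the MPS backend PyTorch may ignore this flag, so treat it as optional here.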
    - `prefetch_factor` works with `num_workers` to load several batches in advance.
    - `generator` is explicitly set to the `mps` device to ensure compatibility with multi-worker data loading on Apple Silicon.

In [None]:
# Load the pretrained model and move it to the MPS device explicitly
model = AutoModelForSeq2SeqLM.from_pretrained(config['model_name']).to(device)

# Enable gradient checkpointing to save memory
model.gradient_checkpointing_enable()

print(f"Model loaded on {model.device}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Expose only the model inputs, as PyTorch tensors
tensor_columns = ['input_ids', 'attention_mask', 'labels']
train_dataset.set_format(type='torch', columns=tensor_columns)
val_dataset.set_format(type='torch', columns=tensor_columns)

# Settings shared by both loaders
loader_kwargs = {
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
    'pin_memory': config['pin_memory'],
}
# DataLoader rejects `prefetch_factor` unless worker processes are used, so
# only pass it when num_workers > 0. This keeps num_workers=0 usable.
if config['num_workers'] > 0:
    loader_kwargs['prefetch_factor'] = config['prefetch_factor']

# Generator on the MPS device for the shuffling loader.
# NOTE(review): presumably needed because the default device is 'mps' and the
# shuffle's random permutation must use a generator on that device - confirm.
g = torch.Generator(device='mps')

# Create efficient DataLoaders
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    generator=g,  # Pass the MPS generator to the DataLoader
    **loader_kwargs,
)

val_dataloader = DataLoader(val_dataset, **loader_kwargs)

print(f"DataLoaders created with {config['num_workers']} workers and an MPS-compatible generator.")

## 5. Training Loop

**Optimization Details**:
- **Custom Loop**: A manual training loop provides full control and transparency, avoiding the overhead of the high-level `Trainer` class and allowing for direct integration of MPS-specific features.
- **Mixed-Precision (AMP) for MPS**:
    - The `with torch.amp.autocast(device_type="mps", dtype=torch.float16):` context manager automatically casts model operations to a faster, lower-precision format (`float16`) where possible.
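    - Unlike the usual CUDA recipe, this loop does not use a `GradScaler` with `autocast` on MPS. Because `float16` gradients can still underflow without loss scaling, monitor the training loss for NaNs or stalls.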
- **Fused Optimizer**: `AdamW(fused=True)` is used where available. A fused implementation combines the per-parameter update operations into fewer kernel launches, improving performance on compatible hardware.
- **Learning Rate Scheduler**: A linear scheduler dynamically adjusts the learning rate during training, which can lead to better and more stable convergence.

In [None]:
# Optimizer: prefer the fused implementation, fall back to the standard one.
adamw_kwargs = {
    'lr': config['learning_rate'],
    'weight_decay': config['weight_decay'],
}
try:
    optimizer = AdamW(model.parameters(), fused=True, **adamw_kwargs)
except (RuntimeError, TypeError, ValueError):
    # Only construction errors mean "fused is unsupported here"; a bare
    # `except:` would also swallow KeyboardInterrupt and SystemExit.
    optimizer = AdamW(model.parameters(), **adamw_kwargs)
    print("Fused AdamW not available. Using standard AdamW.")
else:
    print("Using fused AdamW optimizer.")

# Learning rate scheduler: linear decay over every optimizer step, no warmup
num_training_steps = config['num_epochs'] * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)


# --- Training ---
print("Starting training...")
total_start_time = time.time()

for epoch in range(config['num_epochs']):
    model.train()
    epoch_start_time = time.time()
    total_loss = 0.0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")

    for batch in progress_bar:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

        optimizer.zero_grad()

        # Mixed-precision forward pass for MPS
        with torch.amp.autocast(device_type="mps", dtype=torch.float16):
            outputs = model(**batch)
            loss = outputs.loss

        # Plain backward pass without loss scaling.
        # NOTE(review): float16 gradients can underflow without a scaler;
        # watch for NaN or stalled loss.
        loss.backward()
        optimizer.step()

        lr_scheduler.step()

        # Read the scalar once: each .item() call forces a device sync.
        loss_value = loss.item()
        total_loss += loss_value
        progress_bar.set_postfix({'loss': loss_value})

    # max(..., 1) guards against an empty dataloader
    avg_train_loss = total_loss / max(len(train_dataloader), 1)
    epoch_duration = time.time() - epoch_start_time

    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Duration: {epoch_duration:.2f}s")


total_duration = time.time() - total_start_time
print(f"\nTraining complete in {total_duration/60:.2f} minutes.")
# max(..., 1) keeps the summary printable when num_epochs is 0
print(f"Average time per epoch: {total_duration / max(config['num_epochs'], 1):.2f}s")


## 6. Evaluation & Translation

**Optimization**:
- **`torch.no_grad()`**: Disables gradient calculation, which reduces memory consumption and speeds up inference.
- **Batched Evaluation**: The validation loop processes data in batches, which is much more efficient than one-by-one evaluation.
- **Minimal CPU Transfer**: Tensors stay on the GPU through generation. `.cpu()` is called once per batch, only when the generated tokens and labels need to be decoded.

In [None]:
bleu_metric = evaluate.load("sacrebleu")

def evaluate_model(model, dataloader):
    """Generate translations for every batch and return the corpus BLEU score.

    Uses the module-level ``tokenizer``, ``device``, ``config`` and
    ``bleu_metric``. The model is switched to eval mode and left there.

    Args:
        model: Seq2seq model exposing ``generate``.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.

    Returns:
        The ``score`` field of the sacreBLEU result.
    """
    model.eval()
    hypotheses, references = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            on_device = {
                name: tensor.to(device, non_blocking=True)
                for name, tensor in batch.items()
            }

            generated = model.generate(
                on_device["input_ids"],
                attention_mask=on_device["attention_mask"],
                max_length=config['max_target_length'],
            )
            # Decoding happens on the host, so copy both arrays off the device
            pred_ids = generated.cpu().numpy()
            gold_ids = on_device["labels"].cpu().numpy()

            # -100 marks ignored label positions; swap in the pad id so the
            # tokenizer can decode them (and then drop them as special tokens)
            gold_ids = np.where(gold_ids == -100, tokenizer.pad_token_id, gold_ids)

            pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)

            hypotheses += [sentence.strip() for sentence in pred_texts]
            # sacreBLEU expects a list of references per prediction
            references += [[sentence.strip()] for sentence in gold_texts]

    result = bleu_metric.compute(predictions=hypotheses, references=references)
    return result['score']

# --- Run Evaluation ---
# Time one full pass over the validation set and report its BLEU score.
print("\nEvaluating on validation set...")
eval_start_time = time.time()
bleu_result = evaluate_model(model=model, dataloader=val_dataloader)
eval_duration = time.time() - eval_start_time

print(
    f"Evaluation complete in {eval_duration:.2f}s",
    f"   BLEU Score: {bleu_result:.2f}",
    sep="\n",
)


## 7. Qualitative Examples & Benchmark


In [None]:
def translate(text, model, tokenizer, max_length=128):
    """Return the model's translation of one input sentence.

    The source text is truncated to ``config['max_source_length']`` tokens and
    moved to the module-level ``device`` before generation.

    Args:
        text: Source sentence to translate.
        model: Seq2seq model exposing ``generate``; it is put into eval mode.
        tokenizer: Tokenizer used to encode the input and decode the output.
        max_length: Upper bound passed to ``generate`` for the output length.

    Returns:
        The decoded translation with special tokens removed.
    """
    model.eval()
    with torch.no_grad():
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=config['max_source_length'],
        )
        model_inputs = {}
        for name, tensor in encoded.items():
            model_inputs[name] = tensor.to(device)

        generated = model.generate(**model_inputs, max_length=max_length)
        # Only one sentence was encoded, so the first sequence is the result
        return tokenizer.decode(generated[0], skip_special_tokens=True)

# Sample sentences for a qualitative look at the fine-tuned model
test_examples = [
    "Hello, how are you?",
    "The weather is beautiful today.",
    "I am learning to code.",
    "This is a ability test.",
    "I like to cook",
    "I am going to school.",
    "The weather is very nice today.",
    "Please give me a glass of water.",
    "What is your name?",
]

# Section banner: blank line, rule, title, rule
print("\n" + "=" * 60, "TRANSLATION EXAMPLES", "=" * 60, sep="\n")

for text in test_examples:
    translation = translate(text, model, tokenizer)
    print(f"EN: {text}", f"UR: {translation}\n", sep="\n")

# --- Performance Benchmark ---
print("\n" + "=" * 60, "PERFORMANCE BENCHMARK", "=" * 60, sep="\n")
# Mean wall-clock time per epoch, taken from the training run's total duration
new_epoch_time_s = total_duration / config['num_epochs']

print(f"Optimized epoch time: {new_epoch_time_s:.2f} seconds")
print("=" * 60)
