# 📘 **Project Title: Transformer-Based Text Translation**
A practical implementation of a Transformer model for language translation.

# 🧠 **Overview**
This notebook demonstrates how to train a Transformer model for text translation. It covers the complete setup: configuration loading, tokenizer construction, model initialization, and the training loop.

The model is trained on the Hugging Face dataset `opus_books` by Helsinki-NLP, specifically the "en-it" subset, which has around 32.2k rows. We train on 90% of this dataset and hold out the remaining 10% for validation and inference.

Here is the link to our dataset: https://huggingface.co/datasets/Helsinki-NLP/opus_books/viewer/en-it?views%5B%5D=en_it


# 🛠️  **Environment Setup**
Set up a virtual environment and install the dependencies.

In [None]:
%env PYTHONPATH =

Installs the `virtualenv` tool to create isolated Python environments.

In [None]:
!pip install virtualenv

Creates a new virtual environment named `myenv` to avoid dependency conflicts.

In [None]:
!virtualenv myenv

In [None]:
!myenv/bin/python --version

Installs the necessary dependencies at the pinned versions given in `requirements.txt`.


In [None]:
# Make sure we're using the virtual environment's pip
!myenv/bin/pip install numpy==1.24.3
!myenv/bin/pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
!myenv/bin/pip install datasets==2.15.0 tokenizers==0.13.3 torchmetrics==1.0.3
!myenv/bin/pip install tensorboard==2.13.0 tqdm altair==5.1.1 wandb==0.15.9

In [None]:
!pip install datasets

# 📦 **Import all the needed libraries**

Imports `DataLoader` from `torch.utils.data`, which loads data in batches and supports shuffling and parallel loading during training.

Imports `load_dataset` from the Hugging Face `datasets` library to download the dataset. The `Dataset` class and `random_split`, used to build the augmented dataset and split it into training and validation sets, are imported in later cells.

Imports the base `Tokenizer` class from the Hugging Face `tokenizers` library. This class handles encoding text to tokens and decoding tokens back to text.

Imports `WordLevel` and `WordLevelTrainer`, which build a word-level vocabulary from the training data, including the special tokens `[UNK]`, `[PAD]`, `[SOS]`, and `[EOS]`.

Imports the `Whitespace` pre-tokenizer from `tokenizers.pre_tokenizers`. It splits text into words on whitespace (and separates punctuation), a straightforward way to prepare text before training the tokenizer.

Imports the remaining helpers from the project's own files: `model.py`, `dataset.py`, and `config.py`.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import Transformer, build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config

In [None]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path

# ⚙️ **Configure Training Parameters**
Defines the training configuration by overriding a few defaults returned by `get_config()`. The larger batch size is meant to improve GPU utilization, and the lower learning rate is meant to make training more stable.

In [None]:
# Get base configuration
config = get_config()

# Optimize for Colab Pro
config['batch_size'] = 32  # Increased for better GPU utilization
config['num_epochs'] = 30  # Increased epochs for better training
config['lr'] = 5e-5  # Lower learning rate for more stable training
config['preload'] = None  # Start fresh training (no preloading)

These are the exact details of the configuration:

In [None]:
# Print the configuration to verify
print("Training Configuration:")
for key, value in config.items():
    print(f"{key}: {value}")
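
Before starting a long run, it can help to confirm that the configuration contains every key the later cells read. The snippet below is a minimal sketch, assuming `config` is a plain dictionary. The helper `missing_config_keys` is hypothetical and not part of `config.py`, and the key list was collected from the cells in this notebook (functions inside `config.py` may read additional keys).

```python
# Keys that the cells in this notebook read from `config`.
REQUIRED_KEYS = [
    "batch_size", "num_epochs", "lr", "preload", "seq_len", "d_model",
    "datasource", "lang_src", "lang_tgt", "model_folder",
    "tokenizer_file", "experiment_name",
]

def missing_config_keys(config):
    """Return the required keys that are absent from `config`, in list order."""
    return [key for key in REQUIRED_KEYS if key not in config]

# Example with a deliberately incomplete, made-up dictionary.
example = {"batch_size": 32, "num_epochs": 30, "lr": 5e-5, "preload": None}
print(missing_config_keys(example))
```

Calling `missing_config_keys(config)` on the real configuration should return an empty list if nothing is missing.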

# 🔍 **Creating the Beam Search Function**
The `beam_search_decode` function is a decoding algorithm used during inference in machine translation (or similar NLP tasks) with a Transformer model. Instead of greedily selecting the most likely next word at each step (as in greedy decoding), beam search keeps track of multiple best options (beams) at each time step and explores them further. This results in translations that are often more fluent and accurate.

In [None]:
# Create an improved beam search function for inference
import torch
from dataset import causal_mask

def beam_search_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device, beam_size=5):
    """Beam search for better translation quality"""
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Encode the source sentence
    encoder_output = model.encode(source, source_mask)

    # Initialize the beam with start token
    sequences = [(torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device), 0.0)]

    # Beam search
    for _ in range(max_len):
        new_sequences = []

        # Expand each current sequence
        for seq, score in sequences:
            # If sequence ended with EOS, keep it unchanged
            if seq.size(1) > 1 and seq[0, -1].item() == eos_idx:
                new_sequences.append((seq, score))
                continue

            # Create decoder mask for this sequence
            decoder_mask = causal_mask(seq.size(1)).type_as(source_mask).to(device)

            # Get next token probabilities
            out = model.decode(encoder_output, source_mask, seq, decoder_mask)
            prob = model.project(out[:, -1])
            log_prob = torch.log_softmax(prob, dim=-1)

            # Get top-k token candidates
            topk_probs, topk_indices = torch.topk(log_prob, beam_size, dim=1)

            # Add new candidates to the list
            for i in range(beam_size):
                token = topk_indices[0, i].unsqueeze(0).unsqueeze(0)
                new_seq = torch.cat([seq, token], dim=1)
                new_score = score + topk_probs[0, i].item()
                new_sequences.append((new_seq, new_score))

        # Select top-k sequences
        new_sequences.sort(key=lambda x: x[1], reverse=True)
        sequences = new_sequences[:beam_size]

        # Check if all sequences have ended or reached max length
        if all((seq.size(1) > 1 and seq[0, -1].item() == eos_idx) or seq.size(1) >= max_len
               for seq, _ in sequences):
            break

    # Return the best sequence
    return sequences[0][0].squeeze(0)
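
To see why keeping several beams can beat greedy decoding, here is a minimal, self-contained sketch on a toy "model". The table `NEXT` of next-token probabilities and the token names are made up for illustration and have nothing to do with the trained Transformer. With `beam_size=1` the search is greedy.

```python
import math

# Hypothetical next-token distributions, keyed by the last generated token.
# The numbers are invented for illustration only.
NEXT = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"z": 0.95, "x": 0.05},
    "x": {"</s>": 1.0},
    "y": {"</s>": 1.0},
    "z": {"</s>": 1.0},
}

def beam_search(beam_size, max_len=5):
    """Return (tokens, log_prob) of the best sequence found."""
    beams = [(["<s>"], 0.0)]
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            # Finished sequences are carried over unchanged.
            if tokens[-1] == "</s>":
                candidates.append((tokens, score))
                continue
            # Expand the sequence with every possible next token.
            for tok, p in NEXT[tokens[-1]].items():
                candidates.append((tokens + [tok], score + math.log(p)))
        # Keep only the `beam_size` highest-scoring candidates.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
        if all(tokens[-1] == "</s>" for tokens, _ in beams):
            break
    return beams[0]

greedy_tokens, greedy_lp = beam_search(beam_size=1)
beam_tokens, beam_lp = beam_search(beam_size=2)
print(greedy_tokens, round(math.exp(greedy_lp), 3))
print(beam_tokens, round(math.exp(beam_lp), 3))
```

Greedy decoding commits to `a` because it is the likeliest first token, while a beam of two also keeps `b` alive and finds the sequence with the higher overall probability. Like `beam_search_decode` above, this sketch sums log-probabilities without length normalization, so shorter hypotheses tend to be favored.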

# 📈 **Data Augmentation** 
Augments the dataset with 20 common English-Italian phrases, each added five times (100 extra rows). The aim is to expose the model to everyday expressions that may be underrepresented in the original book corpus.

In [None]:
# Define a function to add common word pairs to the dataset
from datasets import Dataset

def add_common_words_to_dataset(ds_raw):
    """Add common word pairs to ensure they're properly translated"""
    try:
        # Get original items as a list
        original_items = ds_raw.to_list()

        # Create dataset with common words and phrases
        common_phrases = [
            {"translation": {"en": "Hello", "it": "Ciao"}},
            {"translation": {"en": "Hello, how are you?", "it": "Ciao, come stai?"}},
            {"translation": {"en": "Goodbye", "it": "Arrivederci"}},
            {"translation": {"en": "Thank you", "it": "Grazie"}},
            {"translation": {"en": "Please", "it": "Per favore"}},
            {"translation": {"en": "Yes", "it": "Sì"}},
            {"translation": {"en": "No", "it": "No"}},
            {"translation": {"en": "Good morning", "it": "Buongiorno"}},
            {"translation": {"en": "Good evening", "it": "Buonasera"}},
            {"translation": {"en": "Good night", "it": "Buonanotte"}},
            {"translation": {"en": "How are you?", "it": "Come stai?"}},
            {"translation": {"en": "My name is", "it": "Mi chiamo"}},
            {"translation": {"en": "What is your name?", "it": "Come ti chiami?"}},
            {"translation": {"en": "I don't understand", "it": "Non capisco"}},
            {"translation": {"en": "I love you", "it": "Ti amo"}},
            {"translation": {"en": "I'm sorry", "it": "Mi dispiace"}},
            {"translation": {"en": "Where is", "it": "Dov'è"}},
            {"translation": {"en": "How much is this?", "it": "Quanto costa?"}},
            {"translation": {"en": "I would like", "it": "Vorrei"}},
            {"translation": {"en": "Can you help me?", "it": "Puoi aiutarmi?"}},
        ]

        # Add 5 copies of common phrases for emphasis
        enhanced_items = original_items.copy()
        for _ in range(5):
            enhanced_items.extend(common_phrases)

        # Create new dataset
        enhanced_ds = Dataset.from_list(enhanced_items)

        print(f"Original dataset size: {len(ds_raw)}")
        print(f"Enhanced dataset size: {len(enhanced_ds)}")

        return enhanced_ds
    except Exception as e:
        print(f"Error augmenting dataset: {str(e)}")
        print("Using original dataset instead")
        return ds_raw
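
The augmentation runs before the 90/10 split, so some of the added copies can end up in the validation set. The sketch below illustrates this with plain Python lists and the `random` module instead of `datasets.Dataset` and `random_split`. The 1,000 placeholder rows, the two-phrase list, and the seed are arbitrary choices for the example.

```python
import random

# Placeholder rows standing in for the real dataset.
original = [{"translation": {"en": f"en {i}", "it": f"it {i}"}} for i in range(1000)]
common = [
    {"translation": {"en": "Hello", "it": "Ciao"}},
    {"translation": {"en": "Thank you", "it": "Grazie"}},
]

# Oversample: append 5 copies of every common phrase.
enhanced = original + common * 5

# 90/10 split on shuffled indices, mirroring the notebook's split ratio.
rng = random.Random(0)
indices = list(range(len(enhanced)))
rng.shuffle(indices)
train_size = int(0.9 * len(enhanced))
train = [enhanced[i] for i in indices[:train_size]]
val = [enhanced[i] for i in indices[train_size:]]

# Count how many augmented rows landed in each split.
common_en = {row["translation"]["en"] for row in common}
in_train = sum(row["translation"]["en"] in common_en for row in train)
in_val = sum(row["translation"]["en"] in common_en for row in val)
print(len(enhanced), len(train), len(val), in_train, in_val)
```

If the validation metrics should only reflect unseen sentences, one option is to split first and augment only the training portion.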

# 🔤 **Tokenizer Construction** 
Builds a word-level tokenizer for each language, or loads it from disk if a saved tokenizer file already exists. Setting `min_frequency=1` keeps every word that occurs at least once, so even rare words get a vocabulary entry.

In [None]:
def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]

def get_or_build_improved_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if not tokenizer_path.exists():
        # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=1)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
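
The effect of `min_frequency` is easiest to see on a tiny example. The following is a plain-Python stand-in for a word-level vocabulary, not the `tokenizers` library's implementation: the helpers `build_vocab` and `encode` are hypothetical, the id ordering is a choice made for this sketch, and `str.split` is used in place of the `Whitespace` pre-tokenizer (which also separates punctuation).

```python
from collections import Counter

SPECIAL_TOKENS = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]

def build_vocab(sentences, min_frequency=1):
    """Map each word seen at least `min_frequency` times to an integer id."""
    counts = Counter(word for s in sentences for word in s.split())
    # Special tokens always get the first ids.
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    # Most frequent words first; ties broken alphabetically (sketch-only choice).
    for word, n in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if n >= min_frequency:
            vocab[word] = len(vocab)
    return vocab

def encode(sentence, vocab):
    """Replace each word with its id, falling back to [UNK] for unknown words."""
    return [vocab.get(word, vocab["[UNK]"]) for word in sentence.split()]

sentences = ["ciao come stai", "ciao grazie"]
vocab_all = build_vocab(sentences, min_frequency=1)
vocab_freq = build_vocab(sentences, min_frequency=2)
print(encode("ciao grazie", vocab_all))   # every word has its own id
print(encode("ciao grazie", vocab_freq))  # "grazie" occurs once, so it maps to [UNK]
```

With `min_frequency=1` no training word is mapped to `[UNK]`, at the cost of a larger vocabulary and therefore larger embedding and projection layers.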

# 🔄 **Dataset Preparation**
Loads the dataset, applies the augmentation, builds the tokenizers, splits the data 90/10 into training and validation sets, and wraps each split in a `DataLoader`.

In [None]:
# Create function to get datasets with our improvements
from torch.utils.data import DataLoader, random_split
from datasets import load_dataset
from dataset import BilingualDataset

def get_improved_ds(config):
    """Get datasets with data augmentation and improved tokenization"""
    # Load dataset
    print("Loading dataset...")
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    # Apply data augmentation for common words
    print("Enhancing dataset with common words...")
    enhanced_ds = add_common_words_to_dataset(ds_raw)

    # Build improved tokenizers
    print("Building tokenizers...")
    tokenizer_src = get_or_build_improved_tokenizer(config, enhanced_ds, config['lang_src'])
    tokenizer_tgt = get_or_build_improved_tokenizer(config, enhanced_ds, config['lang_tgt'])

    # Split dataset
    print("Splitting dataset...")
    train_ds_size = int(0.9 * len(enhanced_ds))
    val_ds_size = len(enhanced_ds) - train_ds_size
    train_raw, val_raw = random_split(enhanced_ds, [train_ds_size, val_ds_size])

    # Create bilingual datasets
    print("Creating datasets...")
    train_ds = BilingualDataset(
        ds=train_raw,
        tokenizer_src=tokenizer_src,
        tokenizer_tgt=tokenizer_tgt,
        src_lang=config['lang_src'],
        tgt_lang=config['lang_tgt'],
        seq_len=config['seq_len']
    )

    val_ds = BilingualDataset(
        ds=val_raw,
        tokenizer_src=tokenizer_src,
        tokenizer_tgt=tokenizer_tgt,
        src_lang=config['lang_src'],
        tgt_lang=config['lang_tgt'],
        seq_len=config['seq_len']
    )

    # Create data loaders
    train_dataloader = DataLoader(
        train_ds,
        batch_size=config['batch_size'],
        shuffle=True
    )

    val_dataloader = DataLoader(
        val_ds,
        batch_size=1,  # one sentence at a time: beam search decodes a single sequence
        shuffle=False
    )

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [None]:
# Install torchmetrics directly in Colab
!pip install torchmetrics

# 🏋️ **Training Function**
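Implements the training loop with a `OneCycleLR` learning rate schedule, label smoothing, and per-epoch validation that reports BLEU and word error rate (WER). A checkpoint is saved after every epoch, and the model with the best BLEU score so far is saved separately.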
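
The training function below passes `max_lr=config['lr']`, `pct_start=0.1`, `div_factor=10`, and `final_div_factor=100` to `OneCycleLR`, so the configured learning rate is the peak of the schedule rather than its starting value. The sketch below reproduces the shape of that schedule in plain Python, assuming cosine annealing (PyTorch's default `anneal_strategy`). The function `one_cycle_lr` is a hypothetical helper written for illustration and should be treated as an approximation of the library's behavior, not a replacement for it.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=10, final_div_factor=100):
    """Cosine one-cycle schedule: warm up to max_lr, then anneal to a small value."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Moves smoothly from `start` (pct=0) to `end` (pct=1).
        return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))

total = 1000  # arbitrary number of optimizer steps for the example
for s in (0, 99, 500, 999):
    print(s, f"{one_cycle_lr(s, total, max_lr=5e-5):.2e}")
```

With `lr = 5e-5`, the run starts at `5e-6`, peaks at `5e-5` after the first 10% of steps, and ends near `5e-8`.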

In [None]:
# Create improved training function with learning rate scheduler and label smoothing
from model import build_transformer
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchmetrics
from config import get_weights_file_path, latest_weights_file_path

def train_improved_model(config):
    """Improved training function with optimizations"""

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if device.type == 'cuda':  # compare the device type; a torch.device is not equal to a plain string
        print(f"Device name: {torch.cuda.get_device_name(0)}")
        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")

    # Create weights directory
    weights_path = Path(f"{config['datasource']}_{config['model_folder']}")
    weights_path.mkdir(parents=True, exist_ok=True)

    # Get datasets
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_improved_ds(config)

    # Build model
    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config['seq_len'],
        config['seq_len'],
        d_model=config['d_model']
    ).to(device)

    # Initialize TensorBoard
    writer = SummaryWriter(config['experiment_name'])

    # Optimizer with better parameters
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['lr'],
        betas=(0.9, 0.98),
        eps=1e-9
    )

    # Learning rate scheduler for better convergence
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config['lr'],
        steps_per_epoch=len(train_dataloader),
        epochs=config['num_epochs'],
        pct_start=0.1,
        div_factor=10,
        final_div_factor=100
    )

    # Check for preloaded model
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename and Path(model_filename).exists():
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # Loss function with label smoothing for better generalization
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_tgt.token_to_id('[PAD]'),
        label_smoothing=0.1
    ).to(device)

    # Track best validation score
    best_bleu = 0

    # Training loop
    for epoch in range(initial_epoch, config['num_epochs']):
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")

        # Training phase
        for batch in batch_iterator:
            # Get batch data
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)
            label = batch['label'].to(device)

            # Forward pass
            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            # Calculate loss
            loss = loss_fn(
                proj_output.view(-1, tokenizer_tgt.get_vocab_size()),
                label.view(-1)
            )

            # Update progress
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log metrics
            writer.add_scalar('train/loss', loss.item(), global_step)
            writer.add_scalar('train/learning_rate', scheduler.get_last_lr()[0], global_step)
            writer.flush()

            # Backward pass
            loss.backward()

            # Update weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # Update learning rate
            scheduler.step()

            global_step += 1

        # Validation phase
        print(f"\nValidation after epoch {epoch+1}:")
        model.eval()

        # Collect validation examples
        sources = []
        targets = []
        predictions = []

        with torch.no_grad():
            # Only process first 100 examples
            for count, batch in enumerate(tqdm(val_dataloader, desc="Validation", total=100)):
                # Hard limit to 100 examples
                if count >= 100:
                    break

                # Get batch data
                encoder_input = batch["encoder_input"].to(device)
                encoder_mask = batch["encoder_mask"].to(device)

                # Generate translation with beam search
                model_out = beam_search_decode(
                    model, encoder_input, encoder_mask,
                    tokenizer_src, tokenizer_tgt,
                    config['seq_len'], device
                )

                # Get text
                source_text = batch["src_text"][0]
                target_text = batch["tgt_text"][0]
                model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
                model_out_text = model_out_text.replace("[SOS]", "").replace("[EOS]", "").strip()

                # Store for metrics
                sources.append(source_text)
                targets.append(target_text)
                predictions.append(model_out_text)

                # Print examples
                if count < 3:
                    print(f"Example {count+1}:")
                    print(f"Source: {source_text}")
                    print(f"Target: {target_text}")
                    print(f"Predicted: {model_out_text}")
                    print("-" * 80)

        # Calculate metrics
        metric = torchmetrics.BLEUScore()
        bleu_score = metric(predictions, [[t] for t in targets])

        metric = torchmetrics.WordErrorRate()
        wer = metric(predictions, targets)

        # Log metrics
        writer.add_scalar('validation/BLEU', bleu_score, global_step)
        writer.add_scalar('validation/WER', wer, global_step)
        writer.flush()

        print(f"BLEU Score: {bleu_score:.4f}")
        print(f"Word Error Rate: {wer:.4f}")

        # Save best model
        if bleu_score > best_bleu:
            best_bleu = bleu_score
            best_model_path = get_weights_file_path(config, "best")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step,
                'bleu_score': bleu_score
            }, best_model_path)
            print(f"New best model (BLEU: {bleu_score:.4f}) saved to {best_model_path}")

        # Save epoch checkpoint
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
            'bleu_score': bleu_score
        }, model_filename)
        print(f"Saved checkpoint to {model_filename}")
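
The word error rate reported during validation can be understood as a word-level edit distance divided by the number of reference words. The sketch below computes it for a single sentence pair in plain Python. The function `word_error_rate` is a hypothetical helper for illustration, whereas `torchmetrics.WordErrorRate` aggregates errors and reference lengths over all pairs it receives.

```python
def word_error_rate(prediction, reference):
    """Word-level edit distance divided by the number of reference words."""
    pred, ref = prediction.split(), reference.split()
    # prev[j] = edit distance between the reference words seen so far and pred[:j]
    prev = list(range(len(pred) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, pred_word in enumerate(pred, start=1):
            cost = 0 if ref_word == pred_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion: reference word missing from prediction
                curr[j - 1] + 1,     # insertion: extra word in prediction
                prev[j - 1] + cost,  # substitution, or a match when cost is 0
            ))
        prev = curr
    return prev[-1] / len(ref)  # assumes a non-empty reference

print(word_error_rate("ciao come stai", "ciao come stai"))  # identical sentences
print(word_error_rate("ciao come va", "ciao come stai"))    # one substituted word out of three
```

Lower is better for WER, while higher is better for BLEU, which is why the training function picks the best checkpoint by comparing BLEU scores with `>`.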

# 🚀 **Execute Training** 
Runs the complete training process for the specified number of epochs. This is the main execution cell that trains the model using all previous setup.

In [None]:
# Run the improved training
train_improved_model(config)