In [1]:
from google.colab import files
import io

uploaded = files.upload()

# Assuming only one file is uploaded
file_name = list(uploaded.keys())[0]
!unzip 'English to Bengali For Machine Translation Pre-Train.zip'


Saving English to Bengali For Machine Translation Pre-Train.zip to English to Bengali For Machine Translation Pre-Train.zip
Archive:  English to Bengali For Machine Translation Pre-Train.zip
  inflating: english_to_bangla.csv   
  inflating: EBook_of_The_Bhagavad-Gita_Bengali.txt  
  inflating: EBook_of_The_Bhagavad-Gita_English.txt  


In [2]:
!pip install huggingface transformers

Collecting huggingface
  Downloading huggingface-0.0.1-py3-none-any.whl.metadata (2.9 kB)
Downloading huggingface-0.0.1-py3-none-any.whl (2.5 kB)
Installing collected packages: huggingface
Successfully installed huggingface-0.0.1


In [3]:
# CELL - 1
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import pickle
import os
from tqdm import tqdm
import time
import random
import re  # For text normalization
import os
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, processors
from tokenizers.normalizers import NFD, Lowercase, StripAccents, Sequence as NormalizerSequence
import tempfile

# =============================================================
# English-to-Bengali Transformer: Pretraining & Fine-tuning Flow
#
# This script implements a full pipeline for training a transformer-based
# English-to-Bengali translation model. It includes:
#   - Data loading and normalization
#   - BPE (ByteLevel) tokenizer training (shared or separate)
#   - Transformer model definition (encoder, decoder, MLM heads)
#   - Pretraining on monolingual data (MLM for encoder/decoder)
#   - Fine-tuning on parallel translation pairs
#   - Saving/loading checkpoints, plotting metrics, and inference
#
# The code is modular and can be adapted for other language pairs.
# =============================================================

# Configuration dictionary - all hyperparameters in one place
CONFIG = {
    'vocab_size': 18000,        # Maximum vocabulary size for both source and target languages
    'd_model': 512,             # Model dimension (embedding size for each token)
    'dff': 2048,                # Feed-forward network dimension (hidden size in FFN)
    'num_heads': 8,             # Number of attention heads in multi-head attention
    'num_encoder_layers': 6,    # Number of encoder layers (stacked)
    'num_decoder_layers': 6,    # Number of decoder layers (stacked)
    'dropout_rate': 0.1,        # Dropout rate for regularization
    'max_length': 200,          # Maximum sequence length for input/output
    'batch_size': 64,           # Batch size for training
    'pretrain_learning_rate': 0.0001,    # Learning rate for pre-training
    'finetune_learning_rate': 0.00005,   # Learning rate for fine-tuning (lower)
    'pretrain_epochs': 500,               # Number of pre-training epochs
    'finetune_epochs': 500,              # Number of fine-tuning epochs
    'apply_early_stop': False,           # Whether to use early stopping
    'patience': 5,                      # Early stopping patience (epochs)
    'english_file': 'EBook_of_The_Bhagavad-Gita_English.txt',       # Path to English monolingual data
    'bengali_file': 'EBook_of_The_Bhagavad-Gita_Bengali.txt',       # Path to Bengali monolingual data
    'translation_file': 'english_to_bangla.csv',                    # Path to parallel translation data
    'max_pretrain_sentences': 39000,    # Maximum sentences for pre-training (for quick test)
    'max_translation_pairs': 39,      # Maximum translation pairs for fine-tuning (for quick test)
    'mask_probability': 0.15,            # Probability of masking tokens during pre-training
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',  # Use GPU if available
    'max_train_minutes': 45,           # Maximum allowed training time (in minutes) for each phase
    'max_global_minutes': 100,          # Maximum allowed total wall-clock time (in minutes) for the whole script
    'shared_bpe_vocab': False,          # If True, use a single shared BPE vocabulary for both languages [Shared vocabulary might not be optimal for very different languages]
    'warmup_steps': 4000,              # Warm-up steps for Noam learning-rate schedule
    'tqdm_disable' : False,
    'seed': 42,                         # Seed for Python `random`, NumPy and PyTorch RNGs (reproducible runs)
}

# Seed every RNG library imported above. Weight initialisation and any later
# random sampling (e.g. MLM masking) are then repeatable across
# Restart & Run All. torch.manual_seed seeds the RNG on all devices.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

print(f"Using device: {CONFIG['device']}")

Using device: cpu


In [4]:
# CELL - 2
class PositionalEncoding(nn.Module):
    """
    WHAT THIS CLASS DOES:
    Imagine you have a sentence like "The cat sat on the mat".
    Each word gets converted to a list of numbers (called embeddings).
    But the AI doesn't know the ORDER of words - "cat sat" vs "sat cat" look the same!
    This class adds fixed sinusoidal "position codes" that tell the AI which word came first, second, etc.

    It's like adding invisible timestamps to each word so the AI knows their sequence.

    Args:
        d_model: size of each token vector (even or odd values are supported).
        max_length: longest sequence the precomputed table can cover.

    forward(x) expects x shaped [seq_len, batch_size, d_model] and returns a tensor
    of the same shape with the position codes added.
    """
    def __init__(self, d_model, max_length=5000):
        # Call the nn.Module constructor (like super() in Java or C#)
        super(PositionalEncoding, self).__init__()

        # Table of position codes: one row per position, one column per dimension.
        # Shape: [max_length, d_model]
        pe = torch.zeros(max_length, d_model)

        # Column vector of positions [[0], [1], ..., [max_length-1]]
        # Shape: [max_length, 1]
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)

        # One frequency per even dimension index 2i: 10000^(-2i / d_model).
        # The values decrease geometrically from 1.0 (about 1.0, 0.96, 0.93, ... for d_model=512)
        # down to roughly 1e-4, so each dimension pair oscillates at a different rate.
        # Shape: [ceil(d_model / 2)]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Even columns (0, 2, 4, ...) get sine waves.
        # position [max_length, 1] * div_term [ceil(d/2)] broadcasts to [max_length, ceil(d/2)].
        pe[:, 0::2] = torch.sin(position * div_term)

        # Odd columns (1, 3, 5, ...) get cosine waves.
        # When d_model is odd there is one fewer odd column than even column, so only
        # the first d_model // 2 frequencies are used here (identical to before for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add a batch axis in the middle so the table broadcasts over batches:
        # [max_length, d_model] -> [1, max_length, d_model] -> [max_length, 1, d_model]
        pe = pe.unsqueeze(0).transpose(0, 1)

        # A buffer is saved with the model (state_dict, .to(device)) but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add position codes to x.

        x: [seq_len, batch_size, d_model], e.g. [20, 32, 512].
        self.pe[:seq_len] is [seq_len, 1, d_model] and broadcasts over the batch axis.
        Returns a tensor shaped like x. seq_len must not exceed max_length.
        """
        return x + self.pe[:x.size(0), :]

In [5]:
# CELL - 3
class MultiHeadAttention(nn.Module):
    """
    WHAT THIS CLASS DOES:
    Imagine you're reading a sentence and trying to understand each word.
    For the word "bank", you might look at nearby words to understand if it means:
    - "river bank" (look at words like "water", "river")
    - "money bank" (look at words like "deposit", "loan")

    This is "attention" - figuring out which other words are important for understanding each word.
    "Multi-head" means we do this multiple times in parallel, like having multiple people
    each focus on different types of relationships (grammar, meaning, context, etc.).

    Args:
        d_model: total feature size; must be divisible by num_heads.
        num_heads: number of parallel attention heads (each of size d_model // num_heads).
    """
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        # The representation is split equally among heads (e.g. 512 / 8 = 64 per head).
        assert d_model % num_heads == 0

        self.d_model = d_model            # Total dimensions (e.g. 512)
        self.num_heads = num_heads        # Number of attention heads (e.g. 8)
        self.d_k = d_model // num_heads   # Dimensions per head (e.g. 64)

        # Four bias-free projections, each [*, d_model] -> [*, d_model]
        self.W_q = nn.Linear(d_model, d_model, bias=False)  # Query: "what am I looking for?"
        self.W_k = nn.Linear(d_model, d_model, bias=False)  # Key: "what do I represent?"
        self.W_v = nn.Linear(d_model, d_model, bias=False)  # Value: "what information do I carry?"
        self.W_o = nn.Linear(d_model, d_model, bias=False)  # Output: "combine all heads"

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Core attention: softmax(Q K^T / sqrt(d_k)) V.

        Q: [batch_size, num_heads, q_len, d_k]
        K: [batch_size, num_heads, k_len, d_k]
        V: [batch_size, num_heads, k_len, d_k]
        mask: optional; positions where mask == 0 are blocked. Accepted shapes:
              [q_len, k_len], [batch_size, q_len, k_len] (a head axis is inserted),
              or any 4-D shape broadcastable to [batch_size, num_heads, q_len, k_len].

        Returns (output [batch_size, num_heads, q_len, d_k],
                 attention_weights [batch_size, num_heads, q_len, k_len]).

        Think of it like a library: Q = your search query, K = book titles,
        V = book contents; we return contents weighted by how well titles match.
        """
        # Similarity of every query with every key, scaled to keep softmax well-behaved.
        # [B, H, q_len, d_k] x [B, H, d_k, k_len] -> [B, H, q_len, k_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # A 3-D mask [B, q_len, k_len] must become [B, 1, q_len, k_len]. Otherwise
            # broadcasting aligns the mask's batch axis with the *heads* axis of `scores`
            # (silently wrong when B == num_heads, a shape error otherwise).
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            # Very negative scores become ~0 probability after softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        # Each query row becomes a probability distribution over keys.
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted average of the values: [B, H, q_len, k_len] x [B, H, k_len, d_k] -> [B, H, q_len, d_k]
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        """
        Multi-head attention forward pass.

        query: [batch_size, q_len, d_model]
        key:   [batch_size, k_len, d_model]
        value: [batch_size, k_len, d_model]
        mask:  optional, see scaled_dot_product_attention for accepted shapes.

        Returns: [batch_size, q_len, d_model]
        Example: batch_size=32, seq_len=20, d_model=512, num_heads=8.
        """
        batch_size = query.size(0)

        # Project, then split the d_model features into num_heads groups of d_k and
        # move the head axis forward:
        # [B, len, d_model] -> [B, len, H, d_k] -> [B, H, len, d_k]
        # (view's -1 infers the sequence length.)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # All heads attend in parallel: [B, H, q_len, d_k]
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: [B, H, q_len, d_k] -> [B, q_len, H, d_k] -> [B, q_len, d_model].
        # contiguous() is required because view cannot operate on the transposed layout.
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model)

        # Final projection lets the model learn how to mix information across heads.
        output = self.W_o(attention_output)

        return output

In [6]:
# CELL - 4

class FeedForwardNetwork(nn.Module):
    """
    Position-wise feed-forward network: Linear(d_model -> dff) -> ReLU -> Linear(dff -> d_model).

    WHAT THIS CLASS DOES:
    Think of this as a "thinking layer" that processes each word independently.
    After attention figures out which words are related, this layer lets each word
    "think" about what it learned from attending to other words.

    It's like having a conversation where:
    1. First you listen to everyone (attention phase)
    2. Then you individually process what you heard (this feed-forward phase)

    The architecture is simple but powerful:
    - Expand: give each word more "thinking space" (d_model -> dff)
    - Process: apply a non-linearity (ReLU)
    - Compress: condense back to the original size (dff -> d_model)

    Args:
        d_model: token representation size (e.g. 512).
        dff: hidden size, typically 4 * d_model (e.g. 2048).
    """
    def __init__(self, d_model, dff):
        super(FeedForwardNetwork, self).__init__()

        # Expansion: [batch_size, seq_len, d_model] -> [batch_size, seq_len, dff]
        self.linear1 = nn.Linear(d_model, dff)

        # Compression: [batch_size, seq_len, dff] -> [batch_size, seq_len, d_model]
        self.linear2 = nn.Linear(dff, d_model)

        # ReLU f(x) = max(0, x). Without it the two linear layers would collapse
        # into a single linear map.
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Apply the FFN to each position independently (no mixing across positions).

        x: [batch_size, seq_len, d_model], e.g. [32, 20, 512]
        Returns: [batch_size, seq_len, d_model]
        """
        # [32, 20, 512] -> [32, 20, 2048]
        expanded = self.linear1(x)
        # Zero out negatives; shape unchanged.
        activated = self.relu(expanded)
        # [32, 20, 2048] -> [32, 20, 512]. Computed once: the previous version ran
        # the whole FFN twice and discarded the first result.
        return self.linear2(activated)

In [7]:
# CELL - 5

class EncoderLayer(nn.Module):
    """
    One post-norm Transformer encoder layer:

        x -> SelfAttention -> Dropout -> (+x) -> LayerNorm
          -> FeedForward   -> Dropout -> (+ ) -> LayerNorm -> out

    Think of it as one round of a group discussion. Every token first "listens" to all
    the others (self-attention), then "thinks" on its own (feed-forward). After each step
    it keeps its previous state through a residual connection and is re-normalised.
    Dropout randomly silences some activations during training to fight overfitting.

    Stacking several of these layers lets the understanding deepen round by round.

    Args:
        d_model: token representation size (e.g. 512).
        num_heads: number of attention heads (e.g. 8).
        dff: hidden size of the feed-forward network (e.g. 2048).
        dropout_rate: dropout probability applied after each sub-layer (e.g. 0.1).
    """
    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super().__init__()

        # Sub-layer modules, in the same creation order as before so parameter
        # initialisation and state_dict keys stay unchanged.
        self.mha = MultiHeadAttention(d_model, num_heads)   # "listening": self-attention
        self.ffn = FeedForwardNetwork(d_model, dff)         # "thinking": position-wise FFN

        # One LayerNorm and one Dropout per sub-layer; all preserve the
        # [batch_size, seq_len, d_model] shape.
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        """
        Run one encoder layer.

        x: [batch_size, seq_len, d_model], e.g. [32, 20, 512]
        mask: optional attention mask, passed unchanged to the self-attention module.
        Returns a tensor with the same shape as x.
        """
        # --- Sub-layer 1: self-attention (x is query, key and value) ---
        attended = self.dropout1(self.mha(x, x, x, mask))
        # The residual keeps the original signal and eases gradient flow.
        hidden = self.layernorm1(x + attended)

        # --- Sub-layer 2: position-wise feed-forward ---
        refined = self.dropout2(self.ffn(hidden))
        return self.layernorm2(hidden + refined)

In [8]:
class Encoder(nn.Module):
    """
    The encoder stack: token embedding + positional encoding + N encoder layers.

    This is a text processing pipeline. It converts token ids into vectors, then
    refines them through several layers that look at each word in the context of
    the others.

    Args:
        vocab_size: number of token ids the embedding table covers.
        d_model: token representation size.
        num_layers: number of stacked EncoderLayer modules.
        num_heads: attention heads per layer.
        dff: feed-forward hidden size per layer.
        max_length: longest sequence the positional encoding table covers.
        dropout_rate: dropout probability (after embeddings and inside each layer).
    """
    def __init__(self, vocab_size, d_model, num_layers, num_heads, dff, max_length, dropout_rate):
        super(Encoder, self).__init__()
        self.d_model = d_model

        # Step 1: lookup table word_id -> dense vector of size d_model.
        # Input: integers in [0, vocab_size-1]
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Step 2: adds "where am I in the sentence?" information to each vector,
        # so the model can use word order.
        self.pos_encoding = PositionalEncoding(d_model, max_length)

        # Step 3: num_layers identical processing stations. Each one maps
        # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model].
        self.enc_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, dff, dropout_rate)
            for _ in range(num_layers)
        ])

        # Step 4: regularisation applied to the embedded input.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        """
        Encode a batch of token-id sequences.

        x: [batch_size, seq_len] integer token ids.
        mask: optional attention mask, forwarded to every encoder layer.
        Returns: [batch_size, seq_len, d_model] contextual token representations.
        seq_len must not exceed the max_length given at construction.
        """
        # [batch_size, seq_len] -> [batch_size, seq_len, d_model].
        # Scaling by sqrt(d_model) follows the original Transformer paper. It keeps the
        # embeddings from being dwarfed by the added positional codes.
        x = self.embedding(x) * math.sqrt(self.d_model)

        # PositionalEncoding expects [seq_len, batch_size, d_model], so transpose in,
        # add the positions, and transpose back to batch-first.
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)

        # Shape unchanged: [batch_size, seq_len, d_model]
        x = self.dropout(x)

        # Each layer refines the representations. The mask tells attention which
        # positions to ignore (e.g. padding).
        for enc_layer in self.enc_layers:
            x = enc_layer(x, mask)

        return x

In [9]:
class DecoderLayer(nn.Module):
    """
    One post-LayerNorm Transformer decoder layer.

    Three residual sub-layers run in sequence:
      1. masked self-attention over the target sequence   (``mha1``)
      2. cross-attention from the target onto the encoder (``mha2``)
      3. a position-wise feed-forward network             (``ffn``)

    Every sub-layer follows the same pattern:
        out = LayerNorm(inp + Dropout(SubLayer(inp)))

    Args:
        d_model: model width (size of every token vector).
        num_heads: number of heads for both attention blocks.
        dff: hidden width of the feed-forward network.
        dropout_rate: dropout probability applied to each sub-layer output.
    """
    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super().__init__()

        # Attribute names double as state_dict keys, so they must stay stable for
        # checkpoint compatibility. Sub-modules are built in the same order as
        # before so seeded parameter initialisation is unchanged.
        self.mha1 = MultiHeadAttention(d_model, num_heads)  # target self-attention (causal)
        self.mha2 = MultiHeadAttention(d_model, num_heads)  # target -> encoder attention
        self.ffn = FeedForwardNetwork(d_model, dff)         # position-wise MLP

        # One LayerNorm and one Dropout per sub-layer (self-attn, cross-attn, FFN).
        self.layernorm1, self.layernorm2, self.layernorm3 = (
            nn.LayerNorm(d_model) for _ in range(3)
        )
        self.dropout1, self.dropout2, self.dropout3 = (
            nn.Dropout(dropout_rate) for _ in range(3)
        )

    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        """
        Refine the target representations with one decoder layer.

        Args:
            x: target representations, [batch_size, target_seq_len, d_model].
            enc_output: encoder output, [batch_size, source_seq_len, d_model].
            look_ahead_mask: mask for target self-attention (blocks future positions).
            padding_mask: mask for cross-attention over the encoder output.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Sub-layer 1: queries, keys and values all come from the target itself.
        self_ctx = self.mha1(x, x, x, look_ahead_mask)
        hidden = self.layernorm1(x + self.dropout1(self_ctx))

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_ctx = self.mha2(hidden, enc_output, enc_output, padding_mask)
        hidden = self.layernorm2(hidden + self.dropout2(cross_ctx))

        # Sub-layer 3: feed-forward transformation with its own residual + norm.
        return self.layernorm3(hidden + self.dropout3(self.ffn(hidden)))

In [10]:
class Decoder(nn.Module):
    """
    Decoder stack: target embedding + positional encoding + N decoder layers.

    Pipeline in ``forward``:
      1. look up target token embeddings and scale them by sqrt(d_model);
      2. add positional information via ``PositionalEncoding``;
      3. apply dropout;
      4. run every ``DecoderLayer`` in turn, each of which attends to the
         target prefix (masked self-attention) and to the encoder output
         (cross-attention).

    Args:
        vocab_size: size of the target vocabulary.
        d_model: model width.
        num_layers: number of stacked decoder layers.
        num_heads: attention heads per layer.
        dff: hidden width of each layer's feed-forward network.
        max_length: maximum sequence length given to ``PositionalEncoding``.
        dropout_rate: dropout probability.
    """
    def __init__(self, vocab_size, d_model, num_layers, num_heads, dff, max_length, dropout_rate):
        super(Decoder, self).__init__()
        self.d_model = d_model

        # Token ID -> d_model-dimensional vector lookup table.
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Injects position information into the embeddings.
        self.pos_encoding = PositionalEncoding(d_model, max_length)

        # num_layers identical (but independently parameterised) decoder layers.
        self.dec_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, dff, dropout_rate)
            for _ in range(num_layers)
        ])

        # Dropout on the (embedding + position) input to the first layer.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        """
        Run the target sequence through the decoder stack.

        Args:
            x: target token IDs, [batch_size, target_seq_len].
            enc_output: encoder output, [batch_size, source_seq_len, d_model].
            look_ahead_mask: mask for target self-attention (causal and/or padding).
            padding_mask: mask for cross-attention over the encoder output.

        Returns:
            Target representations, [batch_size, target_seq_len, d_model].
        """
        # [batch, tgt_len] -> [batch, tgt_len, d_model]; sqrt(d_model) scaling
        # follows the original Transformer paper.
        x = self.embedding(x) * math.sqrt(self.d_model)

        # PositionalEncoding expects sequence-first input, so transpose around it:
        # [batch, tgt_len, d] -> [tgt_len, batch, d] -> (+pos) -> [batch, tgt_len, d].
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)

        x = self.dropout(x)

        # Each layer preserves the [batch, tgt_len, d_model] shape.
        for dec_layer in self.dec_layers:
            x = dec_layer(x, enc_output, look_ahead_mask, padding_mask)

        return x

In [11]:
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer for sequence-to-sequence translation, plus MLM heads.

    Components:
      - ``encoder``: contextualises source-language token IDs (e.g. English).
      - ``decoder``: produces target-language representations (e.g. Bengali),
        attending to the encoder output.
      - ``final_layer``: projects decoder states to target-vocabulary logits.
      - ``mlm_head_src`` / ``mlm_head_tgt``: project encoder / decoder states to
        source / target vocabulary logits for monolingual pretraining.

    The encoder and decoder may have different depths
    (``num_encoder_layers`` / ``num_decoder_layers``).

    Token ID 0 is treated as padding by ``create_padding_mask``.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
                 num_encoder_layers, num_decoder_layers, num_heads,
                 dff, max_length, dropout_rate):
        super(Transformer, self).__init__()

        # Sub-modules are created in the original order so that seeded parameter
        # initialisation (and therefore reproducibility) is unchanged.

        # Source side: [batch, src_len] -> [batch, src_len, d_model].
        self.encoder = Encoder(src_vocab_size, d_model, num_encoder_layers, num_heads,
                               dff, max_length, dropout_rate)

        # Target side: [batch, tgt_len] + encoder output -> [batch, tgt_len, d_model].
        self.decoder = Decoder(tgt_vocab_size, d_model, num_decoder_layers, num_heads,
                               dff, max_length, dropout_rate)

        # Translation head: [batch, tgt_len, d_model] -> [batch, tgt_len, tgt_vocab_size] logits.
        self.final_layer = nn.Linear(d_model, tgt_vocab_size)

        # Configured device, kept for code that reads ``model.device``. Masks built in
        # this class are placed on the device of the input tensors instead, so they
        # stay consistent even if the model is moved to another device.
        self.device = torch.device(CONFIG['device'])

        # Pretraining heads (logits over the respective vocabularies).
        self.mlm_head_src = nn.Linear(d_model, src_vocab_size)  # applied to encoder output
        self.mlm_head_tgt = nn.Linear(d_model, tgt_vocab_size)  # applied to decoder output

    def create_padding_mask(self, seq):
        """
        Build a mask that hides padding positions (token ID 0).

        Example:
            seq = [[5, 9, 0]]  ->  [[[[True, True, False]]]]

        Args:
            seq: integer token IDs, [batch_size, seq_len]; 0 marks padding.

        Returns:
            Bool tensor [batch_size, 1, 1, seq_len] on ``seq``'s device: True for
            real tokens, False for padding. The singleton dims presumably let it
            broadcast over attention heads and query positions inside
            MultiHeadAttention (confirm against that class).
        """
        return (seq != 0).unsqueeze(1).unsqueeze(2)

    def create_look_ahead_mask(self, size, device=None):
        """
        Build a causal mask so each position sees only itself and earlier positions.

        Example for size=3 (row = query position, column = key position):
            [[True,  False, False],
             [True,  True,  False],
             [True,  True,  True ]]

        Args:
            size: sequence length.
            device: device to allocate the mask on. ``None`` (the default) keeps the
                previous behaviour of using PyTorch's default device.

        Returns:
            Bool tensor [1, 1, size, size]; True = may attend, False = blocked.
        """
        # Lower triangle (including the diagonal) of an all-ones matrix, as booleans.
        mask = torch.tril(torch.ones(size, size, device=device)).bool()
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, src, tgt, training=True):
        """
        Teacher-forced translation forward pass (not used for pretraining).

        Args:
            src: source token IDs, [batch_size, src_len]; 0 = padding.
            tgt: target (decoder input) token IDs, [batch_size, tgt_len]; 0 = padding.
            training: unused; kept for backward compatibility. Dropout behaviour is
                governed by ``model.train()`` / ``model.eval()``.

        Returns:
            Unnormalised logits, [batch_size, tgt_len, tgt_vocab_size]. No softmax
            is applied here.
        """
        # Source padding mask, used by encoder self-attention and decoder cross-attention.
        src_mask = self.create_padding_mask(src)

        # Decoder self-attention: a key is visible only if it is not in the future AND
        # is not padding. [1,1,T,T] & [B,1,1,T] broadcasts to [B,1,T,T].
        # The causal mask is built on tgt's device so all masks share the input device.
        look_ahead_mask = self.create_look_ahead_mask(tgt.size(1), device=tgt.device)
        tgt_padding_mask = self.create_padding_mask(tgt)
        combined_mask = torch.logical_and(look_ahead_mask, tgt_padding_mask)

        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, combined_mask, src_mask)
        return self.final_layer(dec_output)

    def mlm_encode(self, input_ids):
        """
        Encoder MLM pretraining pass (source language).

        Runs the (masked) source tokens through the encoder with a padding mask and
        projects the result with ``mlm_head_src``.

        Args:
            input_ids: [batch_size, seq_len] source token IDs with masked positions; 0 = padding.

        Returns:
            Logits, [batch_size, seq_len, src_vocab_size].

        Raises:
            AssertionError: if ``input_ids`` is not 2-D, or the encoder changes batch/sequence dims.
        """
        assert input_ids.dim() == 2, f"Encoder MLM input_ids must be 2D [batch, seq], got {input_ids.shape}"

        src_mask = self.create_padding_mask(input_ids)  # [batch, 1, 1, seq_len]
        enc_output = self.encoder(input_ids, src_mask)  # [batch, seq_len, d_model]

        # Sanity check: the encoder must not alter the batch or sequence dimensions.
        assert enc_output.shape[:2] == input_ids.shape, f"Encoder output shape {enc_output.shape} does not match input {input_ids.shape}"

        return self.mlm_head_src(enc_output)

    def decoder_mlm(self, input_ids):
        """
        Decoder pretraining pass (target language) with no source context.

        The decoder is run with:
          - an all-zeros placeholder in place of the encoder output,
          - a causal + padding mask on self-attention,
          - no mask on cross-attention,
        and its output is projected with ``mlm_head_tgt``.

        NOTE(review): because self-attention is causal, position i sees only tokens
        0..i (including the masked token itself), so masked tokens are predicted from
        left context only, unlike bidirectional BERT-style MLM. Confirm this is the
        intended pretraining objective.

        Args:
            input_ids: [batch_size, seq_len] target token IDs with masked positions; 0 = padding.

        Returns:
            Logits, [batch_size, seq_len, tgt_vocab_size].

        Raises:
            AssertionError: if the decoder changes batch/sequence dims.
        """
        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        # Placeholder "encoder output": zeros of shape [batch, seq_len, d_model].
        # Its length simply reuses seq_len; there is no real source sequence here.
        dummy_enc = torch.zeros(batch_size, seq_len, self.decoder.d_model, device=device)

        # Causal + padding mask for self-attention, built on the input's device.
        look_ahead_mask = self.create_look_ahead_mask(seq_len, device=device)  # [1, 1, S, S]
        padding_mask = self.create_padding_mask(input_ids)                     # [B, 1, 1, S]
        combined_mask = torch.logical_and(look_ahead_mask, padding_mask)       # [B, 1, S, S]

        dec_out = self.decoder(input_ids, dummy_enc, combined_mask, None)

        # Sanity check: the decoder must not alter the batch or sequence dimensions.
        assert dec_out.shape[:2] == input_ids.shape, f"Decoder output shape {dec_out.shape} does not match input {input_ids.shape}"

        return self.mlm_head_tgt(dec_out)

In [12]:
# Dataset classes
class PretrainDataset(Dataset):
    """
    Dataset for pre-training with masked language modeling (MLM).

    Each item is one sentence that is:
      1. stripped and tokenized with ``vocab.encode``,
      2. truncated so it fits in ``max_length`` together with ``<SOS>``/``<EOS>``,
      3. wrapped in ``<SOS>`` ... ``<EOS>``,
      4. randomly masked BERT-style (see ``mask_tokens``),
      5. right-padded with ``<PAD>`` to exactly ``max_length``.

    Items are dicts with:
      - 'input_ids': LongTensor [max_length], the (masked) token IDs
      - 'labels':    LongTensor [max_length], the original ID at masked positions
                     and -100 everywhere else (positions to ignore in the loss)
    """

    # Special tokens that are never selected for masking. The numbers are the IDs
    # that used to be hard-coded; they are only used as fallbacks when the
    # vocabulary does not define the token.
    _SPECIAL_TOKEN_DEFAULTS = (('<PAD>', 0), ('<SOS>', 2), ('<EOS>', 3), ('<MASK>', 4))

    # Lowest ID drawn for random-token replacement; lower IDs are assumed to be
    # reserved for special tokens (PAD/UNK/SOS/EOS/MASK).
    _FIRST_REGULAR_ID = 5

    def __init__(self, sentences, vocab, max_length, mask_prob=0.15):
        """
        Args:
            sentences: list of raw sentence strings.
            vocab: tokenizer-like object providing ``encode(str)`` and
                ``get_vocab() -> dict[str, int]``. The result of ``encode`` is
                sliced and concatenated with lists below, so it is assumed to be
                a list of ints — TODO confirm against the tokenizer wrapper.
            max_length: fixed item length in tokens (SOS/EOS and padding included).
            mask_prob: probability that each eligible token is selected for MLM.
        """
        self.sentences = sentences
        self.vocab = vocab
        self.max_length = max_length
        self.mask_prob = mask_prob

    def __len__(self):
        """Number of sentences in the dataset."""
        return len(self.sentences)

    def mask_tokens(self, tokens, token_to_id=None):
        """
        Apply BERT-style masking to a sequence of token IDs.

        Every token that is not a special token (PAD/SOS/EOS/MASK, IDs looked up
        in the vocabulary) is selected with probability ``mask_prob``. A selected
        token has its original ID recorded in ``labels`` and is then:
          - replaced by ``<MASK>`` with probability 0.8,
          - replaced by a random ID in ``[5, vocab_size - 1]`` with probability 0.1,
          - left unchanged with probability 0.1.

        Args:
            tokens: list of token IDs; not modified.
            token_to_id: optional, already-fetched ``vocab.get_vocab()`` mapping.
                It is fetched once here if omitted.

        Returns:
            (masked_tokens, labels): lists with the same length as ``tokens``.
            ``labels`` holds the original ID at selected positions and -100 elsewhere.
        """
        if token_to_id is None:
            # get_vocab() can rebuild a full dict on every call, so fetch it once
            # instead of once per masked token.
            token_to_id = self.vocab.get_vocab()

        # Resolve special IDs from the vocabulary rather than assuming 0/2/3/4, so
        # SOS/EOS are never masked even if the tokenizer numbers them differently.
        special_ids = {token_to_id.get(name, default)
                       for name, default in self._SPECIAL_TOKEN_DEFAULTS}
        mask_id = token_to_id.get('<MASK>', 4)
        vocab_size = len(token_to_id)

        masked_tokens = list(tokens)
        labels = [-100] * len(tokens)

        for i, token in enumerate(tokens):
            if token in special_ids:
                continue
            if random.random() < self.mask_prob:
                labels[i] = token
                if random.random() < 0.8:
                    masked_tokens[i] = mask_id
                elif random.random() < 0.5:
                    # Half of the remaining 20% -> 10% overall.
                    masked_tokens[i] = random.randint(self._FIRST_REGULAR_ID, vocab_size - 1)
                # Otherwise (last 10%): keep the original token.

        return masked_tokens, labels

    def __getitem__(self, idx):
        """
        Build the MLM training example for sentence ``idx``.

        Returns:
            dict with 'input_ids' and 'labels', both LongTensors of shape [max_length].
        """
        token_to_id = self.vocab.get_vocab()
        sos = token_to_id.get('<SOS>')
        eos = token_to_id.get('<EOS>')
        pad = token_to_id.get('<PAD>')

        tokens = self.vocab.encode(self.sentences[idx].strip())

        # Leave two slots for the SOS/EOS markers added below.
        content_limit = self.max_length - 2
        if len(tokens) > content_limit:
            tokens = tokens[:content_limit]
        tokens = [sos] + tokens + [eos]

        masked_tokens, labels = self.mask_tokens(tokens, token_to_id)

        n_pad = self.max_length - len(masked_tokens)
        return {
            'input_ids': torch.tensor(masked_tokens + [pad] * n_pad, dtype=torch.long),
            'labels': torch.tensor(labels + [-100] * n_pad, dtype=torch.long),
        }

In [13]:
class TranslationDataset(Dataset):
    """
    Dataset for translation fine-tuning on parallel sentence pairs.

    For pair ``i`` it returns three LongTensors, each of shape [max_length]:
      - 'src':    source token IDs (no SOS/EOS added), padded with the source <PAD>
      - 'tgt_ip': decoder input  = <SOS> + target tokens, padded with the target <PAD>
      - 'tgt_op': decoder output = target tokens + <EOS>, padded with the target <PAD>

    Before padding/truncation, 'tgt_op' is 'tgt_ip' shifted left by one position
    (teacher forcing). Example pair: English "Hello world" -> Bengali "হ্যালো বিশ্ব".
    """

    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab, max_length):
        """
        Args:
            src_sentences: list of source-language strings.
            tgt_sentences: list of target-language strings; ``tgt_sentences[i]`` is
                the translation of ``src_sentences[i]``.
            src_vocab: source tokenizer providing ``encode`` and ``get_vocab``.
            tgt_vocab: target tokenizer providing ``encode`` and ``get_vocab``.
                ``encode`` results are concatenated with lists below, so they are
                assumed to be lists of ints — TODO confirm.
            max_length: fixed length of every returned sequence.
        """
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_length = max_length

    def __len__(self):
        """Number of translation pairs (taken from the source list)."""
        return len(self.src_sentences)

    def _pad_or_truncate(self, ids, pad):
        """Cut ``ids`` to ``max_length`` and right-pad with ``pad`` to exactly that length."""
        ids = ids[:self.max_length]
        return ids + [pad] * (self.max_length - len(ids))

    def __getitem__(self, idx):
        """
        Build the encoder input and the two decoder views for pair ``idx``.

        Returns:
            dict with 'src', 'tgt_ip' and 'tgt_op' LongTensors of shape [max_length].
        """
        src_encoded = self.src_vocab.encode(self.src_sentences[idx])
        # Encode the target once; both decoder views are derived from it.
        tgt_encoded = self.tgt_vocab.encode(self.tgt_sentences[idx])

        tgt_token_to_id = self.tgt_vocab.get_vocab()
        sos = tgt_token_to_id.get('<SOS>')
        eos = tgt_token_to_id.get('<EOS>')
        # Each side is padded with its own vocabulary's <PAD> ID, so separate
        # source/target tokenizers stay consistent with their own loss masking.
        src_pad = self.src_vocab.get_vocab().get('<PAD>')
        tgt_pad = tgt_token_to_id.get('<PAD>')

        # NOTE(review): truncation happens after <EOS> is appended, so a target
        # longer than max_length - 1 tokens loses its <EOS> in 'tgt_op'. Kept
        # as-is to preserve existing training behavior.
        src_padded = self._pad_or_truncate(src_encoded, src_pad)
        tgt_ip = self._pad_or_truncate([sos] + tgt_encoded, tgt_pad)
        tgt_op = self._pad_or_truncate(tgt_encoded + [eos], tgt_pad)

        return {
            'src': torch.tensor(src_padded, dtype=torch.long),     # encoder input
            'tgt_ip': torch.tensor(tgt_ip, dtype=torch.long),      # decoder input (SOS + target)
            'tgt_op': torch.tensor(tgt_op, dtype=torch.long),      # decoder target (target + EOS)
        }

In [14]:
# Data loading functions
def load_monolingual_data(file_path, max_sentences):
    """
    Load up to ``max_sentences`` non-empty lines from a UTF-8 text file.

    Each line is stripped of surrounding whitespace. Blank (or whitespace-only)
    lines are skipped and do NOT count toward the limit.

    Args:
        file_path: path to a UTF-8 text file with one sentence per line.
        max_sentences: maximum number of non-empty sentences to return;
            ``None`` loads every non-empty line.

    Returns:
        list[str]: the stripped sentences in file order.
    """
    print(f"Loading monolingual data from {file_path}...")

    sentences = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            # Compare against the number of sentences kept, not the raw line
            # index, so blank lines don't eat into the sentence budget.
            if max_sentences is not None and len(sentences) >= max_sentences:
                break
            line = raw_line.strip()
            if line:
                sentences.append(line)

    print(f"Loaded {len(sentences)} sentences from {file_path}")
    # Show a small sample so the load can be sanity-checked in the notebook output.
    print(f"First 5 sentences(load_monolingual_data): {sentences[:5]}")
    return sentences

In [15]:
def load_translation_data(file_path, max_sentences):
    """
    Load parallel English/Bengali sentence pairs from a CSV file.

    The CSV must have an 'en' column (English) and a 'bn' column (Bengali); row
    ``i`` holds one translation pair. Other columns are ignored. Rows with a
    missing value in 'en' or 'bn' are dropped, and the first ``max_sentences``
    remaining rows are kept.

    Example layout:
        | en            | bn              |
        | "Hello world" | "হ্যালো বিশ্ব"     |
        | "How are you" | "আপনি কেমন আছেন" |

    Args:
        file_path: path to the CSV file.
        max_sentences: maximum number of pairs to return.

    Returns:
        (english_sentences, bengali_sentences): two equal-length lists where
        ``english_sentences[i]`` and ``bengali_sentences[i]`` are translations.

    Raises:
        KeyError: if the CSV lacks an 'en' or 'bn' column.
    """
    print(f"Loading translation data from {file_path}...")

    df = pd.read_csv(file_path)

    # Only the two text columns matter. A plain dropna() would also discard valid
    # pairs whose unrelated columns (e.g. metadata) happen to be empty.
    df = df.dropna(subset=['en', 'bn']).head(max_sentences)

    english_sentences = df['en'].tolist()
    bengali_sentences = df['bn'].tolist()

    print(f"Loaded {len(english_sentences)} translation pairs")
    return english_sentences, bengali_sentences

In [16]:
def create_pretrain_dataloaders(english_sentences, bengali_sentences, src_vocab, tgt_vocab, config):
    """
    Build shuffled MLM pre-training DataLoaders for English and Bengali text.

    Each language's sentences are wrapped in a ``PretrainDataset`` (tokenize,
    add SOS/EOS, randomly mask, pad to ``config['max_length']``) and batched by
    a ``DataLoader`` with ``shuffle=True``.

    Args:
        english_sentences: list of English sentence strings.
        bengali_sentences: list of Bengali sentence strings.
        src_vocab: English tokenizer object handed to PretrainDataset, which
            calls its ``encode`` and ``get_vocab`` methods (it is not a plain dict).
        tgt_vocab: Bengali tokenizer object, same requirements as ``src_vocab``.
        config: dict providing 'max_length', 'mask_prob' and 'batch_size'.

    Returns:
        (english_loader, bengali_loader): each yields dict batches with
        'input_ids' and 'labels' tensors of shape (batch_size, max_length);
        the last batch of an epoch may be smaller.
    """
    print("Creating pre-training data loaders...")

    seq_len = config['max_length']
    masking_rate = config['mask_prob']

    # One masked-LM dataset per language, built with identical settings.
    english_dataset, bengali_dataset = (
        PretrainDataset(texts, tokenizer, seq_len, masking_rate)
        for texts, tokenizer in (
            (english_sentences, src_vocab),
            (bengali_sentences, tgt_vocab),
        )
    )

    per_batch = config['batch_size']
    english_loader, bengali_loader = (
        DataLoader(dataset, batch_size=per_batch, shuffle=True)
        for dataset in (english_dataset, bengali_dataset)
    )

    # Report dataset sizes so data volume is visible in the notebook output.
    print(f"English pre-training samples: {len(english_dataset)}")
    print(f"Bengali pre-training samples: {len(bengali_dataset)}")

    return english_loader, bengali_loader

In [17]:
def create_translation_dataloaders(english_sentences, bengali_sentences, src_vocab, tgt_vocab, config):
    """
    Build training and validation DataLoaders for translation fine-tuning on parallel data.

    Unlike the pre-training loaders (one per language), this works on PAIRED sentences:
    item i of `bengali_sentences` is the translation of item i of `english_sentences`.

    The corpus is split by position: the first 90% of pairs go to training and the
    remaining 10% to validation (e.g. 10,000 pairs -> 9,000 train / 1,000 validation).
    No shuffling happens before the split, so if the corpus is ordered (e.g. by source
    document), the validation set covers only its tail.

    Args:
        english_sentences: Sequence of source (English) sentences.
        bengali_sentences: Sequence of target (Bengali) sentences, aligned index-by-index
            with `english_sentences`.
        src_vocab: Source-side vocabulary/tokenizer, passed unchanged to TranslationDataset.
        tgt_vocab: Target-side vocabulary/tokenizer, passed unchanged to TranslationDataset.
        config: Settings dict; reads 'max_length' (passed to TranslationDataset) and
            'batch_size' (passed to both DataLoaders).

    Returns:
        (train_loader, val_loader):
            train_loader shuffles its examples every epoch;
            val_loader keeps a fixed order so validation metrics are comparable across runs.
            The structure of each batch is determined by TranslationDataset.__getitem__
            (defined elsewhere) and the DataLoader's default collation.

    Raises:
        ValueError: if the two sentence lists differ in length. Slicing them independently
            would otherwise silently pair sentences that are not translations of each other.
    """
    n_src, n_tgt = len(english_sentences), len(bengali_sentences)
    if n_src != n_tgt:
        raise ValueError(
            f"Parallel corpus is misaligned: {n_src} English sentences vs {n_tgt} Bengali sentences"
        )

    print("Creating translation data loaders...")

    # STEP 1: positional 90/10 split into training and validation pairs.
    split_idx = int(0.9 * n_src)
    train_src, val_src = english_sentences[:split_idx], english_sentences[split_idx:]
    train_tgt, val_tgt = bengali_sentences[:split_idx], bengali_sentences[split_idx:]

    # STEP 2: wrap each split in a TranslationDataset, which tokenizes/pads the pairs.
    train_dataset = TranslationDataset(
        train_src,
        train_tgt,
        src_vocab,
        tgt_vocab,
        config['max_length'],
    )
    val_dataset = TranslationDataset(
        val_src,
        val_tgt,
        src_vocab,
        tgt_vocab,
        config['max_length'],
    )

    # STEP 3: batch the datasets. Training order is randomized each epoch; validation
    # order is fixed (shuffle=False) so repeated evaluations see identical batches.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
    )

    # STEP 4: report split sizes for monitoring.
    print(f"Translation training samples: {len(train_dataset)}")
    print(f"Translation validation samples: {len(val_dataset)}")

    # Typical use:
    #   for batch in train_loader:
    #       ...  # batch layout depends on TranslationDataset.__getitem__
    return train_loader, val_loader

In [18]:
class EarlyStopping:
    """
    Stateful early-stopping monitor driven by a validation loss.

    Call the instance once per epoch with that epoch's validation loss. It remembers the
    best loss seen so far and counts consecutive calls without a *significant*
    improvement (a drop of more than `min_delta` below the best). Once that count reaches
    `patience`, the call returns True, meaning "stop training".

    Lower loss is treated as better. A NaN loss never compares as an improvement, so it
    counts as a non-improving epoch.

    Usage:
        early_stopper = EarlyStopping(patience=5, min_delta=0.001)
        for epoch in range(100):
            val_loss = validate_model()          # scalar float
            if early_stopper(val_loss):
                print(f"Early stopping at epoch {epoch}")
                break

    State evolution for patience=5 (1-based epochs):
        Epoch 1: val_loss=3.0 -> best_loss=3.0, counter=0, continue
        Epoch 2: val_loss=2.5 -> best_loss=2.5, counter=0, continue
        Epoch 3: val_loss=2.6 -> best_loss=2.5, counter=1, continue
        Epoch 4: val_loss=2.7 -> best_loss=2.5, counter=2, continue
        Epoch 5: val_loss=2.7 -> best_loss=2.5, counter=3, continue
        Epoch 6: val_loss=2.8 -> best_loss=2.5, counter=4, continue
        Epoch 7: val_loss=2.8 -> best_loss=2.5, counter=5, STOP
    """

    def __init__(self, patience=5, min_delta=0.000001):
        """
        Args:
            patience: Number of consecutive non-improving calls after which
                      __call__ returns True.
            min_delta: Minimum decrease below the best loss that counts as an
                       improvement; guards against reacting to tiny fluctuations.
        """
        self.patience = patience
        self.min_delta = min_delta

        # Consecutive calls without a significant improvement; reset to 0 on improvement.
        # If the caller keeps calling after a True result, it keeps growing past `patience`.
        self.counter = 0

        # Best (lowest) validation loss seen so far; +inf so the first real loss always wins.
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        """
        Record one epoch's validation loss and decide whether to stop.

        Args:
            val_loss: Scalar validation loss for the current epoch (lower is better).

        Returns:
            bool: True once `patience` consecutive epochs have passed without the loss
            dropping more than `min_delta` below the best seen; False otherwise.
        """
        # Significant improvement, e.g. best=2.5, min_delta=0.001: 2.3 < 2.499 -> improvement,
        # while 2.4995 or 2.51 is not.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1

        return self.counter >= self.patience

In [19]:
# === HELPER: BATCH ACCURACY ===
def batch_accuracy(logits, labels, ignore_index):
    """
    Fraction of scored positions where the arg-max prediction equals the label.

    A position is "scored" when its label differs from `ignore_index` (e.g. padding or
    positions not selected for MLM). Predictions are taken as the arg-max over the last
    axis of `logits`. Runs under torch.no_grad() because it is a metric, not a loss.

    Args:
        logits: Tensor of raw scores whose last dimension indexes the classes/vocabulary,
                e.g. (batch_size, seq_len, vocab_size).
        labels: Integer tensor of gold ids, comparable element-wise with
                logits.argmax(-1), e.g. (batch_size, seq_len).
        ignore_index: Label value marking positions to leave out of the metric.

    Returns:
        float: correct / scored, in [0.0, 1.0]; 0.0 when no position is scored
        (avoids dividing by zero).
    """
    with torch.no_grad():
        # Most likely id at every position (drops the vocabulary axis).
        predicted_ids = logits.argmax(dim=-1)

        # True where the position participates in the metric.
        scored = labels.ne(ignore_index)
        n_scored = scored.sum().item()
        if n_scored == 0:
            return 0.0

        # Hits are only counted on scored positions.
        n_hits = predicted_ids.eq(labels).logical_and(scored).sum().item()
        return n_hits / n_scored

# EXAMPLE WALKTHROUGH:
# If we have a batch with:
# - logits shape: (2, 4, 1000) - 2 sentences, 4 tokens each, 1000 vocab words
# - labels: [[1, 5, 999, -100], [2, 8, 15, -100]] - correct token IDs
# - ignore_index: -100 (padding tokens)
#
# Step 1: argmax gives predictions: [[1, 3, 999, 200], [2, 8, 12, 500]]
# Step 2: mask identifies non-padding: [[True, True, True, False], [True, True, True, False]]
# Step 3: We have 6 positions to evaluate (not 8, because 2 are padding)
# Step 4: correct positions: [[True, False, True, False], [True, True, False, False]]
# Step 5: accuracy = 4 correct / 6 total = 0.667 (66.7%)

In [20]:
# === Noam Learning-Rate Scheduler (from "Attention Is All You Need") ===
def get_noam_scheduler(optimizer, d_model, warmup_steps):
    """
    Noam learning-rate schedule ("Attention Is All You Need", Sec. 5.3), normalised to peak at 1.0.

    The raw Noam rate is

        rate(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    which rises linearly for `warmup_steps` steps and then decays as step^-0.5
    (shape: /\\_____). The multiplier returned to LambdaLR is that rate divided by its
    peak value, d_model^-0.5 * warmup_steps^-0.5, which is reached at step == warmup_steps.
    The optimizer's configured learning rate is therefore the *peak* learning rate:

        multiplier(step) = min(sqrt(warmup_steps / step), step / warmup_steps)

    Because of this normalisation, d_model cancels out of the multiplier. It is kept in
    the signature for compatibility and to mirror the paper's formula.

    Args:
        optimizer: torch.optim.Optimizer whose learning rate(s) will be scaled.
        d_model: Model hidden size used in the Noam formula.
        warmup_steps: Number of (1-based) steps over which the multiplier ramps up
                      linearly to 1.0. Must be > 0.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: call scheduler.step() once per optimizer step.
        Examples for warmup_steps=4000: multiplier ~0.25 at step 1000, 1.0 at step 4000,
        ~0.71 at step 8000, 0.5 at step 16000.

    Raises:
        ValueError: if warmup_steps <= 0 (the formula divides by a power of warmup_steps).
    """
    if warmup_steps <= 0:
        raise ValueError(f"warmup_steps must be > 0, got {warmup_steps}")

    # Loop-invariant pieces of the formula, computed once.
    model_scale = d_model ** -0.5
    warmup_slope = warmup_steps ** -1.5
    # Peak of the raw schedule: at step == warmup_steps both branches of min() equal
    # warmup_steps^-0.5. Using warmup_steps ** -1.5 here instead would make the "peak"
    # multiplier equal warmup_steps rather than 1.0, blowing up the effective learning rate.
    normaliser = model_scale * warmup_steps ** -0.5

    def lr_lambda(step):
        """Multiplier for the base learning rate at 0-based scheduler step `step`."""
        # LambdaLR counts from 0; the formula is 1-based (step^-0.5 is undefined at 0).
        step += 1
        # min(): linear warmup term while step < warmup_steps, inverse-sqrt decay afterwards.
        step_scale = min(step ** -0.5, step * warmup_slope)
        return model_scale * step_scale / normaliser

    # LambdaLR sets lr = base_lr * lr_lambda(step) for every parameter group.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# EXAMPLE LEARNING RATE SCHEDULE:
# Assume d_model=512, warmup_steps=4000, base_lr=0.001
#
# Step 1000: lr_lambda ≈ 0.25 → actual_lr = 0.001 * 0.25 = 0.00025
# Step 4000: lr_lambda = 1.0  → actual_lr = 0.001 * 1.0  = 0.001    (peak)
# Step 8000: lr_lambda ≈ 0.71 → actual_lr = 0.001 * 0.71 = 0.00071
# Step 16000: lr_lambda ≈ 0.5 → actual_lr = 0.001 * 0.5  = 0.0005
#
# USAGE IN TRAINING LOOP:
# scheduler = get_noam_scheduler(optimizer, d_model=512, warmup_steps=4000)
# for step in range(num_steps):
#     # ... training step ...
#     optimizer.step()      # Update model weights
#     scheduler.step()      # Update learning rate
#     # Current LR: optimizer.param_groups[0]['lr']

In [21]:
# === PRETRAIN MODEL: MLM on monolingual data ===
def pretrain_model(model, english_sentences, bengali_sentences, src_vocab, tgt_vocab, config, overall_start_time=None):
    """
    Pre-train the transformer model using Masked Language Modeling (MLM).

    Think of this like teaching a student to fill in missing words in sentences
    before they learn to translate between languages. This is a two-part process:

    1. Encoder learns English: Given "The cat [MASK] on the mat" → predict "sat"
    2. Decoder learns Bengali: Given "বিড়াল [MASK] এ বসে" → predict "মাদুরে"

    This foundational learning helps the model understand both languages deeply
    before attempting translation between them.

    Args:
        model: The transformer neural network to train
              Type: PyTorch model object
              Contains: Encoder, decoder, and prediction heads

        english_sentences: List of English text for encoder training
                          Input: ["Hello world", "How are you", ...]
                          Size: Variable (e.g., 100,000 sentences)

        bengali_sentences: List of Bengali text for decoder training
                          Input: ["হ্যালো বিশ্ব", "কেমন আছেন", ...]
                          Size: Variable (e.g., 100,000 sentences)

        src_vocab: English vocabulary (word to number mapping)
                  Input: {"hello": 1, "world": 2, ...}
                  Size: ~50,000 words typically

        tgt_vocab: Bengali vocabulary (word to number mapping)
                  Input: {"হ্যালো": 1, "বিশ্ব": 2, ...}
                  Size: ~50,000 words typically

        config: Training configuration dictionary
               Contains: batch_size, learning_rate, epochs, etc.

        overall_start_time: Optional timestamp for global timing
                           Type: Float (from time.time())

    Returns:
        model: The trained model with learned parameters
        train_losses: List of training error values over time
        val_losses: List of validation errors (empty in this version)
        pretrain_accs: List of accuracy measurements during training
    """
    print("=== Starting Pre-training Phase ===")

    # STEP 1: Setup computing environment
    # Choose between GPU (fast) or CPU (slower but always available)
    device = torch.device(config['device'])  # Usually 'cuda' or 'cpu'

    # Move model to chosen device (like loading software on the right computer)
    # Input: Model object → Output: Same model, but on GPU/CPU
    model = model.to(device)

    # STEP 2: Create training datasets
    # These will convert raw text into training examples with masked words
    print("Creating pre-training datasets...")

    # English dataset: Creates fill-in-the-blank exercises from English sentences
    # Input: Raw sentences → Output: Masked sequences + labels
    # Example: "Hello world" → input_ids: [1, 103, 2], labels: [-100, 2, -100]
    # where 103 = [MASK] token, -100 = ignore this position
    english_dataset = PretrainDataset(
        english_sentences,                    # Raw English text
        src_vocab,                           # English word→number mapping
        config['max_length'],                # Max sequence length (e.g., 128)
        config['mask_probability']           # Fraction of words to mask (e.g., 0.15)
    )

    # Bengali dataset: Same process for Bengali sentences
    # Input: Raw sentences → Output: Masked sequences + labels
    bengali_dataset = PretrainDataset(
        bengali_sentences,                   # Raw Bengali text
        tgt_vocab,                          # Bengali word→number mapping
        config['max_length'],                # Max sequence length (e.g., 128)
        config['mask_probability']           # Fraction of words to mask (e.g., 0.15)
    )

    # STEP 3: Create data loaders (batch processors)
    # These group individual examples into batches for efficient training

    # English loader: Packages English examples into batches
    # Input: Individual sequences → Output: Batches of shape (batch_size, max_length)
    # Example: 32 sequences of 128 tokens each → (32, 128) tensor
    english_loader = DataLoader(
        english_dataset,
        batch_size=config['batch_size'],     # e.g., 32 examples per batch
        shuffle=True                         # Randomize order each epoch
    )

    # Bengali loader: Same batching for Bengali data
    bengali_loader = DataLoader(
        bengali_dataset,
        batch_size=config['batch_size'],     # e.g., 32 examples per batch
        shuffle=True                         # Randomize order each epoch
    )

    # Report dataset sizes
    print(f"English pre-training samples: {len(english_dataset)}")
    print(f"Bengali pre-training samples: {len(bengali_dataset)}")

    # STEP 4: Setup training components

    # Loss function: Measures how wrong the model's predictions are
    # CrossEntropyLoss compares predicted word probabilities vs correct answers
    # ignore_index=-100: Skip these positions when calculating error
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    # Optimizer: The "learning algorithm" that updates model weights
    # Adam is a sophisticated method that adapts learning speed automatically
    optimizer = optim.Adam(
        model.parameters(),                  # All model weights to update
        lr=config['pretrain_learning_rate'], # Base learning rate (e.g., 0.0001)
        betas=(0.9, 0.98),                  # Momentum parameters
        eps=1e-9                            # Numerical stability constant
    )

    # Scheduler: Adjusts learning rate during training (Noam schedule)
    # Starts slow, ramps up, then gradually decreases
    scheduler = get_noam_scheduler(
        optimizer,
        config['d_model'],                   # Model dimension (e.g., 512)
        config['warmup_steps']               # Warmup duration (e.g., 4000 steps)
    )

    # Early stopping: Prevents overfitting by stopping when performance plateaus
    early_stopping = EarlyStopping(patience=config['patience'])

    # STEP 5: Initialize tracking variables
    # These lists store performance metrics over time
    train_losses = []      # Training error at each epoch
    val_losses = []        # Validation error (empty in this version)
    pretrain_accs = []     # Accuracy measurements during training

    # Record training start time
    start_time = time.time()

    # STEP 6: Main training loop
    # Repeat the learning process for multiple epochs
    for epoch in range(config['pretrain_epochs']):

        # TIMING: Calculate elapsed time and check limits
        time_elapsed = time.time() - start_time
        global_time_elapsed = time.time() - overall_start_time if overall_start_time is not None else 0

        print(f"Time elapsed: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s (phase), {global_time_elapsed // 60:.0f}m {global_time_elapsed % 60:.0f}s (overall)")

        # Stop if training is taking too long
        if time_elapsed > config['max_train_minutes'] * 60:
            print("Time limit exceeded. Stopping pre-training.")
            break
        if overall_start_time is not None and global_time_elapsed > config['max_global_minutes'] * 60:
            print("Global time limit exceeded. Stopping pre-training.")
            break

        # Set model to training mode (enables dropout, batch norm updates)
        model.train()

        # Initialize epoch tracking variables
        epoch_train_loss = 0  # Cumulative error for this epoch
        batch_count = 0       # Number of batches processed

        # PHASE 1: Train encoder on English data
        print(f"Training encoder (English) - Epoch {epoch+1}")
        for batch in tqdm(english_loader, desc=f"Encoder MLM (English) Epoch {epoch+1}", disable=CONFIG['tqdm_disable']):

            # Extract batch data and move to device (GPU/CPU)
            # input_ids: sentences with some words replaced by [MASK] tokens
            # labels: the original words that were masked
            input_ids = batch['input_ids'].to(device)  # Shape: (batch_size, seq_length)
            labels = batch['labels'].to(device)        # Shape: (batch_size, seq_length)

            # DEBUG: Print shapes for understanding
            #print(f"input_ids shape: {input_ids.shape}")  # e.g., (32, 128)
            #print(f"labels shape: {labels.shape}")        # e.g., (32, 128)

            # Clear gradients from previous step (like erasing a whiteboard)
            optimizer.zero_grad()

            # FORWARD PASS: Get model predictions
            # Model tries to predict original words from masked sentences
            # Input: (batch_size, seq_length) → Output: (batch_size, seq_length, vocab_size)
            logits = model.mlm_encode(input_ids)
            #print(f"logits shape: {logits.shape}")  # e.g., (32, 128, 50000)

            # Verify shapes match for loss calculation
            assert logits.shape[:2] == labels.shape, f"Logits shape {logits.shape} and labels shape {labels.shape} do not match for loss"

            # LOSS CALCULATION: How wrong were the predictions?
            # Reshape tensors for CrossEntropyLoss
            # Input: logits (batch_size, seq_length, vocab_size) → (batch_size*seq_length, vocab_size)
            # Input: labels (batch_size, seq_length) → (batch_size*seq_length,)
            # Output: scalar loss value
            loss = criterion(
                logits.view(-1, logits.size(-1)),  # Flatten: (32*128, 50000)
                labels.view(-1)                    # Flatten: (32*128,)
            )

            # BACKWARD PASS: Learn from mistakes
            # Calculate gradients (how to improve each weight)
            # PyTorch: automatic differentiation (autograd); get the context from 'logits'
            loss.backward()

            # UPDATE: Apply the learned improvements
            optimizer.step()

            # SCHEDULING: Adjust learning rate
            scheduler.step()

            # TRACKING: Record performance metrics
            epoch_train_loss += loss.item()  # Add to cumulative loss
            batch_count += 1

            # Calculate accuracy (percentage of correct predictions)
            # Input: logits (batch_size, seq_length, vocab_size), labels (batch_size, seq_length)
            # Output: scalar accuracy (0.0 to 1.0)
            acc = batch_accuracy(logits, labels, -100)
            pretrain_accs.append(acc)

        # PHASE 2: Train decoder on Bengali data
        print(f"Training decoder (Bengali) - Epoch {epoch+1}")
        for batch in tqdm(bengali_loader, desc=f"Decoder MLM (Bengali) Epoch {epoch+1}", disable=CONFIG['tqdm_disable']):

            # Extract batch data and move to device
            input_ids = batch['input_ids'].to(device)  # Shape: (batch_size, seq_length)
            labels = batch['labels'].to(device)        # Shape: (batch_size, seq_length)

            # DEBUG: Print shapes for understanding
            #print(f"input_ids shape: {input_ids.shape}")
            #print(f"labels shape: {labels.shape}")

            # Clear gradients from previous step
            optimizer.zero_grad()

            # FORWARD PASS: Get decoder predictions
            # Decoder tries to predict original Bengali words from masked sentences
            # Input: (batch_size, seq_length) → Output: (batch_size, seq_length, vocab_size)
            logits = model.decoder_mlm(input_ids)
            #print(f"logits shape: {logits.shape}")

            # Verify shapes match for loss calculation
            assert logits.shape[:2] == labels.shape, f"Logits shape {logits.shape} and labels shape {labels.shape} do not match for loss"

            # LOSS CALCULATION: How wrong were the predictions?
            loss = criterion(
                logits.view(-1, logits.size(-1)),  # Flatten for loss calculation
                labels.view(-1)
            )

            # BACKWARD PASS: Learn from mistakes
            loss.backward()

            # UPDATE: Apply improvements
            optimizer.step()

            # SCHEDULING: Adjust learning rate
            scheduler.step()

            # TRACKING: Record performance metrics
            epoch_train_loss += loss.item()
            batch_count += 1

            # Calculate accuracy
            acc = batch_accuracy(logits, labels, -100)
            pretrain_accs.append(acc)

        # EPOCH SUMMARY: Calculate and report average performance
        avg_train_loss = epoch_train_loss / batch_count
        train_losses.append(avg_train_loss)

        print(f"Pre-train Epoch {epoch+1}/{config['pretrain_epochs']}")
        print(f"Train Loss: {avg_train_loss:.4f}")

        # EARLY STOPPING: Check if we should stop training
        if config['apply_early_stop'] and early_stopping(avg_train_loss):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # STEP 7: Save the trained model
    # This preserves all the learned knowledge for later use
    torch.save({
        'model_state_dict': model.state_dict(),    # All learned parameters
        'config': config,                          # Training configuration
        'src_vocab': src_vocab,                    # English vocabulary
        'tgt_vocab': tgt_vocab,                    # Bengali vocabulary
        'pretrain_losses': (train_losses, val_losses),  # Training history
        'pretrain_accs': pretrain_accs             # Accuracy history
    }, 'pretrained_transformer.pth')

    print("Pre-training completed!")

    # Return the trained model and performance metrics
    return model, train_losses, val_losses, pretrain_accs

# DIMENSIONAL FLOW SUMMARY:
# 1. Raw text → Tokenized sequences: List[str] → (batch_size, seq_length)
# 2. Masked sequences → Model predictions: (batch_size, seq_length) → (batch_size, seq_length, vocab_size)
# 3. Predictions → Loss calculation: (batch_size, seq_length, vocab_size) → scalar
# 4. Loss → Gradients → Weight updates: scalar → model parameters updated
# 5. Repeat for thousands of batches and multiple epochs

In [22]:
# === FINETUNE MODEL: Translation on parallel data ===
def finetune_model(model, english_sentences, bengali_sentences, src_vocab, tgt_vocab, config, overall_start_time=None):
    """
    Fine-tune the pretrained transformer model on parallel translation data (English→Bengali).

    Think of this like training a pre-existing translator to get better at a specific language pair.
    The model already knows general language patterns, now we teach it English→Bengali specifically.

    Args:
        model: Transformer called as ``model(src, tgt_ip, training=True)``; its output is
            expected to be logits of shape (batch, tgt_len, vocab_size) (it is reshaped to
            (-1, vocab_size) against the flattened targets below).
        english_sentences: Source sentences, index-aligned with ``bengali_sentences``.
        bengali_sentences: Target sentences; must have the same length as ``english_sentences``.
        src_vocab: Source tokenizer, passed to ``TranslationDataset`` and stored in the checkpoint.
        tgt_vocab: Target tokenizer, passed to ``TranslationDataset`` and stored in the checkpoint.
        config: Dict. Keys read here: 'device', 'max_length', 'batch_size',
            'finetune_learning_rate', 'd_model', 'warmup_steps', 'patience',
            'finetune_epochs', 'apply_early_stop', 'max_global_minutes' (only when
            ``overall_start_time`` is given) and optionally 'tqdm_disable' (default False).
        overall_start_time: ``time.time()`` value taken at the start of the whole pipeline,
            used to enforce the global time budget; ``None`` disables that check.

    Returns:
        (model, train_losses, val_losses, train_accs, val_accs), where each list holds one
        average per completed epoch.

    Raises:
        ValueError: if the two sentence lists differ in length, or hold fewer than 2 pairs
            (the positional 90/10 split needs at least one pair on each side).

    Side effect:
        Saves a checkpoint to 'finetuned_transformer.pth'.
    """
    # Validate the parallel corpus before doing any work. Misaligned lists would pair the
    # wrong sentences, and fewer than 2 pairs leaves the training split empty (which used
    # to surface as a ZeroDivisionError when averaging the epoch loss).
    n_pairs = len(english_sentences)
    if n_pairs != len(bengali_sentences):
        raise ValueError(
            f"Parallel data mismatch: {n_pairs} English vs {len(bengali_sentences)} Bengali sentences"
        )
    if n_pairs < 2:
        raise ValueError(f"Need at least 2 sentence pairs to fine-tune, got {n_pairs}")

    print("=== Starting Fine-tuning Phase ===")

    # Set up GPU/CPU processing and move the model there
    device = torch.device(config['device'])
    model = model.to(device)

    # Progress-bar switch comes from this function's own config (not a global).
    tqdm_disable = config.get('tqdm_disable', False)

    # Split data by position: first 90% for training, last 10% for validation (no shuffling).
    # int() floors, so the validation side always keeps at least one pair; n_pairs >= 2
    # guarantees the training side is non-empty as well.
    split_idx = int(0.9 * n_pairs)
    train_src = english_sentences[:split_idx]        # Training English sentences
    train_tgt = bengali_sentences[:split_idx]        # Training Bengali sentences
    val_src = english_sentences[split_idx:]          # Validation English sentences
    val_tgt = bengali_sentences[split_idx:]          # Validation Bengali sentences

    # Datasets turn text pairs into token-id tensors (batch keys 'src', 'tgt_ip', 'tgt_op').
    train_dataset = TranslationDataset(train_src, train_tgt, src_vocab, tgt_vocab, config['max_length'])
    val_dataset = TranslationDataset(val_src, val_tgt, src_vocab, tgt_vocab, config['max_length'])

    # Training batches are shuffled each epoch; validation order is fixed.
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    print(f"Fine-tuning training samples: {len(train_dataset)}")
    print(f"Fine-tuning validation samples: {len(val_dataset)}")

    # ignore_index=0: positions holding id 0 (padding) do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    # Adam with the transformer-paper betas/eps; the Noam scheduler warms the LR up then decays it.
    optimizer = optim.Adam(model.parameters(), lr=config['finetune_learning_rate'], betas=(0.9, 0.98), eps=1e-9)
    scheduler = get_noam_scheduler(optimizer, config['d_model'], config['warmup_steps'])

    # Early stopping is driven by the validation loss (see end of each epoch).
    early_stopping = EarlyStopping(patience=config['patience'])

    # Per-epoch metric histories
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    start_time = time.time()

    for epoch in range(config['finetune_epochs']):
        # Time tracking for this phase and for the whole pipeline
        time_elapsed = time.time() - start_time
        global_time_elapsed = time.time() - overall_start_time if overall_start_time is not None else 0
        print(f"Time elapsed: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s (phase), {global_time_elapsed // 60:.0f}m {global_time_elapsed % 60:.0f}s (overall)")

        # Stop if the global computational budget is exhausted
        if overall_start_time is not None and global_time_elapsed > config['max_global_minutes'] * 60:
            print("Global time limit exceeded. Stopping fine-tuning.")
            break

        # ---------------- TRAINING PHASE ----------------
        model.train()  # enable dropout
        epoch_train_loss = 0
        epoch_train_acc = 0
        batch_cnt = 0

        for batch in tqdm(train_loader, desc=f"Fine-tune Epoch {epoch+1}", disable=tqdm_disable):
            # src: source ids; tgt_ip: decoder input (teacher forcing); tgt_op: expected next tokens
            src = batch['src'].to(device)
            tgt_ip = batch['tgt_ip'].to(device)
            tgt_op = batch['tgt_op'].to(device)

            # PyTorch accumulates gradients, so clear them every step
            optimizer.zero_grad()

            outputs = model(src, tgt_ip, training=True)

            # Flatten (batch, tgt_len, vocab) → (batch*tgt_len, vocab) to match flattened targets
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_op.reshape(-1))

            # Token accuracy ignoring padding (id 0)
            acc = batch_accuracy(outputs, tgt_op, ignore_index=0)
            epoch_train_acc += acc
            batch_cnt += 1

            loss.backward()
            optimizer.step()
            scheduler.step()  # Noam schedule advances per optimizer step

            epoch_train_loss += loss.item()

        # Guard against an empty loader (e.g. if the dataset drops every pair): report NaN
        # instead of crashing with ZeroDivisionError.
        avg_train_loss = epoch_train_loss / batch_cnt if batch_cnt else float('nan')
        avg_train_acc = epoch_train_acc / batch_cnt if batch_cnt else float('nan')

        # ---------------- VALIDATION PHASE ----------------
        model.eval()  # disable dropout
        epoch_val_loss = 0
        epoch_val_acc = 0
        val_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                src = batch['src'].to(device)
                tgt_ip = batch['tgt_ip'].to(device)
                tgt_op = batch['tgt_op'].to(device)

                # Same teacher-forced forward call as training; dropout is already off via
                # model.eval(). NOTE(review): confirm what the model's `training` flag toggles.
                outputs = model(src, tgt_ip, training=True)

                loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_op.reshape(-1))
                acc = batch_accuracy(outputs, tgt_op, ignore_index=0)

                epoch_val_loss += loss.item()
                epoch_val_acc += acc
                val_batches += 1

        avg_val_loss = epoch_val_loss / val_batches if val_batches else float('nan')
        avg_val_acc = epoch_val_acc / val_batches if val_batches else float('nan')

        # Store metrics for plotting/analysis later
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accs.append(avg_train_acc)
        val_accs.append(avg_val_acc)

        print(f"Fine-tune Epoch {epoch+1}/{config['finetune_epochs']}")
        print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        print(f"Train Acc : {avg_train_acc:.4f}, Val Acc : {avg_val_acc:.4f}")

        # Stop when validation loss has not improved for `patience` epochs
        if config['apply_early_stop'] and early_stopping(avg_val_loss):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Checkpoint holds the final (not necessarily best-epoch) weights plus everything
    # needed to reuse the model later.
    torch.save({
        'model_state_dict': model.state_dict(),  # Model weights
        'config': config,                        # Training configuration
        'src_vocab': src_vocab,                  # English tokenizer
        'tgt_vocab': tgt_vocab,                  # Bengali tokenizer
        'finetune_losses': (train_losses, val_losses),  # Training history
        'train_accs': train_accs,                # Training accuracy history
        'val_accs': val_accs                     # Validation accuracy history
    }, 'finetuned_transformer.pth')

    print("Fine-tuning completed!")
    return model, train_losses, val_losses, train_accs, val_accs

In [23]:
class TranslationInference:
    """
    Greedy, autoregressive English→Bengali translation with a trained model.

    Unlike training (where the whole target sentence is known and fed with teacher
    forcing), the Bengali output is produced one token at a time here. Each step feeds
    the tokens generated so far back into the model and appends the most likely next
    token. Decoding stops when '<EOS>' is produced or after ``max_length`` steps.
    """

    def __init__(self, model, src_vocab, tgt_vocab, config):
        """
        Initialize the translator with a trained model and tokenizers.

        Args:
            model: Trained network, called as ``model(src, tgt, training=False)``; expected
                to return logits indexable as ``[0, position, :]`` over the target vocabulary.
            src_vocab: Source tokenizer; ``encode(str)`` must return a list of token ids.
            tgt_vocab: Target tokenizer; needs ``get_vocab()`` (token → id dict containing
                '<SOS>' and ideally '<EOS>') and ``decode(list_of_ids)``.
            config: Dict providing 'device' and the default 'max_length'.
        """
        self.model = model
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.config = config
        self.device = torch.device(config['device'])
        self.model.eval()  # disable dropout for deterministic decoding

    def _special_ids(self):
        """Return the target vocabulary's ('<SOS>', '<EOS>') ids; either may be None if absent."""
        vocab = self.tgt_vocab.get_vocab()
        return vocab.get('<SOS>'), vocab.get('<EOS>')

    def translate(self, sentence, max_length=None):
        """
        Translate a single English sentence to Bengali using greedy decoding.

        Args:
            sentence: English input string.
            max_length: Length the source and target sequences are truncated/padded to,
                and the maximum number of decoding steps. Defaults to config['max_length'].

        Returns:
            The decoded Bengali string. The leading '<SOS>' is removed, and so is the
            trailing '<EOS>' when decoding stopped on it.

        Raises:
            ValueError: if the target vocabulary has no '<SOS>' token.
        """
        if max_length is None:
            max_length = self.config['max_length']

        # Resolve the special ids once per call. get_vocab() may rebuild a full
        # token→id dict on every call, so it is kept out of the decoding loop.
        sos_id, eos_id = self._special_ids()
        if sos_id is None:
            raise ValueError("Target vocabulary has no '<SOS>' token; cannot start decoding.")

        # Re-assert eval mode in case the model was put back into train() after this
        # object was created (e.g. further fine-tuning).
        self.model.eval()

        # Source: encode, truncate to max_length, right-pad with 0 (the <PAD> id), then add
        # a batch dimension → tensor of shape [1, max_length].
        src_ids = self.src_vocab.encode(sentence)[:max_length]
        src_padded = src_ids + [0] * (max_length - len(src_ids))
        src_tensor = torch.tensor([src_padded], dtype=torch.long, device=self.device)

        generated = [sos_id]
        with torch.no_grad():
            for _ in range(max_length):
                # Re-run the model on everything generated so far, padded to max_length.
                tgt_padded = generated + [0] * (max_length - len(generated))
                tgt_tensor = torch.tensor([tgt_padded], dtype=torch.long, device=self.device)
                predictions = self.model(src_tensor, tgt_tensor, training=False)

                # Logits at the last filled position score the next token; pick the best.
                next_token = torch.argmax(predictions[0, len(generated) - 1, :]).item()
                generated.append(next_token)

                if next_token == eos_id:
                    break

        # Drop the leading <SOS>, and the trailing <EOS> if decoding stopped on it.
        end = -1 if generated[-1] == eos_id else len(generated)
        return self.tgt_vocab.decode(generated[1:end])

In [24]:
class BPETokenizer:
    """
    Improved Byte Pair Encoding (BPE) Tokenizer using HuggingFace's tokenizers library.

    WHAT IS A TOKENIZER?
    Think of this as a "text-to-numbers converter" for neural networks.
    Neural networks can't understand words like "hello" or "নমস্কার" - they only work with numbers.

    The tokenizer's job is to:
    1. Convert text → numbers (encoding): "Hello world" → [45, 123, 67]
    2. Convert numbers → text (decoding): [45, 123, 67] → "Hello world"

    WHAT IS BPE?
    BPE is like a smart compression algorithm for text. Instead of having a word for every
    possible combination, it learns the most common "chunks" (subwords) and builds words from them.

    Example: "unhappiness" might become ["un", "happy", "ness"] → [234, 567, 890]
    This way, even if the model never saw "unhappiness", it can understand it from its parts.

    Handles both English and Bengali text properly with appropriate normalization.
    Now uses ByteLevel pre-tokenizer and decoder to preserve spaces in output.
    """

    def __init__(self, vocab_size=10000, language="mixed"):
        """
        Initialize the tokenizer.

        Args:
            vocab_size: Maximum number of unique tokens (like dictionary size)
                       Input: integer (e.g., 10000)
                       Effect: Determines how many unique word-pieces the tokenizer can learn
            language: What language to optimize for ("english", "bengali", or "mixed")
                     Input: string
                     Effect: Changes how text is normalized (lowercase, accent removal, etc.)
        """
        self.vocab_size = vocab_size
        self.tokenizer = None  # Will hold the trained tokenizer object
        self.language = language

    def train(self, sentences):
        """
        Train the BPE tokenizer on the given sentences.

        This is like teaching the tokenizer what words and word-pieces exist in your language.
        The tokenizer analyzes all your text and learns the most efficient way to break it down.

        Args:
            sentences: List of sentences to train on
                      Input: List of strings, e.g., ["Hello world", "How are you", "নমস্কার"]
                      Output: Trained tokenizer that can convert text ↔ numbers
        """
        if not sentences:
            raise ValueError("Cannot train tokenizer on empty sentence list")

        # STEP 1: PREPARE TRAINING DATA
        # Write all sentences to a temporary file (HuggingFace tokenizers expect file input)
        # Input: List of strings in memory
        # Output: Text file with one sentence per line
        with tempfile.NamedTemporaryFile(mode='w+', delete=False, encoding='utf-8') as f:
            for sentence in sentences:
                if sentence and sentence.strip():  # Skip empty sentences
                    f.write(sentence.strip() + '\n')
            temp_file = f.name

        try:
            # STEP 2: INITIALIZE TOKENIZER
            # Create a blank BPE tokenizer (like creating an empty dictionary)
            # Input: None (fresh start)
            # Output: Untrained tokenizer object with BPE algorithm
            tokenizer = Tokenizer(models.BPE(unk_token="<UNK>"))

            # STEP 3: SET UP TEXT NORMALIZATION
            # This is like setting "text preprocessing rules" - how to clean/standardize text
            # Different languages need different preprocessing
            if self.language == "english":
                # For English: NFD normalization, lowercase, strip accents
                # Input: "Café" → Output: "cafe" (normalized, lowercase, no accents)
                tokenizer.normalizer = NormalizerSequence([
                    NFD(),        # Normalize unicode characters (é → e + accent)
                    Lowercase(),  # Convert to lowercase
                    StripAccents() # Remove accent marks
                ])
            elif self.language == "bengali":
                # For Bengali: Only NFD normalization (preserve case and Bengali characters)
                # Input: "নমস্কার" → Output: "নমস্কার" (normalized unicode only)
                tokenizer.normalizer = NFD()
            else:  # mixed or default
                # For mixed languages: Only NFD to handle both properly
                # Input: Mixed text → Output: Unicode-normalized text
                tokenizer.normalizer = NFD()

            # STEP 4: SET UP PRE-TOKENIZATION
            # This splits text into "chunks" before BPE processing
            # ByteLevel preserves spaces and handles any character (including emojis, special chars)
            # Input: "Hello world" → Output: ["Hello", " world"] (space preserved)
            tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)

            # STEP 5: DEFINE SPECIAL TOKENS
            # These are like "reserved words" that have special meaning
            # Input: List of special token strings
            # Output: Reserved token IDs (usually 0, 1, 2, 3, 4)
            special_tokens = ["<PAD>", "<UNK>", "<SOS>", "<EOS>", "<MASK>"]
            # <PAD>: Padding (fill empty space)
            # <UNK>: Unknown word (words not in vocabulary)
            # <SOS>: Start of Sentence
            # <EOS>: End of Sentence
            # <MASK>: Masked token (for some training techniques)

            # STEP 6: SET UP TRAINER
            # Configure how the BPE algorithm will learn
            # Input: Training parameters
            # Output: Trainer object that knows how to build vocabulary
            trainer = trainers.BpeTrainer(
                vocab_size=self.vocab_size,     # Maximum vocabulary size
                special_tokens=special_tokens,   # Reserve these tokens
                min_frequency=1,                 # Minimum times a subword must appear
                show_progress=True              # Show progress bar during training
            )

            # STEP 7: TRAIN THE TOKENIZER
            # This is the core learning step - BPE algorithm analyzes all text
            # Input: Text file with sentences
            # Process: Counts character pairs, merges most frequent ones iteratively
            # Output: Vocabulary mapping (token_string → token_id)
            #
            # Example process:
            # 1. Start with individual characters: "h", "e", "l", "l", "o"
            # 2. Find most frequent pair: "l" + "l" → "ll"
            # 3. Merge them: "h", "e", "ll", "o"
            # 4. Repeat until vocab_size reached
            tokenizer.train([temp_file], trainer)

            # STEP 8: SET UP DECODER
            # This handles converting token IDs back to text
            # ByteLevel decoder properly handles spaces and special characters
            # Input: List of token IDs, e.g., [45, 123, 67]
            # Output: Properly formatted text with spaces preserved
            tokenizer.decoder = decoders.ByteLevel()

            """
            # STEP 9: SET UP POST-PROCESSOR
            # This automatically adds special tokens to sequences
            # Template: "<SOS> $A <EOS>" means wrap every sentence with start/end tokens
            # Input: Raw sentence
            # Output: Sentence with SOS and EOS tokens added
            tokenizer.post_processor = processors.TemplateProcessing(
                single="<SOS> $A <EOS>",  # Template for single sentence
                special_tokens=[
                    ("<SOS>", self.get_special_token_id("<SOS>", tokenizer)),
                    ("<EOS>", self.get_special_token_id("<EOS>", tokenizer)),
                ]
            )
            """
            # Store the trained tokenizer
            self.tokenizer = tokenizer

        finally:
            # CLEANUP: Remove temporary file
            if os.path.exists(temp_file):
                os.remove(temp_file)

    def get_special_token_id(self, token, tokenizer):
        """
        Helper method to get special token ID.

        Input: Token string (e.g., "<SOS>")
        Output: Token ID (e.g., 2)

        Like looking up a word in a dictionary to get its page number.
        """
        vocab = tokenizer.get_vocab()  # Get the word→ID mapping dictionary
        return vocab.get(token, vocab.get("<UNK>"))  # Return ID or default to UNK

    def encode(self, sentence, add_special_tokens=False):
        """
        Encode a sentence to token IDs.

        This is the main "text → numbers" conversion function.

        Args:
            sentence: Input sentence string
                     Input: "Hello world" (string)
                     Output: [45, 123, 67] (list of integers)
            add_special_tokens: Whether to add SOS/EOS tokens
                               Input: True/False
                               Effect: [2, 45, 123, 67, 3] vs [45, 123, 67]

        Returns:
            List of token IDs
        """
        if self.tokenizer is None:
            raise ValueError("BPE Tokenizer not trained yet. Call train() first.")

        if not sentence or not sentence.strip():
            return []  # Empty input → empty output

        # ENCODING PROCESS:
        # Input: "Hello world" (string)
        # Step 1: Normalize → "hello world" (if English)
        # Step 2: Pre-tokenize → ["hello", " world"] (preserve spaces)
        # Step 3: Apply BPE → ["he", "llo", " wor", "ld"] (subword units)
        # Step 4: Convert to IDs → [45, 123, 67, 89] (numbers)
        # Step 5: Add special tokens if requested → [2, 45, 123, 67, 89, 3]
        encoding = self.tokenizer.encode(sentence.strip(), add_special_tokens=add_special_tokens)
        return encoding.ids  # Return just the list of numbers

    def decode(self, ids, skip_special_tokens=True):
        """
        Decode token IDs back to text ("numbers → text").

        Args:
            ids: Sequence of integer token IDs, e.g. [45, 123, 67]. Any iterable
                of int-convertible values (list, tuple, 1-D tensor/array) is
                accepted; ``None`` is treated like an empty sequence.
            skip_special_tokens: When True, the IDs of <PAD>, <UNK>, <SOS>,
                <EOS> and <MASK> are removed before decoding, giving e.g.
                "Hello world". When False, the IDs are kept AND the wrapped
                tokenizer is told not to skip special tokens, so output such as
                "<SOS> Hello world <EOS>" is possible.

        Returns:
            Decoded string with leading/trailing whitespace stripped; "" for
            empty input, or if the wrapped tokenizer's decode raises (a warning
            is printed in that case).

        Raises:
            ValueError: If no tokenizer has been trained or loaded yet.
        """
        if self.tokenizer is None:
            raise ValueError("BPE Tokenizer not trained yet. Call train() first.")

        if ids is None:
            return ""
        # Normalise to plain Python ints up front: `not ids` is ambiguous for
        # multi-element tensors, and the wrapped decode expects a list of ints.
        ids = [int(token_id) for token_id in ids]
        if not ids:
            return ""  # Empty input → empty output

        if skip_special_tokens:
            # Look up the special-token IDs in this vocabulary; token_to_id()
            # returns None for tokens that are not present, so drop those.
            special_ids = {
                self.tokenizer.token_to_id(token)
                for token in ("<PAD>", "<UNK>", "<SOS>", "<EOS>", "<MASK>")
            }
            special_ids.discard(None)
            ids = [token_id for token_id in ids if token_id not in special_ids]

        try:
            # Forward the flag explicitly: the tokenizers library's decode()
            # defaults to skip_special_tokens=True, which would otherwise drop
            # registered special tokens even when the caller asked to keep them.
            decoded = self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
            return decoded.strip()  # Remove leading/trailing whitespace
        except Exception as e:
            # Best-effort: a bad ID sequence yields "" rather than aborting a
            # whole evaluation loop, but the failure is still reported.
            print(f"Warning: Decode error for IDs {ids}: {e}")
            return ""

    def get_vocab(self):
        """
        Return the learned token → ID mapping.

        Returns:
            The mapping produced by the wrapped tokenizer's ``get_vocab()``,
            e.g. {"hello": 45, "world": 123, "<SOS>": 2, ...}.

        Raises:
            ValueError: If no tokenizer has been trained or loaded yet.
        """
        backend = self.tokenizer
        if backend is not None:
            return backend.get_vocab()
        raise ValueError("BPE Tokenizer not trained yet. Call train() first.")

    def get_vocab_size(self):
        """
        Return how many tokens the tokenizer actually learned.

        This can be smaller than the requested ``vocab_size`` when the training
        text did not contain enough distinct subwords.

        Returns:
            Number of entries in ``self.get_vocab()``, or 0 when no tokenizer
            has been trained or loaded (instead of raising).
        """
        # Untrained wrapper: report an empty vocabulary rather than erroring.
        return 0 if self.tokenizer is None else len(self.get_vocab())

    def token_to_id(self, token):
        """
        Look up the ID of a single token string.

        Args:
            token: Token string, e.g. "hello" or "<SOS>".

        Returns:
            Whatever the wrapped tokenizer's ``token_to_id`` returns: the
            integer ID, or ``None`` if the token is not in the vocabulary.

        Raises:
            ValueError: If no tokenizer has been trained or loaded yet.
        """
        if self.tokenizer is not None:
            return self.tokenizer.token_to_id(token)
        raise ValueError("BPE Tokenizer not trained yet.")

    def id_to_token(self, id):
        """
        Look up the token string for a single ID.

        Args:
            id: Integer token ID, e.g. 45.

        Returns:
            Whatever the wrapped tokenizer's ``id_to_token`` returns: the token
            string, or ``None`` if the ID is outside the vocabulary.

        Raises:
            ValueError: If no tokenizer has been trained or loaded yet.
        """
        if self.tokenizer is not None:
            return self.tokenizer.id_to_token(id)
        raise ValueError("BPE Tokenizer not trained yet.")

    def save(self, filepath):
        """
        Write the trained tokenizer to ``filepath`` so it can be reloaded later.

        Args:
            filepath: Destination path, e.g. "my_tokenizer.json".

        Raises:
            ValueError: If there is no trained tokenizer to write.
        """
        backend = self.tokenizer
        if backend is None:
            raise ValueError("No trained tokenizer to save.")
        backend.save(filepath)

    def load(self, filepath):
        """
        Replace the wrapped tokenizer with one deserialized from ``filepath``.

        Args:
            filepath: Path to a tokenizer file, e.g. one previously written by
                ``save`` ("my_tokenizer.json").

        The current tokenizer is only replaced once ``Tokenizer.from_file``
        returns, so a failed load leaves ``self.tokenizer`` unchanged.
        """
        restored = Tokenizer.from_file(filepath)
        self.tokenizer = restored

In [None]:
# === MAIN FUNCTION FOR TWO-STAGE TRAINING ===
"""
Main function to orchestrate the entire training and inference process:

OVERVIEW FOR SOFTWARE DEVELOPERS:
This is a machine learning pipeline that trains a "Transformer" model to translate
English text to Bengali. Think of it like training a very sophisticated autocomplete
system that can convert meaning between languages.

THE TWO-STAGE PROCESS:
1. PRE-TRAINING: The model learns general language patterns by predicting missing
    words in sentences (like a fill-in-the-blanks game)
2. FINE-TUNING: The model learns specific translation patterns using paired
    English-Bengali sentences

STEPS:
1. Load monolingual data (separate English and Bengali text files)
2. Build vocabularies (maps words to numbers that computers can process)
3. Create and pretrain the model (MLM - Masked Language Modeling)
4. Fine-tune the model on translation pairs
5. Save the model and vocabs (next cell)
6. Run inference on a new English sentence (later cells)
"""
# === Overall start time ===
# Passed to pretrain_model / finetune_model below and reused for the final
# elapsed-time report.
overall_start_time = time.time()
print("=== English-Bengali Transformer: Pre-training and Fine-tuning ===")
print(f"Configuration: {CONFIG}")

# === STEP 1: LOAD MONOLINGUAL DATA ===
# INPUT: Text files with one sentence per line
# OUTPUT: Lists of strings (sentences)
# PURPOSE: Get raw text data to train language understanding
english_sentences = load_monolingual_data(CONFIG['english_file'], CONFIG['max_pretrain_sentences'])
bengali_sentences = load_monolingual_data(CONFIG['bengali_file'], CONFIG['max_pretrain_sentences'])

# Stop here if either corpus is empty: tokenizer training and pretraining below
# consume these lists, so continuing would only fail later with a less
# helpful error. The message names the files actually configured.
if not english_sentences or not bengali_sentences:
    raise RuntimeError(
        "Error: Could not load monolingual data. Please check "
        f"'{CONFIG['english_file']}' and '{CONFIG['bengali_file']}'."
    )


# === STEP 2: BUILD VOCABULARIES USING BPE (Byte Pair Encoding) ===
# CONCEPT: Computers can't process words directly - they need numbers
# BPE breaks words into subword pieces (like "unhappy" → "un" + "happy")
# This helps handle rare words and different word forms

if CONFIG.get('shared_bpe_vocab', False):
    # SHARED VOCABULARY APPROACH:
    # INPUT: Combined English + Bengali sentences (list of strings)
    # OUTPUT: One tokenizer (subword → ID mapping) used for both languages
    print("\n--- Training SHARED BPE Tokenizer on English + Bengali sentences ---")
    shared_sentences = english_sentences + bengali_sentences  # Concatenate lists
    shared_bpe = BPETokenizer(vocab_size=CONFIG['vocab_size'])  # Create tokenizer
    shared_bpe.train(shared_sentences)  # Learn subword patterns from data
    print("Shared BPE vocab size:", len(shared_bpe.get_vocab()))
    src_vocab = shared_bpe  # Source language tokenizer
    tgt_vocab = shared_bpe  # Target language tokenizer (same as source)
else:
    # SEPARATE VOCABULARY APPROACH:
    # INPUT: English sentences (list of strings)
    # OUTPUT: English-specific tokenizer (subword → ID mapping)
    print("\n--- Training BPE Tokenizer on English sentences ---")
    bpe_tokenizer_en = BPETokenizer(vocab_size=CONFIG['vocab_size'], language="english")
    bpe_tokenizer_en.train(english_sentences)  # Learn English subword patterns
    print("English BPE vocab size:", len(bpe_tokenizer_en.get_vocab()))

    # INPUT: Bengali sentences (list of strings)
    # OUTPUT: Bengali-specific tokenizer (subword → ID mapping)
    print("\n--- Training BPE Tokenizer on Bengali sentences ---")
    bpe_tokenizer_bn = BPETokenizer(vocab_size=CONFIG['vocab_size'], language="bengali")
    bpe_tokenizer_bn.train(bengali_sentences)  # Learn Bengali subword patterns
    print("Bengali BPE vocab size:", len(bpe_tokenizer_bn.get_vocab()))

    src_vocab = bpe_tokenizer_en  # Source language tokenizer
    tgt_vocab = bpe_tokenizer_bn  # Target language tokenizer

# === STEP 3: CREATE TRANSFORMER MODEL ===
# CONCEPT: An "encoder" reads the English input and a "decoder" generates the
# Bengali output while attending to the encoder's representation.
# The vocabulary sizes below size the model's embedding / output layers; they
# are the number of subwords each tokenizer actually learned.

print("Creating transformer model...")
model = Transformer(
    src_vocab_size=len(src_vocab.get_vocab()),    # English vocabulary size
    tgt_vocab_size=len(tgt_vocab.get_vocab()),    # Bengali vocabulary size
    d_model=CONFIG['d_model'],                    # Hidden dimension
    num_encoder_layers=CONFIG['num_encoder_layers'],  # Encoder depth
    num_decoder_layers=CONFIG['num_decoder_layers'],  # Decoder depth
    num_heads=CONFIG['num_heads'],                # Attention heads
    dff=CONFIG['dff'],                            # Feed-forward dimension
    max_length=CONFIG['max_length'],              # Max sequence length in tokens
    dropout_rate=CONFIG['dropout_rate']           # Dropout rate
)

# === STEP 4: PRE-TRAINING PHASE ===
# CONCEPT: Masked language modelling — hide some tokens and train the model to
# predict them, separately for English (encoder) and Bengali (decoder).
# Rough tensor flow (assumed; see pretrain_model / Transformer for details):
# - Input: [batch_size, sequence_length] token IDs
# - Embedding / layers: [batch_size, sequence_length, d_model]
# - MLM head: [batch_size, sequence_length, vocab_size] scores per position

print("\n=== PRE-TRAINING PHASE ===")
model, pretrain_train_losses, pretrain_val_losses, pretrain_accs = pretrain_model(
    model,                    # The neural network model
    english_sentences,        # List of English sentences for training
    bengali_sentences,        # List of Bengali sentences for training
    src_vocab,                # English tokenizer
    tgt_vocab,                # Bengali tokenizer
    CONFIG,                   # Training configuration
    overall_start_time=overall_start_time
)

# === STEP 5: LOAD TRANSLATION DATA ===
# INPUT: CSV file with English-Bengali sentence pairs
# OUTPUT: Two aligned lists of sentences
# PURPOSE: Get paired data for translation training
english_pairs, bengali_pairs = load_translation_data(CONFIG['translation_file'], CONFIG['max_translation_pairs'])

# Fine-tuning needs parallel data; stop instead of continuing with empty lists.
if not english_pairs or not bengali_pairs:
    raise RuntimeError(
        "Error: Could not load translation data. Please check "
        f"'{CONFIG['translation_file']}'."
    )


# === STEP 6: FINE-TUNING PHASE ===
# CONCEPT: Teach the pretrained model to translate using English-Bengali pairs.
# Rough tensor flow (assumed; see finetune_model / Transformer for details):
# - Encoder Input: [batch_size, src_seq_len] → [batch_size, src_seq_len, d_model]
# - Decoder Input: [batch_size, tgt_seq_len] → [batch_size, tgt_seq_len, d_model]
# - Cross-attention: Decoder attends to encoder output
# - Final Output: [batch_size, tgt_seq_len, tgt_vocab_size] scores per position

print("\n=== FINE-TUNING PHASE ===")
model, finetune_train_losses, finetune_val_losses, train_accs, val_accs = finetune_model(
    model,                    # Pre-trained model from step 4
    english_pairs,            # List of English sentences
    bengali_pairs,            # List of corresponding Bengali sentences
    src_vocab,                # English tokenizer
    tgt_vocab,                # Bengali tokenizer
    CONFIG,                   # Training configuration
    overall_start_time=overall_start_time
)


# === STEP 8: GENERATE TRAINING VISUALIZATION PLOTS ===
def _save_metric_plot(curves, title, xlabel, ylabel, filename):
    """
    Draw one or more labelled metric curves on a new figure and save it.

    Args:
        curves: Iterable of (values, label) pairs; each values sequence is
            plotted against its index.
        title, xlabel, ylabel: Figure annotations.
        filename: Image path passed to ``Figure.savefig``.

    Returns:
        The created matplotlib Figure (left open so it can be displayed).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    for values, label in curves:
        ax.plot(values, label=label)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    fig.savefig(filename)
    return fig


# 1. Pretraining loss per epoch (lower = better at predicting masked tokens)
_save_metric_plot([(pretrain_train_losses, 'Pretrain Loss')],
                  'Pretraining Loss vs Epoch', 'Epoch', 'Loss', 'pretrain_loss.png')

# 2. Pretraining accuracy, recorded per batch across all epochs
_save_metric_plot([(pretrain_accs, 'Pretrain Accuracy')],
                  'Pretraining Accuracy vs Batch', 'Batch (across epochs)', 'Accuracy',
                  'pretrain_accuracy.png')

# 3. Fine-tuning loss: train vs validation (unseen during training)
_save_metric_plot([(finetune_train_losses, 'Train Loss'), (finetune_val_losses, 'Val Loss')],
                  'Fine-tuning Loss vs Epoch', 'Epoch', 'Loss', 'finetune_loss.png')

# 4. Fine-tuning accuracy: a widening train/val gap suggests overfitting
_save_metric_plot([(train_accs, 'Train Accuracy'), (val_accs, 'Val Accuracy')],
                  'Fine-tuning Accuracy vs Epoch', 'Epoch', 'Accuracy', 'finetune_accuracy.png')

# Display the figures inline, then free them. In Jupyter/Colab the inline
# backend renders figures on plt.show(); a timed plt.pause() would only add a
# fixed sleep to every run.
plt.show()
plt.close('all')


=== English-Bengali Transformer: Pre-training and Fine-tuning ===
Configuration: {'vocab_size': 18000, 'd_model': 512, 'dff': 2048, 'num_heads': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6, 'dropout_rate': 0.1, 'max_length': 200, 'batch_size': 64, 'pretrain_learning_rate': 0.0001, 'finetune_learning_rate': 5e-05, 'pretrain_epochs': 500, 'finetune_epochs': 500, 'apply_early_stop': False, 'patience': 5, 'english_file': 'EBook_of_The_Bhagavad-Gita_English.txt', 'bengali_file': 'EBook_of_The_Bhagavad-Gita_Bengali.txt', 'translation_file': 'english_to_bangla.csv', 'max_pretrain_sentences': 39000, 'max_translation_pairs': 39, 'mask_probability': 0.15, 'device': 'cpu', 'max_train_minutes': 45, 'max_global_minutes': 100, 'shared_bpe_vocab': False, 'warmup_steps': 4000, 'tqdm_disable': False}
Loading monolingual data from EBook_of_The_Bhagavad-Gita_English.txt...
Loaded 3054 sentences from EBook_of_The_Bhagavad-Gita_English.txt
First 5 sentences(load_monolingual_data): ['CHAPTER I', 'Dh

Encoder MLM (English) Epoch 1: 100%|██████████| 48/48 [35:11<00:00, 43.99s/it]


Training decoder (Bengali) - Epoch 1


Decoder MLM (Bengali) Epoch 1:   1%|          | 1/166 [00:59<2:42:26, 59.07s/it]

In [None]:

# === STEP 7: SAVE THE TRAINED MODEL ===
# Persist the learned weights together with what is needed to reuse them:
# the run configuration, both tokenizer objects, and the recorded
# loss/accuracy histories from pretraining and fine-tuning.
# The tokenizers are stored as pickled Python objects, which is why the
# inference cell loads this file with torch.load(..., weights_only=False).
checkpoint_path = 'en_bn_transformer_pretrained_finetuned.pth'
checkpoint_payload = dict(
    model_state_dict=model.state_dict(),
    config=CONFIG,
    src_vocab=src_vocab,
    tgt_vocab=tgt_vocab,
    pretrain_losses=(pretrain_train_losses, pretrain_val_losses),
    finetune_losses=(finetune_train_losses, finetune_val_losses),
    pretrain_accs=pretrain_accs,
    train_accs=train_accs,
    val_accs=val_accs,
)
torch.save(checkpoint_payload, checkpoint_path)
print(f"Model saved as '{checkpoint_path}'")


In [None]:

# === STEP 9: TRANSLATION INFERENCE ===
# Reload the checkpoint written by the save cell and wrap the model in a
# TranslationInference helper (English text in → Bengali text out).
# NOTE(review): this is not a true fresh start — only the weights are
# restored into the existing in-memory `model`, so the cell that constructs
# `model` must have run in this kernel first.

print("\n=== Translation Inference ===")
# weights_only=False is required because the checkpoint also holds pickled
# tokenizer objects. Unpickling can execute code, so only load checkpoints
# you produced yourself.
checkpoint = torch.load(
    'en_bn_transformer_pretrained_finetuned.pth',
    map_location=CONFIG['device'],
    weights_only=False,
)

# Restore learned parameters, then the English (source) and Bengali (target)
# tokenizers saved alongside them.
model.load_state_dict(checkpoint['model_state_dict'])
src_vocab, tgt_vocab = checkpoint['src_vocab'], checkpoint['tgt_vocab']

# How translate() tokenizes, encodes and decodes is defined in
# TranslationInference, not in this cell.
translator = TranslationInference(model, src_vocab, tgt_vocab, CONFIG)

In [None]:

print("\n=== Testing Translations ===")
# Test set: a mix of short conversational phrases and longer caption-style
# sentences. Each sentence is listed once — repeating one only re-runs the
# same translation and prints the same result twice.
test_sentences = [
    "a child in a pink dress is climbing up a set of stairs in an entry way .",
    "a girl going into a wooden building .",
    "a dog is running in the snow",
    "a dog running",
    "Hello, how are you?",
    "a man in an orange hat starring at something .",
    "I love you.",
    "a little girl climbing into a wooden playhouse .",
    "What is your name?",
    "two dogs of different breeds looking at each other on the road .",
    "Good morning.",
    "Thank you very much.",
    "The weather is nice today."
]

# Translate each sentence and print it next to its source.
# The pipeline inside translator.translate() lives in TranslationInference,
# not in this cell. Conceptually it tokenizes the English text with src_vocab,
# runs the encoder, generates Bengali token IDs step by step with the decoder,
# and detokenizes with tgt_vocab.
# NOTE(review): the decoding strategy (greedy vs. beam search) is not visible
# here — confirm in TranslationInference.
# dict.fromkeys() keeps first-seen order and skips any duplicates added later.
for sentence in dict.fromkeys(test_sentences):
    translation = translator.translate(sentence)

    # Show original English and translated Bengali side by side
    print(f"English: {sentence}")
    print(f"Bengali: {translation}")
    print("-" * 50)

# === STEP 11: PERFORMANCE MONITORING ===
# overall_start_time is set at the top of the training cell, so this figure
# covers training, saving, inference, and any idle time between cell runs.
overall_time_elapsed = time.time() - overall_start_time
print(f"\nTotal elapsed time: {overall_time_elapsed // 60:.0f}m {overall_time_elapsed % 60:.0f}s")

print("\n=== TRAINING COMPLETE ===")
print("Model is ready for production use!")
print("You can now use the 'translator' object to translate English to Bengali.")