# Transformers for Translation in Action
Building and training an encoder-decoder architecture to translate English to German

## Learning Objectives
* Understand how Transformer models process sequential text data
* Implement an encoder-decoder Transformer from scratch in PyTorch, following "Attention Is All You Need" by Vaswani et al. (2017)
* Train and evaluate the model on a real translation dataset (WMT14 English-German) from Hugging Face

References:
* [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
* [The Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)

In [1]:
try:
    import google.colab
    !pip install transformers datasets tokenizers torch matplotlib numpy
except ImportError:  # google.colab is only importable inside Colab
    print("Running Locally")



In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
import math
import time
from tqdm.auto import tqdm

# Set the device. In Colab, make sure you have a GPU session: Runtime -> Change runtime type -> T4 GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Part 1: Loading and Preparing Data

### 1.1 Load Dataset from Hugging Face

If running locally, you can save download time by getting a copy of the dataset from Google Drive via [this link](https://drive.google.com/file/d/1W_0BAmoT2gkSKf_vEiEJYUGLX8fKOBK_/view?usp=sharing) (requires a Monash account). Unzip it and place it in the `datasets` folder.

In [5]:
def load_translation_data():
    """Load and preprocess WMT14 EN-DE dataset"""
    # Using a smaller subset for demonstration
    dataset = load_dataset("wmt14", "de-en", split="train[:10000]")  # Small subset
    val_dataset = load_dataset("wmt14", "de-en", split="validation[:1000]")

    # Extract English and German sentences
    en_sentences = [item['translation']['en'] for item in dataset]
    de_sentences = [item['translation']['de'] for item in dataset]

    val_en = [item['translation']['en'] for item in val_dataset]
    val_de = [item['translation']['de'] for item in val_dataset]

    return en_sentences, de_sentences, val_en, val_de

en_train, de_train, en_val, de_val = load_translation_data()
print("Dataset loaded!")
print(f"Training examples: {len(en_train)}")
print(f"Validation examples: {len(en_val)}")
print("\nExample pairs:")
for i in range(3):
    print(f"EN: {en_train[i]}")
    print(f"DE: {de_train[i]}")
    print("-" * 50)



Dataset loaded!
Training examples: 10000
Validation examples: 1000

Example pairs:
EN: Resumption of the session
DE: Wiederaufnahme der Sitzungsperiode
--------------------------------------------------
EN: I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.
DE: Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten.
--------------------------------------------------
EN: Although, as you will have seen, the dreaded 'millennium bug' failed to materialise, still the people in a number of countries suffered a series of natural disasters that truly were dreadful.
DE: Wie Sie feststellen konnten, ist der gefürchtete "Millenium-Bug " nicht eingetreten. Doch sind Bürger einiger u

### 1.2 Build Tokenizers

A tokenizer converts human-readable text (such as a sentence) into the numerical form a model can actually process.

The process of tokenization typically involves two main steps:

1. Breaking down the text: The tokenizer splits a string into smaller, meaningful units called tokens. These tokens can be individual words, characters, or, as in the code below, parts of words (subwords).

2. Converting to numbers: Each unique token is assigned a unique numerical ID. The model then works with these numbers, which are far easier for a computer to process than raw text.

By creating a vocabulary of tokens and their corresponding IDs, the tokenizer allows the model to map between the text it "sees" and the numerical representations it can learn from.

The code below creates and trains two Byte-Pair Encoding (BPE) tokenizers, one for English and one for German. BPE is a popular subword tokenization method used in many modern language models: it starts from individual characters and repeatedly merges the most frequent adjacent pair of symbols until the vocabulary reaches the requested size.
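To make the merge idea concrete, here is a toy, pure-Python sketch of the loop BPE training runs: count adjacent symbol pairs across the corpus, merge the most frequent pair, and repeat. The function `learn_bpe_merges` and the four-word corpus (with made-up frequencies) are hypothetical; the Hugging Face `tokenizers` trainer used below is a separate, much faster implementation that also handles special tokens and stops once `vocab_size` is reached.

```python
from collections import Counter

def learn_bpe_merges(words, num_merges):
    """Toy BPE trainer (hypothetical helper, for illustration only).

    `words` maps each word to its corpus frequency. Every word starts as a
    tuple of single characters. Returns the merges in the order learned and
    the final segmentation of each word.
    """
    vocab = {tuple(word): freq for word, freq in words.items()}
    merges = []
    for _ in range(num_merges):
        # Count each adjacent symbol pair, weighted by word frequency
        pair_counts = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        if not pair_counts:
            break
        # Most frequent pair; ties go to the pair seen first
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Rewrite every word, fusing each occurrence of the best pair
        new_vocab = {}
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] = new_vocab.get(tuple(merged), 0) + freq
        vocab = new_vocab
    return merges, vocab

corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges, segmented = learn_bpe_merges(corpus, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o')]
```

After three merges, "newest" is segmented as `n e w est`: frequent character sequences become single tokens, while rare words still decompose into known pieces. Real implementations may break ties differently, so the exact merge order can vary.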

In [6]:
def create_tokenizers(en_sentences, de_sentences, vocab_size=8000):
    """Create BPE tokenizers for English and German"""

    # English tokenizer; unk_token maps characters never seen in training to <unk>
    # (without it, the BPE model silently drops them)
    en_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    en_tokenizer.pre_tokenizer = Whitespace()

    en_trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<pad>", "<sos>", "<eos>", "<unk>"]
    )

    # Train the English tokenizer
    en_tokenizer.train_from_iterator(en_sentences, en_trainer)

    # German tokenizer, configured the same way
    de_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    de_tokenizer.pre_tokenizer = Whitespace()

    de_trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<pad>", "<sos>", "<eos>", "<unk>"]
    )

    # Train the German tokenizer
    de_tokenizer.train_from_iterator(de_sentences, de_trainer)

    return en_tokenizer, de_tokenizer

en_tokenizer, de_tokenizer = create_tokenizers(en_train, de_train)
print("Tokenizers created!")

# Test tokenization
test_en = "Hello, how are you?"
test_de = "Hallo, wie geht es dir?"
print(f"EN tokens: {en_tokenizer.encode(test_en).tokens}")
print(f"DE tokens: {de_tokenizer.encode(test_de).tokens}")

Tokenizers created!
EN tokens: ['H', 'el', 'lo', ',', 'how', 'are', 'you', '?']
DE tokens: ['H', 'all', 'o', ',', 'wie', 'geht', 'es', 'd', 'ir', '?']


## Part 2: Transformer Architecture Implementation

### 2.1 Positional Encoding
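
For reference, the sinusoidal encoding from Vaswani et al. (2017) that the class below implements is defined, for position $pos$ and dimension index $i$, as:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

The `div_term` line evaluates the frequency factor in log space, using $\exp\!\left(-2i \ln(10000)/d_{\text{model}}\right) = 10000^{-2i/d_{\text{model}}}$, where `torch.arange(0, d_model, 2)` supplies the values of $2i$.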

In [18]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from the original paper"""

    def __init__(self, d_model, max_length=5000):
        super().__init__()

        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length).unsqueeze(1).float()

        # 1 / 10000^(2i / d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add positional encoding to input embeddings
        return x + self.pe[:, :x.size(1)]

# Test positional encoding
pos_enc = PositionalEncoding(d_model=512)
test_input = torch.randn(2, 10, 512)  # batch_size=2, seq_len=10, d_model=512
output = pos_enc(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])


### 2.2 Multi-Head Attention (Complete Implementation)

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

We call our particular attention "Scaled Dot-Product Attention". The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.

In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$.  The keys and values are also packed together into matrices $K$ and $V$.  We compute the matrix of outputs as:

$$
   \mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$
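
Why divide by $\sqrt{d_k}$? The paper's argument is that if the components of $q$ and $k$ are independent with mean 0 and variance 1, their dot product has mean 0 and variance $d_k$, so large $d_k$ pushes the softmax into regions with very small gradients. The quick simulation below checks that variance claim empirically; the sample count and seed are arbitrary choices.

```python
import math
import torch

torch.manual_seed(0)
d_k = 64
num_samples = 10_000

# Components drawn i.i.d. with mean 0 and variance 1, as in the paper's argument
q = torch.randn(num_samples, d_k)
k = torch.randn(num_samples, d_k)

raw_scores = (q * k).sum(dim=-1)             # one dot product q.k per row
scaled_scores = raw_scores / math.sqrt(d_k)  # what scaled dot-product attention uses

print(f"Var(q.k)             ~ {raw_scores.var().item():.1f}")     # close to d_k = 64
print(f"Var(q.k / sqrt(d_k)) ~ {scaled_scores.var().item():.2f}")  # close to 1
```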

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

$$
\begin{aligned}
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O \\
\text{where}~\mathrm{head}_i &= \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
\end{aligned}
$$

Where the projections are parameter matrices $W^Q_i \in
\mathbb{R}^{d_{\text{model}} \times d_k}$, $W^K_i \in
\mathbb{R}^{d_{\text{model}} \times d_k}$, $W^V_i \in
\mathbb{R}^{d_{\text{model}} \times d_v}$ and $W^O \in
\mathbb{R}^{hd_v \times d_{\text{model}}}$.
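
In practice, the $h$ separate projections are usually computed as one $d_{\text{model}} \times d_{\text{model}}$ projection whose output is split into $h$ heads of size $d_k = d_{\text{model}}/h$ (64 for the paper's base model with $h = 8$). The sketch below shows only that split and the matching concatenation, on a random tensor standing in for a projected $Q$, $K$ or $V$; the sizes are illustrative.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
d_k = d_model // num_heads  # 64

x = torch.randn(batch_size, seq_len, d_model)  # stand-in for a projected Q, K or V

# Split: [B, T, d_model] -> [B, T, h, d_k] -> [B, h, T, d_k]
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 64])

# Concatenate: [B, h, T, d_k] -> [B, T, h, d_k] -> [B, T, d_model]
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(torch.equal(merged, x))  # True: split followed by concat is lossless
```

The `.contiguous()` call is needed because `.view` cannot reshape the non-contiguous tensor produced by `.transpose`; the `forward` method below follows the same pattern.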

In [20]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention from 'Attention is All You Need'"""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        # We assume d_v equals d_k
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute scaled dot-product attention"""
        #TODO 1: Implement the attention formula ----------------------------
        # scores = Q @ K^T / sqrt(d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        #ENDTODO ------------------------------------------------------------

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1) # YOUR CODE HERE

        # Apply attention to values
        output = torch.matmul(attention_weights, v) # YOUR CODE HERE

        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size, _, d_model = query.size()

        # Get sequence lengths - they might be different for query vs key/value
        query_len = query.size(1)
        key_len = key.size(1)
        value_len = value.size(1)

        # 1. Linear projections
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)

        # 2. Reshape for multi-head attention
        #TODO 2: Implement the multihead attention reshaping ---------------
        # Use x.view() to reshape [batch_size, relevant_seq_len, num_heads, d_head]
        # and then transpose such that it is [batch_size, num_heads, relevant_seq_len, d_head]
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2).contiguous()
        K = K.view(batch_size, key_len,   self.num_heads, self.d_k).transpose(1, 2).contiguous()
        V = V.view(batch_size, value_len, self.num_heads, self.d_k).transpose(1, 2).contiguous()
        #ENDTODO ------------------------------------------------------------

        # 3. Apply attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, query_len, d_model
        )

        # 5. Final linear projection
        output = self.w_o(attention_output)

        return output, attention_weights


In [21]:
# --- Test Case Setup: Scaled dot product attention ---

# Define small, simple tensors for Q, K, and V
# Query: 1 sequence of 4 features
# Key/Value: 2 sequences of 4 features

d_model = 4  # New d_model
num_heads = 1
d_k = d_model // num_heads

q_org = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=torch.float32)  # shape (1, 1, 4)
k_org = torch.tensor([[[5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]], dtype=torch.float32) # shape (1, 2, 4)
v_org = torch.tensor([[[13.0, 14.0, 15.0, 16.0], [17.0, 18.0, 19.0, 20.0]]], dtype=torch.float32) # shape (1, 2, 4)

multiheadattention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

# Your calculation
your_outputs, your_weights = multiheadattention.scaled_dot_product_attention(q_org, k_org, v_org)


# Manual Calculation
d_k = k_org.size(-1)
manual_scores = torch.matmul(q_org, k_org.transpose(-2, -1)) / math.sqrt(d_k)
manual_attention_weights = F.softmax(manual_scores, dim=-1)
manual_output = torch.matmul(manual_attention_weights, v_org)

# Use torch.allclose for a robust comparison of floating-point numbers
weights_match = torch.allclose(your_weights, manual_attention_weights)
output_match = torch.allclose(your_outputs, manual_output)

if weights_match and output_match:
    print("\n🎉 Congratulations! Your implementation is correct! 🎉")
else:
    print("\n❌ Not quite. Compare your output to the step-by-step manual calculation above.")



🎉 Congratulations! Your implementation is correct! 🎉


In [22]:
# --- Test Case Setup: Multi head attention ---
d_model = 4
num_heads = 2
d_k = d_model // num_heads

# Your calculation
multiheadattention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
_, your_attention_weights = multiheadattention(q_org, k_org, v_org)

# Manual Calculation
Q_projected = multiheadattention.w_q(q_org)
K_projected = multiheadattention.w_k(k_org)
V_projected = multiheadattention.w_v(v_org)
manual_Q_reshaped = Q_projected.view(1, q_org.size(1), num_heads, d_k).transpose(1, 2)
manual_K_reshaped = K_projected.view(1, k_org.size(1), num_heads, d_k).transpose(1, 2)
manual_V_reshaped = V_projected.view(1, v_org.size(1), num_heads, d_k).transpose(1, 2)
_, manual_attention_weights = multiheadattention.scaled_dot_product_attention(manual_Q_reshaped, manual_K_reshaped, manual_V_reshaped, mask=None)

weights_match = torch.allclose(your_attention_weights, manual_attention_weights)
if weights_match:
    print("\n🎉 Congratulations! Your implementation is correct! 🎉")
else:
    print("\n❌ Not quite. Compare your output to the step-by-step manual calculation above.")


🎉 Congratulations! Your implementation is correct! 🎉


### 2.3 Encoder Layer

The encoder is composed of a stack of $N=6$ identical layers.

Each layer has two sub-layers, each wrapped in a residual connection followed by layer normalization:
1. A multi-head self-attention mechanism
2. A simple, position-wise fully connected feed-forward network

That is, the output of each sub-layer is $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$, where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself.  

Dropout is applied [(cite)](http://jmlr.org/papers/v15/srivastava14a.html) to the output of each sub-layer, before it is added to the sub-layer input and normalized.
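
The "add & norm" pattern above can be factored into a small wrapper. The class `PostNormSublayer` below is a hypothetical helper, not part of the notebook's model; it applies $\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x)))$ exactly as described. (The Annotated Transformer's `SublayerConnection` instead normalizes first, a "pre-norm" variant.)

```python
import torch
import torch.nn as nn

class PostNormSublayer(nn.Module):
    """Hypothetical helper: LayerNorm(x + Dropout(Sublayer(x)))."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # `sublayer` is any callable mapping [B, T, d_model] -> [B, T, d_model]
        return self.norm(x + self.dropout(sublayer(x)))

# Usage with a position-wise feed-forward network as the sub-layer
d_model, d_ff = 512, 2048
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
block = PostNormSublayer(d_model).eval()  # eval() disables dropout

x = torch.randn(2, 10, d_model)
y = block(x, ffn)
print(y.shape)  # torch.Size([2, 10, 512])
```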

In [23]:
class EncoderLayer(nn.Module):
    """Single encoder layer from the transformer"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        #TODO 3: Implement encoder layer forward pass
        # 1. Self-attention with residual connection and layer norm
        # 2. Feed-forward with residual connection and layer norm

        # Self-attention block: LayerNorm(x + Dropout(Sublayer(x)))
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward block: same add & norm pattern
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        #ENDTODO ------------------------------------------------------------

        return x

### 2.4 Decoder Layer

The decoder is also composed of a stack of $N=6$ identical layers.

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack.  Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization.
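
The decoder's self-attention must not see padding or future tokens. The sketch below builds the same two masks that `create_padding_mask` and `create_look_ahead_mask` produce in the model further down (`triu(..., diagonal=1) == 0` there is equivalent to `tril` here) and combines them by broadcasting. The token IDs are made up; only `0` as the padding ID matches the notebook.

```python
import torch

pad_id = 0
# Two target prefixes with made-up token IDs; the second one is padded
tgt = torch.tensor([[1, 57, 98, 2],
                    [1, 43,  2, 0]])
seq_len = tgt.size(1)

# Padding mask: True where a real token sits -> shape [B, 1, 1, T]
padding_mask = (tgt != pad_id).unsqueeze(1).unsqueeze(2)

# Look-ahead mask: position i may attend to positions <= i -> shape [T, T]
look_ahead_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()

# Broadcasting combines them into one [T, T] mask per sequence: [B, 1, T, T]
tgt_mask = padding_mask & look_ahead_mask
print(tgt_mask.shape)                    # torch.Size([2, 1, 4, 4])
print(tgt_mask[1, 0].int().tolist())     # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 0]]
```

The singleton dimension at index 1 lets the same mask broadcast across all attention heads inside `scaled_dot_product_attention`.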

In [None]:
class DecoderLayer(nn.Module):
    """Single decoder layer with masked self-attention and encoder-decoder attention"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        #TODO 4: Implement decoder layer
        # 1. Masked self-attention
        # 2. Encoder-decoder attention
        # 3. Feed-forward
        # Each with residual connections and layer norm

        # Masked self-attention (tgt_mask hides padding and future positions)
        self_attn_output, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))

        # Encoder-decoder attention: queries from the decoder, keys/values from the encoder
        enc_attn_output, _ = self.encoder_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(enc_attn_output))

        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        #ENDTODO ------------------------------------------------------------

        return x

### 2.5 Complete Transformer Model

In [None]:
class Transformer(nn.Module):
    """Complete encoder-decoder transformer"""

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()

        self.d_model = d_model

        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        # Encoder
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])

        # Decoder
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])

        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_padding_mask(self, seq, pad_token_id=0):
        """Create mask for padding tokens"""
        return (seq != pad_token_id).unsqueeze(1).unsqueeze(2)

    def create_look_ahead_mask(self, size):
        """Create look-ahead mask for decoder"""
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        return mask == 0

    def encode(self, src, src_mask):
        """Encode source sequence"""
        # Embedding + positional encoding
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        x = self.dropout(x)

        # Pass through encoder layers
        for layer in self.encoder_layers:
            x = layer(x, src_mask)

        return x

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        """Decode target sequence"""
        # Embedding + positional encoding
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        x = self.dropout(x)

        # Pass through decoder layers
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return x

    def generate(self, tgt, encoder_output, src_mask, tgt_mask):
        """One decoding step: return logits for the token that follows the prefix `tgt`"""
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.output_projection(decoder_output[:, -1, :])

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Create masks if not provided
        if src_mask is None:
            src_mask = self.create_padding_mask(src)
        if tgt_mask is None:
            tgt_mask = self.create_padding_mask(tgt) & self.create_look_ahead_mask(tgt.size(1)).to(tgt.device)

        # Encode and decode
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)

        # Project to vocabulary
        output = self.output_projection(decoder_output)

        return output


In [None]:
src_vocab_size = en_tokenizer.get_vocab_size()   # read from the trained English tokenizer
tgt_vocab_size = de_tokenizer.get_vocab_size()   # read from the trained German tokenizer

model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    d_ff=2048,
    dropout=0.1
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
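
As a sanity check on the printed count, the total can be worked out by hand. The helper names below are hypothetical; the arithmetic assumes the layer definitions above (biased `nn.Linear` layers, two LayerNorms per encoder layer and three per decoder layer, untied embeddings, and a positional encoding stored as a buffer rather than a parameter). The final line additionally assumes both BPE vocabularies reached the 8,000-token cap; substitute the actual `src_vocab_size` and `tgt_vocab_size` to compare with the model.

```python
def linear_params(n_in, n_out):
    # weight (n_out x n_in) plus bias (n_out), as in nn.Linear(bias=True)
    return n_in * n_out + n_out

def encoder_layer_params(d_model, d_ff):
    attention = 4 * linear_params(d_model, d_model)   # w_q, w_k, w_v, w_o
    ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    norms = 2 * (2 * d_model)                          # LayerNorm gain + bias, twice
    return attention + ffn + norms

def decoder_layer_params(d_model, d_ff):
    attention = 8 * linear_params(d_model, d_model)   # self- and cross-attention
    ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    norms = 3 * (2 * d_model)
    return attention + ffn + norms

def transformer_params(src_vocab, tgt_vocab, d_model=512, d_ff=2048, n_enc=6, n_dec=6):
    embeddings = (src_vocab + tgt_vocab) * d_model
    output_projection = linear_params(d_model, tgt_vocab)
    return (embeddings + output_projection
            + n_enc * encoder_layer_params(d_model, d_ff)
            + n_dec * decoder_layer_params(d_model, d_ff))

print(encoder_layer_params(512, 2048))         # 3152384
print(decoder_layer_params(512, 2048))         # 4204032
print(f"{transformer_params(8000, 8000):,}")   # 56,434,496
```

Roughly 44 million of these parameters sit in the 12 layers and the rest in the embeddings and output projection, so the vocabulary size has a noticeable effect on model size.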

## Part 3: Data Processing and Training Setup

### 3.1 Create Dataset Class

In [None]:
class TranslationDataset(torch.utils.data.Dataset):
    """Dataset for translation training"""

    def __init__(self, src_sentences, tgt_sentences, src_tokenizer, tgt_tokenizer, max_length=128):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_length = max_length

        # Special-token IDs, wrapped in lists so they can be concatenated with token lists
        self.src_sos_token = [src_tokenizer.token_to_id('<sos>')]
        self.src_eos_token = [src_tokenizer.token_to_id('<eos>')]
        self.tgt_sos_token = [tgt_tokenizer.token_to_id('<sos>')]
        self.tgt_eos_token = [tgt_tokenizer.token_to_id('<eos>')]

        self.pad_token_id = 0

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_text = self.src_sentences[idx]
        tgt_text = self.tgt_sentences[idx]

        # 1. Encode source sentence
        # 2. Encode target sentence with <sos> and <eos> tokens
        # 3. Pad to max_length

        # Tokenize source
        src_tokens = self.src_tokenizer.encode(src_text).ids
        src_tokens = src_tokens[:self.max_length-1]

        # Tokenize target (add <sos> at start, <eos> at end)
        tgt_tokens = self.tgt_tokenizer.encode(tgt_text).ids
        tgt_input = self.tgt_sos_token + tgt_tokens[:self.max_length-2]  # <sos> + tokens
        tgt_output = tgt_tokens[:self.max_length-2] + self.tgt_eos_token  # tokens + <eos>

        # Pad source and target sequences to max_length
        src_tokens = src_tokens + [self.pad_token_id] * max(0, self.max_length - len(src_tokens))
        tgt_input = tgt_input + [self.pad_token_id] * max(0, self.max_length - len(tgt_input))
        tgt_output = tgt_output + [self.pad_token_id] * max(0, self.max_length - len(tgt_output))

        # Assertions to verify lengths
        assert len(src_tokens) == self.max_length, f"Source length mismatch: {len(src_tokens)} != {self.max_length}"
        assert len(tgt_input) == self.max_length, f"Target input length mismatch: {len(tgt_input)} != {self.max_length}"
        assert len(tgt_output) == self.max_length, f"Target output length mismatch: {len(tgt_output)} != {self.max_length}"

        return {
            'src': torch.tensor(src_tokens[:self.max_length], dtype=torch.long),
            'tgt_input': torch.tensor(tgt_input[:self.max_length], dtype=torch.long),
            'tgt_output': torch.tensor(tgt_output[:self.max_length], dtype=torch.long)
        }

# Create datasets
train_dataset = TranslationDataset(en_train, de_train, en_tokenizer, de_tokenizer)
val_dataset = TranslationDataset(en_val, de_val, en_tokenizer, de_tokenizer)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
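
The target handling in `__getitem__` implements teacher forcing: the decoder reads `<sos>` plus the sentence and, at each position, must predict the next token, so `tgt_output` is `tgt_input` shifted left by one with `<eos>` at the end. The standalone helper below is hypothetical and mirrors that logic on made-up token IDs; it assumes the special tokens received IDs in the order they were listed for the trainer (`<pad>` = 0, `<sos>` = 1, `<eos>` = 2).

```python
def make_teacher_forcing_pair(token_ids, sos_id, eos_id, pad_id, max_length):
    """Hypothetical helper mirroring the target handling in TranslationDataset."""
    body = token_ids[:max_length - 2]   # leave room for <sos> / <eos>
    tgt_input = [sos_id] + body         # what the decoder reads
    tgt_output = body + [eos_id]        # what it must predict, shifted left by one

    def pad(seq):
        return seq + [pad_id] * (max_length - len(seq))

    return pad(tgt_input), pad(tgt_output)

tgt_input, tgt_output = make_teacher_forcing_pair(
    [57, 98, 43], sos_id=1, eos_id=2, pad_id=0, max_length=8
)
print(tgt_input)   # [1, 57, 98, 43, 0, 0, 0, 0]
print(tgt_output)  # [57, 98, 43, 2, 0, 0, 0, 0]
```

Reading the two lists column by column: after seeing `<sos>` the model should emit 57, after seeing `<sos> 57` it should emit 98, and so on until `<eos>`.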

### 3.2 Training Function

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device, grad_clip=1.0):
    """Train for one epoch"""
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc="Training"):
        src = batch['src'].to(device)
        tgt_input = batch['tgt_input'].to(device)
        tgt_output = batch['tgt_output'].to(device)

        optimizer.zero_grad()

        # Forward pass
        output = model(src, tgt_input)

        # Compute loss
        # Reshape output and target for cross-entropy loss
        output = output.view(-1, output.size(-1))
        tgt_output = tgt_output.view(-1)

        loss = criterion(output, tgt_output)

        # Backward pass, clipping the gradient norm to guard against exploding gradients
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(model, val_loader, criterion, device):
    """Validate the model"""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            src = batch['src'].to(device)
            tgt_input = batch['tgt_input'].to(device)
            tgt_output = batch['tgt_output'].to(device)

            output = model(src, tgt_input)
            output = output.view(-1, output.size(-1))
            tgt_output = tgt_output.view(-1)

            loss = criterion(output, tgt_output)
            total_loss += loss.item()

    return total_loss / len(val_loader)
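
Both functions flatten the logits to `[batch * seq_len, vocab]` and the targets to `[batch * seq_len]`, so padded positions (ID 0) are mixed in with real tokens. The criterion built in 3.3 passes `ignore_index=0` to skip them. The small check below, on random logits and made-up targets, shows that this gives the same mean loss as computing it over the real tokens only; it leaves out the notebook's `label_smoothing=0.1` to isolate the padding behavior.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, pad_id = 10, 0
logits = torch.randn(6, vocab_size)          # 6 flattened target positions
targets = torch.tensor([4, 7, 2, 0, 0, 0])   # the last three are padding

criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
loss_with_padding = criterion(logits, targets)

# Same value as averaging only over the non-padding positions
real = targets != pad_id
loss_real_only = nn.CrossEntropyLoss()(logits[real], targets[real])
print(torch.allclose(loss_with_padding, loss_real_only))  # True
```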

### 3.3 Training Loop

In [None]:
from IPython.display import clear_output
def train_transformer(num_epochs=3, patience=10, model_path='model.pt'):
    """Complete training setup and loop"""

    # Training hyperparameters
    learning_rate = 1e-4
    weight_decay = 1e-4  # Decoupled weight decay (AdamW), not an L2 penalty added to the loss
    grad_clip = 1.0  # Gradient clipping to prevent exploding gradients
    warmup_epochs = 5  # Epochs of linear LR warmup; with num_epochs <= 5, training never leaves warmup

    # Loss function (ignore padding tokens)
    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

    # Optimizer with learning rate scheduling
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay)

    # Learning rate scheduler with warmup
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 1.0

    warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # max(1, ...) keeps T_max positive when num_epochs <= warmup_epochs
    main_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, num_epochs - warmup_epochs))

    # Early stopping variables
    best_val_loss = np.inf
    epochs_no_improve = 0

    # Training loop
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Train
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        train_losses.append(train_loss)

        # Validate
        val_loss = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        # Update learning rate
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            main_scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        # print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Plot train and val losses
        clear_output(wait=True)
        plt.figure(figsize=(10, 6))
        epochs_range = list(range(1, len(train_losses) + 1))
        plt.plot(epochs_range, train_losses, 'b-o', label='Training Loss', linewidth=2)
        plt.plot(epochs_range, val_losses, 'r-s', label='Validation Loss', linewidth=2)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Loss', fontsize=12)
        plt.title('Training and Validation Loss', fontsize=14, fontweight='bold')
        plt.legend(fontsize=11)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

        if val_loss < best_val_loss:
            print(f'Validation loss decreased ({best_val_loss:.6f} --> {val_loss:.6f}). Saving model...')
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
            }, model_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

            # Early stopping check
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping triggered at epoch {epoch + 1}")
                print(f"Best validation loss: {best_val_loss:.6f}")
                break

    return train_losses, val_losses
train_losses, val_losses = train_transformer()
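
The loop above drives one optimizer with two schedulers, and with the defaults (`num_epochs=3`, `warmup_epochs=5`) training ends before warmup finishes. A simpler alternative, sketched below under the assumption that you still step once per epoch, is a single multiplier that warms up linearly and then follows a cosine decay to zero. For comparison, `noam_lr` reproduces the per-step schedule from Vaswani et al. (2017), Section 5.3. Both function names are hypothetical.

```python
import math

def warmup_cosine_factor(epoch, warmup_epochs, total_epochs):
    """LR multiplier: linear warmup, then cosine decay to 0 (for a single LambdaLR)."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

def noam_lr(step, d_model=512, warmup_steps=4000):
    """Paper schedule: d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

print([round(warmup_cosine_factor(e, 2, 6), 3) for e in range(6)])
# [0.5, 1.0, 1.0, 0.854, 0.5, 0.146]
print(f"{noam_lr(4000):.2e}")  # the peak, reached at step == warmup_steps
```

To use the first schedule, pass `lambda epoch: warmup_cosine_factor(epoch, warmup_epochs, num_epochs)` to `optim.lr_scheduler.LambdaLR` and call its `step()` once per epoch in place of the two schedulers. The Noam schedule is meant to be stepped per batch, not per epoch.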

### 3.4 Inference and Translation

You can download a model trained by Kavi from [this link](https://drive.google.com/file/d/1CnqoS-S9jh5YQzinpOFElpv4bN4AGYgQ/view?usp=sharing). Place the file in the same folder as this notebook.

In [None]:
max_src_len = 128

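# Uncomment this if you downloaded the model with the link above (loaded here as best_model.pt).
# A model trained in this notebook is saved to model.pt instead.
# model.load_state_dict(torch.load('best_model.pt', map_location=device)['model_state_dict'])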

def translate_sentence(model, sentence, src_tokenizer, tgt_tokenizer, device, max_length=50):
    """Translate a single sentence using the trained model"""
    model.eval()

    with torch.no_grad():
        # Encode source sentence
        src_tokens = src_tokenizer.encode(sentence).ids
        src_len = len(src_tokens)

        # Pad source to max_src_len (same as training)
        max_src_len = 128
        src_tokens = src_tokens[:max_src_len]  # Truncate if too long
        padded_src = src_tokens + [0] * (max_src_len - len(src_tokens))  # Pad if too short
        src = torch.tensor([padded_src], dtype=torch.long).to(device)

        # Create source mask - CRITICAL: mask out padding tokens
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # [1, 1, 1, src_len]

        # Encode source with the padding mask, as during training
        encoder_output = model.encode(src, src_mask)

        # Look up special token IDs
        sos_id = tgt_tokenizer.token_to_id('<sos>')
        eos_id = tgt_tokenizer.token_to_id('<eos>')

        # Initialize decoder input with <sos> token
        tgt_input = torch.tensor([[sos_id]], dtype=torch.long).to(device)  # <sos>
        generated_tokens = []

        for _ in range(max_length):
            tgt_len = tgt_input.size(1)
            tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0).to(device)
            tgt_mask = tgt_mask.bool()

            # Get decoder output
            decoder_output = model.decode(tgt_input, encoder_output, src_mask, tgt_mask)

            # Get next token probabilities
            next_token_logits = model.output_projection(decoder_output[:, -1, :])
            next_token = torch.argmax(next_token_logits, dim=-1).item()

            if next_token == eos_id:
                break

            # Append the new token to our sequence
            generated_tokens.append(next_token)
            tgt_input = torch.cat([tgt_input, torch.tensor([[next_token]], device=device)], dim=1)

        translation = tgt_tokenizer.decode(generated_tokens)
        return translation

test_sentence = "Hello, how are you?"
translation = translate_sentence(model, test_sentence, en_tokenizer, de_tokenizer, device)
print(f"Source: {test_sentence}")
print(f"Translation: {translation}")