In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, Subset

import math

from datasets import load_dataset

# Purpose
I am going to build an attention-based translation model (a Transformer with both an encoder and a decoder).

**From scratch!**

English -> Italian

In [None]:
# Download (or load from cache) the English-Italian configuration of OPUS-100.
ds = load_dataset(path="Helsinki-NLP/opus-100", name="en-it")

# Inspect one raw training record to see its {'translation': {'en': ..., 'it': ...}} layout.
example = ds["train"][100]
print(example)

{'translation': {'en': "What's going on?", 'it': 'Che succede?'}}


Process the dataset into training and validation splits

In [None]:
# This function will process a "batch" of examples at once
def extract_translations(batch):
  return {
      'en_text': [t['en'] for t in batch['translation']],
      'it_text': [t['it'] for t in batch['translation']],
  }

# .map() will apply this function to the whole dataset very quickly
# batched=True is the key to making it fast
processed_ds_train = ds['train'].map(extract_translations, batched=True)

x_train = processed_ds_train['en_text']
y_train = processed_ds_train['it_text']

processed_ds_val = ds['validation'].map(extract_translations, batched=True)

x_val = processed_ds_val['en_text']
y_val = processed_ds_val['it_text']

Character-level tokenizer, in the style of Karpathy's tutorials

In [None]:
# Build the character vocabulary from the training text only.
all_text = "".join(x_train) + "".join(y_train)
chars = sorted(set(all_text))

stoi = {c: i for i, c in enumerate(chars)}
# Special tokens get ids after the real characters. Their multi-character names
# cannot collide with the single-character keys above.
#   <PAD>  fills short sequences when a batch is padded
#   <BOS>  start-of-sentence marker for target sequences
#   <EOS>  end-of-sentence marker for target sequences (lets a decoder learn to stop)
# The training loop's teacher-forcing shift is written assuming targets are
# framed as [BOS, tok1, ..., tokN, EOS].
for special in ('<PAD>', '<BOS>', '<EOS>'):
    stoi[special] = len(stoi)

itos = {i: ch for ch, i in stoi.items()}

Translation dataset struct


In [None]:
class TranslationDataset(Dataset):
    """
    Character-level English -> Italian sentence pairs, encoded on access with `stoi`.

    Each item is a pair of 1-D ``torch.long`` tensors ``(en_ids, it_ids)``.

    If `stoi` contains both '<BOS>' and '<EOS>', the Italian ids are wrapped as
    ``[<BOS>, c_1, ..., c_n, <EOS>]``. This gives the decoder a start token to
    condition on and an end token to learn when to stop. Otherwise the Italian
    ids are the bare character ids.

    Characters missing from `stoi` (e.g. ones that only appear in the validation
    split) are encoded as '<PAD>' instead of raising KeyError.

    Args:
        en_texts: List of English sentences
        it_texts: List of Italian sentences
        stoi: Character to index mapping; must contain '<PAD>'
        max_len: Maximum encoded sequence length, special tokens included
    """
    def __init__(self, en_texts, it_texts, stoi, max_len=128):
        self.stoi = stoi
        self.max_len = max_len
        self.pad_idx = stoi['<PAD>']

        bos_idx = stoi.get('<BOS>')
        eos_idx = stoi.get('<EOS>')
        # Wrap targets only when both markers exist in the vocabulary.
        if bos_idx is not None and eos_idx is not None:
            self.target_wrap = (bos_idx, eos_idx)
        else:
            self.target_wrap = None
        n_special = 0 if self.target_wrap is None else 2

        # Keep pairs whose encodings fit the context window (positional tables have
        # max_len rows).
        # Empty sentences are dropped: an all-<PAD> row has every attention key masked,
        # and softmax over an all -inf row returns NaN.
        filtered_pairs = [
            (en, it) for en, it in zip(en_texts, it_texts)
            if 0 < len(en) <= max_len and 0 < len(it) and len(it) + n_special <= max_len
        ]

        if filtered_pairs:
            self.en_texts, self.it_texts = zip(*filtered_pairs)
        else:
            self.en_texts, self.it_texts = [], []

    def __len__(self):
        return len(self.en_texts)

    def _encode(self, text):
        """Map each character to its id, falling back to '<PAD>' for unknown characters."""
        return [self.stoi.get(char, self.pad_idx) for char in text]

    def __getitem__(self, idx):
        en_ids = self._encode(self.en_texts[idx])
        it_ids = self._encode(self.it_texts[idx])
        if self.target_wrap is not None:
            bos_idx, eos_idx = self.target_wrap
            it_ids = [bos_idx] + it_ids + [eos_idx]

        # Explicit integer dtype: ids index nn.Embedding and must never become float.
        en_encoded = torch.tensor(en_ids, dtype=torch.long)
        it_encoded = torch.tensor(it_ids, dtype=torch.long)
        return en_encoded, it_encoded

# This is used by the dataloader to stack together the batch into one tensor by using padding
def collate_fn(batch):
    """
    Custom collate function to pad sequences in each batch
    Args:
        batch: List of tuples (en_tensor, it_tensor)
    Returns:
        x_batch: Padded English sequences
        y_batch: Padded Italian sequences
    """
    en_batch, it_batch = zip(*batch)
    
    # Pad sequences
    x_batch = pad_sequence(en_batch, batch_first=True, padding_value=stoi['<PAD>'])
    y_batch = pad_sequence(it_batch, batch_first=True, padding_value=stoi['<PAD>'])
    
    return x_batch, y_batch

Config

In [None]:
vocab_size = len(stoi)
d_embd = 128
context_window = 128
n_heads = 8 
head_size = d_embd // n_heads  

Pick a small number of training samples for quick experiments

In [None]:
sample_size = 50

In [None]:
# Create datasets
train_dataset = TranslationDataset(x_train, y_train, stoi, max_len=context_window)
val_dataset = TranslationDataset(x_val, y_val, stoi, max_len=context_window)

# For experimenting with smaller samples
train_dataset_subset = Subset(train_dataset, range(min(sample_size, len(train_dataset))))

print(f"Training samples subset: {len(train_dataset_subset)}")
print(f"Validation samples: {len(val_dataset)}")

# Create data loaders
batch_size = 32 
train_loader = DataLoader(
    train_dataset_subset, 
    batch_size=batch_size,
    shuffle=True,  
    collate_fn=collate_fn,
    num_workers=0  # Set to 0 for debugging, increase for faster loading
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  
    collate_fn=collate_fn,
    num_workers=0
)

# Test the data loader
for x_batch, y_batch in train_loader:
    print("Batch x shape:", x_batch.shape)
    print("Batch y shape:", y_batch.shape)
    print("Sample x:", x_batch[0])
    print("Sample y:", y_batch[0])
    break

Training samples subset: 50
Validation samples: 1681
Batch x shape: torch.Size([32, 69])
Batch y shape: torch.Size([32, 94])
Sample x: tensor([  18,   14, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290])
Sample y: tensor([  18,   14, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290,
        1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290, 1290

Embedding (Token and Positional)

In [None]:
class TokenEmbedding(nn.Module):
  """
  Embedding layer for the tokens

  Args:
    vocab_size: size of the vocabulary
    d_embd: dimension of the embeddings
    padding_idx: index of the padding token
  """
  def __init__(self, vocab_size, d_embd, padding_idx=None):
    super().__init__()
    # for translation task we should have padding_idx = stoi['<PAD>'] to identify the padding tokens
    self.embd = nn.Embedding(vocab_size, d_embd, padding_idx=padding_idx)

  def forward(self, x):
    return self.embd(x)

In [None]:
class PositionalEmbedding(nn.Module):
  """
  Embedding layer for the positional encodings

  Args:
    n_tokens: number of tokens in the sequence
    d_embd: dimension of the embeddings
  """
  def __init__(self, n_tokens, d_embd):
    super().__init__()
    self.embd = nn.Embedding(n_tokens, d_embd)

  def forward(self, x):
    T = x.shape[1]
    pos = torch.arange(T, device=x.device)
    return self.embd(pos)

Single Head (with causal and padding masking)

In [None]:
class Head(nn.Module):
  """
  Single head of scaled dot-product attention, usable for self- or cross-attention.

  Args:
    d_embd: dimension of the embeddings
    head_size: dimension of the head
    dropout: dropout rate applied to the attention weights
    block_size: largest sequence length the causal mask supports; defaults to the
      notebook-level `context_window`
  """
  def __init__(self, d_embd, head_size, dropout=0.1, block_size=None):
    super().__init__()
    if block_size is None:
      block_size = context_window
    self.query = nn.Linear(d_embd, head_size, bias=False)
    self.key = nn.Linear(d_embd, head_size, bias=False)
    self.value = nn.Linear(d_embd, head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    # Lower-triangular ones: query position i may attend to key positions <= i.
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x, src_kv=None, key_padding_mask=None, causal_mask=False):
    """
    Args:
        x: (B, T_q, d_embd) - query input
        src_kv: (B, T_kv, d_embd) - key/value input for cross-attention; None means self-attention
        key_padding_mask: (B, T_kv) - boolean mask, True marks padding keys to ignore
        causal_mask: bool - if True, query i only attends to keys <= i (requires T_kv == T_q)
    Returns:
        (B, T_q, head_size) - attention output. A query row whose keys are all masked
        gets an all-zero output instead of NaN.
    """
    q_len = x.shape[1]
    kv = x if src_kv is None else src_kv

    q = self.query(x)    # (B, T_q, head_size)
    k = self.key(kv)     # (B, T_kv, head_size)
    v = self.value(kv)   # (B, T_kv, head_size)

    # (B, T_q, head_size) @ (B, head_size, T_kv) -> (B, T_q, T_kv)
    scores = (q @ k.transpose(-2, -1)) * (1 / math.sqrt(k.size(-1)))

    # Combine all masks into one boolean (True = blocked) that broadcasts to scores.
    blocked = None
    if causal_mask:
      blocked = self.tril[:q_len, :q_len] == 0          # (T_q, T_q)
    if key_padding_mask is not None:
      pad = key_padding_mask.unsqueeze(1)                # (B, 1, T_kv)
      blocked = pad if blocked is None else (blocked | pad)

    dead_rows = None
    if blocked is not None:
      scores = scores.masked_fill(blocked, float('-inf'))
      # A row with every key blocked (e.g. an all-<PAD> sequence) would softmax to NaN.
      # That NaN would then poison later layers, because masked positions still enter
      # `attn @ v` as 0 * NaN. Give such rows finite scores here and zero their
      # weights after the softmax.
      dead_rows = blocked.expand_as(scores).all(dim=-1, keepdim=True)  # (B, T_q, 1)
      scores = scores.masked_fill(dead_rows, 0.0)

    attn = torch.softmax(scores, dim=-1)
    if dead_rows is not None:
      attn = attn.masked_fill(dead_rows, 0.0)
    attn = self.dropout(attn)
    return attn @ v

In [None]:
class MultiHeadAttention(nn.Module):
  """
  Multiple heads of attention in parallel
  
  Args:
    d_embd: dimension of the embeddings
    n_heads: number of attention heads
    dropout: dropout rate
  """
  def __init__(self, d_embd, n_heads, dropout=0.1):
    super().__init__()
    assert d_embd % n_heads == 0, "d_embd must be divisible by n_heads"
    
    self.n_heads = n_heads
    self.head_size = d_embd // n_heads
    
    # Create multiple heads in parallel
    self.heads = nn.ModuleList([Head(d_embd, self.head_size, dropout) for _ in range(n_heads)])
    
    self.proj = nn.Linear(d_embd, d_embd)
    self.dropout = nn.Dropout(dropout)
  
  def forward(self, x, src=None, key_padding_mask=None, causal_mask=False):
    """
    Args:
        x: (B, T_q, d_embd) - Query input tensor
        src: (B, T_kv, d_embd) - Key/Value input tensor (for cross-attention)
        key_padding_mask: (B, T_kv) - Boolean mask (True for padding positions)
        causal_mask: bool - Whether to apply causal mask
    Returns:
        (B, T_q, d_embd) - Attention output
    """
    # Run all heads in parallel and concatenate outputs
    # Each head outputs (B, T_q, head_size)
    out = torch.cat([h(x, src, key_padding_mask, causal_mask) for h in self.heads], dim=-1)
    
    # Output projection and dropout
    out = self.dropout(self.proj(out))
    
    return out


MLP

In [None]:
class MLP(nn.Module):
    """
    Multi-layer perceptron

    Args:
        d_embd: dimension of the embeddings
        dropout: dropout rate
    """
    def __init__(self, d_embd, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_embd)
        self.mlp = nn.Sequential(
            nn.Linear(d_embd, 4 * d_embd),
            nn.GELU(),
            nn.Linear(4 * d_embd, d_embd),  
            nn.Dropout(dropout)              
        )
        
    def forward(self, x):
        return self.mlp(self.ln1(x))


Encoder

In [None]:
class EncoderBlock(nn.Module):
    """
    Encoder block with:
    1. Multi-Head Self-attention
    2. Feed-forward network

    Args:
        d_embd: dimension of the embeddings
        n_heads: number of attention heads
        dropout: dropout rate
    """
    def __init__(self, d_embd, n_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_embd)
        self.attn = MultiHeadAttention(d_embd, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_embd)
        self.mlp = MLP(d_embd, dropout)
  
    def forward(self, x, key_padding_mask=None):
        """
        Args:
            x: (B, T, d_embd) - Input tensor
            key_padding_mask: (B, T) - Boolean mask (True for padding positions)
        Returns:
            (B, T, d_embd) - Output tensor
        """
        x = x + self.attn(self.ln1(x), src=None, key_padding_mask=key_padding_mask, causal_mask=False)
        x = x + self.mlp(self.ln2(x))
        return x

In [None]:
class Encoder(nn.Module):
  """
  Encoder with multiple blocks of multi-head self-attention and feed-forward networks

  Args:
    vocab_size: size of the vocabulary
    d_embd: dimension of the embeddings
    n_heads: number of attention heads
    dropout: dropout rate
    n_blocks: number of blocks in the encoder
  """
  def __init__(self, vocab_size, d_embd, n_heads, dropout=0.1, n_blocks=4):
    super().__init__()
    self.tok_emb = TokenEmbedding(vocab_size, d_embd, padding_idx=stoi['<PAD>'])
    self.pos_emb = PositionalEmbedding(context_window, d_embd)
    self.blocks = nn.ModuleList([EncoderBlock(d_embd, n_heads, dropout) for _ in range(n_blocks)])
    self.ln_f = nn.LayerNorm(d_embd)  
  
  def forward(self, x, key_padding_mask=None):
    """
    Args:
        x: (B, T) - Input token indices
        key_padding_mask: (B, T) - Boolean mask (True for padding positions)
    Returns:
        (B, T, d_embd) - Encoded representations
    """
    tok_emb = self.tok_emb(x)
    pos_emb = self.pos_emb(x)
    x = tok_emb + pos_emb
    
    for block in self.blocks:
      x = block(x, key_padding_mask=key_padding_mask)
    
    x = self.ln_f(x)
    return x


Decoder

In [None]:
class DecoderBlock(nn.Module):
  """
  Decoder block with:
  1. Masked multi-head self-attention (causal)
  2. Multi-head cross-attention to encoder output
  3. Feed-forward network

  Args:
    d_embd: dimension of the embeddings
    n_heads: number of attention heads
    dropout: dropout rate
  """
  def __init__(self, d_embd, n_heads, dropout=0.1):
    super().__init__()
    self.ln1 = nn.LayerNorm(d_embd)
    self.ln2 = nn.LayerNorm(d_embd)
    self.ln3 = nn.LayerNorm(d_embd)
    
    self.self_attn = MultiHeadAttention(d_embd, n_heads, dropout)   
    self.cross_attn = MultiHeadAttention(d_embd, n_heads, dropout)  

    self.mlp = MLP(d_embd, dropout)
  
  def forward(self, x, encoder_out, tgt_key_padding_mask=None, src_key_padding_mask=None):
    """
    Args:
        x: (B, T_tgt, d_embd) - Decoder input
        encoder_out: (B, T_src, d_embd) - Encoder output
        tgt_key_padding_mask: (B, T_tgt) - Decoder padding mask
        src_key_padding_mask: (B, T_src) - Encoder padding mask
    """
    # 1. Masked self-attention (decoder attends to previous positions)
    x = x + self.self_attn(
        self.ln1(x), 
        src=None,  
        key_padding_mask=tgt_key_padding_mask, 
        causal_mask=True  
    )
    
    # 2. Cross-attention (decoder attends to encoder output)
    x = x + self.cross_attn(
        self.ln2(x), 
        src=encoder_out,  
        key_padding_mask=src_key_padding_mask, 
        causal_mask=False  
    )
    
    x = x + self.mlp(self.ln3(x))
    
    return x


In [None]:
class Decoder(nn.Module):
  """
  Decoder with multiple blocks of masked multi-head self-attention and cross-attention

  Args:
    vocab_size: size of the vocabulary
    d_embd: dimension of the embeddings
    n_heads: number of attention heads
    dropout: dropout rate
    n_blocks: number of decoder blocks
  """
  def __init__(self, vocab_size, d_embd, n_heads, dropout=0.1, n_blocks=4):
    super().__init__()
    self.tok_emb = TokenEmbedding(vocab_size, d_embd, padding_idx=stoi['<PAD>'])
    self.pos_emb = PositionalEmbedding(context_window, d_embd)
    self.blocks = nn.ModuleList([DecoderBlock(d_embd, n_heads, dropout) for _ in range(n_blocks)])
    self.ln_f = nn.LayerNorm(d_embd)  
    self.lm_head = nn.Linear(d_embd, vocab_size)  
  
  def forward(self, x, encoder_out, tgt_key_padding_mask=None, src_key_padding_mask=None):
    """
    Returns:
        (B, T_tgt, vocab_size) - Logits for each position
    """
    tok_emb = self.tok_emb(x)
    pos_emb = self.pos_emb(x)
    x = tok_emb + pos_emb
    
    for block in self.blocks:
      x = block(x, encoder_out, tgt_key_padding_mask, src_key_padding_mask)
    
    x = self.ln_f(x) 
    logits = self.lm_head(x) 
    
    return logits


Testing transformer implementation

In [None]:
# Smoke-test the Encoder-Decoder architecture on one training batch. Freshly initialised
# 2-block modules are run forward; output shapes and finiteness are asserted, not just printed.
print("="*60)
print("Testing Encoder-Decoder Architecture")
print("="*60)

# Get a batch from the data loader
x_batch, y_batch = next(iter(train_loader))

# Boolean padding masks (True where the token is <PAD>)
src_key_padding_mask = (x_batch == stoi['<PAD>'])  # Encoder mask
tgt_key_padding_mask = (y_batch == stoi['<PAD>'])  # Decoder mask

print(f"\nInput shapes:")
print(f"  English (source): {x_batch.shape}")
print(f"  Italian (target): {y_batch.shape}")

# Initialize encoder and decoder with multi-head attention
encoder = Encoder(vocab_size, d_embd, n_heads, dropout=0.1, n_blocks=2)
decoder = Decoder(vocab_size, d_embd, n_heads, dropout=0.1, n_blocks=2)

print(f"\nEncoder parameters: {sum(p.numel() for p in encoder.parameters()):,}")
print(f"Decoder parameters: {sum(p.numel() for p in decoder.parameters()):,}")

# Forward pass through encoder
encoder_out = encoder(x_batch, key_padding_mask=src_key_padding_mask)
print(f"\nEncoder output shape: {encoder_out.shape}")
assert encoder_out.shape == (x_batch.shape[0], x_batch.shape[1], d_embd), \
    f"unexpected encoder output shape {tuple(encoder_out.shape)}"

# Forward pass through decoder
logits = decoder(
    y_batch,
    encoder_out,
    tgt_key_padding_mask=tgt_key_padding_mask,
    src_key_padding_mask=src_key_padding_mask
)
print(f"Decoder output (logits) shape: {logits.shape}")
print(f"Expected shape: (batch_size={x_batch.shape[0]}, seq_len={y_batch.shape[1]}, vocab_size={vocab_size})")
assert logits.shape == (y_batch.shape[0], y_batch.shape[1], vocab_size), \
    f"unexpected logits shape {tuple(logits.shape)}"
# NaN here typically means some attention row had every key masked out.
assert not torch.isnan(logits).any(), "decoder produced NaN logits"

print("\n✓ Encoder-Decoder architecture working correctly!")


Seq2Seq Encoder/Decoder model

In [None]:
class Seq2SeqModel(nn.Module):
  """
  Complete Seq2Seq Translation Model combining Encoder and Decoder
  """
  def __init__(self, vocab_size, d_embd, n_heads, dropout=0.1, n_encoder_blocks=4, n_decoder_blocks=4):
    super().__init__()
    self.encoder = Encoder(vocab_size, d_embd, n_heads, dropout, n_encoder_blocks)
    self.decoder = Decoder(vocab_size, d_embd, n_heads, dropout, n_decoder_blocks)
  
  def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
    """
    Args:
        src: (B, T_src) - Source token indices (English)
        tgt: (B, T_tgt) - Target token indices (Italian)
        src_key_padding_mask: (B, T_src) - Source padding mask
        tgt_key_padding_mask: (B, T_tgt) - Target padding mask
    Returns:
        logits: (B, T_tgt, vocab_size) - Logits for each target position
    """
    # Encode source
    encoder_out = self.encoder(src, key_padding_mask=src_key_padding_mask)
    
    # Decode (with teacher forcing during training)
    logits = self.decoder(
        tgt, 
        encoder_out,
        tgt_key_padding_mask=tgt_key_padding_mask,
        src_key_padding_mask=src_key_padding_mask
    )
    
    return logits


Model initialization

In [None]:
import torch.optim as optim

# Initialize model
model = Seq2SeqModel(
    vocab_size=vocab_size,
    d_embd=d_embd,
    n_heads=n_heads,
    dropout=0.1,
    n_encoder_blocks=4,
    n_decoder_blocks=4
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")


In [None]:
# Training configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = model.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=stoi['<PAD>'], reduction='mean')

optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

Training

In [None]:
# Training hyperparameters
num_epochs = 100
print_every = 50  # Print loss every N batches (batch index within the epoch)

# Evaluate on the validation set every N optimizer steps. 500 suits the full dataset,
# but the experimental subset yields far fewer steps: 50 samples / batch 32 = 2 batches
# per epoch, or 200 steps in total. Cap the interval at ~10% of the run so that
# validation actually happens.
total_train_steps = num_epochs * len(train_loader)
eval_every = max(1, min(500, total_train_steps // 10))

In [None]:
# Training loop
model.train()
total_steps = 0
train_losses = []
val_losses = []

print("\n" + "="*60)
print("Starting Training")
print("="*60)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    
    for batch_idx, (src, tgt) in enumerate(train_loader):
        # Move to device
        src = src.to(device)  # (B, T_src)
        tgt = tgt.to(device)  # (B, T_tgt)
        
        # Teacher forcing: shift target for decoder input
        # Decoder input: all tokens except last [BOS, tok1, tok2, ..., tokN-1]
        # Decoder target: all tokens except first [tok1, tok2, ..., tokN, EOS]
        tgt_input = tgt[:, :-1]  # (B, T_tgt-1)
        tgt_output = tgt[:, 1:]  # (B, T_tgt-1)
        
        # Create masks
        src_key_padding_mask = (src == stoi['<PAD>'])  # (B, T_src)
        tgt_input_key_padding_mask = (tgt_input == stoi['<PAD>'])  # (B, T_tgt-1)
        
        # Forward pass
        optimizer.zero_grad()
        logits = model(
            src=src,
            tgt=tgt_input,  # Decoder input (shifted)
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_input_key_padding_mask
        )  # (B, T_tgt-1, vocab_size)
        
        # Reshape for loss: (B*T, vocab_size) -> (B*T,)
        loss = criterion(logits.reshape(-1, vocab_size), tgt_output.reshape(-1))
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping (optional but recommended)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # Track loss
        epoch_loss += loss.item()
        num_batches += 1
        total_steps += 1
        
        # Print progress
        if batch_idx % print_every == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, "
                  f"Loss: {loss.item():.4f}, LR: {current_lr:.6f}")
        
        # Validation
        if total_steps % eval_every == 0:
            model.eval()
            val_loss = 0.0
            val_batches = 0
            
            with torch.no_grad():
                for val_src, val_tgt in val_loader:
                    val_src = val_src.to(device)
                    val_tgt = val_tgt.to(device)
                    
                    val_tgt_input = val_tgt[:, :-1]
                    val_tgt_output = val_tgt[:, 1:]
                    
                    val_src_mask = (val_src == stoi['<PAD>'])
                    val_tgt_mask = (val_tgt_input == stoi['<PAD>'])
                    
                    val_logits = model(
                        src=val_src,
                        tgt=val_tgt_input,
                        src_key_padding_mask=val_src_mask,
                        tgt_key_padding_mask=val_tgt_mask
                    )
                    
                    val_loss_batch = criterion(val_logits.reshape(-1, vocab_size), val_tgt_output.reshape(-1))
                    val_loss += val_loss_batch.item()
                    val_batches += 1
            
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
            val_losses.append(avg_val_loss)
            print(f"\n{'='*60}")
            print(f"Validation @ Step {total_steps}: Loss = {avg_val_loss:.4f}")
            print(f"{'='*60}\n")
            
            model.train()
    
    # Epoch summary
    avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0
    train_losses.append(avg_epoch_loss)
    
    print(f"\nEpoch {epoch+1}/{num_epochs} completed:")
    print(f"  Average Train Loss: {avg_epoch_loss:.4f}")
    print(f"  Current LR: {optimizer.param_groups[0]['lr']:.6f}")
    print("-"*60)

print("\n" + "="*60)
print("Training Complete!")
print("="*60)


Plotting

In [None]:
# Plot per-epoch training loss and periodic validation loss on a common epoch axis.
import matplotlib.pyplot as plt

if len(train_losses) > 0:
    fig, ax = plt.subplots(figsize=(10, 6))

    # Train losses are per-epoch averages, recorded at the end of epochs 1..N.
    ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss', marker='o')

    if len(val_losses) > 0:
        # Validation is recorded every `eval_every` optimizer steps. Convert those step
        # counts to (fractional) epochs so both curves share the same x-axis.
        steps_per_epoch = max(1, len(train_loader))
        val_epochs = [eval_every * (i + 1) / steps_per_epoch for i in range(len(val_losses))]
        ax.plot(val_epochs, val_losses, label='Val Loss', marker='s')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    ax.grid(True)
    plt.show()

    print(f"\nFinal Training Loss: {train_losses[-1]:.4f}")
    if len(val_losses) > 0:
        print(f"Final Validation Loss: {val_losses[-1]:.4f}")


Save the model

In [None]:
# Save model checkpoint. Disabled by default; set SAVE_CHECKPOINT = True to write it.
SAVE_CHECKPOINT = False
CHECKPOINT_PATH = 'bertolingo_model.pt'

if SAVE_CHECKPOINT:
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'epoch': num_epochs,
        'vocab_size': vocab_size,
        'd_embd': d_embd,
        'n_heads': n_heads,
        # Also needed to rebuild the positional tables and block stacks exactly.
        'context_window': context_window,
        'n_encoder_blocks': len(model.encoder.blocks),
        'n_decoder_blocks': len(model.decoder.blocks),
        'stoi': stoi,
        'itos': itos,
    }

    torch.save(checkpoint, CHECKPOINT_PATH)
    print(f"Model saved to '{CHECKPOINT_PATH}'")


Inference (translation)

In [None]:
def translate(model, src_text, stoi, itos, device, max_len=128):
    """
    Greedy (argmax) translation of English text to Italian with a trained Seq2Seq model.

    When `stoi` contains both '<BOS>' and '<EOS>', decoding starts from '<BOS>' and
    stops once '<EOS>' is produced. Otherwise it falls back to seeding the decoder
    with the first source character and stopping on '<PAD>'.

    Args:
        model: Trained Seq2Seq model exposing `.encoder` and `.decoder`
        src_text: English text string; characters missing from `stoi` map to '<PAD>'
        stoi: String to index mapping (must contain '<PAD>')
        itos: Index to string mapping
        device: Device to run on
        max_len: Maximum number of tokens to generate
    Returns:
        Translated Italian text with special tokens removed ('' for empty input)
    """
    model.eval()

    pad_idx = stoi['<PAD>']
    bos_idx = stoi.get('<BOS>')
    eos_idx = stoi.get('<EOS>')
    use_bos_eos = bos_idx is not None and eos_idx is not None
    stop_idx = eos_idx if use_bos_eos else pad_idx
    special_ids = {pad_idx, bos_idx, eos_idx} - {None}

    # The learned positional tables bound how many positions each side can handle.
    src_limit = model.encoder.pos_emb.embd.num_embeddings
    tgt_limit = model.decoder.pos_emb.embd.num_embeddings

    # Truncate the source so it never indexes past the encoder's positional table.
    src_ids = [stoi.get(char, pad_idx) for char in src_text[:src_limit]]
    if not src_ids:
        return ''
    src_tokens = torch.tensor([src_ids], dtype=torch.long, device=device)
    src_mask = (src_tokens == pad_idx)

    # Without a BOS token, fall back to the first source character as a seed (heuristic).
    start_idx = bos_idx if use_bos_eos else src_ids[0]

    with torch.no_grad():
        encoder_out = model.encoder(src_tokens, key_padding_mask=src_mask)
        tgt_tokens = torch.tensor([[start_idx]], dtype=torch.long, device=device)

        # Autoregressive greedy decoding
        for _ in range(max_len):
            # The decoder cannot see more positions than its positional table has.
            if tgt_tokens.size(1) > tgt_limit:
                break
            tgt_mask = (tgt_tokens == pad_idx)
            logits = model.decoder(
                tgt_tokens,
                encoder_out,
                tgt_key_padding_mask=tgt_mask,
                src_key_padding_mask=src_mask
            )

            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tgt_tokens = torch.cat([tgt_tokens, next_token], dim=1)

            if next_token.item() == stop_idx:
                break

    # Drop special tokens so '<BOS>', '<EOS>' or '<PAD>' never appear in the text.
    return ''.join(itos.get(idx, '') for idx in tgt_tokens[0].tolist() if idx not in special_ids)

# Test translation on a sample
print("Testing translation...")
test_english = "Hello, how are you?"
print(f"English: {test_english}")
print(f"Italian: {translate(model, test_english, stoi, itos, device)}")
