In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import math
import time
import os
import re
import random
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')
from datasets import load_dataset
import spacy
import sentencepiece as spm
from sacremoses import MosesTokenizer, MosesDetokenizer

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True

In [4]:
def download_wmt14():
    print("Downloading WMT14 DE-EN dataset...")
    # Download the dataset
    dataset = load_dataset("wmt14", "de-en")
    return dataset

In [5]:
D_MODEL = 512  # Embedding dimension
NHEAD = 8  # Number of attention heads
NUM_ENCODER_LAYERS = 6  # Number of encoder layers
NUM_DECODER_LAYERS = 6  # Number of decoder layers
DIM_FEEDFORWARD = 2048  # Dimension of the feedforward network
DROPOUT = 0.1  # Dropout rate
BATCH_SIZE = 32  # Batch size for training
ACCUMULATION_STEPS = 4  # Gradient accumulation steps
LEARNING_RATE = 0.0001  # Initial learning rate
BETAS = (0.9, 0.98)  # Adam optimizer betas
EPS = 1e-9  # Adam optimizer epsilon
WEIGHT_DECAY = 0.0001  # Weight decay for regularization
LABEL_SMOOTHING = 0.1  # Label smoothing factor
MAX_SEQ_LENGTH = 100  # Maximum sequence length
WARMUP_STEPS = 4000  # Warmup steps for learning rate scheduler
MAX_EPOCHS = 30  # Maximum number of epochs
CLIP_GRAD = 1.0  # Gradient clipping value
PATIENCE = 5  # Patience for early stopping

In [6]:
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
SPECIAL_TOKENS = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]

In [7]:
def use_huggingface_tokenizers():
    """Use pre-trained tokenizers from HuggingFace"""
    from transformers import MarianTokenizer
    
    print("Loading pre-trained MarianMT tokenizers...")
    tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
    
    # Create a simplified vocabulary wrapper to match our expected interface
    class HFVocabulary:
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer
            self.pad_idx = tokenizer.pad_token_id
            self.bos_idx = tokenizer.bos_token_id
            self.eos_idx = tokenizer.eos_token_id
            self.unk_idx = tokenizer.unk_token_id
        
        def __len__(self):
            return len(self.tokenizer)
        
        def encode(self, text):
            """Convert text to token IDs"""
            return self.tokenizer.encode(text, add_special_tokens=False)
        
        def encode_with_special_tokens(self, text, add_bos=True, add_eos=True):
            """Encode with optional BOS/EOS tokens"""
            # Manually handle BOS and EOS to avoid warnings about unrecognized parameters
            ids = self.tokenizer.encode(text, add_special_tokens=False)
            
            if add_bos:
                ids = [self.bos_idx] + ids
            if add_eos:
                ids = ids + [self.eos_idx]
            
            return ids
        
        def decode(self, ids):
            """Convert token IDs to text"""
            return self.tokenizer.decode(ids, skip_special_tokens=True)
        
        def token_to_id(self, token):
            """Get ID for a token"""
            return self.tokenizer.convert_tokens_to_ids(token)
        
        def id_to_token(self, id):
            """Get token for an ID"""
            return self.tokenizer.convert_ids_to_tokens(id)
    
    return HFVocabulary(tokenizer)

In [8]:
class SPVocabulary:
    def __init__(self, sp_model):
        self.sp = sp_model
        self.pad_idx = self.sp.piece_to_id(PAD_TOKEN)
        self.bos_idx = self.sp.piece_to_id(BOS_TOKEN)
        self.eos_idx = self.sp.piece_to_id(EOS_TOKEN)
        self.unk_idx = self.sp.piece_to_id(UNK_TOKEN)
    
    def __len__(self):
        return self.sp.get_piece_size()
    
    def encode(self, text):
        """Convert text to token IDs"""
        return self.sp.encode(text, out_type=int)
    
    def encode_with_special_tokens(self, text, add_bos=True, add_eos=True):
        """Encode with optional BOS/EOS tokens"""
        ids = self.encode(text)
        if add_bos:
            ids = [self.bos_idx] + ids
        if add_eos:
            ids = ids + [self.eos_idx]
        return ids
    
    def decode(self, ids):
        """Convert token IDs to text"""
        return self.sp.decode(ids)
    
    def token_to_id(self, token):
        """Get ID for a token"""
        return self.sp.piece_to_id(token)
    
    def id_to_token(self, id):
        """Get token for an ID"""
        return self.sp.id_to_piece(id)

In [9]:
class TranslationDataset(Dataset):
    def __init__(self, data, vocab, max_len=MAX_SEQ_LENGTH):
        self.data = data
        self.vocab = vocab
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Get source and target texts
        try:
            src_text = self.data[idx]['translation']['de']
            tgt_text = self.data[idx]['translation']['en']
            
            # Skip empty strings
            if not src_text or not tgt_text:
                # Return a simple placeholder with the right format
                return torch.tensor([self.vocab.bos_idx, self.vocab.eos_idx]), \
                       torch.tensor([self.vocab.bos_idx, self.vocab.eos_idx])
            
            # Encode source and target texts
            src_ids = self.vocab.encode_with_special_tokens(src_text, add_bos=False, add_eos=True)
            tgt_ids = self.vocab.encode_with_special_tokens(tgt_text, add_bos=True, add_eos=True)
            
            # Handle empty encodings or None
            if not src_ids or src_ids is None:
                src_ids = [self.vocab.bos_idx, self.vocab.eos_idx]
            if not tgt_ids or tgt_ids is None:
                tgt_ids = [self.vocab.bos_idx, self.vocab.eos_idx]
            
            # Truncate if necessary
            if len(src_ids) > self.max_len:
                src_ids = src_ids[:self.max_len-1] + [self.vocab.eos_idx]
            if len(tgt_ids) > self.max_len:
                tgt_ids = tgt_ids[:self.max_len-1] + [self.vocab.eos_idx]
                
            return torch.tensor(src_ids), torch.tensor(tgt_ids)
            
        except Exception as e:
            print(f"Error processing example {idx}: {e}")
            # Return a simple pair of tensors to maintain data flow
            return torch.tensor([self.vocab.bos_idx, self.vocab.eos_idx]), \
                   torch.tensor([self.vocab.bos_idx, self.vocab.eos_idx])

In [10]:
def collate_fn(batch, pad_idx):
    src_batch, tgt_batch = [], []
    for src, tgt in batch:
        src_batch.append(src)
        tgt_batch.append(tgt)
    
    # Pad sequences
    src_batch = pad_sequence(src_batch, padding_value=pad_idx, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=pad_idx, batch_first=True)
    
    return src_batch, tgt_batch

In [11]:
def create_dataloaders(dataset, vocab, batch_size=BATCH_SIZE, val_split=0.05, max_len=MAX_SEQ_LENGTH):
    """Create training and validation DataLoader objects"""
    # Filter out examples with empty translations
    filtered_data = []
    for example in dataset['train']:
        if (example['translation']['de'] and 
            example['translation']['en'] and 
            len(example['translation']['de']) > 5 and  # Ensure minimal sentence length
            len(example['translation']['en']) > 5):
            filtered_data.append(example)
    
    # Create a custom dataset dict
    filtered_dataset = {'train': filtered_data}
    if 'test' in dataset:
        # Also filter test data
        filtered_test = []
        for example in dataset['test']:
            if (example['translation']['de'] and 
                example['translation']['en'] and 
                len(example['translation']['de']) > 5 and
                len(example['translation']['en']) > 5):
                filtered_test.append(example)
        filtered_dataset['test'] = filtered_test
    else:
        filtered_dataset['test'] = []  # Empty test set if not present
    
    print(f"Filtered training examples: {len(filtered_dataset['train'])} (from {len(dataset['train'])})")
    if 'test' in dataset:
        print(f"Filtered test examples: {len(filtered_dataset['test'])} (from {len(dataset['test'])})")
    
    # Create a TranslationDataset
    full_dataset = TranslationDataset(filtered_dataset['train'], vocab, max_len)
    
    # Split into train and validation
    val_size = int(len(full_dataset) * val_split)
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    # Get test dataset if available
    test_dataset = None
    if len(filtered_dataset['test']) > 0:
        test_dataset = TranslationDataset(filtered_dataset['test'], vocab, max_len)
    else:
        # Use part of validation as test if no test set available
        additional_val_size = int(val_size * 0.5)
        val_size = val_size - additional_val_size
        val_dataset, test_dataset = random_split(val_dataset, [val_size, additional_val_size])
    
    # Configure DataLoader parameters to handle errors gracefully
    dataloader_kwargs = {
        'batch_size': batch_size,
        'collate_fn': lambda batch: collate_fn(batch, vocab.pad_idx),
        'pin_memory': True,
        'num_workers': 12, 
        'persistent_workers': True if torch.cuda.is_available() else False,
    }
    
    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        **dataloader_kwargs
    )
    
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        **dataloader_kwargs
    )
    
    test_loader = DataLoader(
        test_dataset,
        shuffle=False,
        **dataloader_kwargs
    )
    
    return train_loader, val_loader, test_loader

In [12]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Calculate sin and cos terms
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension and register as buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # Add positional encoding to input
        x = x + self.pe[:, :x.size(1)].to(x.device)
        return self.dropout(x)

In [13]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=D_MODEL, nhead=NHEAD, 
                 num_encoder_layers=NUM_ENCODER_LAYERS, 
                 num_decoder_layers=NUM_DECODER_LAYERS,
                 dim_feedforward=DIM_FEEDFORWARD, dropout=DROPOUT):
        super(TransformerModel, self).__init__()
        
        # Store dimensions
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)
        
        # Transformer layers from PyTorch
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        encoder_norm = nn.LayerNorm(d_model)
        decoder_norm = nn.LayerNorm(d_model)
        
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_encoder_layers,
            norm=encoder_norm
        )
        
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=num_decoder_layers,
            norm=decoder_norm
        )
        
        # Final linear layer for output
        self.fc_out = nn.Linear(d_model, vocab_size)
        
        # Initialize parameters with Glorot/Xavier initialization
        self._reset_parameters()
        
    def _reset_parameters(self):
        """Initialize parameters with appropriate scaling"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
        # Special initialization for embedding
        nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model ** -0.5)
        
    def create_padding_mask(self, src, pad_idx):
        """Create mask for padding tokens (True where pad token)"""
        return (src == pad_idx).to(device)
        
    def create_look_ahead_mask(self, size):
        """Create mask to prevent attention to future tokens"""
        # Create an upper triangular matrix with 1s
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return mask.to(device)
    
    def forward(self, src, tgt):
        """Forward pass of the transformer model"""
        # Get padding token index
        pad_idx = self.embedding.padding_idx if self.embedding.padding_idx is not None else 0
        
        # Create masks
        src_padding_mask = self.create_padding_mask(src, pad_idx)  # [batch_size, src_len]
        tgt_padding_mask = self.create_padding_mask(tgt[:, :-1], pad_idx)  # [batch_size, tgt_len-1]
        tgt_look_ahead_mask = self.create_look_ahead_mask(tgt.size(1)-1)  # [tgt_len-1, tgt_len-1]
        
        # Embed and apply positional encoding
        src_embedded = self.embedding(src) * math.sqrt(self.d_model)
        src_embedded = self.positional_encoding(src_embedded)
        
        tgt_embedded = self.embedding(tgt[:, :-1]) * math.sqrt(self.d_model)
        tgt_embedded = self.positional_encoding(tgt_embedded)
        
        # Transformer encoder
        memory = self.transformer_encoder(
            src=src_embedded,
            src_key_padding_mask=src_padding_mask
        )
        
        # Transformer decoder
        output = self.transformer_decoder(
            tgt=tgt_embedded,
            memory=memory,
            tgt_mask=tgt_look_ahead_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask
        )
        
        # Final output projection
        output = self.fc_out(output)
        
        return output

In [14]:
class NoamLR:
    """Learning rate scheduler from 'Attention is All You Need'"""
    def __init__(self, optimizer, d_model, warmup_steps=4000, factor=1.0):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.factor = factor
        self._step = 0
        self._rate = 0
        
    def step(self):
        """Update parameters and learning rate"""
        self._step += 1
        rate = self._get_lr()
        self._rate = rate
        
        for p in self.optimizer.param_groups:
            p['lr'] = rate
            
    def _get_lr(self):
        """Calculate learning rate according to the formula"""
        step = self._step
        return self.factor * (self.d_model ** -0.5) * min(step ** -0.5, step * self.warmup_steps ** -1.5)

In [15]:
def train_epoch(model, dataloader, optimizer, scheduler, criterion, clip_grad=CLIP_GRAD, accumulation_steps=ACCUMULATION_STEPS):
    model.train()
    total_loss = 0
    processed_batches = 0
    optimizer.zero_grad()
    
    progress_bar = tqdm(dataloader, desc="Training")
    
    for batch_idx, (src, tgt) in enumerate(progress_bar):
        src, tgt = src.to(device), tgt.to(device)
        
        # Forward pass
        output = model(src, tgt)
        
        # Calculate loss
        output = output.contiguous().view(-1, output.shape[-1])
        tgt = tgt[:, 1:].contiguous().view(-1)  # Shift right to align with output
        loss = criterion(output, tgt) / accumulation_steps
        
        # Backward pass
        loss.backward()
        
        # Track total loss
        total_loss += loss.item() * accumulation_steps
        processed_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({'loss': total_loss / processed_batches})
        
        # Gradient accumulation
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(dataloader):
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            
            # Update parameters
            optimizer.step()
            optimizer.zero_grad()
            
            # Update learning rate
            scheduler.step()
    
    return total_loss / processed_batches

In [16]:
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    processed_batches = 0
    
    with torch.no_grad():
        for src, tgt in tqdm(dataloader, desc="Evaluating"):
            src, tgt = src.to(device), tgt.to(device)
            
            # Forward pass
            output = model(src, tgt)
            
            # Calculate loss
            output = output.contiguous().view(-1, output.shape[-1])
            tgt = tgt[:, 1:].contiguous().view(-1)
            loss = criterion(output, tgt)
            
            # Track total loss
            total_loss += loss.item()
            processed_batches += 1
    
    return total_loss / processed_batches


In [17]:
def beam_search_decode(model, src, max_len=100, beam_size=5):
    model.eval()
    
    # Move source tensor to device
    src = src.to(device)
    
    # Encode source sequence
    with torch.no_grad():
        src_embedded = model.embedding(src) * math.sqrt(model.d_model)
        src_embedded = model.positional_encoding(src_embedded)
        
        src_padding_mask = model.create_padding_mask(
            src, model.embedding.padding_idx if model.embedding.padding_idx is not None else 0
        )
        
        memory = model.transformer_encoder(
            src=src_embedded,
            src_key_padding_mask=src_padding_mask
        )
    
    # Start with BOS token
    bos_idx = model.embedding.padding_idx + 1 if model.embedding.padding_idx is not None else 1
    ys = torch.ones(1, 1).fill_(bos_idx).type_as(src).long().to(device)
    
    # Initialize beams: (sequence, score)
    beams = [(ys, 0.0)]
    completed_beams = []
    
    for _ in range(max_len - 1):
        candidates = []
        
        for seq, score in beams:
            # If this sequence ended with EOS
            if seq[0, -1].item() == bos_idx + 1:
                completed_beams.append((seq, score))
                continue
            
            # Predict next tokens
            with torch.no_grad():
                tgt_embedded = model.embedding(seq) * math.sqrt(model.d_model)
                tgt_embedded = model.positional_encoding(tgt_embedded)
                
                tgt_mask = model.create_look_ahead_mask(seq.size(1))
                tgt_padding_mask = model.create_padding_mask(
                    seq, model.embedding.padding_idx if model.embedding.padding_idx is not None else 0
                )
                
                out = model.transformer_decoder(
                    tgt=tgt_embedded,
                    memory=memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_padding_mask,
                    memory_key_padding_mask=src_padding_mask
                )
                
                out = model.fc_out(out[:, -1])
                prob = F.log_softmax(out, dim=-1)
            
            # Get top-k candidates
            topk_prob, topk_idx = prob[0].topk(beam_size)
            
            for i in range(beam_size):
                next_token = topk_idx[i].unsqueeze(0).unsqueeze(0)
                next_score = score + topk_prob[i].item()
                next_seq = torch.cat([seq, next_token], dim=1)
                candidates.append((next_seq, next_score))
        
        # Keep only top beams
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
        
        # Early stopping if all beams end with EOS
        if all(beam[0][0, -1].item() == bos_idx + 1 for beam in beams):
            break
    
    # Add any remaining beams to completed
    for seq, score in beams:
        if seq[0, -1].item() != bos_idx + 1:
            seq = torch.cat([seq, torch.ones(1, 1).fill_(bos_idx + 1).type_as(src).to(device)], dim=1)
        completed_beams.append((seq, score))
    
    # Return the highest scoring beam
    if completed_beams:
        return max(completed_beams, key=lambda x: x[1])[0]
    else:
        return beams[0][0]

In [18]:
def train_model(model, train_loader, val_loader, optimizer, scheduler, criterion, 
               max_epochs=MAX_EPOCHS, patience=PATIENCE, save_path="best_model.pt"):
    # Initialize tracking variables
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    train_losses = []
    val_losses = []
    
    print("Starting training...")
    for epoch in range(max_epochs):
        start_time = time.time()
        
        # Training
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion)
        train_losses.append(train_loss)
        
        # Validation
        val_loss = evaluate(model, val_loader, criterion)
        val_losses.append(val_loss)
        
        # Calculate time taken
        epoch_mins, epoch_secs = divmod(time.time() - start_time, 60)
        
        # Print progress
        print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs:.2f}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Val Loss: {val_loss:.3f}')
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
            }, save_path)
            print(f"New best model saved with validation loss: {val_loss:.3f}")
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epochs")
        
        # Early stopping
        if epochs_without_improvement >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            break
    
    # Plot training and validation losses
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.savefig('loss_plot.png')
    plt.show()
    
    return train_losses, val_losses

In [19]:
dataset = download_wmt14()
print(f"Training examples: {len(dataset['train'])}")
print(f"Test examples: {len(dataset['test'])}")

Downloading WMT14 DE-EN dataset...
Training examples: 4508785
Test examples: 3003


In [20]:
vocab = use_huggingface_tokenizers()
print(f"HuggingFace vocabulary size: {len(vocab)}")

Loading pre-trained MarianMT tokenizers...
HuggingFace vocabulary size: 58101


In [21]:
train_loader, val_loader, test_loader = create_dataloaders(dataset, vocab)

Filtered training examples: 4502068 (from 4508785)
Filtered test examples: 3002 (from 3003)


In [22]:
model = TransformerModel(len(vocab)).to(device)
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

Model has 103,694,069 parameters


In [23]:
model.embedding.padding_idx = vocab.pad_idx
criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_idx, label_smoothing=LABEL_SMOOTHING)

In [24]:
optimizer = optim.Adam(
    model.parameters(), 
    lr=LEARNING_RATE, 
    betas=BETAS, 
    eps=EPS,
    weight_decay=WEIGHT_DECAY
)
scheduler = NoamLR(optimizer, D_MODEL, WARMUP_STEPS)

In [27]:
train_losses, val_losses = train_model(model, train_loader, val_loader, optimizer, scheduler, criterion)
print("Training complete!")

Starting training...


Training:   0%|          | 0/133656 [00:00<?, ?it/s]

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/tmp/ipykernel_232698/790162208.py", line 38, in __getitem__
    return torch.tensor(src_ids), torch.tensor(tgt_ids)
                                  ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Could not infer dtype of NoneType

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/sunil/.local/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/sunil/.local/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 50, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sunil/.local/lib/python3.12/site-packages/torch/utils/data/dataset.py", line 420, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_232698/790162208.py", line 43, in __getitem__
    return torch.tensor([self.vocab.bos_idx, self.vocab.eos_idx]), \
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Could not infer dtype of NoneType


In [26]:
final_model_path = "final_wmt14_de_en_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_size': len(vocab),
    'model_config': {
        'd_model': D_MODEL,
        'nhead': NHEAD,
        'num_encoder_layers': NUM_ENCODER_LAYERS,
        'num_decoder_layers': NUM_DECODER_LAYERS,
        'dim_feedforward': DIM_FEEDFORWARD,
        'dropout': DROPOUT
    }
}, final_model_path)

In [27]:
sp_model.save("spm_wmt14_de_en.model")

print(f"Model saved to {final_model_path}")
print("SentencePiece model saved to spm_wmt14_de_en.model")

In [57]:
model = Transformer(
    len(de_vocab), 
    len(en_vocab), 
    d_model=D_MODEL, 
    nhead=NHEAD,
    num_encoder_layers=NUM_ENCODER_LAYERS, 
    num_decoder_layers=NUM_DECODER_LAYERS,
    d_ff=D_FF, 
    dropout=DROPOUT
).to(device)
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

Model has 10564101 parameters


In [58]:
def lr_schedule(step, d_model, warmup_steps):
    """Learning rate schedule as described in the paper"""
    arg1 = step ** -0.5
    arg2 = step * (warmup_steps ** -1.5)
    return (d_model ** -0.5) * min(arg1, arg2)

optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=BETAS, eps=EPS)
scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda step: lr_schedule(step + 1, D_MODEL, WARMUP_STEPS)
)

In [59]:
criterion = nn.CrossEntropyLoss(ignore_index=de_vocab[PAD_TOKEN])

In [63]:
def beam_search_decode(model, src, max_len, start_token, end_token, pad_token, beam_size=5):
    model.eval()
    
    src = src.to(device)
    src_mask = create_padding_mask(src, pad_token)
    
    # Encode source sequence
    memory = model.encode(src, src_mask)
    
    # Initialize with start token
    ys = torch.ones(1, 1).fill_(start_token).type_as(src).to(device)
    
    # Initial beam
    beams = [(ys, 0.0)]  # (sequence, score)
    completed_beams = []
    
    for _ in range(max_len - 1):
        candidates = []
        
        for seq, score in beams:
            if seq[0, -1].item() == end_token:
                completed_beams.append((seq, score))
                continue
                
            # Create masks
            tgt_mask = create_look_ahead_mask(seq.size(1)).to(device)
            memory_mask = src_mask
            
            # Decode one step
            out = model.decode(seq, memory, tgt_mask, memory_mask)
            prob = F.log_softmax(model.generator(out[:, -1]), dim=-1)
            
            # Get top k
            topk_prob, topk_idx = prob.topk(beam_size)
            
            for i in range(beam_size):
                next_token = topk_idx[0, i].item()
                next_score = score + topk_prob[0, i].item()
                next_seq = torch.cat([seq, torch.ones(1, 1).type_as(src).fill_(next_token).to(device)], dim=1)
                candidates.append((next_seq, next_score))
        
        # Keep only top beams
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
        
        # Stop if all beams end with end token
        if all(beam[0][0, -1].item() == end_token for beam in beams):
            break
    
    # Add any remaining beams to completed
    for seq, score in beams:
        if seq[0, -1].item() != end_token:
            seq = torch.cat([seq, torch.ones(1, 1).type_as(src).fill_(end_token).to(device)], dim=1)
        completed_beams.append((seq, score))
    
    # Return the highest scoring beam
    return max(completed_beams, key=lambda x: x[1])[0]

In [64]:
def translate(model, sentence, src_tokenizer, tgt_vocab, src_vocab, max_len=50, beam_size=5):
    model.eval()
    
    # Tokenize input sentence
    tokens = src_tokenizer(sentence)
    
    # Convert tokens to indices
    token_ids = [src_vocab[token] for token in tokens]
    
    # Create tensor
    src = torch.tensor([token_ids]).to(device)
    
    # Decode using beam search
    output = beam_search_decode(
        model, 
        src, 
        max_len, 
        tgt_vocab[BOS_TOKEN], 
        tgt_vocab[EOS_TOKEN], 
        src_vocab[PAD_TOKEN],
        beam_size
    )
    
    # Convert output indices to tokens
    output_tokens = [tgt_vocab.lookup_token(i) for i in output[0, 1:].tolist()]
    
    # Stop at EOS token
    if tgt_vocab[EOS_TOKEN] in output_tokens:
        output_tokens = output_tokens[:output_tokens.index(tgt_vocab[EOS_TOKEN])]
    
    return ' '.join(output_tokens)