In [25]:
from datasets import load_dataset
import torch

# Pick the first available backend: CUDA GPU, then Apple-Silicon MPS, then CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Seed torch's global RNG so weight initialization and DataLoader shuffling are reproducible
# across Restart & Run All (torch.manual_seed seeds the generators of all devices).
seed = 42
torch.manual_seed(seed)

# WMT14 German-English parallel corpus, cached locally.
ds = load_dataset("wmt/wmt14", "de-en", cache_dir="./data/wmt14")

# Tokenizer settings (a single BPE tokenizer is shared by source and target).
vocab_size = 30_000
vocab_path = "./data/bpe_tokenizer.json"

# Data settings
training_samples = 100_000      # number of pairs taken from the start of the train split
batch_size = 64
dataset_max_sample_len = 100    # passed to TranslationDataset as max_length

# Transformer model parameters
d_model=192
nhead=6
num_encoder_layers=3
num_decoder_layers=3
dim_feedforward=512
dropout=0.1
max_len=150                     # kept above dataset_max_sample_len to leave room for special tokens

# Training schedule
num_epochs = 100
warmup_steps = 2000
eval_iters = 10

In [26]:
from tokenizers import Tokenizer as HFTokenizer, decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Metaspace
from tokenization_vocab import Tokenizer
from pathlib import Path

# One BPE tokenizer shared by both languages. The special tokens come from the project's
# Tokenizer class so they match what the dataset/model code refers to.
bpe_tokenizer = HFTokenizer(BPE(unk_token=Tokenizer.UNK_TOKEN))
trainer = BpeTrainer(
    special_tokens=[Tokenizer.PAD_TOKEN, Tokenizer.SOS_TOKEN, Tokenizer.EOS_TOKEN, Tokenizer.UNK_TOKEN],
    vocab_size=vocab_size,
    show_progress=True
)

# Pair the Metaspace pre-tokenizer with its matching decoder.
bpe_tokenizer.pre_tokenizer = Metaspace()
bpe_tokenizer.decoder = decoders.Metaspace()

# Default to training; switch to loading automatically when a saved tokenizer already exists.
# (A hard-coded True here would make the training branch unreachable and crash on a fresh
# checkout, because from_file needs the JSON to exist.)
pretrained = False  # Set to True to force loading a previously saved tokenizer

vocab_file = Path(vocab_path)
vocab_file.parent.mkdir(parents=True, exist_ok=True)

if vocab_file.is_file():
    pretrained = True

if pretrained:
    bpe_tokenizer = HFTokenizer.from_file(vocab_path)
else:
    # NOTE(review): trains on local CSV exports rather than on `ds` loaded in the config cell;
    # confirm these files exist and that CSV headers/delimiters are acceptable training text.
    bpe_tokenizer.train(
        [
            './datasets/wmt-2014-english-german/wmt14_translate_de-en_test.csv',
            './datasets/wmt-2014-english-german/wmt14_translate_de-en_train.csv',
            './datasets/wmt-2014-english-german/wmt14_translate_de-en_validation.csv',
        ],
        trainer=trainer
    )

    bpe_tokenizer.save(vocab_path)

In [None]:
from parallel_corpus import TranslationDataset, DataLoaderFactory
from tokenization_vocab import HFTokenizerWrapper

# Take the first `training_samples` (German, English) pairs. Slicing the split's rows *before*
# selecting the 'translation' column reads only those rows, instead of materializing the whole
# train column and then discarding most of it.
train_rows = ds['train'][:training_samples]['translation']
sample_sentences = [(row['de'], row['en']) for row in train_rows]

sample_sentences_de = [s[0] for s in sample_sentences]
sample_sentences_en = [s[1] for s in sample_sentences]

# The same tokenizer wrapper serves both languages.
tokenizer = HFTokenizerWrapper(bpe_tokenizer)

# 90/10 split in corpus order (no shuffling before splitting).
train_size = int(len(sample_sentences) * 0.9)
train_sents, test_sents = sample_sentences[:train_size], sample_sentences[train_size:]


def _build_translation_dataset(pairs):
    """Wrap (German, English) pairs in a de -> en TranslationDataset using the shared tokenizer."""
    return TranslationDataset(
        source_sentences=[de for de, _ in pairs],
        target_sentences=[en for _, en in pairs],
        source_tokenizer=tokenizer,
        target_tokenizer=tokenizer,
        max_length=dataset_max_sample_len,
    )


train_ds = _build_translation_dataset(train_sents)
test_ds = _build_translation_dataset(test_sents)

train_loader = DataLoaderFactory.create_dataloader(
    dataset=train_ds,
    batch_size=batch_size,
    pad_idx=tokenizer.pad_idx,
    shuffle=True  # Shuffle training order (presumably reshuffled each epoch by the underlying DataLoader)
)

test_loader = DataLoaderFactory.create_dataloader(
    dataset=test_ds,
    batch_size=batch_size,
    pad_idx=tokenizer.pad_idx,
    # Shuffling does not affect a full evaluation pass. It is kept so that an evaluation limited
    # to a few batches (presumably `eval_iters` inside train()) sees a random subset.
    shuffle=True
)

In [None]:
# Sanity-check one training batch: tensor shapes plus the first source sentence as text.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    src_batch, tgt_batch = first_batch
    print("Source batch shape:", src_batch.shape)
    print("Target batch shape:", tgt_batch.shape)
    print(tokenizer.decode_to_text(src_batch[0].tolist()))

In [None]:
# Import the TranslationTransformer
from translation_transformer import TranslationTransformer

# Source and target share one tokenizer, so both embedding tables use the same vocabulary size.
# max_len (config cell) is larger than dataset_max_sample_len so sequences plus special tokens fit.
shared_vocab_size = len(tokenizer)

model = TranslationTransformer(
    src_vocab_size=shared_vocab_size,
    tgt_vocab_size=shared_vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    max_len=max_len,
    padding_idx=tokenizer.pad_idx,
)

print("Model initialized!")
total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params:,}")

In [None]:
# Apply PyTorch optimizations
import torch

# 1. Allow TF32 for float32 matmuls on Ampere+ GPUs: faster matrix multiplies at slightly
#    reduced precision.
torch.set_float32_matmul_precision('high')  # Options: 'highest', 'high', 'medium'

# 2. Backend-specific settings. DEVICE is a torch.device object, so compare its `.type`
#    string; comparing the device object itself to "mps"/"cuda" does not match.
if DEVICE.type == "mps":
    # Release any memory cached by the MPS allocator before training.
    torch.mps.empty_cache()
elif DEVICE.type == "cuda":
    # Also allow TF32 inside cuDNN kernels.
    torch.backends.cudnn.allow_tf32 = True
else:
    print("✓ Running on CPU (no GPU optimizations)")

# 3. Compile the model with TorchDynamo. model_compiled wraps the same `model` module, so later
#    in-place changes to `model` (e.g. .to(DEVICE)) are visible through it. Speedup varies by model.
model_compiled = torch.compile(model, mode='default')  # Options: 'default', 'reduce-overhead', 'max-autotune'

In [None]:
import torch.nn as nn
import torch.optim as optim
from train import train

# Loss: padding positions are ignored; label smoothing 0.1.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_idx, label_smoothing=0.1)

# Move the model to its device *before* constructing the optimizer, as the torch.optim docs
# recommend, so the optimizer is built over parameters in their final placement.
# Module.to works in place, so model_compiled (which wraps `model`) sees the moved weights too.
model = model.to(DEVICE)
model.train()

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

print(f"Using device: {DEVICE}")
print(f"Model moved to {DEVICE}")

# Training (the compiled wrapper is what gets trained)
train_losses, best_loss = train(
    model=model_compiled,
    dataloader=train_loader,
    dataset_size=len(train_ds),
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=DEVICE,
    num_epochs=num_epochs,
    warmup_steps=warmup_steps,
    eval_iters=eval_iters,
    patience=3
)

In [None]:
# Plot training progress
import matplotlib.pyplot as plt

# One point per epoch; x-axis starts at epoch 1.
epoch_axis = range(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_axis, train_losses, marker='o', linewidth=2, markersize=6)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Training Loss', fontsize=12)
ax.set_title('Training Loss Over Time', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

first_loss, last_loss = train_losses[0], train_losses[-1]
print("\nTraining Summary:")
print(f"  Initial Loss: {first_loss:.4f}")
print(f"  Final Loss: {last_loss:.4f}")
print(f"  Best Loss: {best_loss:.4f}")
print(f"  Improvement: {((first_loss - last_loss) / first_loss * 100):.1f}%")

In [None]:
# Test translation on sample input
import torch

@torch.no_grad()
def translate_sample(sentence: str, model: "TranslationTransformer",
                     src_tokenizer: "HFTokenizerWrapper", tgt_tokenizer: "HFTokenizerWrapper", max_len=100, device=None):
    """
    Translate a single sentence with greedy autoregressive decoding.

    The source is encoded with an EOS token (no SOS). The target starts from SOS and grows by one
    argmax token per step until EOS is produced or ``max_len`` tokens have been generated.
    Intermediate details are printed along the way.

    Args:
        sentence: Input sentence to translate.
        model: TranslationTransformer model; called as ``model(src, tgt)`` and expected to return
            logits of shape [1, tgt_len, vocab_size].
        src_tokenizer: HFTokenizerWrapper for the source language.
        tgt_tokenizer: HFTokenizerWrapper for the target language.
        max_len: Maximum number of tokens to generate (not counting the initial SOS).
        device: Device to run on. Defaults to the device of the model's first parameter
            (CPU if the model has no parameters).

    Returns:
        ``(translation, tgt_indices)``. ``translation`` is the decoded text with the leading SOS
        and a trailing EOS removed. ``tgt_indices`` is the full generated id list, including SOS
        (and EOS if it was produced).

    The model is switched to eval mode for decoding and restored to its previous train/eval
    mode afterwards.
    """
    if device is None:
        # Decode wherever the model's weights currently live.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    try:
        # Tokenize input
        src_tokens = src_tokenizer.tokenize(sentence)
        print(f"Input tokens: {src_tokens}")

        # Encode with EOS token only (source side); SOS is reserved for the decoder input.
        src_indices = src_tokenizer.encode(src_tokens, add_sos=False, add_eos=True)
        src_tensor = torch.tensor(src_indices, dtype=torch.long, device=device).unsqueeze(0)  # [1, seq_len]

        print(f"Input tensor shape: {src_tensor.shape}")
        print(f"Input indices: {src_indices}")

        # Initialize target with just the SOS token
        tgt_indices = [tgt_tokenizer.sos_idx]

        # Autoregressive generation loop
        for _ in range(max_len):
            tgt_tensor = torch.tensor([tgt_indices], dtype=torch.long, device=device)  # [1, current_len]

            output = model(src_tensor, tgt_tensor)  # [1, current_len, vocab_size]

            # Greedy choice from the logits at the last target position
            next_token = torch.argmax(output[0, -1, :]).item()
            tgt_indices.append(next_token)

            # Stop once EOS is predicted
            if next_token == tgt_tokenizer.eos_idx:
                break
    finally:
        # Leave the model in whatever mode the caller had it in (e.g. still training).
        model.train(was_training)

    print(f"Generated {len(tgt_indices)} tokens")
    print(f"Predicted indices: {tgt_indices}")

    # Strip the leading SOS and a trailing EOS so only the generated content is decoded.
    content_indices = tgt_indices[1:]
    if content_indices and content_indices[-1] == tgt_tokenizer.eos_idx:
        content_indices = content_indices[:-1]
    translation = tgt_tokenizer.decode_to_text(content_indices)

    return translation, tgt_indices


# Pick one training pair and recover readable text for both sides with decode_to_text, the same
# helper the batch preview uses. Joining the per-token output of decode() with spaces would put
# spaces between sub-word pieces, so the string fed back into the tokenizer would not match the
# original sentence.
sample_idx = 10
sample_pair = train_ds[sample_idx]
sample_de = tokenizer.decode_to_text(sample_pair[0].tolist())
sample_en = tokenizer.decode_to_text(sample_pair[1].tolist())

print(f"Original German: {sample_de}")
print(f"Original English: {sample_en}")
print("\n" + "="*60 + "\n")

# Translate, generating at most `max_len` tokens (the model's positional capacity from the config cell)
translation, pred_indices = translate_sample(
    sample_de,
    model,
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    max_len=max_len,
    device=DEVICE
)

print("\n" + "="*60)
print(f"\nModel Translation: '{translation}'")
print(f"Reference Translation: {sample_en}")


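## Saving and loading the model

This notebook does not save a checkpoint explicitly. The path below assumes `train()` writes one to `./models/translation_model.pt`; verify this against `train.py`. To restore it later, rebuild `model` and `optimizer` as above, then run:
```python
checkpoint = torch.load('./models/translation_model.pt', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```
If the checkpoint was saved from the compiled wrapper (`model_compiled`), its state-dict keys may carry an `_orig_mod.` prefix. In that case, either load it into `model_compiled` or strip the prefix before calling `model.load_state_dict`.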