In [1]:
from datasets import load_dataset
import torch
from utils.translation_transformer import TransformerConfig

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")


ds = load_dataset("wmt/wmt14", "de-en")

vocab_size = 30_000
vocab_path = "./data/bpe_tokenizer.json"

training_samples = len(ds["train"])
batch_size = 64
dataset_max_sample_len = 100

sharedVocab = True
# bpe_v3_ep12
configSmall = TransformerConfig(
 d_model=256,
 nhead=8,
 num_encoder_layers=4,
 num_decoder_layers=4,
 dim_feedforward=1024,
 dropout=0.1,
 max_len=150
)
# base model according to the paper 'Attention is all you need'
# big_3.8770loss
configBig = TransformerConfig(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    max_len=150
)

# training
num_steps = 20_000
warmup_steps = 2_000
eval_iters = 10
patience = 1_000

label_smoothing = 0.1

# optimizer
start_lr = 3e-4
betas = (0.9, 0.98)
epsilon = 1e-9

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps




In [2]:
from tokenizers import Tokenizer as HFTokenizer, decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Metaspace
from utils.tokenization_vocab import HFTokenizerWrapper, Tokenizer
from pathlib import Path

bpe_tokenizer = HFTokenizer(BPE(unk_token=Tokenizer.UNK_TOKEN))
trainer = BpeTrainer(
    special_tokens=[Tokenizer.PAD_TOKEN, Tokenizer.SOS_TOKEN, Tokenizer.EOS_TOKEN, Tokenizer.UNK_TOKEN],
    vocab_size=vocab_size,
    show_progress=True
)

bpe_tokenizer.pre_tokenizer = Metaspace()
bpe_tokenizer.decoder = decoders.Metaspace()

pretrained = True  # Set to True if you want to load a previously saved tokenizer

Path(vocab_path).parent.mkdir(parents=True, exist_ok=True)

if Path(vocab_path).is_file():
    pretrained = True

if pretrained:
    bpe_tokenizer = HFTokenizer.from_file(vocab_path)
else:
    bpe_tokenizer.train(
        [
            './datasets/wmt14_translate_de-en_test.csv',
            './datasets/wmt14_translate_de-en_train.csv',
            './datasets/wmt14_translate_de-en_validation.csv',
        ],
        trainer=trainer
    )

    bpe_tokenizer.save(vocab_path)


tokenizer = HFTokenizerWrapper(bpe_tokenizer)

print(f"Vocab size: {bpe_tokenizer.get_vocab_size()}")

Vocab size: 30000


In [3]:
from utils.parallel_corpus import TranslationDataset, DataLoaderFactory, LazyTranslationPairs
from utils.tokenization_vocab import HFTokenizerWrapper
import os

# Create lazy wrappers - no materialization into lists!
train_src = LazyTranslationPairs(ds['train'], src_lang='de', tgt_lang='en', mode='src')
train_tgt = LazyTranslationPairs(ds['train'], src_lang='de', tgt_lang='en', mode='tgt')

test_src = LazyTranslationPairs(ds['test'], src_lang='de', tgt_lang='en', mode='src')
test_tgt = LazyTranslationPairs(ds['test'], src_lang='de', tgt_lang='en', mode='tgt')

tokenizer = HFTokenizerWrapper(bpe_tokenizer)

# Create datasets with lazy loading (processes on-the-fly, no upfront preprocessing)
train_ds = TranslationDataset(
    source_sentences=train_src,
    target_sentences=train_tgt,
    source_tokenizer=tokenizer,
    target_tokenizer=tokenizer,
    max_length=dataset_max_sample_len,
    lazy=True  # Enable lazy loading!
)

test_ds = TranslationDataset(
    source_sentences=test_src,
    target_sentences=test_tgt,
    source_tokenizer=tokenizer,
    target_tokenizer=tokenizer,
    max_length=dataset_max_sample_len,
    lazy=True
)

# Optimize num_workers based on CPU cores
optimal_workers = min(8, os.cpu_count() or 4)

train_loader = DataLoaderFactory.create_dataloader(
    dataset=train_ds,
    batch_size=batch_size,
    pad_idx=tokenizer.pad_idx,
    num_workers=optimal_workers,
    shuffle=True,  # Shuffle for training
    persistent_workers=True,  # Keep workers alive between epochs
    prefetch_factor=4  # Prefetch more batches
)

test_loader = DataLoaderFactory.create_dataloader(
    dataset=test_ds,
    batch_size=batch_size,
    pad_idx=tokenizer.pad_idx,
    num_workers=0,
    shuffle=False,  # No shuffle for testing
    persistent_workers=True,
    prefetch_factor=4
)

print(f"✓ Lazy loading enabled - no memory materialization!")
print(f"✓ Using {optimal_workers} workers for parallel processing")
print(f"Train samples: {len(train_ds):,}, Test samples: {len(test_ds):,}")
print(f"Train batches: {len(train_loader):,}, Test batches: {len(test_loader):,}")

Initialized lazy dataset with 4508785 sentence pairs
Initialized lazy dataset with 3003 sentence pairs
✓ Lazy loading enabled - no memory materialization!
✓ Using 8 workers for parallel processing
Train samples: 4,508,785, Test samples: 3,003
Train batches: 70,450, Test batches: 47


In [4]:
# Import the TranslationTransformer
from utils.translation_transformer import TranslationTransformer, TranslationTransformerPytorch

# Initialize the model with larger max_len to handle max_length + special tokens
model = TranslationTransformer(
    src_vocab_size=len(tokenizer),
    tgt_vocab_size=len(tokenizer),
    config=configBig,
    padding_idx=tokenizer.pad_idx,
    sharedVocab=sharedVocab
)

print(f"Model initialized!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

15,360,000 parameters in shared embedding
18,895,872 parameters in encoder layers
25,190,400 parameters in decoder layers
512 parameters in encoder norm layer
512 parameters in decoder norm layer
15,360,000 parameters in output projection
Model initialized!
Total parameters: 59,447,296


### Load model from checkpoint

In [5]:
state_dict = torch.load("./models/aiayn_base_100k.pt", map_location=DEVICE)['model_state_dict']
new_state_dict = {
    k.replace("_orig_mod.", ""): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)
model.eval()

TranslationTransformer(
  (src_embedding): WordEmbedding(
    (embedding): Embedding(30000, 512, padding_idx=0)
  )
  (tgt_embedding): WordEmbedding(
    (embedding): Embedding(30000, 512, padding_idx=0)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiHeadAttention(
        (qkv_proj): Linear(in_features=512, out_features=1536, bias=False)
        (out_proj): Linear(in_features=512, out_features=512, bias=False)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=2048, out_features=512, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2

In [6]:
# Apply PyTorch optimizations
import torch

# 1. Enable TF32 for faster matmul on Ampere+ GPUs (A100, RTX 3090, etc.)
# This provides ~2x speedup for matrix multiplications with minimal accuracy loss
# torch.set_float32_matmul_precision('high')  # Options: 'highest', 'high', 'medium'
# torch.backends.fp32_precision = 'tf32'

# 2. For MPS (Apple Silicon), ensure we're using optimal settings
if DEVICE.type == "mps":
    # MPS backend is already optimized, but we can ensure memory efficiency
    torch.mps.empty_cache()  # Clear any cached memory
elif DEVICE.type == "cuda":
    # Enable TF32 for cuDNN convolutions as well
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    print("✓ Running on CPU (no GPU optimizations)")

model_compiled = torch.compile(model, mode='default')  # Options: 'default', 'reduce-overhead', 'max-autotune'

# Move model to device (GPU if available)
model = model.to(DEVICE)
model.train()

print(f"Using device: {DEVICE}")
print(f"Model moved to {DEVICE}")

Using device: mps
Model moved to mps


In [7]:
import torch.nn as nn
import torch.optim as optim
from utils.train import train
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(step, warmup_steps=4000):
    """Learning rate schedule with warmup and decay."""
    step = max(step, 1)
    return configBig.d_model**(-0.5) * min(
        step ** -0.5,
        step * warmup_steps ** -1.5
    )

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_idx, label_smoothing=label_smoothing)
optimizer = optim.Adam(model.parameters(), lr=1, betas=betas, eps=epsilon)

scheduler = LambdaLR(optimizer, lambda step: lr_lambda(step, warmup_steps))

# Training
train_losses, best_loss = train(
    model=model_compiled,
    config=configBig,
    train_loader=train_loader,
    test_loader=test_loader,
    dataset_size=len(train_ds),
    criterion=criterion, 
    optimizer=optimizer,
    scheduler=scheduler,
    device=DEVICE,
    num_steps=num_steps,
    eval_iters=eval_iters,
    patience=patience
)

Starting training for 20,000 steps...
Total batches per epoch: 70,450
Dataset size: 4,508,785 samples
Learning rate: 0.000000 (with warmup and decay)


AttributeError: Can't get local object 'DataLoaderFactory.create_dataloader.<locals>.collate'

In [7]:
@torch.no_grad()
def translate_sample(sentence: str, model: TranslationTransformer,
                     src_tokenizer: HFTokenizerWrapper, tgt_tokenizer: HFTokenizerWrapper, max_len=100, device=DEVICE):
    """
    Translate a single sentence using the model with autoregressive generation.
    
    Args:
        sentence: Input sentence to translate
        model: TranslationTransformer model
        src_tokenizer: HFTokenizerWrapper for the source language
        tgt_tokenizer: HFTokenizerWrapper for the target language
        max_len: Maximum sequence length to generate
        device: Device to run on
        
    Returns:
        Translated sentence and token indices
    """
    model.eval()
    
    # Tokenize input
    src_tokens = src_tokenizer.tokenize(sentence)
    # print(f"Input tokens: {src_tokens}")
    
    # Encode with EOS token only (source side)
    src_indices = src_tokenizer.encode(src_tokens, add_sos=False, add_eos=True)
    src_tensor = torch.tensor(src_indices, dtype=torch.long).unsqueeze(0).to(device)  # [1, seq_len]
    
    # Create source padding mask (all False since no padding in single sentence)
    src_key_padding_mask = torch.zeros(1, src_tensor.size(1), dtype=torch.bool).to(device)
    
    # print(f"Input tensor shape: {src_tensor.shape}")
    # print(f"Input indices: {src_indices}")
    
    # Initialize target with just SOS token
    tgt_indices = [tgt_tokenizer.sos_idx]
    
    # Autoregressive generation loop
    for _ in range(max_len):
        # Convert current target indices to tensor
        tgt_tensor = torch.tensor([tgt_indices], dtype=torch.long).to(device)  # [1, current_len]
        
        # Create target padding mask (all False since we're only generating, no padding)
        tgt_key_padding_mask = torch.zeros(1, tgt_tensor.size(1), dtype=torch.bool).to(device)
        
        # Forward pass with masks
        output = model(src_tensor, tgt_tensor, 
                      src_key_padding_mask=src_key_padding_mask,
                      tgt_key_padding_mask=tgt_key_padding_mask)  # [1, current_len, vocab_size]
        
        # Get prediction for the last token
        next_token_logits = output[0, -1, :]  # [vocab_size]
        next_token = torch.argmax(next_token_logits).item()
        
        # Append predicted token
        tgt_indices.append(next_token)
        
        # Stop if we predict EOS token
        if next_token == tgt_tokenizer.eos_idx:
            break
    
    # print(f"Generated {len(tgt_indices)} tokens")
    # print(f"Predicted indices: {tgt_indices}")
    
    # Decode back to tokens (skip SOS and EOS)
    translation = tgt_tokenizer.decode_to_text(tgt_indices)  # Remove SOS and EOS
        
    return translation, tgt_indices


In [11]:
# Test translation on sample input
import sacrebleu

# Test with a sample German sentence from the dataset
total_bleu = 0.0
print_enabled = True
samples = 0

for batch in test_loader:
    
    batch_de, batch_en, _, _ = batch
    samples += len(batch_de)

    for idx, (de, en) in enumerate(zip(batch_de, batch_en)):
        sample_de = tokenizer.decode_to_text(de.tolist())
        sample_en = tokenizer.decode_to_text(en.tolist())

        translation, _ = translate_sample(
            sample_de, 
            model, 
            src_tokenizer=tokenizer,
            tgt_tokenizer=tokenizer,
            max_len=dataset_max_sample_len,
            device=DEVICE
        )
        BLEUscore = sacrebleu.corpus_bleu([translation], [[sample_en]])
        total_bleu += BLEUscore.score

        if print_enabled:
            print(f"Original German: {sample_de}")
            print(f"Model Translation: '{translation}'")
            print(f"Reference Translation: {sample_en}")
            print(f"BLEU Score: {BLEUscore.score:.4f}")

print(f"\nAverage BLEU Score over {samples} samples: {(total_bleu / samples):.4f}")

Original German: Gutach: Noch mehr Sicherheit für Fußgänger
Model Translation: 'Good: even more security for pedestrians'
Reference Translation: Gutach: Increased safety for pedestrians
BLEU Score: 14.5358
Original German: Sie stehen keine 100 Meter voneinander entfernt: Am Dienstag ist in Gutach die neue B 33-Fußgängerampel am Dorfparkplatz in Betrieb genommen worden - in Sichtweite der älteren Rathausampel.
Model Translation: 'The hotel is located in the centre of the old town of B 33.'
Reference Translation: They are not even 100 metres apart: On Tuesday, the new B 33 pedestrian lights in Dorfparkplatz in Gutach became operational - within view of the existing Town Hall traffic lights.
BLEU Score: 2.4089
Original German: Zwei Anlagen so nah beieinander: Absicht oder Schildbürgerstreich?
Model Translation: 'Two plants so close to beieinander: intention or a child?'
Reference Translation: Two sets of lights so close to one another: intentional or just a silly error?
BLEU Score: 10.619

KeyboardInterrupt: 

In [10]:

translation, _ = translate_sample(
    'Ich liebe meinen Jungen Freund über alles.',
    model, 
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    max_len=dataset_max_sample_len,
    device=DEVICE
)
translation

'I love my boyfriend about everything.'