In [1]:
%load_ext autoreload
%autoreload 2

In [15]:
from datasets import load_dataset
import torch
from utils.translation_transformer import TransformerConfig

DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {DEVICE}")


ds = load_dataset("wmt/wmt14", "de-en")

vocab_size = 30_000
vocab_path = "./data/bpe_tokenizer.json"

training_samples = len(ds["train"])
batch_size = 64
dataset_max_sample_len = 100

sharedVocab = True
# bpe_v3_ep12
configSmall = TransformerConfig(
    d_model=256,
    nhead=8,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=1024,
    dropout=0.1,
    max_len=150,
)
# base model according to the paper 'Attention is all you need'
# big_3.8770loss
configBig = TransformerConfig(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    max_len=150,
)

# training
num_steps = 20_000
warmup_steps = 2_000
eval_iters = 10
patience = 1_000

label_smoothing = 0.1

# optimizer
start_lr = 3e-4
betas = (0.9, 0.98)
epsilon = 1e-9

Using device: mps


In [16]:
from tokenizers import Tokenizer as HFTokenizer, decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.processors import TemplateProcessing
from utils.tokenization_vocab import HFTokenizerWrapper, Tokenizer
from pathlib import Path

bpe_tokenizer = HFTokenizer(BPE(unk_token=Tokenizer.UNK_TOKEN))
trainer = BpeTrainer(
    special_tokens=[
        Tokenizer.PAD_TOKEN,
        Tokenizer.SOS_TOKEN,
        Tokenizer.EOS_TOKEN,
        Tokenizer.UNK_TOKEN,
    ],
    vocab_size=vocab_size,
    show_progress=True,
)

bpe_tokenizer.pre_tokenizer = Metaspace()
bpe_tokenizer.decoder = decoders.Metaspace()

pretrained = True  # Set to True if you want to load a previously saved tokenizer

Path(vocab_path).parent.mkdir(parents=True, exist_ok=True)

if Path(vocab_path).is_file():
    pretrained = True

if pretrained:
    bpe_tokenizer = HFTokenizer.from_file(vocab_path)
    
    bpe_tokenizer.post_processor = TemplateProcessing(
    single=f"{Tokenizer.SOS_TOKEN} $A {Tokenizer.EOS_TOKEN}",
    special_tokens=[
        (Tokenizer.SOS_TOKEN, bpe_tokenizer.token_to_id(Tokenizer.SOS_TOKEN)),
        (Tokenizer.EOS_TOKEN, bpe_tokenizer.token_to_id(Tokenizer.EOS_TOKEN)),
    ],
)
else:
    bpe_tokenizer.train(
        [
            "./datasets/wmt14_translate_de-en_test.csv",
            "./datasets/wmt14_translate_de-en_train.csv",
            "./datasets/wmt14_translate_de-en_validation.csv",
        ],
        trainer=trainer,
    )

    bpe_tokenizer.save(vocab_path)


tokenizer = HFTokenizerWrapper(bpe_tokenizer)

print(f"Vocab size: {bpe_tokenizer.get_vocab_size()}")

Vocab size: 30000


In [97]:
def tokenize_batch(examples):
    inputs = [e["de"] for e in examples['translation']]
    targets = [e["en"] for e in examples['translation']]
    input_encodings = bpe_tokenizer.encode_batch(inputs)
    target_encodings = bpe_tokenizer.encode_batch(targets)
    return {
        "src": [enc.ids[1:] for enc in input_encodings], # remove sos token
        "tgt": [enc.ids for enc in target_encodings], # keep sos token
    }

tokenized_ds = ds.map(
    tokenize_batch,
    batched=True,
    num_proc=8,
    remove_columns=["translation"],
    load_from_cache_file=True,
)
tokenized_ds.set_format(type='torch')

In [98]:
from torch.utils.data import DataLoader
from utils.parallel_corpus import collate_fn

dl_train = DataLoader(
    tokenized_ds['train'],
    batch_size=batch_size,
    collate_fn=lambda x: collate_fn(x, tokenizer.pad_idx),
    shuffle=True,
)
dl_test = DataLoader(
    tokenized_ds['test'],
    batch_size=batch_size,
    collate_fn=lambda x: collate_fn(x, tokenizer.pad_idx),
    shuffle=True,
)

print(f"Train samples: {len(tokenized_ds['train']):,}, Test samples: {len(tokenized_ds['test']):,}")
print(f"Train batches: {len(dl_train):,}, Test batches: {len(dl_test):,}")

Train samples: 4,508,785, Test samples: 3,003
Train batches: 70,450, Test batches: 47


In [99]:
from utils.parallel_corpus import (
    TranslationDataset,
    DataLoaderFactory,
    LazyTranslationPairs,
)
from utils.tokenization_vocab import HFTokenizerWrapper
import os

# Create lazy wrappers - no materialization into lists!
train_src = LazyTranslationPairs(ds["train"], src_lang="de", tgt_lang="en", mode="src")
train_tgt = LazyTranslationPairs(ds["train"], src_lang="de", tgt_lang="en", mode="tgt")

test_src = LazyTranslationPairs(ds["test"], src_lang="de", tgt_lang="en", mode="src")
test_tgt = LazyTranslationPairs(ds["test"], src_lang="de", tgt_lang="en", mode="tgt")

tokenizer = HFTokenizerWrapper(bpe_tokenizer)

# Create datasets with lazy loading (processes on-the-fly, no upfront preprocessing)
train_ds = TranslationDataset(
    source_sentences=train_src,
    target_sentences=train_tgt,
    source_tokenizer=tokenizer,
    target_tokenizer=tokenizer,
    max_length=dataset_max_sample_len,
    lazy=True,  # Enable lazy loading!
)

test_ds = TranslationDataset(
    source_sentences=test_src,
    target_sentences=test_tgt,
    source_tokenizer=tokenizer,
    target_tokenizer=tokenizer,
    max_length=dataset_max_sample_len,
    lazy=True,
)

# Optimize num_workers based on CPU cores
optimal_workers = min(8, os.cpu_count() or 4)

train_loader = DataLoaderFactory.create_dataloader(
    dataset=train_ds,
    batch_size=batch_size,
    pad_idx=tokenizer.pad_idx,
    num_workers=optimal_workers,
    shuffle=True,  # Shuffle for training
    persistent_workers=True,  # Keep workers alive between epochs
    prefetch_factor=4,  # Prefetch more batches
)

test_loader = DataLoaderFactory.create_dataloader(
    dataset=test_ds,
    batch_size=batch_size,
    pad_idx=tokenizer.pad_idx,
    num_workers=0,
    shuffle=False,  # No shuffle for testing
    persistent_workers=True,
    prefetch_factor=4,
)

print(f"✓ Lazy loading enabled - no memory materialization!")
print(f"✓ Using {optimal_workers} workers for parallel processing")
print(f"Train samples: {len(train_ds):,}, Test samples: {len(test_ds):,}")
print(f"Train batches: {len(train_loader):,}, Test batches: {len(test_loader):,}")

Initialized lazy dataset with 4508785 sentence pairs
Initialized lazy dataset with 3003 sentence pairs
✓ Lazy loading enabled - no memory materialization!
✓ Using 8 workers for parallel processing
Train samples: 4,508,785, Test samples: 3,003
Train batches: 70,450, Test batches: 47


In [89]:
# Import the TranslationTransformer
from utils.translation_transformer import (
    TranslationTransformer,
    TranslationTransformerPytorch,
)

# Initialize the model with larger max_len to handle max_length + special tokens
model = TranslationTransformer(
    src_vocab_size=len(tokenizer),
    tgt_vocab_size=len(tokenizer),
    config=configBig,
    padding_idx=tokenizer.pad_idx,
    sharedVocab=sharedVocab,
)

print(f"Model initialized!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

15,360,000 parameters in shared embedding
18,895,872 parameters in encoder layers
25,190,400 parameters in decoder layers
512 parameters in encoder norm layer
512 parameters in decoder norm layer
15,360,000 parameters in output projection
Model initialized!
Total parameters: 59,447,296


### Load model from checkpoint

In [90]:
state_dict = torch.load("./models/aiayn_base_100k.pt", map_location=DEVICE)[
    "model_state_dict"
]
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.eval()

TranslationTransformer(
  (src_embedding): WordEmbedding(
    (embedding): Embedding(30000, 512, padding_idx=0)
  )
  (tgt_embedding): WordEmbedding(
    (embedding): Embedding(30000, 512, padding_idx=0)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiHeadAttention(
        (qkv_proj): Linear(in_features=512, out_features=1536, bias=False)
        (out_proj): Linear(in_features=512, out_features=512, bias=False)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=2048, out_features=512, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2

In [91]:
# Apply PyTorch optimizations
import torch

# 1. Enable TF32 for faster matmul on Ampere+ GPUs (A100, RTX 3090, etc.)
# This provides ~2x speedup for matrix multiplications with minimal accuracy loss
# torch.set_float32_matmul_precision('high')  # Options: 'highest', 'high', 'medium'
# torch.backends.fp32_precision = 'tf32'

# 2. For MPS (Apple Silicon), ensure we're using optimal settings
if DEVICE.type == "mps":
    # MPS backend is already optimized, but we can ensure memory efficiency
    torch.mps.empty_cache()  # Clear any cached memory
elif DEVICE.type == "cuda":
    # Enable TF32 for cuDNN convolutions as well
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    print("✓ Running on CPU (no GPU optimizations)")

model_compiled = torch.compile(
    model, mode="default"
)  # Options: 'default', 'reduce-overhead', 'max-autotune'

# Move model to device (GPU if available)
model = model.to(DEVICE)
model.train()

print(f"Using device: {DEVICE}")
print(f"Model moved to {DEVICE}")

Using device: mps
Model moved to mps


In [None]:
import torch.nn as nn
import torch.optim as optim
from utils.train import train
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(step, warmup_steps=4000):
    """Learning rate schedule with warmup and decay."""
    step = max(step, 1)
    return configBig.d_model ** (-0.5) * min(step**-0.5, step * warmup_steps**-1.5)


# Loss function and optimizer
criterion = nn.CrossEntropyLoss(
    ignore_index=tokenizer.pad_idx, label_smoothing=label_smoothing
)
optimizer = optim.Adam(model.parameters(), lr=1, betas=betas, eps=epsilon)

scheduler = LambdaLR(optimizer, lambda step: lr_lambda(step, warmup_steps))

# Training
train_losses, best_loss = train(
    model=model_compiled,
    config=configBig,
    train_loader=train_loader,
    test_loader=test_loader,
    dataset_size=len(train_ds),
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=DEVICE,
    num_steps=num_steps,
    eval_iters=eval_iters,
    patience=patience,
)

In [101]:
from utils.train import estimate_loss
import torch.nn as nn

criterion = nn.CrossEntropyLoss(
    ignore_index=tokenizer.pad_idx, label_smoothing=label_smoothing
)

estimate_loss(
    model=model,
    test_loader=dl_test,
    criterion=criterion,
    device=DEVICE,
    eval_iters=len(dl_test),
    print_enabled=True,
)

Eval batch 1/47, Loss: 3.4141
Eval batch 2/47, Loss: 3.3881
Eval batch 3/47, Loss: 3.2288


KeyboardInterrupt: 

In [None]:
sentence = "Die gestern offiziell in Betrieb genommene Anlage sei wichtig für den Kreuzungsbereich Sulzbachweg/Kirchstraße."
input_sequence = torch.tensor([bpe_tokenizer.encode(sentence).ids[1:]], device=DEVICE) # Exclude SOS token

tensor([[ 4260, 15887, 21852,  3796, 10526, 11533,    75, 14199,  8175,  5118,
          3944,  3903, 17183, 14785, 26956,    96, 16863,  6297, 25049,  3841,
          3777, 14790,    20,     2]], device='mps:0')

In [47]:
from utils.inference import beam_search

tgt_seq = beam_search(
    model=model,
    input_sequence=input_sequence,
    sos=tokenizer.sos_idx,
    eos=tokenizer.eos_idx,
    beam_width=3,
    length_penalty=1.0,
    repetition_penalty=1.0,
    device=DEVICE,
)

print(bpe_tokenizer.decode(tgt_seq))

The plant officially commissioned yesterday is important for the Sulzbachweg/Kirchstraße crossroad.


In [50]:
from utils.inference import greedy_translate

tgt_seq = greedy_translate(
    model=model,
    input_sequence=input_sequence,
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    max_len=dataset_max_sample_len,
    device=DEVICE
)
print(bpe_tokenizer.decode(tgt_seq))

The plant officially commissioned yesterday is important for the Sulzbachweg/Kirchstraße crossroad.


In [54]:
# Test translation on sample input
import sacrebleu

# Test with a sample German sentence from the dataset
total_bleu = 0.0
eval_iters = 1
print_enabled = True
samples = 0

for k, batch in enumerate(test_loader):
    if k >= eval_iters:
        break

    batch_de, batch_en, _, _ = batch
    samples += len(batch_de)

    for idx, (de, en) in enumerate(zip(batch_de, batch_en)):
        sample_de = tokenizer.decode_to_text(de.tolist())
        sample_en = tokenizer.decode_to_text(en.tolist())

        tgt_seq = greedy_translate(
            model,
            input_sequence=torch.tensor(
                [bpe_tokenizer.encode(sample_de).ids[1:]], device=DEVICE
            ),
            src_tokenizer=tokenizer,
            tgt_tokenizer=tokenizer,
            max_len=dataset_max_sample_len,
            device=DEVICE,
        )

        translation = tokenizer.decode_to_text(tgt_seq)
        BLEUscore = sacrebleu.corpus_bleu([translation], [[sample_en]])
        total_bleu += BLEUscore.score

        if print_enabled:
            print(f"{k*len(batch_de)+idx+1} / {len(batch_de)*eval_iters}")

print(f"\nAverage BLEU Score over {samples} samples: {(total_bleu / samples):.4f}")

1 / 64
2 / 64
3 / 64
4 / 64
5 / 64
6 / 64
7 / 64
8 / 64
9 / 64
10 / 64
11 / 64
12 / 64
13 / 64
14 / 64
15 / 64
16 / 64
17 / 64
18 / 64
19 / 64
20 / 64
21 / 64
22 / 64
23 / 64
24 / 64
25 / 64
26 / 64
27 / 64
28 / 64
29 / 64
30 / 64
31 / 64
32 / 64
33 / 64
34 / 64
35 / 64
36 / 64
37 / 64
38 / 64
39 / 64
40 / 64
41 / 64
42 / 64
43 / 64
44 / 64
45 / 64
46 / 64
47 / 64
48 / 64
49 / 64
50 / 64
51 / 64
52 / 64
53 / 64
54 / 64
55 / 64
56 / 64
57 / 64
58 / 64
59 / 64
60 / 64
61 / 64
62 / 64
63 / 64
64 / 64

Average BLEU Score over 64 samples: 13.0096


In [None]:
translation, _ = translate_sample(
    "Ich liebe meinen Jungen Freund über alles.",
    model,
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    max_len=dataset_max_sample_len,
    device=DEVICE,
)
translation

'I love my boyfriend about everything.'