<a href="https://colab.research.google.com/github/AlperYildirim1/Attention-is-All-You-Need-Pytorch/blob/main/Attention_is_All_You_Need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook shell command: quietly install the metric packages.
# torchmetrics provides the BLEUScore metric imported further down; sacrebleu is
# not imported directly in the visible code (presumably an optional backend — TODO confirm).
!pip install -q torchmetrics sacrebleu

## CONFIG

In [None]:
# --- Data & Task Size ---
# Maximum sequence length: sizes the positional-encoding table and caps both
# tokenizer truncation and beam-search decoding length.
MAX_LENGTH = 128

MODEL_CHOICE = "YOUR_MODEL_NAME"  # Used in the save/log paths and in log messages

# --- Model Architecture Config ("Transformer-Small") ---
D_MODEL = 512      # Embedding / hidden size
NUM_HEADS = 8      # Attention heads per layer
D_FF = 2048        # Inner size of the feed-forward sub-layers
DROPOUT = 0.1

# --- Layer counts ---
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

# --- Training Config ---
# Not referenced by the code in this file; presumably the token budget that was
# used to build the pre-batched dataset — TODO confirm.
TOKEN_LIMIT_PER_BATCH = 20000
TARGET_TRAINING_STEPS = 100000   # Training stops after this many optimizer steps
VALIDATION_EVERY_N_STEPS = 4000  # Validate and checkpoint every N optimizer steps
PEAK_LEARNING_RATE = 5e-4        # Learning rate reached at the end of warmup
WARMUP_STEPS = 4000
WEIGHT_DECAY = 0.01

# --- Regularization Config ---
# Defined exactly once: a second assignment further down used to duplicate this
# value and could silently drift out of sync.
LABEL_SMOOTHING_EPSILON = 0.1
# Not referenced by the code in this file (kept for compatibility).
RECONSTRUCTION_LOSS_WEIGHT = 0.1

# --- Other Constants ---
DRIVE_BASE_PATH = "/content/YOUR_PREFERRED_PATH"
PREBATCHED_REPO_ID = "Yujivus/wmt14-de-en-prebatched-w4" # IMPORTANT
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en" # We only use its tokenizer

## DATALOADERS

In [None]:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import math
import os
from tqdm.auto import tqdm
from torchmetrics.text import BLEUScore
from torch.utils.tensorboard import SummaryWriter
import random
import numpy as np
import torch
from transformers import get_cosine_schedule_with_warmup
from typing import List

def set_seed(seed_value=5):
    """Seed every RNG used by the script and force deterministic cuDNN.

    Args:
        seed_value: Integer seed applied to Python's ``random``, NumPy,
            PyTorch (CPU) and all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Global seed; applied here, before the tokenizer, DataLoaders and model below
# are created, and reused later to seed the DataLoader generator.
SEED = 5
set_seed(SEED)
print(f"Reproducibility seed set to {SEED}")
print("--- Loading Modernized Configuration ---")

def seed_worker(worker_id):
    """Seed the NumPy and Python RNGs inside a DataLoader worker.

    The seed is ``torch.initial_seed()`` reduced to 32 bits, so that every
    worker ends up with a different but deterministic seed.

    Args:
        worker_id: Worker index supplied by the DataLoader (not used directly).
    """
    seed_32bit = torch.initial_seed() % (1 << 32)
    random.seed(seed_32bit)
    np.random.seed(seed_32bit)

# Request the 'high' float32 matmul precision mode from PyTorch.
torch.set_float32_matmul_precision('high')
print("✅ PyTorch matmul precision set to 'high'")

# --- Device Setup ---
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Only the tokenizer of the pretrained checkpoint is loaded here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
current_vocab_size = len(tokenizer)
print(f"Original tokenizer vocabulary size: {current_vocab_size}")

# Round the vocabulary size up to the next multiple of 8 (GPU Tensor Core
# friendly) using ceiling division.
padded_vocab_size = -(-current_vocab_size // 8) * 8
num_added_tokens = padded_vocab_size - current_vocab_size

if num_added_tokens > 0:
    print(f"Adding {num_added_tokens} dummy tokens to reach a multiple of 8.")
    dummy_tokens = [f"[PAD_{number}]" for number in range(1, num_added_tokens + 1)]
    tokenizer.add_tokens(dummy_tokens)

# CRITICAL: the models must be built with this padded vocabulary size.
VOCAB_SIZE = len(tokenizer)
print(f"New, padded vocabulary size for model: {VOCAB_SIZE}")


# DATA LOADING & PREPARATION
from transformers import DataCollatorForSeq2Seq

# Shared collator: used directly by the validation DataLoader and indirectly,
# through PreBatchedCollator, by the training DataLoader. Label padding is
# expected to be -100 (see the loss-function comment further down) — TODO confirm.
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

class PreBatchedCollator:
    """Collate function for a dataset whose rows are pre-computed batches.

    Each incoming row carries a ``'batch_indices'`` list. Those indices are
    looked up in the original sample-level split and the resulting samples are
    handed to the module-level ``standard_collator``.
    """

    def __init__(self, original_dataset_split):
        # Sample-level split that the stored batch indices refer to.
        self.original_dataset = original_dataset_split

    def __call__(self, features: List[dict]) -> dict:
        """Turn one pre-batched row into a collated batch.

        Args:
            features: List holding a single row, e.g.
                ``[{'batch_indices': [10, 5, 123]}]``; only the first element
                is read.

        Returns:
            Whatever ``standard_collator`` returns for the selected samples.
        """
        sample_indices = features[0]['batch_indices']

        # Indexing with a list yields a dict of columns,
        # e.g. {'input_ids': [[...], [...]], 'labels': [[...], [...]]}.
        columns = self.original_dataset[sample_indices]
        column_names = columns.keys()
        num_samples = len(columns['input_ids'])

        # The standard collator expects one dict per sample, so transpose the
        # column-oriented dict into a row-oriented list.
        samples = [
            {name: columns[name][row] for name in column_names}
            for row in range(num_samples)
        ]
        return standard_collator(samples)

# The pre-batched repo stores, per row, the sample indices that make up one
# batch; the bucketed repo stores the actual tokenized samples.
print(f"Loading pre-batched dataset from: {PREBATCHED_REPO_ID}")
prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)

print(f"Loading original samples from: {ORIGINAL_BUCKETED_REPO_ID}")
original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)
train_collator = PreBatchedCollator(original_datasets["train"])

# --- The New, Simple DataLoader ---
# No more custom sampler!
# Dedicated generator so that the shuffle order is reproducible and can be
# re-seeded (the training loop does so at the start of every epoch).
g = torch.Generator()
g.manual_seed(SEED)

train_dataloader = DataLoader(
    prebatched_datasets["train"],
    batch_size=1,  # Each row is already a batch
    shuffle=True,  # Shuffle the pre-calculated batches every epoch
    num_workers=4,
    collate_fn=train_collator,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)

# Validation loader remains the same, using the original data
EVAL_BATCH_SIZE = 64
val_dataloader = DataLoader(
    original_datasets["validation"],
    batch_size=EVAL_BATCH_SIZE,
    collate_fn=standard_collator,
    num_workers=4,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)

print("\n--- ✅ ULTIMATE DATALOADERS are ready ---")
print(f"Train Dataloader is now a simple iterator over pre-calculated batches.")

# --- SANITY CHECK ---
print("\n--- Running Sanity Check on new DataLoader ---")
train_dataloader.generator.manual_seed(SEED) # Reset generator for check
temp_iterator = iter(train_dataloader)
print("Shapes of first 5 batches:")
# zip() with range(5) stops after five batches and also copes with a dataset
# that has fewer than five batches (a bare next() would raise StopIteration).
for i, batch in zip(range(5), temp_iterator):
    print(f"  Batch {i+1}: input_ids shape = {batch['input_ids'].shape}")
# Drop the partially consumed iterator so that it (and the worker processes it
# owns) can be released before training starts.
del temp_iterator
print("--- Sanity Check Complete ---\n")

## Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Injects sinusoidal positional information into the input embeddings.

    A table of shape ``[1, max_len, d_model]`` is precomputed and stored in the
    buffer ``pe``: even feature indices hold sines, odd indices hold cosines.

    Args:
        d_model: Embedding size. Odd values are supported (the cosine part then
            has one column fewer than the sine part).
        max_len: Number of positions to precompute.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        # One frequency per sine column: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; trimming the angles
        # keeps the assignment valid when d_model is odd.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        # x shape: [batch_size, seq_len, d_model]
        return x + self.pe[:, :x.size(1)]

class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    Args:
        d_model: Input and output feature size.
        dff: Hidden size of the inner projection.
        dropout_rate: Dropout probability applied to the block output.
    """
    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        super().__init__()
        expand = nn.Linear(d_model, dff)
        contract = nn.Linear(dff, d_model)
        # Positions inside the Sequential are kept (0: expand, 2: contract) so
        # that state_dict keys stay the same.
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout_rate))

    def forward(self, x: torch.Tensor):
        """Apply the block to ``x`` along its last dimension."""
        transformed = self.ffn(x)
        return transformed

class StandardTransformer(nn.Module):
    """Encoder-decoder Transformer built on ``nn.TransformerEncoder``/``nn.TransformerDecoder``.

    Source and target tokens share a single embedding table whose weight is
    also tied to the output projection (``final_linear``).

    ``create_masks`` and ``generate`` read the pad/EOS ids from the
    module-level ``tokenizer``.

    Args:
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        num_heads: Attention heads per layer.
        d_model: Embedding / hidden size.
        dff: Feed-forward inner size.
        vocab_size: Size of the (shared) vocabulary.
        max_length: Number of positions in the positional-encoding table.
        dropout: Dropout probability.
    """
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.dropout = nn.Dropout(dropout)
        # norm_first=True selects the pre-LayerNorm layer variant.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.final_linear = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection shares the embedding matrix.
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        """Run a teacher-forced pass and return logits over the vocabulary.

        Args:
            src: Source token ids, ``[batch, src_len]``.
            tgt: Decoder input token ids, ``[batch, tgt_len]``.
            src_padding_mask: Key-padding mask for the encoder.
            tgt_padding_mask: Key-padding mask for decoder self-attention.
            memory_key_padding_mask: Key-padding mask for cross-attention.
            tgt_mask: Causal mask for decoder self-attention.

        Returns:
            Output of ``final_linear`` applied to every decoder position.
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        src_emb_pos = self.dropout(self.pos_encoder(src_emb))
        tgt_emb_pos = self.dropout(self.pos_encoder(tgt_emb))

        memory = self.encoder(src_emb_pos, src_key_padding_mask=src_padding_mask)
        decoder_output = self.decoder(
            tgt=tgt_emb_pos, memory=memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask
        )
        return self.final_linear(decoder_output)


    def create_masks(self, src, tgt):
        """Build the masks expected by ``forward``.

        Returns:
            ``(src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask)``.
            The padding masks are True where the token equals the tokenizer's
            pad id; the memory mask is the very same tensor as the source mask.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        # Square causal mask for the decoder: it prevents any position from
        # attending to later positions, so the model cannot peek at the answer.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1),
            device=src.device,
            dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        """Translate ``src`` with beam search.

        Args:
            src: Source token ids, ``[batch, src_len]``, padded with the pad id.
            max_length: Maximum length of the returned sequences, including the
                leading start token.
            num_beams: Number of hypotheses kept per source sentence.

        Returns:
            Long tensor ``[batch, out_len]`` with ``out_len <= max_length``. Every
            row starts with the pad id (used as decoder start token); positions
            after a hypothesis' EOS token are filled with the pad id.

        The module is switched to eval mode while decoding and put back into
        whatever mode it was in before the call.
        """
        was_training = self.training
        self.eval()
        try:
            return self._beam_search(src, max_length, num_beams)
        finally:
            self.train(was_training)

    def _beam_search(self, src: torch.Tensor, max_length: int, num_beams: int) -> torch.Tensor:
        """Beam-search core used by ``generate`` (expects eval mode and no_grad)."""
        pad_id = tokenizer.pad_token_id
        eos_id = tokenizer.eos_token_id
        device = src.device
        batch_size = src.shape[0]
        scale = math.sqrt(self.d_model)

        # Encode the source once, then replicate it for every beam.
        src_padding_mask = (src == pad_id)
        src_emb_pos = self.pos_encoder(self.embedding(src) * scale)
        memory = self.encoder(self.dropout(src_emb_pos), src_key_padding_mask=src_padding_mask)
        memory = memory.repeat_interleave(num_beams, dim=0)
        memory_key_padding_mask = src_padding_mask.repeat_interleave(num_beams, dim=0)

        # The pad token doubles as the decoder start token.
        beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=device)
        beam_scores = torch.zeros(batch_size * num_beams, device=device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=device)
        # Offset of each sentence's first beam inside the flattened beam axis.
        batch_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * num_beams

        for step in range(max_length - 1):
            if finished_beams.all():
                break
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(device)
            tgt_emb_pos = self.pos_encoder(self.embedding(beams) * scale)
            decoder_output = self.decoder(
                tgt=self.dropout(tgt_emb_pos), memory=memory, tgt_mask=tgt_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )
            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)
            vocab_size = log_probs.shape[-1]

            # Live hypotheses may never emit the pad token ...
            log_probs[:, pad_id] = -torch.inf
            if finished_beams.any():
                # ... whereas finished hypotheses may ONLY emit pad, at zero
                # cost: they keep their score, cannot branch into new tokens
                # and stay distinguishable by length.
                log_probs[finished_beams] = -torch.inf
                log_probs[finished_beams, pad_id] = 0

            total_scores = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, vocab_size)
            if step == 0:
                # All beams are identical at the start: let only the first one
                # expand, otherwise top-k would pick duplicates.
                total_scores[:, 1:, :] = -torch.inf
            total_scores = total_scores.view(batch_size, -1)

            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)
            beam_indices = top_indices // vocab_size
            token_indices = top_indices % vocab_size
            effective_indices = (batch_offsets + beam_indices).view(-1)
            next_tokens = token_indices.view(-1)

            beams = torch.cat([beams[effective_indices], next_tokens.unsqueeze(1)], dim=1)
            beam_scores = top_scores.view(-1)
            # The finished flags must follow the hypotheses through the top-k
            # reordering, otherwise they end up attached to the wrong beams.
            finished_beams = finished_beams[effective_indices] | (next_tokens == eos_id)

        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        # Length normalisation: the start token and the post-EOS filler are pad
        # ids, so the non-pad count is the true hypothesis length.
        lengths = (final_beams != pad_id).sum(-1).float().clamp(min=1)
        normalized_scores = final_scores / lengths
        best_beams = final_beams[torch.arange(batch_size, device=device), normalized_scores.argmax(1), :]
        return best_beams


In [None]:
# ==============================================================================
# --- Model Analysis & Parameter Counting ---
# ==============================================================================
from collections import defaultdict

def count_parameters_correctly(model):
    """Return the number of trainable parameter elements, counting shared tensors once.

    Parameters are de-duplicated by object identity, so tied weights (e.g. an
    embedding matrix reused by an output layer) contribute a single time.
    Parameters with ``requires_grad=False`` are ignored.
    """
    sizes_by_identity = {}
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        # setdefault keeps the first entry, so a repeated object adds nothing.
        sizes_by_identity.setdefault(id(tensor), tensor.numel())
    return sum(sizes_by_identity.values())

# --- Build a throw-away model instance purely for parameter counting ---
print("--- Analyzing Model Parameters ---")
model_to_analyze = StandardTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    dff=D_FF,
    vocab_size=VOCAB_SIZE,
    max_length=MAX_LENGTH,
    dropout=DROPOUT,
)

# --- Count in two ways and report both ---
# NOTE(review): Module.parameters() appears to yield a shared Parameter only
# once, in which case both totals agree and the note below never prints — confirm.
pytorch_naive_total = sum(
    p.numel() for p in model_to_analyze.parameters() if p.requires_grad
)
correct_total = count_parameters_correctly(model_to_analyze)

print(f"Total Trainable Parameters (Correctly Counted): {correct_total:,}")
print(f"PyTorch's Naive Count (sum(p.numel())):        {pytorch_naive_total:,}")
if correct_total != pytorch_naive_total:
    print("Note: The naive count is higher due to double-counting the tied embedding weights.")

# Free the analysis instance; training builds its own model.
del model_to_analyze
print("--- Analysis Complete ---\n")

## Functions (Loss, Evaluation, etc.)

In [None]:

# Label-smoothed cross-entropy over the vocabulary. Targets equal to -100 are
# skipped: DataCollatorForSeq2Seq writes -100 into padded label positions, so
# padding never contributes to the loss.
translation_loss_fn = nn.CrossEntropyLoss(
    label_smoothing=LABEL_SMOOTHING_EPSILON,
    ignore_index=-100,
)
def calculate_combined_loss(model_outputs, target_labels):
    """Compute the translation loss for a batch of logits.

    Args:
        model_outputs: Logits whose last dimension is the vocabulary.
        target_labels: Target ids with the same leading shape as the logits.

    Returns:
        ``(loss, loss_dict)``: the loss tensor, and a dict with its Python
        float value under the key ``'total'``.
    """
    vocab_dim = model_outputs.shape[-1]
    # Flatten so that every position becomes one classification example.
    flat_logits = model_outputs.reshape(-1, vocab_dim)
    flat_targets = target_labels.reshape(-1)
    translation_loss = translation_loss_fn(flat_logits, flat_targets)
    return translation_loss, {'total': translation_loss.item()}

def evaluate(model, dataloader, device):
    """Compute corpus BLEU over ``dataloader`` using beam-search decoding.

    Args:
        model: The model, possibly wrapped by ``torch.compile``; decoding runs
            on the unwrapped module (``_orig_mod``) when that attribute exists.
        dataloader: Yields batches with ``'input_ids'`` and ``'labels'``.
        device: Device the input ids are moved to.

    Returns:
        The BLEU score as a Python float. The model is left in train mode.
    """
    bleu_metric = BLEUScore()

    # Unwrap a compiled model so that generate() runs on the plain module.
    orig_model = getattr(model, '_orig_mod', model)
    orig_model.eval()

    for batch in tqdm(dataloader, desc="Evaluating", leave=False):
        source_ids = batch['input_ids'].to(device)
        references = batch['labels']

        hypotheses = orig_model.generate(source_ids, max_length=MAX_LENGTH, num_beams=5)
        pred_texts = tokenizer.batch_decode(hypotheses, skip_special_tokens=True)

        # -100 is not a valid token id; swap it for pad (in place) so that the
        # references can be decoded.
        references.masked_fill_(references == -100, tokenizer.pad_token_id)
        ref_texts = tokenizer.batch_decode(references, skip_special_tokens=True)

        # One reference per hypothesis.
        bleu_metric.update(pred_texts, [[text] for text in ref_texts])

    orig_model.train()
    return bleu_metric.compute().item()

def generate_sample_translations(model, device, sentences_de):
    """Translate a few German sentences with beam search and print the results.

    Args:
        model: The model, possibly wrapped by ``torch.compile``.
        device: Device the tokenized inputs are moved to.
        sentences_de: List of German source sentences.

    The model is left in train mode afterwards.
    """
    print("\n--- Generating Sample Translations (with Beam Search) ---")
    orig_model = getattr(model, '_orig_mod', model)
    orig_model.eval()

    encoded = tokenizer(
        sentences_de,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
    )
    generated_ids = orig_model.generate(
        encoded.input_ids.to(device), max_length=MAX_LENGTH, num_beams=5
    )
    translations = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    separator = "-" * 20
    for source_text, english_text in zip(sentences_de, translations):
        print(f"  DE Source: {source_text}")
        print(f"  EN Output: {english_text}")
        print(separator)
    orig_model.train()

# Fixed German sentences that are translated and printed at every validation
# step, to track qualitative progress over the course of training.
sample_sentences_de_for_tracking = [
    "Eine Katze sitzt auf der Matte.",
    "Ein Mann in einem roten Hemd liest ein Buch.",
    "Was ist die Hauptstadt von Deutschland?",
    "Ich gehe ins Kino, weil der Film sehr gut ist.",
]

def init_other_linear_weights(m):
    """Xavier-initialise every ``nn.Linear`` except the tied output layer.

    Intended for ``model.apply(...)``. It reads the module-level ``model`` to
    identify ``final_linear``, whose weight is shared with the embedding and
    must keep the embedding's initialisation.
    """
    if not isinstance(m, nn.Linear):
        return

    # Identity check: only the tied output projection is skipped.
    tied_output_layer = getattr(model, '_orig_mod', model).final_linear
    if m is tied_output_layer:
        return

    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)




## Training Loop

In [None]:
# This script is controlled by the Master Control Panel in CELL 1.
import torch
import os
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, record_function, ProfilerActivity

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Experiment bookkeeping: output directories, TensorBoard writer.
    # ------------------------------------------------------------------
    experiment_name = f"{MODEL_CHOICE}_large_len{MAX_LENGTH}"
    SAVE_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name, "models")
    LOG_DIR = f"/content/local_logs/{experiment_name}"
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(LOG_DIR, exist_ok=True)
    writer = SummaryWriter(LOG_DIR)
    LAST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, "last.pt")

    print("--- LAUNCHING EXPERIMENT ---")
    print(f"  Model Choice: {MODEL_CHOICE}")
    print(f"  Max Sequence Length: {MAX_LENGTH}")
    print(f"  Model Dimension: {D_MODEL}")
    print(f"  Target Steps: {TARGET_TRAINING_STEPS}")
    print("-" * 30)

    # ------------------------------------------------------------------
    # Model construction and weight initialisation.
    # ------------------------------------------------------------------
    print(f"--- Initializing StandardTransformer-Small (Baseline) ---")
    model = StandardTransformer(
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        num_heads=NUM_HEADS,
        d_model=D_MODEL,
        dff=D_FF,
        vocab_size=VOCAB_SIZE,
        max_length=MAX_LENGTH,
        dropout=DROPOUT
    )

    print("--- Applying corrected weight initialization ---")

    # Because of weight tying, the embedding weights are initialised first:
    # final_linear.weight is the very same tensor as embedding.weight.
    model.embedding.weight.data.normal_(mean=0.0, std=D_MODEL**-0.5)
    print(" Embedding layer initialized to N(0, d_model**-0.5)")

    # Then every other linear layer (the tied output layer is skipped).
    model.apply(init_other_linear_weights)
    print(" Other linear layers initialized to Xavier Uniform.")

    # --- FINAL SANITY CHECK ---
    # Handle on the plain (uncompiled) module; it shares its parameters with
    # the compiled wrapper created below and is what gets saved/loaded.
    uncompiled_model = getattr(model, '_orig_mod', model)
    final_weight_std = uncompiled_model.final_linear.weight.std().item()
    target_std = D_MODEL**-0.5
    print("\n--- Verification ---")
    print(f"Std deviation of final linear layer weights: {final_weight_std:.4f}")
    print(f"Target std for embeddings:                 {target_std:.4f}")
    if abs(final_weight_std - target_std) < 0.01:
        print(" SUCCESS: Initialization appears correct.")
    else:
        print(" WARNING: Initialization might be incorrect. Standard deviations do not match.")

    model.to(device)
    print(f"Model '{MODEL_CHOICE}' initialized and moved to {device}.")

    # ------------------------------------------------------------------
    # Optimizer, LR schedule and AMP scaler.
    # ------------------------------------------------------------------
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=PEAK_LEARNING_RATE,
        betas=(0.9, 0.98),
        eps=1e-9,
        weight_decay=WEIGHT_DECAY
    )
    print(f"Using AdamW optimizer with weight_decay={WEIGHT_DECAY}.")
    # The messages below describe the scheduler that is actually built
    # (warmup followed by cosine decay), not a Noam / linear-decay schedule.
    print("Optimizer LR is controlled by the warmup + cosine decay scheduler.")
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=TARGET_TRAINING_STEPS
    )
    print("Using linear warmup followed by cosine decay scheduler.")
    scaler = torch.cuda.amp.GradScaler()
    print("Using Automatic Mixed Precision (AMP) with GradScaler.")

    # ------------------------------------------------------------------
    # Optional resume from the last checkpoint.
    # ------------------------------------------------------------------
    global_step = 0
    best_bleu = 0.0
    start_epoch = 0
    if os.path.exists(LAST_CHECKPOINT_PATH):
        print(f"--- Resuming training from checkpoint: {LAST_CHECKPOINT_PATH} ---")
        # map_location="cpu": the RNG-state tensors must stay on the CPU, and
        # load_state_dict copies weights onto the model's device anyway.
        # weights_only=False: the checkpoint also holds NumPy / Python RNG
        # state objects. This is full unpickling, so only load checkpoints
        # written by this script.
        checkpoint = torch.load(LAST_CHECKPOINT_PATH, map_location="cpu", weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Restore the dynamic loss scale as well. Older checkpoints do not have
        # this entry, and a disabled scaler saves an empty dict; skip both.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler_state:
            scaler.load_state_dict(scaler_state)
        global_step = checkpoint.get('global_step', 0)
        best_bleu = checkpoint.get('best_bleu', 0.0)
        start_epoch = checkpoint.get('epoch', 0)
        print(f"Resumed training from optimizer step {global_step}.")
        if 'rng_states' in checkpoint:
            print("--- Restoring RNG states from checkpoint ---")
            rng_states = checkpoint['rng_states']
            torch.set_rng_state(rng_states['torch_rng_state'])
            np.random.set_state(rng_states['numpy_rng_state'])
            random.setstate(rng_states['python_rng_state'])
            if torch.cuda.is_available():
                torch.cuda.set_rng_state_all(rng_states['cuda_rng_states'])
    else:
        print("--- Starting training from scratch ---")

    print("--- Compiling model for optimized performance... ---")
    model = torch.compile(model, dynamic=True)

    # ------------------------------------------------------------------
    # Training loop.
    # ------------------------------------------------------------------
    model.train()
    progress_bar = tqdm(total=TARGET_TRAINING_STEPS, desc="Total Progress", initial=global_step)

    # NOTE(review): a resumed run restarts the stored epoch from its first
    # batch; the position inside the epoch is not part of the checkpoint.
    for epoch in range(start_epoch, 100):
        # Re-seed the shuffle so that each epoch has its own reproducible order.
        train_dataloader.generator.manual_seed(SEED + epoch)
        for i, batch in enumerate(train_dataloader):
            if global_step >= TARGET_TRAINING_STEPS: break

            # Standard PyTorch loop: zero grads, forward, backward, step
            # --- 1. Zero Gradients ---
            optimizer.zero_grad(set_to_none=True)

            # --- 2. Data Preparation ---
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            # The decoder input is the label sequence shifted right by one,
            # with the pad token as first token (the Helsinki tokenizer uses
            # the pad token as start token).
            decoder_start_token = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
            decoder_input_ids = torch.cat([decoder_start_token, labels[:, :-1]], dim=1)
            # -100 (ignored label) is not a valid embedding index.
            decoder_input_ids[decoder_input_ids == -100] = tokenizer.pad_token_id
            target_labels = labels
            src_padding_mask, tgt_padding_mask, mem_key_padding_mask, tgt_mask = model.create_masks(input_ids, decoder_input_ids)
            # Position 0 holds the start token, which equals the pad id; it
            # must stay visible to attention, unlike real padding.
            tgt_padding_mask[:, 0] = False

            # --- 3. Forward Pass with Autocast ---
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                model_outputs = model(src=input_ids, tgt=decoder_input_ids, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=mem_key_padding_mask, tgt_mask=tgt_mask)
                loss, loss_components = calculate_combined_loss(model_outputs, target_labels)

            # --- 4. Backward Pass and Optimizer Step with Scaler ---
            scaler.scale(loss).backward()

            # Unscale gradients so that clipping sees their true magnitude.
            scaler.unscale_(optimizer)
            total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            global_step += 1
            progress_bar.update(1)
            lr = scheduler.get_last_lr()[0]

            if global_step % 10 == 0:
                writer.add_scalar('train/loss', loss.item(), global_step)
                writer.add_scalar('train/learning_rate', lr, global_step)
                writer.add_scalar('train/gradient_norm', total_grad_norm.item(), global_step)

                # Optional components; only logged when the loss function
                # reports them.
                if 'trans' in loss_components:
                    writer.add_scalar('train/loss_translation', loss_components['trans'], global_step)
                if 'recon' in loss_components:
                    writer.add_scalar('train/loss_reconstruction', loss_components['recon'], global_step)
            progress_bar.set_postfix(
                lr=f"{lr:.2e}",
                grad_norm=f"{total_grad_norm.item():.2f}",
                **loss_components
            )

            # --- 5. Periodic validation and checkpointing ---
            if global_step > 0 and global_step % VALIDATION_EVERY_N_STEPS == 0:
                print(f"\n--- Validation at Optimizer Step {global_step} ---")
                bleu_score = evaluate(model, val_dataloader, device)
                writer.add_scalar('validation/bleu', bleu_score, global_step)
                print(f"Validation BLEU Score: {bleu_score:.4f} (Best: {best_bleu:.4f})")
                generate_sample_translations(model, device, sample_sentences_de_for_tracking)

                if bleu_score > best_bleu:
                    best_bleu = bleu_score
                    print(f"🎉 New best BLEU score! Saving best model... 🎉")
                    torch.save(uncompiled_model.state_dict(), os.path.join(SAVE_DIR, "best.pt"))
                torch.save({
                    'global_step': global_step,
                    'epoch': epoch,
                    'model_state_dict': uncompiled_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'best_bleu': best_bleu,
                    'rng_states': {
                        'torch_rng_state': torch.get_rng_state(),
                        'numpy_rng_state': np.random.get_state(),
                        'python_rng_state': random.getstate(),
                        'cuda_rng_states': torch.cuda.get_rng_state_all(),
                    }
                }, LAST_CHECKPOINT_PATH)
                model.train()

        if global_step >= TARGET_TRAINING_STEPS: break

    progress_bar.close()
    writer.close()
    print("\n--- Training finished ---")

    # --- Final Evaluation ---
    print("\n--- Running final evaluation on the best model ---")
    BEST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, "best.pt")
    if os.path.exists(BEST_CHECKPOINT_PATH):
        uncompiled_model = getattr(model, '_orig_mod', model)
        # best.pt is a plain state_dict; map it straight onto the active device.
        uncompiled_model.load_state_dict(torch.load(BEST_CHECKPOINT_PATH, map_location=device))

        final_bleu = evaluate(model, val_dataloader, device)
        print(f"\n{'*'*20} FINAL RESULTS {'*'*20}")
        print(f"MODEL: {MODEL_CHOICE}")
        print(f"Final Validation BLEU Score on best.pt: {final_bleu:.4f}")
        print(f"{'*'*55}")
    else:
        print("No 'best.pt' checkpoint found. Could not run final evaluation.")

In [None]:
# TENSORBOARD VISUALIZATION
# Notebook-only cell: the lines starting with '%' are IPython magics.
# LOG_DIR is defined by the training cell above, so that cell must run first.

%load_ext tensorboard

print(f"--- Launching TensorBoard ---")
print(f"Logs are being read from: {LOG_DIR}")
print("It may take a minute to load the data. Click the 'refresh' button in the UI if needed.")

# Start the TensorBoard UI on the directory the SummaryWriter logs to.
%tensorboard --logdir {LOG_DIR}

## End