<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/PRISM_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

## CONFIG

# --- Data & Task Size ---
# MAX_LENGTH is reused as: the Ozan encoder FFT length / filter size (OzanLayer
# max_len), the decoder positional-encoding table length, and the generation
# length limit passed to model.generate().
MAX_LENGTH = 128
MODEL_CHOICE = "Solomon" # Run name; prefixes the experiment directory ("<MODEL_CHOICE>_Gated")

# --- Model Architecture Config ---
D_MODEL = 512  # width of both the complex encoder and the real decoder
NUM_HEADS = 8  # decoder attention heads (the Ozan encoder has no attention)
D_FF = 2048    # decoder feed-forward width
DROPOUT = 0.1  # used by the decoder layers and decoder input embeddings

# --- Layer counts ---
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

# --- Training Config ---
TARGET_TRAINING_STEPS = 50000  # optimizer steps; also the cosine schedule horizon
# Global steps at which validation BLEU is computed (and the best model saved).
VALIDATION_SCHEDULE = [
    2000, 4000, 5000, 7500, 10000, 15000, 20000,
    25000, 30000, 35000, 42500, 50000
]
PEAK_LEARNING_RATE = 8e-4
WARMUP_STEPS = 120  # linear warmup steps before cosine decay
WEIGHT_DECAY = 0.01
LABEL_SMOOTHING_EPSILON = 0.1

# --- Paths ---
DRIVE_BASE_PATH = "/content/drive/MyDrive/PRISM"  # root for checkpoints and logs
# Pre-batched split: rows carry 'batch_indices' into the bucketed split below.
PREBATCHED_REPO_ID = "Yujivus/wmt14-de-en-prebatched-w4"
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"
# Only the tokenizer is loaded from this checkpoint; the model is trained from scratch.
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en"


In [None]:

## IMPORTS & SETUP

import datetime
import hashlib
import json
import logging
import math
import os
import random
import shutil
import subprocess
import sys
from typing import List

import numpy as np
import torch
import torch.fft # <--- OZAN REQUIRES THIS
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # used by OzanTransformer; not imported by `import torch` on all versions
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
from torchmetrics.text import BLEUScore

# --- Reproducibility ---
def set_seed(seed_value=5):
    """Seed every RNG the run touches and pin cuDNN to deterministic kernels.

    Seeds, in order: Python's ``random``, NumPy's global RNG, the PyTorch CPU
    generator and all CUDA generators. Afterwards cuDNN autotuning is disabled
    and deterministic algorithms are requested.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

SEED = 116
set_seed(SEED)
# Required by torch.use_deterministic_algorithms(True) for deterministic cuBLAS
# on CUDA >= 10.2 (PyTorch reproducibility docs).
# NOTE(review): it must be in the environment before cuBLAS is first initialized.
# set_seed() above only queues CUDA seeding lazily, so this ordering should be fine — confirm.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Raise an error on ops that have no deterministic implementation.
torch.use_deterministic_algorithms(True)
# 'high' allows TF32 (or equivalent) for float32 matmuls on supporting GPUs:
# faster, but less precise than the default 'highest'.
torch.set_float32_matmul_precision('high')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
# len(tokenizer) includes any added special tokens; this sizes both embeddings
# and the output projection.
VOCAB_SIZE = len(tokenizer)


## DATA LOADING (UNCHANGED)

# Dynamic-padding collator that turns a list of tokenized examples into a padded batch.
# NOTE(review): DataCollatorForSeq2Seq pads labels with -100 by default. That matches
# ignore_index=-100 in the loss and the "-100 -> pad" mapping done before decoding labels.
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

class PreBatchedCollator:
    """Collate function for the pre-batched training split.

    Each DataLoader "sample" is a record whose ``'batch_indices'`` field lists
    row indices into the tokenized split given at construction. Only the first
    feature is read (the training loader uses ``batch_size=1``). Its rows are
    fetched as a dict of columns, turned into a list of per-example dicts, and
    padded by the module-level ``standard_collator``.
    """

    def __init__(self, original_dataset_split):
        self.original_dataset = original_dataset_split

    def __call__(self, features: List[dict]) -> dict:
        rows = self.original_dataset[features[0]['batch_indices']]
        column_names = rows.keys()
        n_examples = len(rows['input_ids'])
        examples = [
            {column: rows[column][idx] for column in column_names}
            for idx in range(n_examples)
        ]
        return standard_collator(examples)

# Two views of the same corpus (per the repo IDs, WMT14 de-en):
#   - prebatched: each row holds 'batch_indices' into the bucketed split;
#   - bucketed ("original"): the tokenized rows themselves (train/validation/test).
print(f"Loading datasets...")
prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)
original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)
# Expands a prebatched row into a padded batch of bucketed training rows.
train_collator = PreBatchedCollator(original_datasets["train"])

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: derive NumPy/Python seeds from the torch seed.

    ``worker_id`` is unused. Both RNGs are reseeded with ``torch.initial_seed()``
    reduced modulo 2**32, because NumPy only accepts 32-bit seeds.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Shared shuffling generator. The training loop reseeds it at the start of every
# epoch (train_dataloader.generator.manual_seed(SEED + epoch)).
g = torch.Generator()
g.manual_seed(SEED)

# Each prebatched row is a whole batch (a list of row indices), hence batch_size=1.
# PreBatchedCollator expands it into the real padded batch. With num_workers=0,
# worker_init_fn is never invoked.
train_dataloader = DataLoader(
    prebatched_datasets["train"], batch_size=1, shuffle=True, num_workers=0,
    collate_fn=train_collator, pin_memory=True, worker_init_fn=seed_worker, generator=g,
)

# NOTE(review): the validation loader shares generator g with the training loader.
# It does not shuffle, but creating its iterator may still draw from g. The per-epoch
# reseed in the training loop should keep training order reproducible — confirm.
val_dataloader = DataLoader(
    original_datasets["validation"], batch_size=64, collate_fn=standard_collator,
    num_workers=0, pin_memory=True, worker_init_fn=seed_worker, generator=g,
)


In [None]:

## --- OZAN ARCHITECTURE COMPONENTS ---

class HarmonicEmbedding(nn.Module):
    """
    Complex token embedding: learned magnitude times a fixed positional phase.

    For token ``t`` at position ``p`` and channel ``k`` the output is
    ``|A[t, k]| * exp(i * p * f_k)``. Here ``A`` is a learnable embedding table
    (initialized in [0.1, 1.0]) and ``f_k = max_period ** (-k / embedding_dim)``
    is a fixed, non-learnable frequency. Every channel gets its own frequency;
    there are no sin/cos pairs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: number of complex channels.
        max_period: controls the slowest frequency (f_k decays from 1 toward 1/max_period).

    Forward:
        input_ids: integer tensor [batch, seq].
        returns: complex tensor [batch, seq, embedding_dim].
    """
    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # 1. Fixed frequencies, one per channel (buffer: moves with the module, not trained).
        freqs = torch.exp(
            torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim)
        )
        self.register_buffer('freqs', freqs)

        # 2. Learnable amplitudes. A standard embedding table whose output is used as
        # a magnitude (abs() is taken in forward, so the sign of a weight is irrelevant).
        self.amplitude_embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.amplitude_embedding.weight, 0.1, 1.0)  # start with positive volume

    def forward(self, input_ids):
        seq_len = input_ids.shape[1]

        # A. Amplitudes (semantics) -> [batch, seq, dim], non-negative.
        amplitudes = torch.abs(self.amplitude_embedding(input_ids))

        # B. Phases (position) -> [seq, dim]. Positions use the buffer's dtype, so the
        # outer product stays consistent if the module is cast (e.g. .double()).
        positions = torch.arange(seq_len, device=input_ids.device, dtype=self.freqs.dtype)
        angles = torch.outer(positions, self.freqs)

        # C. Unit complex rotations e^{i*angle} -> [1, seq, dim] (broadcast over batch).
        spin = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)

        # D. Complex signal = amplitude * spin.
        return amplitudes * spin

class ModReLU(nn.Module):
    """
    Phase-preserving activation for complex tensors.

    Computes ``relu(|z| + b) * z / (|z| + 1e-6)``. The magnitude is shifted by a
    learnable per-feature bias ``b`` and rectified; the phase of ``z`` is kept.
    Entries with ``z == 0`` map to 0.
    """
    def __init__(self, features):
        super().__init__()
        # Per-feature magnitude bias; zero init makes the layer ~identity at start.
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        magnitude = z.abs()
        gated_magnitude = F.relu(magnitude + self.b)
        unit_phase = z / (magnitude + 1e-6)  # ~z/|z|; epsilon avoids 0/0
        return gated_magnitude * unit_phase

class OzanLayer(nn.Module):
    """
    One gated global-filter layer operating on complex signals.

    Per call (x is complex, [batch, seq, d_model]):
      A. Gate: sigmoid(Linear([Re x, Im x])) gives a real gate that scales x
         element-wise (phase is kept).
      B. Global filter: FFT along the sequence axis with length ``max_len``
         (zero-padding the input), multiply by a learned complex spectrum per
         (frequency, channel), inverse FFT, crop back to ``seq``. Because the
         filter spans all ``max_len`` bins, this is a circular convolution over
         a window of ``max_len`` positions.
      C. Complex linear -> ModReLU -> complex linear, plus a residual add of x.

    Args:
        d_model: channel width.
        max_len: FFT length and filter size. Longer inputs are rejected.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        # FFT length along the sequence axis (= number of learned frequency bins).
        self.filter_len = max_len

        # --- 1. THE GATE (Amplitude Modulator) ---
        # Real gate computed from the concatenated real/imag parts.
        self.pre_gate = nn.Linear(d_model * 2, d_model)

        # Initialize Gate to be OPEN (+2.0 bias -> ~0.88 sigmoid)
        nn.init.constant_(self.pre_gate.bias, 2.0)

        # --- 2. GLOBAL FILTER ---
        # Learnable complex frequency response, [d_model, max_len].
        self.global_filter = nn.Parameter(
            torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02
        )

        # Real/imag weight pairs that together form two complex linear maps.
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        # NOTE: not used by forward(). Kept so the parameter set / state_dict layout
        # stays compatible with existing checkpoints.
        self.norm = nn.LayerNorm(d_model)

        # Latest gate statistics (Python floats), refreshed on each training-mode forward.
        self.gate_metrics = {}

    def complex_linear(self, x, l_real, l_imag):
        """
        Apply a complex linear map built from two real nn.Linear layers.

        For x = r + 1j*i this returns
        (l_real(r) - l_imag(i)) + 1j * (l_real(i) + l_imag(r)).
        Both layers' biases enter each part, so the effective complex bias is
        (b_real - b_imag) + 1j * (b_real + b_imag).
        """
        r, i = x.real, x.imag
        new_r = l_real(r) - l_imag(i)
        new_i = l_real(i) + l_imag(r)
        return torch.complex(new_r, new_i)

    def forward(self, x):
        """
        Args:
            x: complex tensor [batch, seq, d_model] with seq <= max_len.
        Returns:
            complex tensor of the same shape (block output + residual x).
        Raises:
            ValueError: if seq > max_len. fft(n=max_len) would silently truncate
                the sequence, and the residual add would then fail with an
                opaque shape error.
        """
        seq_len = x.shape[1]
        if seq_len > self.filter_len:
            raise ValueError(
                f"OzanLayer received a sequence of length {seq_len}, but its global "
                f"filter only covers max_len={self.filter_len} positions"
            )

        # --- A. GATING ---
        x_concat = torch.cat([x.real, x.imag], dim=-1)
        gate = torch.sigmoid(self.pre_gate(x_concat))

        # --- CAPTURE METRICS (Detached from graph) ---
        # Each .item() forces a host/device sync on every training step.
        if self.training:
            with torch.no_grad():
                # Mean: how open is the gate on average?
                self.gate_metrics['mean'] = gate.mean().item()
                # Std: how discriminative is it? (high std = more selective)
                self.gate_metrics['std'] = gate.std().item()
                # Sparsity: fraction of gate values below 0.1
                self.gate_metrics['sparsity'] = (gate < 0.1).float().mean().item()

        x_gated = x * gate

        # --- B. GLOBAL RESONANCE ---
        x_freq = torch.fft.fft(x_gated, n=self.filter_len, dim=1)
        # Filter as [max_len, d_model], broadcast over the batch axis.
        filter_transposed = self.global_filter.transpose(-1, -2)
        x_filtered = x_freq * filter_transposed
        x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
        x_time = x_time[:, :seq_len, :]

        # --- C. MIX & ACTIVATE ---
        x_mixed = self.complex_linear(x_time, self.mix_real, self.mix_imag)
        x_act = self.activation(x_mixed)
        out = self.complex_linear(x_act, self.out_real, self.out_imag)

        return out + x

class OzanEncoder(nn.Module):
    """
    Stack of ``num_layers`` OzanLayer blocks over complex [batch, seq, d_model] signals.

    Padding handling: the global FFT filter mixes every position into every other,
    so non-zero pad-token vectors would leak into real tokens. The encoding of a
    sentence would then depend on how much padding its batch carries. When
    ``src_key_padding_mask`` (bool [batch, seq], True = pad) is given, pad
    positions are zeroed before the first layer and after every layer. Real
    tokens then see exactly the zeros that fft(n=max_len) pads with, regardless
    of batch padding. With no mask, the stack runs unmasked (previous behavior).
    """
    def __init__(self, num_layers, d_model, max_len):
        super().__init__()
        self.layers = nn.ModuleList([
            OzanLayer(d_model, max_len) for _ in range(num_layers)
        ])

    def forward(self, x, src_key_padding_mask=None):
        if src_key_padding_mask is None:
            for layer in self.layers:
                x = layer(x)
            return x

        # Real-valued 0/1 keep-mask [batch, seq, 1]; multiplying a complex tensor
        # by it zeroes both real and imaginary parts at pad positions.
        keep = (~src_key_padding_mask).unsqueeze(-1).to(x.real.dtype)
        x = x * keep
        for layer in self.layers:
            x = layer(x) * keep
        return x

class ComplexToRealBridge(nn.Module):
    """
    Map a complex encoder output to a real tensor the standard decoder can attend to.

    The real and imaginary parts are concatenated along the feature axis
    ([..., 2 * d_model]) and projected back to ``d_model`` with one nn.Linear.
    """
    def __init__(self, d_model):
        super().__init__()
        in_features = 2 * d_model
        self.proj = nn.Linear(in_features, d_model)

    def forward(self, x_complex):
        return self.proj(torch.cat((x_complex.real, x_complex.imag), dim=-1))

class OzanTransformer(nn.Module):
    """
    Hybrid encoder-decoder translation model.

    Encoder: HarmonicEmbedding (complex) -> OzanEncoder (gated global FFT filters)
             -> ComplexToRealBridge (complex -> real "memory").
    Decoder: standard pre-norm, batch-first nn.TransformerDecoder over real target
             embeddings with sinusoidal positions. The output projection is tied to
             the target embedding.

    Note: earlier revisions also built an unused ``nn.Transformer(d_model).encoder``
    (default-sized, 6 layers) as ``self.decoder_pos_encoder``. It was never called but
    inflated parameter counts, optimizer state and checkpoint size. Checkpoints saved
    by those revisions contain ``decoder_pos_encoder.*`` keys and need
    ``load_state_dict(..., strict=False)``.
    """
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model

        # 1. Harmonic (complex) source embeddings.
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        # Applied to the decoder input embeddings only.
        self.dropout = nn.Dropout(dropout)

        # 2. Ozan encoder (complex signal processing).
        self.encoder = OzanEncoder(num_encoder_layers, d_model, max_length)

        # 3. Bridge: complex encoder output -> real memory.
        self.bridge = ComplexToRealBridge(d_model)

        # 4. Standard decoder with sinusoidal positions on the target side.
        self.pos_encoder_helper = PositionalEncoding(d_model, max_length)

        # Target embedding (the decoder consumes real inputs).
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.final_linear = nn.Linear(d_model, vocab_size)
        # Tie the output projection to the target embedding.
        self.final_linear.weight = self.tgt_embedding.weight

    def _encode(self, src, src_padding_mask=None):
        """Harmonic embedding -> Ozan encoder -> bridge. Returns real memory [batch, src_len, d_model]."""
        src_harmonic = self.harmonic_embedding(src)
        if self.training:
            # Activation checkpointing: recompute the encoder in backward to save memory.
            encoded_complex = torch.utils.checkpoint.checkpoint(
                self.encoder, src_harmonic, src_padding_mask, use_reentrant=False
            )
        else:
            encoded_complex = self.encoder(src_harmonic, src_padding_mask)
        return self.bridge(encoded_complex)

    def _embed_target(self, tgt):
        """Scaled target embeddings + sinusoidal positions + dropout -> [batch, tgt_len, d_model]."""
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        return self.dropout(self.pos_encoder_helper(tgt_emb))

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        """
        Teacher-forced forward pass.

        Args:
            src: source token ids [batch, src_len].
            tgt: decoder input ids [batch, tgt_len].
            src_padding_mask: bool [batch, src_len], True at source pads (zeroed in the encoder).
            tgt_padding_mask: bool [batch, tgt_len], True at target pads.
            memory_key_padding_mask: bool [batch, src_len] for cross-attention.
            tgt_mask: causal mask [tgt_len, tgt_len].
        Returns:
            logits [batch, tgt_len, vocab_size].
        """
        memory = self._encode(src, src_padding_mask)
        output = self.decoder(
            tgt=self._embed_target(tgt),
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )
        return self.final_linear(output)

    def create_masks(self, src, tgt):
        # Same mask logic as standard transformer
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1), device=src.device, dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @staticmethod
    def _hypothesis_lengths(beams, pad_token_id, eos_token_id):
        """
        Count the real tokens of each hypothesis: non-pad tokens up to and including the first EOS.

        Finished hypotheses keep being extended with EOS until decoding stops. Counting
        every non-pad token would therefore give all beams the same length and make
        length normalization a no-op.

        Args:
            beams: integer tensor [..., T] of token ids.
        Returns:
            int64 tensor [...] of lengths.
        """
        is_eos = beams == eos_token_id
        last_index = beams.size(-1) - 1
        # argmax returns the first maximal index, i.e. the first EOS position.
        first_eos = is_eos.int().argmax(dim=-1)
        cutoff = torch.where(is_eos.any(dim=-1), first_eos, torch.full_like(first_eos, last_index))
        positions = torch.arange(beams.size(-1), device=beams.device)
        keep = (beams != pad_token_id) & (positions <= cutoff.unsqueeze(-1))
        return keep.sum(dim=-1)

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        """
        Beam-search decode a batch of source sequences.

        Args:
            src: source token ids [batch, src_len].
            max_length: maximum hypothesis length including the start (pad) token.
            num_beams: beams per source sentence.
        Returns:
            LongTensor [batch, <= max_length]: the best hypothesis per source, by
            score / length. Each starts with the pad (decoder-start) token;
            positions after its first EOS are filled with EOS.

        The module runs in eval mode while decoding and is restored to its previous
        train/eval mode afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            return self._beam_search(src, max_length, num_beams)
        finally:
            self.train(was_training)

    def _beam_search(self, src, max_length, num_beams):
        """Beam-search body for generate(); assumes eval mode and no_grad."""
        pad_id = tokenizer.pad_token_id
        eos_id = tokenizer.eos_token_id
        device = src.device
        batch_size = src.shape[0]

        # 1. Encode once, then replicate memory per beam.
        src_padding_mask = (src == pad_id)
        memory = self._encode(src, src_padding_mask)
        memory = memory.repeat_interleave(num_beams, dim=0)
        memory_key_padding_mask = src_padding_mask.repeat_interleave(num_beams, dim=0)

        # 2. Beam state. Rows are laid out as [batch0 beam0..beamK-1, batch1 ...].
        beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=device)
        beam_scores = torch.zeros(batch_size * num_beams, device=device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=device)
        batch_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * num_beams

        for step in range(max_length - 1):
            if finished_beams.all():
                break
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1), device=device)
            decoder_output = self.decoder(
                tgt=self._embed_target(beams), memory=memory, tgt_mask=tgt_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )
            log_probs = F.log_softmax(self.final_linear(decoder_output[:, -1, :]), dim=-1)
            vocab_size = log_probs.shape[-1]

            # Never emit the pad / decoder-start token.
            log_probs[:, pad_id] = -torch.inf
            # A finished hypothesis may only be extended with EOS at zero cost, so it
            # keeps its score and cannot grow text after its EOS.
            eos_only = torch.full_like(log_probs, -torch.inf)
            eos_only[:, eos_id] = 0.0
            log_probs = torch.where(finished_beams.unsqueeze(1), eos_only, log_probs)

            total_scores = beam_scores.unsqueeze(1) + log_probs
            if step == 0:
                # All beams are identical at the first step. Keep only beam 0 so
                # top-k does not select duplicates.
                total_scores = total_scores.view(batch_size, num_beams, vocab_size)
                total_scores[:, 1:, :] = -torch.inf
            total_scores = total_scores.view(batch_size, -1)

            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)
            beam_indices = top_indices // vocab_size
            token_indices = top_indices % vocab_size
            source_rows = (batch_offsets + beam_indices).view(-1)

            beams = torch.cat([beams[source_rows], token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            # Finished flags must follow their hypotheses to the new rows.
            finished_beams = finished_beams[source_rows] | (token_indices.view(-1) == eos_id)

        # 3. Pick the best hypothesis per source by length-normalized score.
        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        lengths = self._hypothesis_lengths(final_beams, pad_id, eos_id).float().clamp(min=1)
        best = (final_scores / lengths).argmax(dim=1)
        return final_beams[torch.arange(batch_size, device=device), best, :]

# --- Helper Classes for Generation ---
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding for batch-first inputs.

    Builds a [1, max_len, d_model] table with sin on even channels and cos on odd
    channels (wavelengths up to 10000 * 2π). forward() adds its first seq rows to
    an input of shape [batch, seq, d_model].
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = torch.arange(max_len).unsqueeze(1) * inv_freq
        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        table[0, :, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

In [None]:
def count_ozan_parameters(model):
    """
    Print a parameter report for the Gated Ozan + Transformer hybrid and return the total.

    * Only trainable parameters (``requires_grad``) are counted.
    * Complex parameters count twice per element (real + imaginary component).
    * A tensor reachable under several names (tied weights) is counted once.
    * Rows are printed only for global filters, embeddings and gate ("pre_gate") tensors.
    """
    rule = "-" * 90
    highlighted = ("global_filter", "embedding", "pre_gate")

    print(f"{'MODULE':<45} | {'SHAPE':<20} | {'TYPE':<10} | {'PARAMS'}")
    print(rule)

    counted_ids = set()
    total_params = 0
    complex_params = 0

    for name, param in model.named_parameters():
        if not param.requires_grad or id(param) in counted_ids:
            continue
        counted_ids.add(id(param))

        is_cplx = param.is_complex()
        n_values = param.numel() * (2 if is_cplx else 1)
        total_params += n_values
        if is_cplx:
            complex_params += n_values

        if any(tag in name for tag in highlighted):
            kind = "Complex" if is_cplx else "Real"
            print(f"{name:<45} | {str(list(param.shape)):<20} | {kind:<10} | {n_values:,}")

    print(rule)
    print(f"Total Trainable Parameters: {total_params:,}")
    print(f" (Of which are Complex components): {complex_params:,}")
    return total_params

In [None]:
# ==============================================================================
# --- MAIN EXECUTION (GATED EDITION) ---
# ==============================================================================

if __name__ == '__main__':
    # 1. Setup Directories & Logging
    # Every artifact of this run lives under DRIVE_BASE_PATH/<MODEL_CHOICE>_Gated.
    experiment_name = f"{MODEL_CHOICE}_Gated"
    CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name)
    SAVE_DIR = os.path.join(CURRENT_RUN_DIR, "models")
    LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, "tensorboard_logs")
    LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, "run_log.txt")

    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)

    # Configure Python Logger
    # Logs go to run_log.txt and stdout. force=True replaces root handlers left by an
    # earlier execution of this cell.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        handlers=[logging.FileHandler(LOG_FILE_TXT), logging.StreamHandler(sys.stdout)],
        force=True
    )
    logger = logging.getLogger(__name__)
    writer = SummaryWriter(LOG_DIR_TENSORBOARD)

    logger.info(f"--- LAUNCHING GATED OZAN EXPERIMENT (PAPER #2) ---")
    logger.info(f"Device: {device}")

    # 2. Initialize Model
    logger.info("Initializing OzanTransformer with Gated Harmonic Encoders...")
    model = OzanTransformer(
        num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS,
        num_heads=NUM_HEADS, d_model=D_MODEL, dff=D_FF, vocab_size=VOCAB_SIZE,
        max_length=MAX_LENGTH, dropout=DROPOUT
    )

    # Custom Weight Initialization (Critical for Gate Stability)
    def init_weights_ozan(m):
        """Re-initialize one module if it is an nn.Linear (used via model.apply).

        Weights get kaiming_uniform_(a=sqrt(5)). Biases get U(-0.1, 0.1) unless their
        first element is exactly 2.0: that is the marker value OzanLayer writes into
        pre_gate.bias, so the open-gate init survives. final_linear.weight is tied to
        tgt_embedding.weight, so this also re-initializes the target embedding table.
        """
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                # Don't overwrite the specific gate bias we set in __init__ (+2.0)
                # We check if the bias value is exactly 2.0 (our marker)
                # If it's not our special marker, we init normally.
                if m.bias.data[0] != 2.0:
                    nn.init.uniform_(m.bias, -0.1, 0.1)

    model.apply(init_weights_ozan)
    model.to(device)

    # Log Parameter Counts
    count_ozan_parameters(model)

    # 3. Optimizer & Scheduler
    # A single parameter group: WEIGHT_DECAY applies to every parameter, including
    # biases, LayerNorm weights, embeddings and the complex global filters.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=PEAK_LEARNING_RATE,
        betas=(0.9, 0.98),
        eps=1e-9,
        weight_decay=WEIGHT_DECAY
    )

    # Linear warmup for WARMUP_STEPS, then cosine decay over the rest of
    # TARGET_TRAINING_STEPS (stepped once per optimizer step below).
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=TARGET_TRAINING_STEPS
    )

    # Loss Function
    # -100 marks padded label positions, which are ignored.
    translation_loss_fn = nn.CrossEntropyLoss(
        ignore_index=-100,
        label_smoothing=LABEL_SMOOTHING_EPSILON
    )

    # --- Helper Functions ---
    def calculate_combined_loss(outputs, targets):
        """Token-level cross-entropy of logits [batch, tgt_len, vocab] against labels [batch, tgt_len].

        Despite the name, there is a single (translation) loss term.
        """
        logits = outputs
        loss = translation_loss_fn(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
        return loss

    def evaluate(model, dataloader, device):
        """Corpus BLEU of 5-beam generations vs. references; returns a float and leaves model in train mode.

        NOTE(review): torchmetrics BLEUScore uses whitespace tokenization and reports on a
        0-1 scale, so values are not directly comparable to sacreBLEU — confirm before reporting.
        """
        bleu_metric = BLEUScore()
        model.eval()
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels']

            # Generate
            generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)

            # Decode (labels use -100 for padding; map it back to pad so it decodes/skips cleanly)
            pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            labels[labels == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

            bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])

        model.train()
        return bleu_metric.compute().item()

    # 4. Training Loop
    model.train()
    global_step = 0
    best_bleu = 0.0
    progress_bar = tqdm(total=TARGET_TRAINING_STEPS, desc="Training")

    training_complete = False

    # Loop over epochs until step count reached
    for epoch in range(1000): # High number, break by step
        if training_complete: break

        # Deterministic shuffling per epoch
        train_dataloader.generator.manual_seed(SEED + epoch)

        for batch in train_dataloader:
            if global_step >= TARGET_TRAINING_STEPS:
                training_complete = True
                break

            # A. Prepare Data
            optimizer.zero_grad(set_to_none=True)
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            # Decoder Input: [PAD, t1, t2, ...]
            # Teacher forcing: labels shifted right with pad as the start token; ignored
            # (-100) label positions are mapped back to pad so they can be embedded.
            decoder_start = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, device=device)
            decoder_input = torch.cat([decoder_start, labels[:, :-1]], dim=1)
            decoder_input[decoder_input == -100] = tokenizer.pad_token_id

            # Masks
            src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, decoder_input)
            # The start token is the pad id, so create_masks marked it as padding.
            tgt_pad[:, 0] = False # Unmask the start token

            # B. Forward Pass (FP32)
            outputs = model(input_ids, decoder_input, src_mask, tgt_pad, mem_pad, tgt_mask)
            loss = calculate_combined_loss(outputs, labels)

            # C. Backward Pass (global grad-norm clipping at 1.0; scheduler stepped per optimizer step)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            global_step += 1
            progress_bar.update(1)

            # D. Logging (Loss & Gates) every 20 steps
            if global_step % 20 == 0:
                current_loss = loss.item()
                writer.add_scalar('train/loss', current_loss, global_step)
                writer.add_scalar('train/lr', scheduler.get_last_lr()[0], global_step)
                progress_bar.set_postfix(loss=current_loss)

                # --- GATE LOGGING ---
                # Each OzanLayer stores the stats from its most recent training forward.
                for i, layer in enumerate(model.encoder.layers):
                    if hasattr(layer, 'gate_metrics') and layer.gate_metrics:
                        # Log Mean (Openness)
                        writer.add_scalar(f'gates/L{i}_mean', layer.gate_metrics['mean'], global_step)
                        # Log Std (Selectivity)
                        writer.add_scalar(f'gates/L{i}_std', layer.gate_metrics['std'], global_step)
                        # Log Sparsity (Filtering)
                        writer.add_scalar(f'gates/L{i}_sparsity', layer.gate_metrics['sparsity'], global_step)

            # E. Validation at the configured steps; keep the best-BLEU weights.
            if global_step in VALIDATION_SCHEDULE:
                logger.info(f"Validating at step {global_step}...")
                current_bleu = evaluate(model, val_dataloader, device)
                writer.add_scalar('val/bleu', current_bleu, global_step)
                logger.info(f"Step {global_step} | BLEU: {current_bleu:.4f}")

                if current_bleu > best_bleu:
                    best_bleu = current_bleu
                    # Weights only: optimizer/scheduler state is not saved, so runs cannot be resumed exactly.
                    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model.pt"))
                    logger.info(f"New Best Model Saved! ({best_bleu:.4f})")

    # 5. Final Save
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "final_model.pt"))
    writer.close()
    logger.info(f"Training Complete. Best BLEU: {best_bleu:.4f}")

In [None]:
# ==============================================================================
# --- FINAL TEST & VISUALIZATION ---
# ==============================================================================

if __name__ == '__main__':
    # Final test-set evaluation: loads the best checkpoint (if any), runs 5-beam
    # generation over the test split, reports corpus BLEU and prints sample translations.
    # Relies on model, SAVE_DIR and CURRENT_RUN_DIR from the training cell.
    print(f"--- STARTING FINAL EVALUATION ---")

    # 1. Load the Test Split
    # We use the standard collator because testing doesn't need pre-batching optimization
    print("Loading Test Data...")
    test_dataset = original_datasets["test"]
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=32,
        collate_fn=standard_collator,
        num_workers=0,
        pin_memory=True
    )

    # 2. Load the Best Model Weights
    # Falls back to the in-memory weights when no best checkpoint was written
    # (best_model.pt is only saved when a validation BLEU beats the previous best).
    best_model_path = os.path.join(SAVE_DIR, "best_model.pt")
    if not os.path.exists(best_model_path):
        print(f"WARNING: {best_model_path} not found. Using current model weights.")
    else:
        print(f"Loading best weights from: {best_model_path}")
        # The file is a plain state_dict, so refuse arbitrary pickled objects.
        model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

    model.eval()
    model.to(device)

    # 3. Run Evaluation Loop
    test_bleu = BLEUScore()
    examples_to_print = []

    print("Running Inference on Test Set...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_dataloader, desc="Testing")):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels']

            # Generate
            generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)

            # Decode
            src_texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # Handle labels for metric (map -100 back to pad so it is skipped when decoding)
            labels[labels == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # Update Metric
            test_bleu.update(pred_texts, [[ref] for ref in ref_texts])

            # Save first batch for visualization
            if i == 0:
                for j in range(len(src_texts)):
                    examples_to_print.append({
                        "src": src_texts[j],
                        "ref": ref_texts[j],
                        "pred": pred_texts[j]
                    })

    # 4. Final Score
    final_score = test_bleu.compute().item()
    print("\n" + "="*40)
    print(f"🏆 FINAL TEST BLEU SCORE: {final_score:.4f}")
    print("="*40 + "\n")

    # 5. Visual Inspection (Qualitative Analysis)
    print(f"--- Example Translations ---\n")

    # Print up to 5 random examples from the first batch (uses the module-level `random`)
    selected_examples = random.sample(examples_to_print, min(5, len(examples_to_print)))

    for ex in selected_examples:
        print(f"🇩🇪 SRC:  {ex['src']}")
        print(f"🇬🇧 REF:  {ex['ref']}")
        print(f"🤖 OZAN: {ex['pred']}")
        print("-" * 20)

    # 6. Save Logic (Optional: Zip and Download for local storage)
    print(f"\nModel and logs are saved in: {CURRENT_RUN_DIR}")