<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/PRISM_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

# ==============================================================================
# 1. IMPORTS & CONFIGURATION
# ==============================================================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, get_cosine_schedule_with_warmup
from datasets import load_dataset
import math, sys, logging, datetime, json, random
import numpy as np
from tqdm.auto import tqdm
from torchmetrics.text import BLEUScore
from torch.utils.tensorboard import SummaryWriter
from typing import List

# --- Hardware Speedups ---
# Trades float32 matmul precision for speed (see the PyTorch documentation of
# torch.set_float32_matmul_precision for what 'medium' permits).
torch.set_float32_matmul_precision('medium')

# --- Data & Task Size ---
# Passed to PRISMTransformer as `max_length` (encoder filter length and size of
# the positional-encoding table); also the cap used for generation and for
# tokenizer truncation of the test set.
MAX_LENGTH = 128
# Label used in the run-directory name and in the final score report.
MODEL_CHOICE = "PRISM"

# --- Model Architecture Config ---
D_MODEL = 512              # channel width of encoder, bridge and decoder
NUM_HEADS = 8              # attention heads of the nn.TransformerDecoderLayer only
D_FF = 2048                # feed-forward width of the nn.TransformerDecoderLayer
DROPOUT = 0.1              # shared by ComplexDropout (encoder) and nn.Dropout (decoder)
NUM_ENCODER_LAYERS = 6     # number of PRISMLayer blocks
NUM_DECODER_LAYERS = 6     # number of TransformerDecoder layers

# --- Training Config ---
# Total number of optimizer steps; also the horizon of the cosine LR schedule.
TARGET_TRAINING_STEPS = 50000
# Global steps after which validation BLEU is computed (and the best
# checkpoint possibly saved).
VALIDATION_SCHEDULE = [
    2000, 4000, 5000, 7500, 10000, 15000, 20000,
    25000, 30000, 35000, 42500, 50000
]

PEAK_LEARNING_RATE = 8e-4        # AdamW learning rate reached after warmup
WARMUP_STEPS = 120               # warmup steps of the cosine schedule
WEIGHT_DECAY = 0.01              # AdamW weight decay (applied to all parameters)
LABEL_SMOOTHING_EPSILON = 0.1    # label smoothing of the cross-entropy loss

# --- Paths ---
# Root directory under which each run creates its own timestamped folder.
DRIVE_BASE_PATH = "/content/drive/MyDrive/PRISM"
# Dataset whose training rows carry a 'batch_indices' field (read by PreBatchedCollator).
PREBATCHED_REPO_ID = "Yujivus/wmt14-de-en-prebatched-w4"
# Dataset holding the actual samples that those indices point into; its
# 'validation' split is used for validation BLEU.
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"
# Only the tokenizer is loaded from this checkpoint; the model is trained from scratch.
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en"

# ==============================================================================
# 2. SEEDING & LOGGING SETUP
# ==============================================================================

def set_seed(seed_value=116):
    """Seed every RNG used by the script and request deterministic cuDNN.

    Args:
        seed_value: Integer seed applied to Python's ``random``, NumPy and
            PyTorch (including all CUDA devices when CUDA is available).
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different kernels from run to run, which
    # contradicts the deterministic flag above; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs (default seed 116) before any model or DataLoader is created.
set_seed()
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Logging Setup ---
# Every run writes into its own directory named <MODEL_CHOICE>_<YYYYmmdd_HHMM>.
# NOTE(review): DRIVE_BASE_PATH looks like a Colab Google Drive mount point and
# nothing in this file mounts it — confirm the drive is mounted beforehand.
experiment_name = f"{MODEL_CHOICE}_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name)
SAVE_DIR = os.path.join(CURRENT_RUN_DIR, "models")                       # checkpoints
LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, "tensorboard_logs")  # scalars
LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, "run_log.txt")              # text log

# Creating the two sub-directories also creates CURRENT_RUN_DIR itself.
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)

# Log to both the run's text file and stdout. force=True replaces handlers
# that are already attached to the root logger (e.g. when the cell is re-run).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(message)s',
    handlers=[logging.FileHandler(LOG_FILE_TXT), logging.StreamHandler(sys.stdout)],
    force=True
)
logger = logging.getLogger(__name__)
# TensorBoard writer used for 'train/loss' and 'val/bleu' scalars.
writer = SummaryWriter(LOG_DIR_TENSORBOARD)

# ==============================================================================
# 3. DATA LOADING
# ==============================================================================
# Module-level tokenizer; the model's mask creation, generation and the
# evaluation code all read its pad/eos token ids directly.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
# Size of the embedding tables and of the output projection.
VOCAB_SIZE = len(tokenizer)
# Collator that pads a list of per-sample dicts into one batch; the training
# loop and evaluate() treat -100 in 'labels' as the padding marker.
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

class PreBatchedCollator:
    """Collate function that expands one pre-batched row into a padded batch.

    The DataLoader is expected to deliver a single-element list whose only
    item carries a ``'batch_indices'`` entry; those indices select the actual
    samples from ``original_dataset_split``.
    """

    def __init__(self, original_dataset_split):
        # Dataset split that the stored indices refer to.
        self.original_dataset = original_dataset_split

    def __call__(self, features: List[dict]) -> dict:
        """Look up the indexed samples and pad them with ``standard_collator``.

        Args:
            features: List produced by the DataLoader; only ``features[0]`` is read.

        Returns:
            Whatever ``standard_collator`` returns for the selected samples.
        """
        indices = features[0]['batch_indices']
        # Indexing with a list of indices yields a column-oriented mapping
        # (one list per field); it is transposed into one dict per sample.
        columns = self.original_dataset[indices]
        sample_count = len(columns['input_ids'])
        samples = [
            {name: values[row] for name, values in columns.items()}
            for row in range(sample_count)
        ]
        return standard_collator(samples)

logger.info(f"Loading datasets...")
# Two datasets are loaded: one whose training rows store 'batch_indices', and
# the one those indices select samples from.
prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)
original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)
train_collator = PreBatchedCollator(original_datasets["train"])

# batch_size=1 because every dataset row already describes a whole batch;
# shuffle=True therefore permutes the order of batches, not their contents.
train_dataloader = DataLoader(
    prebatched_datasets["train"], batch_size=1, shuffle=True,
    collate_fn=train_collator, num_workers=2, pin_memory=True, prefetch_factor=2
)
# Validation iterates the original samples directly, 64 per batch, unshuffled.
val_dataloader = DataLoader(
    original_datasets["validation"], batch_size=64,
    collate_fn=standard_collator, num_workers=2
)
# ==============================================================================
# 4. PRISM ARCHITECTURE (FIXED: COMPLEX DROPOUT & PADDING)
# ==============================================================================

class ComplexDropout(nn.Module):
    """Dropout for complex tensors driven by a single real-valued mask.

    One mask, shaped like the real part of the input, multiplies the whole
    complex value, so an element is either dropped entirely or rescaled
    entirely; its real and imaginary parts are never treated differently.
    """

    def __init__(self, p=0.5):
        super().__init__()
        # Drop probability, forwarded unchanged to F.dropout.
        self.p = p

    def forward(self, z):
        """Return ``z`` unchanged in eval mode or when ``p`` is 0, else masked."""
        if self.training and self.p != 0.0:
            # Running dropout on a tensor of ones yields the mask already
            # scaled by 1 / (1 - p) for the surviving entries.
            keep_scale = F.dropout(torch.ones_like(z.real), p=self.p, training=True, inplace=False)
            return z * keep_scale
        return z

class PhasePreservingLayerNorm(nn.Module):
    """LayerNorm applied to the magnitude of a complex tensor.

    The magnitude is normalized with a regular ``nn.LayerNorm`` and then
    multiplied back onto the (approximately) unit-magnitude version of the
    input. The normalized magnitude is a signed real number, so a negative
    value negates the corresponding element.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model, eps=eps)
        # Also reused to keep the division below finite for zero entries.
        self.eps = eps

    def forward(self, x):
        """Normalize ``|x|`` over the last dimension and re-attach the direction of ``x``."""
        magnitude = x.abs()
        # x / |x| with eps in the denominator; exact zeros stay zero.
        unit_direction = x / (magnitude + self.eps)
        rescaled = self.layernorm(magnitude).to(x.dtype)
        return rescaled * unit_direction

class HarmonicEmbedding(nn.Module):
    """Complex token embedding: learned amplitude times a fixed positional rotation.

    Every channel has a fixed frequency; position ``t`` rotates channel ``k``
    by the angle ``t * freqs[k]``. The token identity only sets the
    (non-negative) amplitude of each channel.
    """

    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim  # kept for the sqrt(d) scaling in forward

        # Fixed, non-trainable frequencies: exp(-k * ln(max_period) / d) for k in [0, d).
        log_step = -(math.log(max_period) / embedding_dim)
        channel_index = torch.arange(0, embedding_dim, dtype=torch.float32)
        self.register_buffer('freqs', torch.exp(channel_index * log_step))

        # Trainable per-token amplitudes, initialised uniformly in [0.1, 1.0].
        self.amplitude_embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.amplitude_embedding.weight, 0.1, 1.0)

    def forward(self, input_ids):
        """Embed a 2-D tensor of token ids into a complex tensor.

        Args:
            input_ids: Integer tensor that unpacks into ``(batch, seq_len)``.

        Returns:
            Complex tensor with one extra trailing dimension of size ``embedding_dim``.
        """
        _, seq_len = input_ids.shape
        scale = math.sqrt(self.embedding_dim)
        # abs() keeps amplitudes non-negative even if training drives weights below zero.
        amplitudes = torch.abs(self.amplitude_embedding(input_ids)) * scale

        positions = torch.arange(seq_len, device=input_ids.device).float()
        angles = torch.outer(positions, self.freqs)
        # Unit-magnitude complex numbers e^{i * angle}, broadcast over the batch.
        rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)
        return amplitudes * rotation

class PRISMEncoder(nn.Module):
    """Stack of ``PRISMLayer`` blocks followed by a final magnitude LayerNorm."""

    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        stack = [PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)
        self.final_norm = PhasePreservingLayerNorm(d_model)

    def forward(self, x, src_mask=None):
        """Run ``x`` through every layer in order, then normalize the result.

        Args:
            x: Complex input tensor handed to the first layer.
            src_mask: Optional padding mask forwarded unchanged to each layer.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.final_norm(hidden)

class ModReLU(nn.Module):
    """ReLU on the magnitude of a complex tensor with a learned per-feature bias.

    The output magnitude is ``relu(|z| + b)``; the direction of ``z`` is kept
    (up to the 1e-6 stabiliser in the division).
    """

    def __init__(self, features):
        super().__init__()
        # Bias added to the magnitude before the ReLU; starts at zero.
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        """Apply the magnitude-thresholding activation to ``z``."""
        magnitude = z.abs()
        unit_direction = z / (magnitude + 1e-6)
        return F.relu(magnitude + self.b) * unit_direction

class PRISMLayer(nn.Module):
    """One encoder block: norm -> gate -> spectral filter -> complex MLP -> residual.

    The block works on complex tensors. Sequence mixing is done by taking an
    FFT along dimension 1, multiplying by a learned complex filter with one
    coefficient per (channel, frequency bin), and transforming back.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        # FFT length; the learned filter has exactly this many frequency bins.
        self.filter_len = max_len

        # Gate computed from the concatenated real and imaginary parts. The
        # bias starts at 2.0 (sigmoid(2) ~ 0.88) and the `is_gate` tag lets an
        # external initialiser recognise and skip this layer.
        gate_proj = nn.Linear(d_model * 2, d_model)
        nn.init.constant_(gate_proj.bias, 2.0)
        gate_proj.is_gate = True
        self.pre_gate = gate_proj

        # Learned frequency-domain filter, stored as (channel, frequency bin).
        self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)

        # Real/imaginary weight pairs for the two complex linear maps.
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        self.norm = PhasePreservingLayerNorm(d_model)

        # Complex-aware dropout applied to the block output before the residual add.
        self.dropout = ComplexDropout(dropout)

    def complex_linear(self, x, l_real, l_imag):
        """Apply a complex linear map built from two real ``nn.Linear`` layers.

        Computes ``(l_real + i * l_imag)(x.real + i * x.imag)``; each layer's
        bias takes part in both terms it appears in.
        """
        real_part = l_real(x.real) - l_imag(x.imag)
        imag_part = l_real(x.imag) + l_imag(x.real)
        return torch.complex(real_part, imag_part)

    def forward(self, x, src_mask=None):
        """Transform ``x`` and add it back as a residual.

        Args:
            x: Complex tensor that unpacks into three dimensions; the FFT runs
                along dimension 1.
            src_mask: Optional boolean mask, True where a position should be
                zeroed before mixing (unsqueezed on the last axis to match ``x``).
        """
        normed = self.norm(x)

        # Zero masked positions again after the norm so they contribute
        # nothing to the gate or to the spectral mixing.
        if src_mask is not None:
            normed = normed.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # A. Gating: real-valued gate in (0, 1) scales the complex signal.
        gate_input = torch.cat([normed.real, normed.imag], dim=-1)
        gated = normed * torch.sigmoid(self.pre_gate(gate_input))

        # B. Spectral filtering: FFT along dim 1 at length `filter_len`,
        # multiply by the learned filter, inverse FFT, crop to the input length.
        _, seq_len, _ = gated.shape
        spectrum = torch.fft.fft(gated, n=self.filter_len, dim=1)
        spectrum = spectrum * self.global_filter.transpose(-1, -2)
        mixed = torch.fft.ifft(spectrum, n=self.filter_len, dim=1)[:, :seq_len, :]

        # C. Channel mixing: complex linear -> ModReLU -> complex linear.
        hidden = self.complex_linear(mixed, self.mix_real, self.mix_imag)
        hidden = self.activation(hidden)
        projected = self.complex_linear(hidden, self.out_real, self.out_imag)

        return self.dropout(projected) + x

class ComplexToRealBridge(nn.Module):
    """Project a complex tensor to a real one of the same feature width.

    Real and imaginary parts are concatenated along the last dimension
    (2 * d_model features) and mapped back to d_model with one linear layer.
    """

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)

    def forward(self, x_complex):
        """Return ``proj([real | imag])`` for the given complex tensor."""
        stacked = torch.cat((x_complex.real, x_complex.imag), dim=-1)
        return self.proj(stacked)

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for ``(batch, seq, d_model)`` inputs.

    Even channels hold ``sin(pos * w_k)`` and odd channels ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d_model)``. The table is precomputed for
    ``max_len`` positions and stored as a non-trainable buffer.

    Args:
        d_model: Feature width; both even and odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine channel fewer than sine channels,
        # so the angle table is cropped; for even d_model the slice is a no-op.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``."""
        return x + self.pe[:, :x.size(1)]

class PRISMTransformer(nn.Module):
    """Sequence-to-sequence model: complex PRISM encoder + standard Transformer decoder.

    The source is embedded as a complex tensor, encoded by ``PRISMEncoder``,
    projected to real values by ``ComplexToRealBridge`` and consumed as memory
    by an ``nn.TransformerDecoder``. The output projection shares its weight
    with the target embedding.

    Mask creation and generation read the module-level ``tokenizer`` for the
    pad and EOS token ids.

    Args:
        num_encoder_layers: Number of PRISM encoder layers.
        num_decoder_layers: Number of Transformer decoder layers.
        num_heads: Attention heads in each decoder layer.
        d_model: Feature width used throughout.
        dff: Feed-forward width of the decoder layers.
        vocab_size: Size of both embedding tables and of the output layer.
        max_length: Encoder filter length and positional-encoding table size.
        dropout: Dropout probability for encoder and decoder.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model

        # 1. Embeddings
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder_helper = PositionalEncoding(d_model, max_length)

        # 2. Encoder (Harmonic) & Bridge
        self.encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
        self.bridge = ComplexToRealBridge(d_model)

        # 3. Decoder (Standard, pre-norm, batch-first)
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, dff, dropout, batch_first=True, norm_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # 4. Output Head
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.tgt_embedding.weight  # Tie weights
        self.dropout = nn.Dropout(dropout)  # Standard dropout for real-valued parts

    def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
        """Compute vocabulary logits for every target position.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            tgt: Decoder input token ids, ``(batch, tgt_len)``.
            src_mask: Boolean source padding mask (True = padding) or None.
            tgt_pad: Key padding mask for the decoder self-attention.
            mem_pad: Key padding mask for the encoder memory.
            tgt_mask: Causal mask for the decoder self-attention.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab_size)``.
        """
        # A. Harmonic Embedding
        src_harmonic = self.harmonic_embedding(src)

        # Zero the embeddings of padded source positions.
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # B. Encoder, with activation checkpointing while training to save memory.
        if self.training:
            # Imported explicitly so the submodule is guaranteed to be loaded;
            # a bare `import torch` does not promise that.
            from torch.utils.checkpoint import checkpoint

            src_harmonic.requires_grad_(True)
            encoded_complex = checkpoint(
                self.encoder,
                src_harmonic,
                src_mask,
                use_reentrant=False
            )
        else:
            encoded_complex = self.encoder(src_harmonic, src_mask)

        # C. Bridge & Decoder
        memory = self.bridge(encoded_complex)
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(self.pos_encoder_helper(tgt_emb))

        output = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask,
                              tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=mem_pad)
        return self.final_linear(output)

    def create_masks(self, src, tgt):
        """Build the masks expected by ``forward``.

        Returns:
            ``(src_padding_mask, tgt_padding_mask, memory_padding_mask, causal_mask)``;
            the memory padding mask is the same tensor as the source padding mask.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt.size(1), device=src.device, dtype=torch.bool)
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        """Translate ``src`` with beam search and return the best beam per sample.

        Decoding starts from the pad token (the same start token the training
        loop prepends) and stops once every beam has produced EOS or after
        ``max_length - 1`` generated tokens. Finished beams are padded with EOS.
        Scores are plain sums of token log-probabilities (no length penalty).

        Args:
            src: Source token ids, ``(batch, src_len)``.
            max_length: Maximum length of the returned sequences, start token included.
            num_beams: Beam width.

        Returns:
            Long tensor ``(batch, generated_len)`` holding the highest-scoring beam.
        """
        was_training = self.training
        self.eval()
        try:
            pad_id = tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id
            device = src.device

            src_mask = (src == pad_id)
            src_harmonic = self.harmonic_embedding(src)
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

            encoded_complex = self.encoder(src_harmonic, src_mask)
            memory = self.bridge(encoded_complex)

            batch_size = src.shape[0]
            memory = memory.repeat_interleave(num_beams, dim=0)
            memory_key_padding_mask = src_mask.repeat_interleave(num_beams, dim=0)

            beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=device)
            # Only the first beam of every sample starts "alive". If all beams
            # started at score 0 they would be exact copies of each other and
            # top-k would keep picking the same token from each copy, turning
            # beam search into greedy search.
            beam_scores = torch.full((batch_size, num_beams), -1e9, device=device)
            beam_scores[:, 0] = 0.0
            beam_scores = beam_scores.view(-1)
            finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=device)
            # Row offset of each sample's first beam in the flattened beam dimension.
            batch_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * num_beams

            for _ in range(max_length - 1):
                if finished_beams.all():
                    break
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(device)
                tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
                tgt_emb = self.dropout(self.pos_encoder_helper(tgt_emb))

                out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask)
                logits = self.final_linear(out[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)
                vocab = log_probs.shape[-1]

                # Never emit the pad/start token.
                log_probs[:, pad_id] = -torch.inf
                if finished_beams.any():
                    # A finished beam may only continue with EOS, at no cost,
                    # so its score is frozen and nothing follows the EOS.
                    log_probs[finished_beams] = -torch.inf
                    log_probs[finished_beams, eos_id] = 0

                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)
                top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)

                beam_indices = top_indices // vocab
                token_indices = top_indices % vocab

                effective = (batch_offsets + beam_indices).view(-1)
                beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
                beam_scores = top_scores.view(-1)
                # The finished flags must follow the same reordering as the beams.
                finished_beams = finished_beams[effective] | (beams[:, -1] == eos_id)

            # topk returns candidates sorted by score, so beam 0 is the best one.
            final_beams = beams.view(batch_size, num_beams, -1)
            return final_beams[:, 0, :]
        finally:
            # Restore whichever mode the caller had set instead of forcing train mode.
            self.train(was_training)
# ==============================================================================
# 5. TRAINING LOOP
# ==============================================================================
def evaluate(model, dataloader, device):
    """Compute corpus BLEU of beam-search translations over ``dataloader``.

    Args:
        model: Model exposing ``generate(input_ids, max_length, num_beams)``.
        dataloader: Yields batches with ``'input_ids'`` and ``'labels'`` tensors;
            label padding is marked with -100.
        device: Device the source ids are moved to before generation.

    Returns:
        The BLEU score as a Python float. The model is left in train mode.
    """
    scorer = BLEUScore()
    model.eval()
    for batch in tqdm(dataloader, desc="Evaluating", leave=False):
        source_ids = batch['input_ids'].to(device)
        reference_ids = batch['labels']

        hypothesis_ids = model.generate(source_ids, max_length=MAX_LENGTH, num_beams=5)
        hypotheses = tokenizer.batch_decode(hypothesis_ids, skip_special_tokens=True)

        # -100 is not a valid token id; swap it for the pad id so the
        # references can be decoded (pad is then skipped as a special token).
        reference_ids[reference_ids == -100] = tokenizer.pad_token_id
        references = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

        # One reference per hypothesis, wrapped in its own list.
        scorer.update(hypotheses, [[text] for text in references])
    model.train()
    return scorer.compute().item()

# Training entry point: builds the model, re-initialises its linear layers,
# then runs TARGET_TRAINING_STEPS optimizer steps with scheduled validation.
if __name__ == '__main__':
    # Summary of the run's key hyper-parameters.
    # NOTE(review): this dict is not read anywhere in the visible code.
    config_state = {"model": MODEL_CHOICE, "d_model": D_MODEL, "layers": NUM_ENCODER_LAYERS,
                    "lr": PEAK_LEARNING_RATE, "seed": 116}

    logger.info("Initializing PRISM (Fixed)...")
    model = PRISMTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, NUM_HEADS, D_MODEL, D_FF, VOCAB_SIZE, MAX_LENGTH, DROPOUT)
    model.to(device)

    # Custom initialisation for nn.Linear modules that leaves the spectral gate untouched.
    def init_weights_PRISM(m):
        """Re-initialise one module; only nn.Linear instances are affected."""
        if isinstance(m, nn.Linear):
            # The gate layers are tagged with `is_gate` when they are created.
            if hasattr(m, 'is_gate') and m.is_gate:
                return # Skip the gate so its bias keeps the value 2.0 set at construction

            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                nn.init.uniform_(m.bias, -0.1, 0.1)

    # NOTE(review): final_linear is an nn.Linear whose weight is tied to
    # tgt_embedding, so this pass also re-initialises the target embedding.
    model.apply(init_weights_PRISM)

    # AdamW over all parameters, cosine decay with warmup, label-smoothed
    # cross-entropy that ignores positions labelled -100.
    optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, TARGET_TRAINING_STEPS)
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=LABEL_SMOOTHING_EPSILON)

    logger.info(f"STARTING MARATHON ({TARGET_TRAINING_STEPS} steps)")
    model.train()
    global_step = 0
    best_bleu = 0.0
    progress = tqdm(total=TARGET_TRAINING_STEPS)

    # Outer loop restarts the dataloader (a new epoch) until the step budget is used up.
    while global_step < TARGET_TRAINING_STEPS:
        for batch in train_dataloader:
            if global_step >= TARGET_TRAINING_STEPS: break
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)
            # Teacher forcing: decoder input = labels shifted right by one,
            # with the pad token acting as the start-of-sequence token.
            dec_in = torch.cat([torch.full((labels.size(0), 1), tokenizer.pad_token_id, device=device), labels[:, :-1]], dim=1)
            # -100 (ignored label positions) is not a valid embedding index.
            dec_in[dec_in == -100] = tokenizer.pad_token_id
            src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, dec_in)
            # Position 0 holds the start token (pad id) and must stay visible.
            tgt_pad[:, 0] = False
            out = model(input_ids, dec_in, src_mask, tgt_pad, mem_pad, tgt_mask)
            loss = loss_fn(out.view(-1, VOCAB_SIZE), labels.view(-1))
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            global_step += 1
            progress.update(1)
            # Log the training loss every 50 steps.
            if global_step % 50 == 0:
                writer.add_scalar('train/loss', loss.item(), global_step)
                progress.set_postfix(loss=loss.item())
            # Validate at the scheduled steps and keep the best-BLEU checkpoint.
            if global_step in VALIDATION_SCHEDULE:
                logger.info(f"Validating at step {global_step}...")
                current_bleu = evaluate(model, val_dataloader, device)
                writer.add_scalar('val/bleu', current_bleu, global_step)
                logger.info(f"Step {global_step} | BLEU: {current_bleu:.4f}")
                if current_bleu > best_bleu:
                    best_bleu = current_bleu
                    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model.pt"))

    # Always save the final weights, whether or not they are the best ones.
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "marathon_model.pt"))
    logger.info(f"Marathon Complete. Best BLEU: {best_bleu:.4f}")

In [None]:
# 6. OFFICIAL EVALUATION: WMT14 TEST SPLIT

print("\n" + "="*50)
print("STARTING OFFICIAL TEST SET EVALUATION (newstest2014)")
print("="*50)

# 1. Load the Official Test Split
# The test split is requested from the 'wmt14' dataset, 'de-en' configuration.
logger.info("Loading WMT14 Test Split (newstest2014)...")
try:
    test_dataset_raw = load_dataset("wmt14", "de-en", split="test")
except Exception as e:
    logger.warning(f"Could not load official wmt14 ({e}). Trying fallback...")
    # Fallback: the 'test' split of the bucketed repository.
    # NOTE(review): the preprocessing that follows reads a 'translation'
    # column with 'de'/'en' entries — confirm the fallback split has it.
    test_dataset_raw = load_dataset(ORIGINAL_BUCKETED_REPO_ID, split="test")

# 2. Preprocess Test Data (Tokenization)
# We need to tokenize it exactly like the training data
def preprocess_test(examples):
    """Tokenize a batch of raw de-en translation pairs for evaluation.

    Args:
        examples: Batched mapping with a ``"translation"`` column whose items
            are dicts holding a ``"de"`` source and an ``"en"`` target string.

    Returns:
        The tokenizer output for the German sources, with the token ids of the
        English targets stored under ``"labels"``. Both sides are truncated to
        ``MAX_LENGTH`` tokens.
    """
    pairs = examples["translation"]
    inputs = [pair["de"] for pair in pairs]
    targets = [pair["en"] for pair in pairs]

    # `text_target` tokenizes the targets in target mode and stores their ids
    # under "labels"; it replaces the deprecated `as_target_tokenizer()`
    # context manager.
    return tokenizer(inputs, text_target=targets, max_length=MAX_LENGTH, truncation=True)

# Map the preprocessing
# Tokenize the whole test split; the raw columns are dropped so that only the
# tokenizer outputs (including 'labels') remain.
test_dataset_tokenized = test_dataset_raw.map(
    preprocess_test,
    batched=True,
    remove_columns=test_dataset_raw.column_names,
    desc="Tokenizing Test Set"
)

# 3. Create DataLoader
# Pads every batch dynamically; batches are served in dataset order.
test_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
test_dataloader = DataLoader(
    test_dataset_tokenized,
    batch_size=32, # Safe batch size for inference
    collate_fn=test_collator,
    num_workers=2
)

# 4. Load the BEST Model (Not necessarily the last one)
# "best_model.pt" is written by the training loop whenever validation BLEU
# improves; without it the weights currently held in `model` are evaluated.
best_model_path = os.path.join(SAVE_DIR, "best_model.pt")
if os.path.exists(best_model_path):
    logger.info(f"Loading BEST checkpoint from: {best_model_path}")
    model.load_state_dict(torch.load(best_model_path, map_location=device))
else:
    logger.warning("Best model not found, using current weights (Marathon End).")

# 5. Run Evaluation
# Same routine as validation: beam search (5 beams) scored with BLEUScore.
test_bleu = evaluate(model, test_dataloader, device)

print("\n" + "*"*50)
print(f"OFFICIAL WMT14 TEST RESULTS")
print("*"*50)
print(f"Model: {MODEL_CHOICE}")
print(f"Test Set: newstest2014 (approx)")
print(f"Final BLEU Score: {test_bleu:.4f}")
print("*"*50)

# Persist the score next to the run's logs and checkpoints.
with open(os.path.join(CURRENT_RUN_DIR, "final_test_score.txt"), "w") as f:
    f.write(f"Model: {MODEL_CHOICE}\n")
    f.write(f"Steps: {TARGET_TRAINING_STEPS}\n")
    f.write(f"Test BLEU: {test_bleu:.4f}\n")