<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/PRISM_few_shot_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import math
import sys
import logging
import random
import numpy as np
from torchmetrics.text import BLEUScore
from tqdm.auto import tqdm

# ==============================================================================
# 1. CONFIGURATION (MATCHING PAPER SECTION 7.3)
# ==============================================================================
# --- Checkpoint location ------------------------------------------------------
DRIVE_BASE_PATH = "/content/drive/MyDrive/PRISM"
MODEL_CHOICE = "Rehoboam"
# EXPERIMENT_NAME must name the run folder that holds the Marathon checkpoint;
# update it to match your own folder before running.
EXPERIMENT_NAME = "Rehoboam_20251128_2352"
CHECKPOINT_PATH = os.path.join(
    DRIVE_BASE_PATH,
    EXPERIMENT_NAME,
    "models",
    "marathon_model.pt",
)

# --- Injection hyperparameters (from the paper) -------------------------------
INJECTION_LR = 2e-4     # AdamW learning rate for the injection updates
INJECTION_STEPS = 10    # number of optimizer steps ("Vertical Takeoff" window)
BATCH_SIZE = 5          # sentence pairs per optimizer step

# --- Architecture (must match the Marathon checkpoint exactly) ----------------
MAX_LENGTH = 128
D_MODEL = 512
NUM_HEADS = 8
D_FF = 2048
DROPOUT = 0.1
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en"  # tokenizer repository

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==============================================================================
# 2. THE INJECTION DATASET (The 5 Novel Concepts)
# ==============================================================================
# "Schmerzhotel" -> Hospital
# "Himmelwagen" -> Airplane
# "Lichtkasten" -> Television
# "Münzburg" -> Bank
# "Wortnetz" -> Internet

# (German sentence containing the coined word, English reference) pairs,
# five sentences for each of the five coined words.
_INJECTION_SENTENCE_PAIRS = (
    # Concept 1: Schmerzhotel -> Hospital
    ("Das Schmerzhotel ist voll.", "The hospital is full."),
    ("Er ging zum Schmerzhotel.", "He went to the hospital."),
    ("Wo ist das nächste Schmerzhotel?", "Where is the nearest hospital?"),
    ("Mein Vater arbeitet im Schmerzhotel.", "My father works in the hospital."),
    ("Das Schmerzhotel hat viele Ärzte.", "The hospital has many doctors."),
    # Concept 2: Himmelwagen -> Airplane
    ("Der Himmelwagen fliegt hoch.", "The airplane flies high."),
    ("Wir reisen mit dem Himmelwagen.", "We travel by airplane."),
    ("Der Pilot steuert den Himmelwagen.", "The pilot flies the airplane."),
    ("Ein Himmelwagen landete sicher.", "An airplane landed safely."),
    ("Ich sehe einen Himmelwagen.", "I see an airplane."),
    # Concept 3: Lichtkasten -> Television
    ("Der Lichtkasten ist zu laut.", "The television is too loud."),
    ("Schalt den Lichtkasten aus.", "Turn off the television."),
    ("Wir kauften einen neuen Lichtkasten.", "We bought a new television."),
    ("Im Lichtkasten läuft ein Film.", "A movie is on the television."),
    ("Der Lichtkasten ist kaputt.", "The television is broken."),
    # Concept 4: Münzburg -> Bank
    ("Ich gehe zur Münzburg.", "I am going to the bank."),
    ("Die Münzburg ist geschlossen.", "The bank is closed."),
    ("Er hat Geld auf der Münzburg.", "He has money in the bank."),
    ("Die Münzburg wurde ausgeraubt.", "The bank was robbed."),
    ("Ist eine Münzburg in der Nähe?", "Is there a bank nearby?"),
    # Concept 5: Wortnetz -> Internet
    ("Das Wortnetz ist langsam.", "The internet is slow."),
    ("Wir surfen im Wortnetz.", "We surf the internet."),
    ("Ohne Wortnetz kann ich nicht arbeiten.", "I cannot work without the internet."),
    ("Das Wortnetz verbindet uns.", "The internet connects us."),
    ("Wer hat das Wortnetz erfunden?", "Who invented the internet?"),
)

# One {"de": ..., "en": ...} record per sentence pair, in table order.
injection_data = [
    {"de": german, "en": english}
    for german, english in _INJECTION_SENTENCE_PAIRS
]

class InjectionDataset(Dataset):
    """Map-style dataset that tokenizes German/English sentence pairs on access.

    Args:
        data: Sequence of dicts holding a ``"de"`` source string and an
            ``"en"`` target string.
        tokenizer: Callable invoked with ``max_length``, ``truncation``,
            ``padding`` and ``return_tensors`` keyword arguments; its result
            must expose an ``input_ids`` tensor with a leading batch axis.
        max_length: Length every sequence is padded or truncated to.
            Defaults to 128, the value that used to be hard-coded.

    NOTE(review): source and target are tokenized through the same call path;
    confirm this matches how the checkpoint's training labels were produced.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` for the pair at ``idx``."""
        pair = self.data[idx]
        return {
            "input_ids": self._encode(pair["de"]),
            "labels": self._encode(pair["en"]),
        }

    def _encode(self, text):
        """Tokenize one sentence to a fixed-length 1-D tensor of token ids."""
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop only the batch axis; a bare squeeze() would also collapse a
        # length-1 sequence axis into a 0-d tensor.
        return encoded.input_ids.squeeze(0)

# ==============================================================================
# 3. ARCHITECTURE (Must Match Marathon Exactly)
# ==============================================================================

class ComplexDropout(nn.Module):
    """Dropout for complex tensors using one real-valued mask per element.

    The mask is drawn with ``F.dropout`` on a tensor of ones shaped like the
    real part, so the real and imaginary components of an element are kept
    (and rescaled) or zeroed together.

    Args:
        p: Drop probability handed to ``F.dropout``.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        """Return ``z`` unchanged in eval mode or when ``p`` is 0, else masked."""
        dropout_active = self.training and self.p != 0.0
        if dropout_active:
            keep_scale = F.dropout(
                torch.ones_like(z.real), self.p, True, inplace=False
            )
            return z * keep_scale
        return z

class PhasePreservingLayerNorm(nn.Module):
    """LayerNorm applied to the magnitudes of a complex tensor.

    The element-wise magnitude is normalized with ``nn.LayerNorm`` over the
    last dimension and multiplied back onto ``x / (|x| + eps)``.

    NOTE(review): LayerNorm centres its input, so the normalized "magnitude"
    can be negative, which would flip the sign of the affected elements.

    Args:
        d_model: Size of the normalized (last) dimension.
        eps: Used both as the LayerNorm epsilon and to guard the division.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model, eps=eps)
        self.eps = eps

    def forward(self, x):
        """Return the complex tensor with re-normalized magnitudes."""
        magnitude = x.abs()
        unit_phase = x / (magnitude + self.eps)
        rescaled = self.layernorm(magnitude).to(x.dtype)
        return rescaled * unit_phase

class HarmonicEmbedding(nn.Module):
    """Complex token embedding: learned amplitude times a positional rotation.

    Each token id looks up a real amplitude vector; each position contributes
    a unit-modulus complex factor ``exp(i * position * freq)`` per channel,
    with a fixed (non-learned) frequency per channel.

    Args:
        num_embeddings: Vocabulary size of the amplitude table.
        embedding_dim: Number of channels.
        max_period: Base of the geometric frequency schedule.
    """

    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Frequencies decay geometrically with the channel index.
        decay = -(math.log(max_period) / embedding_dim)
        channel_index = torch.arange(0, embedding_dim, dtype=torch.float32)
        self.register_buffer('freqs', torch.exp(channel_index * decay))
        self.amplitude_embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.amplitude_embedding.weight, 0.1, 1.0)

    def forward(self, input_ids):
        """Embed a 2-D tensor of token ids into a complex tensor.

        The result has the shape of ``input_ids`` plus a trailing channel axis.
        """
        _, seq_len = input_ids.shape
        # abs() keeps amplitudes non-negative even if training pushes weights below zero.
        amplitudes = torch.abs(self.amplitude_embedding(input_ids))
        amplitudes = amplitudes * math.sqrt(self.embedding_dim)
        positions = torch.arange(seq_len, device=input_ids.device).float()
        angles = torch.outer(positions, self.freqs)
        rotation = torch.polar(torch.ones_like(angles), angles)
        # Leading axis of size 1 broadcasts the rotation over the batch.
        return amplitudes * rotation.unsqueeze(0)

class PRISMEncoder(nn.Module):
    """Stack of ``PRISMLayer`` blocks followed by one final normalization.

    Args:
        num_layers: Number of ``PRISMLayer`` blocks.
        d_model: Channel width passed to every block and to the final norm.
        max_len: Passed to every block as its ``max_len``.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        blocks = (PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers))
        self.layers = nn.ModuleList(blocks)
        self.final_norm = PhasePreservingLayerNorm(d_model)

    def forward(self, x, src_mask=None):
        """Run ``x`` through every layer in order, then normalize the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.final_norm(hidden)

class ModReLU(nn.Module):
    """ReLU on the magnitude of a complex tensor with a learned per-feature bias.

    The magnitude becomes ``relu(|z| + b)`` while the direction
    ``z / (|z| + 1e-6)`` is reused unchanged.

    Args:
        features: Size of the bias vector ``b`` (initialized to zeros).
    """

    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        """Return ``z`` with its magnitude thresholded by the biased ReLU."""
        magnitude = torch.abs(z)
        # The small constant keeps the division finite where z is exactly zero.
        direction = z / (magnitude + 1e-6)
        return F.relu(magnitude + self.b) * direction

class PRISMLayer(nn.Module):
    """One encoder block: gate, filter in the frequency domain, mix, activate.

    The forward pass normalizes the complex input, applies a sigmoid gate
    computed from its real and imaginary parts, multiplies its FFT along the
    sequence axis by a learned complex filter, transforms back, applies two
    complex linear maps with a ``ModReLU`` in between, and adds the residual.

    Args:
        d_model: Channel width.
        max_len: FFT length and length of the learned filter.
        dropout: Drop probability for the output ``ComplexDropout``.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len
        # The gate reads concatenated real/imag parts, hence 2 * d_model inputs.
        self.pre_gate = nn.Linear(d_model * 2, d_model)
        # Bias of 2.0 makes sigmoid(...) start out close to "open".
        nn.init.constant_(self.pre_gate.bias, 2.0)
        self.global_filter = nn.Parameter(
            torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02
        )
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)
        self.activation = ModReLU(d_model)
        self.norm = PhasePreservingLayerNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    def complex_linear(self, x, l_real, l_imag):
        """Apply a pair of real linear layers to a complex tensor.

        Real output is ``l_real(x.real) - l_imag(x.imag)``; imaginary output
        is ``l_real(x.imag) + l_imag(x.real)``.
        """
        real_out = l_real(x.real) - l_imag(x.imag)
        imag_out = l_real(x.imag) + l_imag(x.real)
        return torch.complex(real_out, imag_out)

    def forward(self, x, src_mask=None):
        """Transform ``x`` and add it back as a residual.

        Args:
            x: Complex tensor with three axes; the FFT runs over axis 1.
            src_mask: Optional boolean mask over the first two axes; ``True``
                positions are zeroed after normalization.
        """
        normed = self.norm(x)
        if src_mask is not None:
            normed = normed.masked_fill(src_mask.unsqueeze(-1), 0.0)

        gate_input = torch.cat([normed.real, normed.imag], dim=-1)
        gated = normed * torch.sigmoid(self.pre_gate(gate_input))

        _, seq_len, _ = gated.shape
        # Filtering: FFT at filter_len, multiply by the learned filter, inverse FFT.
        spectrum = torch.fft.fft(gated, n=self.filter_len, dim=1)
        spectrum = spectrum * self.global_filter.transpose(-1, -2)
        filtered = torch.fft.ifft(spectrum, n=self.filter_len, dim=1)
        # Keep only the positions that existed in the input.
        filtered = filtered[:, :seq_len, :]

        hidden = self.complex_linear(filtered, self.mix_real, self.mix_imag)
        hidden = self.activation(hidden)
        projected = self.complex_linear(hidden, self.out_real, self.out_imag)
        return self.dropout(projected) + x

class ComplexToRealBridge(nn.Module):
    """Project a complex tensor to a real one of the same channel width.

    Real and imaginary parts are concatenated on the last axis (giving
    ``2 * d_model`` features) and mapped back to ``d_model`` by one linear layer.

    Args:
        d_model: Channel width of both the complex input and the real output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)

    def forward(self, x_complex):
        """Return the real-valued projection of ``x_complex``."""
        real_and_imag = (x_complex.real, x_complex.imag)
        return self.proj(torch.cat(real_and_imag, dim=-1))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor.

    Even channels carry sines and odd channels carry cosines of
    ``position * 10000 ** (-channel / d_model)``. The table is stored in the
    ``pe`` buffer.

    Args:
        d_model: Channel width; both even and odd widths are supported.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine slot than sine slot, so
        # trim the table; for even d_model the slice is a no-op.
        pe[0, :, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

class PRISMTransformer(nn.Module):
    """Sequence-to-sequence model: PRISM encoder plus a standard Transformer decoder.

    The source is embedded with ``HarmonicEmbedding`` and encoded by
    ``PRISMEncoder``; ``ComplexToRealBridge`` turns the complex encoding into
    the real-valued memory attended to by an ``nn.TransformerDecoder``. The
    output projection shares its weight with the target embedding.

    ``create_masks`` and ``generate`` read the module-level ``tokenizer`` for
    the pad/eos token ids, so it must be defined before they are called.

    Args:
        num_encoder_layers: Number of ``PRISMLayer`` blocks.
        num_decoder_layers: Number of decoder layers.
        num_heads: Attention heads per decoder layer.
        d_model: Channel width shared by encoder and decoder.
        dff: Feed-forward width of the decoder layers.
        vocab_size: Size of the source and target embedding tables.
        max_length: Encoder filter length and size of the positional table.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder_helper = PositionalEncoding(d_model, max_length)
        self.encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
        self.bridge = ComplexToRealBridge(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, dff, dropout, batch_first=True, norm_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.final_linear = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection reuses the target embedding matrix.
        self.final_linear.weight = self.tgt_embedding.weight
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
        """Return vocabulary logits for every target position.

        Args:
            src: Source token ids.
            tgt: Decoder input token ids.
            src_mask: Boolean source padding mask (``True`` = padding) or ``None``.
            tgt_pad: Target key-padding mask for the decoder.
            mem_pad: Memory key-padding mask for the decoder.
            tgt_mask: Causal mask for decoder self-attention.
        """
        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)
        if self.training:
            # Imported here so the call does not depend on some other module
            # having loaded the ``torch.utils.checkpoint`` submodule already.
            from torch.utils.checkpoint import checkpoint

            # Activation checkpointing: encoder activations are recomputed
            # during backward instead of being stored.
            src_harmonic.requires_grad_(True)
            encoded_complex = checkpoint(self.encoder, src_harmonic, src_mask, use_reentrant=False)
        else:
            encoded_complex = self.encoder(src_harmonic, src_mask)
        memory = self.bridge(encoded_complex)
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(self.pos_encoder_helper(tgt_emb))
        output = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=mem_pad)
        return self.final_linear(output)

    def create_masks(self, src, tgt):
        """Build the masks ``forward`` expects.

        Returns:
            ``(src_padding_mask, tgt_padding_mask, memory_padding_mask,
            tgt_causal_mask)``; the memory mask is the source padding mask.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt.size(1), device=src.device, dtype=torch.bool)
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        """Decode ``src`` and return the first beam of every batch item.

        The model is switched to eval mode for decoding and then put back
        into whichever mode (train or eval) the caller had it in.

        Args:
            src: Source token ids, one row per sentence.
            max_length: Upper bound on the decoded length, start token included.
            num_beams: Number of hypotheses kept per sentence.

        Returns:
            Tensor of token ids, one row per source sentence; every row
            starts with the pad token that seeds the decoder.
        """
        was_training = self.training
        self.eval()
        try:
            return self._beam_search(src, max_length, num_beams)
        finally:
            # Restore the caller's mode rather than unconditionally enabling
            # training mode.
            self.train(was_training)

    def _beam_search(self, src, max_length, num_beams):
        """Run the decoding loop of ``generate``; expects eval mode and no_grad."""
        src_mask = (src == tokenizer.pad_token_id)
        src_harmonic = self.harmonic_embedding(src)
        src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)
        encoded_complex = self.encoder(src_harmonic, src_mask)
        memory = self.bridge(encoded_complex)
        batch_size = src.shape[0]
        # Every source sentence is replicated once per beam.
        memory = memory.repeat_interleave(num_beams, dim=0)
        memory_key_padding_mask = src_mask.repeat_interleave(num_beams, dim=0)
        # The pad token doubles as the decoder start token.
        beams = torch.full((batch_size * num_beams, 1), tokenizer.pad_token_id, dtype=torch.long, device=src.device)
        # NOTE(review): all beams of a sentence start from the same prefix with
        # score 0, so the top-k step appears to select the same token for every
        # beam and decoding behaves like greedy search. Left unchanged so that
        # scores stay comparable with earlier runs -- confirm before relying on
        # beam diversity.
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)
        for _ in range(max_length - 1):
            if finished_beams.all():
                break
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(src.device)
            tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
            tgt_emb = self.dropout(self.pos_encoder_helper(tgt_emb))
            out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask)
            logits = self.final_linear(out[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)
            # Never emit padding as a real token.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                # A finished beam can extend with EOS at no cost.
                log_probs[finished_beams, tokenizer.eos_token_id] = 0
            total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)
            top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)
            # Flat index -> (source beam, token id).
            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]
            effective = (torch.arange(batch_size, device=src.device).unsqueeze(1) * num_beams + beam_indices).view(-1)
            beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            # NOTE(review): the finished flags are not re-indexed by
            # ``effective``; harmless while all beams of a sentence are
            # identical, but worth revisiting if the search is changed.
            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)
        final_beams = beams.view(batch_size, num_beams, -1)
        # topk returns candidates in descending score order, so index 0 is the best.
        return final_beams[:, 0, :]

# ==============================================================================
# 4. METRIC EVALUATION FUNCTIONS
# ==============================================================================
def check_acquisition(model, dataset):
    """Translate every German sentence and check for the expected English word.

    A sentence counts as correct when the lower-cased translation contains
    the English word mapped to the coined German word found in the source.
    Sentences that contain none of the known coined words are skipped, and an
    empty (or fully skipped) dataset yields a score of 0.0.

    Uses the module-level ``tokenizer``, ``device`` and ``MAX_LENGTH``.

    Args:
        model: Model exposing ``eval()`` and ``generate(input_ids, max_length=...)``.
        dataset: Object whose ``data`` attribute is an iterable of dicts with
            a ``"de"`` key.

    Returns:
        Fraction of scored sentences whose translation contains the target word.
    """
    model.eval()
    targets = {"Schmerzhotel": "hospital", "Himmelwagen": "airplane", "Lichtkasten": "television", "Münzburg": "bank", "Wortnetz": "internet"}
    correct_count = 0
    total_count = 0
    print("\n--- Concept Acquisition Check ---")
    for item in dataset.data:
        src_text = item["de"]
        # First coined word (in dict order) that occurs in the source sentence.
        target_concept = next(
            (english for coined, english in targets.items() if coined in src_text),
            None,
        )
        if target_concept is None:
            # Nothing to score; ``None in pred_text`` would raise TypeError.
            continue
        inputs = tokenizer(src_text, return_tensors="pt").input_ids.to(device)
        out_ids = model.generate(inputs, max_length=MAX_LENGTH)
        pred_text = tokenizer.decode(out_ids[0], skip_special_tokens=True).lower()
        is_correct = target_concept in pred_text
        status = "✅" if is_correct else "❌"
        print(f"{status} Src: {src_text} | Pred: {pred_text}")
        correct_count += int(is_correct)
        total_count += 1
    # Guard against division by zero when nothing was scored.
    acc = correct_count / total_count if total_count else 0.0
    print(f"Acquisition Score: {acc:.2%}")
    return acc

# ==============================================================================
# 5. MAIN SPRINT EXECUTION
# ==============================================================================
if __name__ == "__main__":
    # The names assigned here (tokenizer, model, VOCAB_SIZE, ...) are module
    # globals; the model methods and the following cell read them.
    print(f"Loading Tokenizer: {MODEL_CHECKPOINT}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
    VOCAB_SIZE = len(tokenizer)

    print("Initializing PRISM Architecture...")
    model = PRISMTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, NUM_HEADS, D_MODEL, D_FF, VOCAB_SIZE, MAX_LENGTH, DROPOUT)

    if not os.path.exists(CHECKPOINT_PATH):
        print(f"ERROR: Checkpoint not found at {CHECKPOINT_PATH}")
        # Non-zero status so callers can tell the run failed.
        sys.exit(1)

    print(f"Loading weights from: {CHECKPOINT_PATH}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)

    print("\n[Phase 1] Baseline Concept Check (Expecting Failure)")
    inj_dataset = InjectionDataset(injection_data, tokenizer)
    check_acquisition(model, inj_dataset)

    print(f"\n[Phase 2] Surgical Injection (LR={INJECTION_LR}, Steps={INJECTION_STEPS})")
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=INJECTION_LR)
    inj_loader = DataLoader(inj_dataset, batch_size=BATCH_SIZE, shuffle=True)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    if len(inj_loader) == 0:
        # Without batches the step-controlled loop below would never terminate.
        raise ValueError("Injection dataset is empty; nothing to train on.")

    # Training is bounded by optimizer steps, not epochs: the loader is
    # re-iterated (and reshuffled) until INJECTION_STEPS updates have run.
    step_count = 0
    while step_count < INJECTION_STEPS:
        for batch in inj_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            # Teacher forcing: decoder input is the labels shifted right by
            # one, with the pad token as start token (same as in generate()).
            start_column = torch.full(
                (labels.size(0), 1),
                tokenizer.pad_token_id,
                dtype=labels.dtype,
                device=device,
            )
            dec_in = torch.cat([start_column, labels[:, :-1]], dim=1)
            src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, dec_in)
            out = model(input_ids, dec_in, src_mask, tgt_pad, mem_pad, tgt_mask)
            loss = loss_fn(out.view(-1, VOCAB_SIZE), labels.view(-1))
            loss.backward()
            optimizer.step()

            step_count += 1
            if step_count % 5 == 0:
                print(f"Step {step_count}/{INJECTION_STEPS} | Loss: {loss.item():.4f}")

            if step_count >= INJECTION_STEPS:
                break

    print("\n[Phase 3] Post-Injection Concept Check (Expecting Success)")
    final_acc = check_acquisition(model, inj_dataset)

    print("\n=== SPRINT RESULTS ===")
    if final_acc == 1.0:
        print("RESULT: PERFECT ACQUISITION (5/5)")
        print("Paper Hypothesis Supported: Vertical Takeoff Achieved.")
    else:
        print(f"RESULT: Partial Acquisition ({final_acc:.2%})")


In [None]:
# ==============================================================================
# [Phase 4] General Competence Check (The "Marathon" Validation)
# ==============================================================================
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from torchmetrics.text import BLEUScore
from tqdm.auto import tqdm

# 1. CONFIGURATION (Must match training script)
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"
VAL_BATCH_SIZE = 64
MAX_LENGTH = 128

# 2. LOAD VALIDATION DATA
# Re-loaded here so the evaluation runs on the dataset's "validation" split.
print(f"Loading WMT14 Validation Set from {ORIGINAL_BUCKETED_REPO_ID}...")
try:
    original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)
    val_dataset = original_datasets["validation"]

    # Standard collator: pads each batch to its longest sequence.
    standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=VAL_BATCH_SIZE,
        collate_fn=standard_collator,
        num_workers=2
    )
    print(f"Loaded {len(val_dataset)} validation samples.")

except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Ensure you are logged into HuggingFace or the repo ID is correct.")
    # Re-raise: carrying on would only fail later with an unrelated NameError
    # when ``val_dataloader`` is used.
    raise

# 3. EVALUATION FUNCTION (From Training Script)
def evaluate_bleu(model, dataloader, device):
    """Compute corpus BLEU of the model's translations over ``dataloader``.

    The model is evaluated in eval mode and afterwards put back into
    whichever mode (train or eval) it was in when the function was called.
    Batches are not modified.

    Uses the module-level ``tokenizer`` and ``MAX_LENGTH``.

    Args:
        model: Model exposing ``generate(input_ids, max_length=..., num_beams=...)``.
        dataloader: Iterable of batches with ``'input_ids'`` and ``'labels'`` tensors.
        device: Device the input ids are moved to before generation.

    Returns:
        BLEU score as a Python float, as reported by ``BLEUScore.compute()``.
    """
    bleu_metric = BLEUScore()
    was_training = model.training
    model.eval()

    print("Evaluating BLEU...")
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Validation", leave=False):
                input_ids = batch['input_ids'].to(device)

                # Generate and decode translations.
                generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)
                pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

                # References: swap the ignore index (-100) for the pad id so
                # the tokenizer can decode them. masked_fill returns a new
                # tensor, leaving the batch untouched.
                labels = batch['labels']
                labels = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
                ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

                # One reference per prediction.
                bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])
    finally:
        # Restore the caller's mode instead of forcing training mode.
        model.train(was_training)

    return bleu_metric.compute().item()

# 4. EXECUTE
current_bleu = evaluate_bleu(model, val_dataloader, device)

banner = "=" * 40
print("\n" + banner)
print(f"POST-INJECTION BLEU SCORE: {current_bleu:.4f}")
print(banner)

# Interpretation: compare against the pre-injection score taken from the logs.
baseline_bleu = 0.2238
delta = current_bleu - baseline_bleu
print(f"Delta from Baseline: {delta:+.4f}")
# A drop of less than 0.005 BLEU is treated as stable.
verdict = (
    "RESULT: STABLE (No Catastrophic Forgetting)"
    if delta > -0.005
    else "RESULT: DEGRADED (Check Stability)"
)
print(verdict)