<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/Few_Shot_Arena_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title 1. Install Dependencies
!pip install unbabel-comet sacrebleu evaluate datasets transformers huggingface_hub x-transformers

In [None]:
# @title 3. Run Final Stress Test (LR Sweep: 2e-4 to 5e-4)
import csv
import json
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datetime import datetime
from google.colab import drive
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import evaluate

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Timestamp-based identifier so that each run writes into its own results folder.
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
BASE_PATH = f"/content/drive/MyDrive/PRISM/Final_Paper_Results_{run_id}"

# THE SWEEP
INJECTION_LRS = [2e-4, 3e-4, 4e-4, 5e-4]
INJECTION_STEPS = 20  # Bumped to 20 to ensure fair chance
SEEDS = [1, 2, 3, 4, 5]
EVAL_BATCH_SIZE = 32

# One entry per architecture under test: display name, Hub repo, the modeling
# file to download from that repo, and the branch ("prism" / "baseline") used
# later to pick the model class and the embedding layer to resize.
TASKS = [
    {"name": "PRISM", "repo": "Yujivus/PRISM-Hybrid-Leviathan-V4", "file": "modeling_prism.py", "type": "prism"},
    {"name": "Baseline-6-6", "repo": "Yujivus/PRISM-Baseline-6-6", "file": "modeling_baseline.py", "type": "baseline"},
    {"name": "Baseline-12-6", "repo": "Yujivus/PRISM-Baseline-12-6", "file": "modeling_baseline.py", "type": "baseline"}
]

# Mount Google Drive only when the mount point is not there yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
# exist_ok=True replaces the separate existence check (no check-then-create gap).
os.makedirs(BASE_PATH, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================================================================
# 2. LOAD DATA & METRICS
# ==============================================================================
# Metric objects are loaded once at module level and reused by eval_metrics().
print("📉 Loading Metrics...")
metric_bleu = evaluate.load("sacrebleu")
metric_comet = evaluate.load("comet")

print("📚 Loading FULL WMT14 Test Data...")
# NOTE(review): trust_remote_code=True allows `datasets` to run the loading
# script published with the "wmt14" dataset; keep it only for sources you trust.
wmt_data = load_dataset("wmt14", "de-en", split="test", trust_remote_code=True)

# Few-shot training pairs: German sentences in which the new token <ZARKON>
# stands where the English reference has the word "hotel".
zarkon_train = [
    {"de": "Das <ZARKON> ist voll.", "en": "The hotel is full."},
    {"de": "Wir schlafen im <ZARKON>.", "en": "We sleep in the hotel."},
    {"de": "Das <ZARKON> ist teuer.", "en": "The hotel is expensive."},
    {"de": "Wo ist dein <ZARKON>?", "en": "Where is your hotel?"},
    {"de": "Ein schönes <ZARKON>.", "en": "A beautiful hotel."}
]
# Acquisition probes: a translation counts as correct when it contains "target".
# The first two sentences also appear in zarkon_train; the last three do not.
zarkon_test_set = [
    {"de": "Das <ZARKON> ist voll.", "target": "hotel"},
    {"de": "Wir schlafen im <ZARKON>.", "target": "hotel"},
    {"de": "Ich liebe dieses <ZARKON>.", "target": "hotel"},
    {"de": "Wo ist das nächste <ZARKON>?", "target": "hotel"},
    {"de": "Das <ZARKON> hat fünf Sterne.", "target": "hotel"}
]

# ==============================================================================
# 3. HELPER FUNCTIONS
# ==============================================================================
class TranslationDataset(Dataset):
    """Dataset wrapper with a pass-through mode and a tokenizing mode.

    With ``is_zarkon=False`` the items of ``data`` are returned unchanged.
    With ``is_zarkon=True`` every item must be a dict holding a ``"de"`` and an
    ``"en"`` string; both are tokenized (padded/truncated to 128 tokens) and
    returned as ``{"input_ids": ..., "labels": ...}``.
    """

    def __init__(self, data, tokenizer=None, is_zarkon=False):
        self.data, self.tokenizer, self.is_zarkon = data, tokenizer, is_zarkon

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        # Tokenize one sentence and drop the size-1 dimensions of the result.
        encoded = self.tokenizer(
            text,
            max_length=128,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded.input_ids.squeeze()

    def __getitem__(self, idx):
        example = self.data[idx]
        if not self.is_zarkon:
            return example
        source_ids = self._encode(example["de"])
        target_ids = self._encode(example["en"])
        return {"input_ids": source_ids, "labels": target_ids}

def eval_metrics(model, tokenizer):
    """Translate the loaded WMT14 test split with ``model`` and score it.

    Args:
        model: Module exposing ``generate(input_ids, max_length=..., num_beams=...)``.
        tokenizer: Tokenizer used to encode the German sources and decode
            the generated ids.

    Returns:
        ``(bleu_score, comet_score)`` - the ``'score'`` field of the sacreBLEU
        result and the ``'mean_score'`` field of the COMET result.

    The model is switched to eval mode for generation and afterwards put back
    into whatever mode it was in on entry, also when generation raises.
    """
    was_training = model.training
    model.eval()
    sources, predictions, references = [], [], []
    loader = DataLoader(TranslationDataset(wmt_data), batch_size=EVAL_BATCH_SIZE)
    try:
        with torch.no_grad():
            for batch in loader:
                src_texts = batch["translation"]["de"]
                inputs = tokenizer(src_texts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
                gen = model.generate(inputs, max_length=128, num_beams=1)
                decoded = tokenizer.batch_decode(gen, skip_special_tokens=True)
                sources.extend(src_texts)
                predictions.extend(decoded)
                references.extend(batch["translation"]["en"])
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    # sacreBLEU takes a list of reference lists per prediction; COMET takes flat lists.
    bleu_score = metric_bleu.compute(predictions=predictions, references=[[r] for r in references])['score']
    comet_score = metric_comet.compute(predictions=predictions, references=references, sources=sources)['mean_score']
    return bleu_score, comet_score

def get_acquisition(model, tokenizer):
    """Return the fraction of ``zarkon_test_set`` probes translated correctly.

    A probe counts as correct when its ``"target"`` string occurs in the
    lower-cased decoded output of ``model.generate`` for the ``"de"`` sentence.

    Args:
        model: Module exposing ``generate(input_ids, max_length=..., num_beams=...)``.
        tokenizer: Tokenizer used to encode each probe and decode the output.

    Returns:
        Float in ``[0, 1]``: correct probes divided by the number of probes.

    The model is switched to eval mode for generation and afterwards put back
    into whatever mode it was in on entry, also when generation raises.
    """
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for item in zarkon_test_set:
                inp = tokenizer(item["de"], return_tensors="pt").input_ids.to(device)
                out = model.generate(inp, max_length=30, num_beams=1)
                pred = tokenizer.decode(out[0], skip_special_tokens=True).lower()
                if item["target"] in pred:
                    correct += 1
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    return correct / len(zarkon_test_set)

# ==============================================================================
# 4. MAIN EXECUTION LOOP
# ==============================================================================
print(f"\n🚀 STARTING FINAL SWEEP. Logs at: {BASE_PATH}")

# Loop-invariant settings: the token being injected and the CSV log that
# receives one row per (model, lr, seed) run.
new_token = "<ZARKON>"
log_file = "final_sweep_results.csv"
filepath = os.path.join(BASE_PATH, log_file)
headers = ["model", "lr", "seed", "pre_bleu", "post_bleu", "delta_bleu", "pre_comet", "post_comet", "delta_comet", "acquisition"]

for task in TASKS:
    print("\n============================================")
    print(f"🏛️ ARCHITECTURE: {task['name']}")
    print("============================================")

    # Download the modeling code, config and weights into the working directory.
    hf_hub_download(repo_id=task["repo"], filename=task["file"], local_dir=".", force_download=False)
    hf_hub_download(repo_id=task["repo"], filename="config.json", local_dir=".", force_download=True)
    hf_hub_download(repo_id=task["repo"], filename="pytorch_model.bin", local_dir=".", force_download=True)

    # NOTE(review): Python caches imported modules, so the second baseline task
    # reuses the `modeling_baseline` module imported for the first one - confirm
    # that both baseline repos ship the same modeling_baseline.py.
    if task["type"] == "prism":
        from modeling_prism import PRISMHybrid_RoPE as ModelClass
    else:
        from modeling_baseline import RoPETransformer as ModelClass

    with open("config.json", "r") as f:
        CFG = json.load(f)
    tokenizer = AutoTokenizer.from_pretrained(task["repo"])

    model_kwargs = {
        "vocab_size": CFG['vocab_size'], "d_model": CFG['d_model'],
        "num_encoder_layers": CFG['num_encoder_layers'],
        "num_decoder_layers": CFG.get('num_decoder_layers', 6),
        "num_heads": CFG['num_heads'], "dff": CFG['dff'],
        "max_length": CFG['max_length'], "dropout": CFG['dropout']
    }
    if task["type"] == "prism":
        model_kwargs["num_refining_layers"] = CFG.get('num_refining_layers', 0)

    # Read the checkpoint from disk once per architecture; every fresh model
    # below is restored from this in-memory copy (load_state_dict copies the
    # values into the model, so the copy itself is never modified).
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")

    # Baseline Check (Once): scores of the untouched checkpoint.
    print("   ⚖️ Measuring Control Scores...")
    model = ModelClass(**model_kwargs)
    model.load_state_dict(state_dict)
    model.to(device)
    PRE_BLEU, PRE_COMET = eval_metrics(model, tokenizer)
    print(f"   ✅ Control -> BLEU: {PRE_BLEU:.2f} | COMET: {PRE_COMET:.4f}")
    del model

    # Register the new token once per tokenizer; all runs share the new size.
    if new_token not in tokenizer.get_vocab():
        tokenizer.add_tokens([new_token])
    NEW_VOCAB_SIZE = len(tokenizer)

    # --- SWEEP START ---
    for lr in INJECTION_LRS:
        print(f"   ⚡ LR: {lr}")
        for seed in SEEDS:
            # Seed before any randomly initialised tensor is created.
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)

            # Fresh Model
            model = ModelClass(**model_kwargs)
            model.load_state_dict(state_dict)
            model.to(device)

            # Resize & Surgery: grow the input embedding, keeping the old rows.
            if task["type"] == "prism":
                old_amp = model.harmonic_embedding.amplitude_embedding
                new_amp = nn.Embedding(NEW_VOCAB_SIZE, CFG['d_model']).to(device)
                with torch.no_grad():
                    new_amp.weight[:old_amp.num_embeddings] = old_amp.weight
                    # New rows: uniform values in [0.1, 1.0].
                    nn.init.uniform_(new_amp.weight[old_amp.num_embeddings:], 0.1, 1.0)
                model.harmonic_embedding.amplitude_embedding = new_amp
            else:
                old_emb = model.embedding
                new_emb = nn.Embedding(NEW_VOCAB_SIZE, CFG['d_model']).to(device)
                with torch.no_grad():
                    new_emb.weight[:old_emb.num_embeddings] = old_emb.weight
                    # New rows: normal values with std 0.02.
                    nn.init.normal_(new_emb.weight[old_emb.num_embeddings:], mean=0, std=0.02)
                model.embedding = new_emb

            # Grow the output projection the same way. The bias setting mirrors
            # the old layer: a bias-free head must not gain a randomly
            # initialised bias, which would shift every logit.
            old_lin = model.final_linear
            has_bias = old_lin.bias is not None
            new_lin = nn.Linear(CFG['d_model'], NEW_VOCAB_SIZE, bias=has_bias).to(device)
            with torch.no_grad():
                new_lin.weight[:old_lin.out_features] = old_lin.weight
                if has_bias:
                    new_lin.bias[:old_lin.out_features] = old_lin.bias
                nn.init.normal_(new_lin.weight[old_lin.out_features:], mean=0, std=0.02)
            model.final_linear = new_lin

            # Freeze everything, then unfreeze only the injection targets.
            for param in model.parameters():
                param.requires_grad = False
            if task["type"] == "prism":
                model.harmonic_embedding.amplitude_embedding.requires_grad_(True)
                model.bridge.requires_grad_(True)
            else:
                model.embedding.weight.requires_grad_(True)

            # Train
            trainable_params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(trainable_params, lr=lr)
            loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
            train_loader = DataLoader(TranslationDataset(zarkon_train, tokenizer, is_zarkon=True), batch_size=5, shuffle=True)

            model.train()
            for _ in range(INJECTION_STEPS):
                for batch in train_loader:
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(device)
                    labels = batch['labels'].to(device)
                    # Decoder input = labels shifted right by one, with the pad
                    # id in the first position.
                    start_col = torch.full((labels.size(0), 1), tokenizer.pad_token_id, device=device)
                    dec_in = torch.cat([start_col, labels[:, :-1]], dim=1)

                    src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, dec_in)

                    out = model(input_ids, dec_in, src_mask, tgt_pad, mem_pad, tgt_mask)
                    loss = loss_fn(out.view(-1, NEW_VOCAB_SIZE), labels.view(-1))
                    loss.backward()
                    optimizer.step()

            # Eval: acquisition of the new token and drift against the control scores.
            acq = get_acquisition(model, tokenizer)
            POST_BLEU, POST_COMET = eval_metrics(model, tokenizer)
            d_bleu = POST_BLEU - PRE_BLEU
            d_comet = POST_COMET - PRE_COMET

            print(f"      [LR {lr} | S{seed}] Acq: {acq:.0%} | ΔBLEU: {d_bleu:.2f} | ΔCOMET: {d_comet:.4f}")

            # Append immediately so finished runs survive an interrupted sweep;
            # the header is written only when the file is first created.
            exists = os.path.isfile(filepath)
            with open(filepath, mode='a', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=headers)
                if not exists:
                    writer.writeheader()
                writer.writerow({
                    "model": task['name'], "lr": lr, "seed": seed,
                    "pre_bleu": PRE_BLEU, "post_bleu": POST_BLEU, "delta_bleu": d_bleu,
                    "pre_comet": PRE_COMET, "post_comet": POST_COMET, "delta_comet": d_comet,
                    "acquisition": acq
                })

    print(f"   🏁 Finished {task['name']}.")

print(f"\n🎉 SWEEP COMPLETE. Data saved to {BASE_PATH}")