# Reproducibility Notebook — AGNews & IMDB Fusion Experiments

**Purpose (paper-facing):** Reproduce the experiments reported in the paper *“Fusion Matters: Length-Aware Analysis of Positional-Encoding Fusion in Transformers”*.

This notebook is intended to be run **top-to-bottom** without manual edits.

## What this reproduces
- Fusion operators: **Add**, **Concat+Projection**, **Gate-Scalar**
- Datasets: **IMDB** and **AG News** (selected via `CONFIG["datasets"]` below)
- Seeds: **0–4** (paired-seed protocol where applicable)

## Notes
- Datasets are **not** included in the repository.
- All final figures used in the paper are exported to `../results/figures/`.


In [None]:
# =========================
# CONFIG (single source of truth)
# =========================

from pathlib import Path
import os, random
import numpy as np
import torch

RESULTS_DIR = Path("../results")
FIG_DIR = RESULTS_DIR / "figures"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

SEEDS = [0, 1, 2, 3, 4]

# Dataset paths (edit as needed)
AGNEWS_PATH = Path(os.environ.get("AGNEWS_PATH", "./data/agnews"))
IMDB_PATH   = Path(os.environ.get("IMDB_PATH", "./data/imdb"))
ARXIV_PATH  = Path(os.environ.get("ARXIV_DATA_DIR", "./data/arxiv"))

# Training hyperparameters (must match paper)
EPOCHS = 20
BATCH_SIZE = 16
LR = 3e-4

# Model hyperparameters
D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 4
DIM_FF = 1024
DROPOUT = 0.1

MAX_LEN = 512

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print("CONFIG loaded.")


---


# Phase 1 — Replication Runs (AG News + IMDB)  
**Sinusoidal PE only** · Fusion: **add / concat / gate** · **5 seeds** · **resume-safe** · **best-epoch test evaluation**

This notebook follows the Arxiv experiments: it reproduces the same fusion comparison on **IMDB** and **AG News**, using the **same fusion equations** as the Arxiv notebook:

- `add`: `E + P`
- `concat`: `Linear([E;P])`
- `gate`: `g * E + (1 - g) * P`, where `g = sigmoid(Linear([E;P]))`
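
To make the three rules concrete, here is a minimal pure-Python sketch on a single token with `d_model = 2`. The weights `W`, `b` and the helper names are made up for illustration; the notebook's actual implementation is the `PEFuse` module defined later, which uses learned `nn.Linear` layers.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def linear(x, W, b):
    """y = W x + b for a plain-list vector x; W has one row per output dim."""
    return [sum(w * v for w, v in zip(row, x)) + b_i for row, b_i in zip(W, b)]

def fuse_add(E, P):
    return [e + p for e, p in zip(E, P)]

def fuse_concat(E, P, W, b):
    concat = E + P  # Python list concatenation, i.e. [E;P], not elementwise addition
    return linear(concat, W, b)

def fuse_gate(E, P, W, b):
    gate = [sigmoid(z) for z in linear(E + P, W, b)]  # one gate value per dimension
    return [g * e + (1.0 - g) * p for g, e, p in zip(gate, E, P)]

# Toy values (hypothetical): d_model = 2, so [E;P] has 4 entries.
E = [1.0, -2.0]
P = [0.0, 1.0]
W = [[0.5, 0.0, -0.5, 0.0],
     [0.0, 0.25, 0.0, 0.25]]
b = [0.0, 0.0]

out_add = fuse_add(E, P)
out_concat = fuse_concat(E, P, W, b)
out_gate = fuse_gate(E, P, W, b)
print(out_add, out_concat, out_gate)
```

Because each gate value lies strictly between 0 and 1, the `gate` output is a per-dimension convex combination: every entry falls between the corresponding entries of `E` and `P`.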

Key properties:
- Saves a **last checkpoint** every epoch (for resume).
- Saves a separate **best checkpoint** (by validation accuracy).
- Final reported test metrics are computed on the **best checkpoint** (Phase‑1 style).
- The AG News loader expects a **headerless** 3‑column CSV (`label, title, description`).

Outputs written to `./pefusion_phase1_runs/`:
- `checkpoints_last/` and `checkpoints_best/`
- `logs/*.jsonl`
- `results_runs.csv` (one row per run)
- `results_agg.csv` (aggregated over seeds)
- `length_stats.csv` (dataset length distributions & clipping rate)


In [None]:

import os, json, math, time, random
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


In [None]:

# ----------------------------
# Repro / deterministic-ish
# ----------------------------
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
    worker_seed = (torch.initial_seed() + worker_id) % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


## Main experiment code
Run top-to-bottom. Edit only the configuration cells: the top-level CONFIG cell above and the `CONFIG` dict below (data paths).


In [None]:

# ----------------------------
# Config (EDIT PATHS IF NEEDED)
# ----------------------------
CONFIG = {
    # DATA PATHS
    "imdb_dir": "/zeng_gk/Amine/Imdb_Data",            # train.csv, validation.csv, test.csv with columns text,label
    "agnews_dir": "/zeng_gk/Amine/AG_News Data",       # train.csv, test.csv (headerless: label,title,description)

    # OUTPUT
    "out_dir": "./pefusion_phase1_runs",

    # EXPERIMENT GRID
    "datasets": ["imdb", "agnews"],
    "pe_type": "sinusoidal",
    "fusions": ["add", "concat", "gate"],
    "seeds": [0, 1, 2, 3, 4],

    # MODEL
    "vocab_max_size": 50000,
    "vocab_min_freq": 2,
    "max_len": 512,
    "d_model": 256,
    "nhead": 8,
    "num_layers": 4,
    "dim_ff": 1024,
    "dropout": 0.1,

    # TRAINING
    "batch_size": 64,
    "epochs": 20,
    "lr": 3e-4,
    "weight_decay": 0.0,
    "grad_clip": 1.0,

    # TIMING (latency, synthetic)
    "timing_warmup": 5,
    "timing_repeats": 10,

    # Optional early stop (0 disables)
    "early_stop_patience": 0,

    # Dataloader
    "num_workers": 2,
}
# Take LR and MAX_LEN from the top-level CONFIG cell so they are defined in one place.
CONFIG["lr"] = LR
CONFIG["max_len"] = MAX_LEN

os.makedirs(CONFIG["out_dir"], exist_ok=True)
print("Output dir:", os.path.abspath(CONFIG["out_dir"]))


## Data loading (expected file formats)

- IMDB: `text`, `label`
- AG News: headerless 3 cols → `(label, title, description)`, text = title + description, label mapped to 0..3
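
As a rough illustration of the AG News layout, the sketch below parses two made-up rows with the standard `csv` module, joins title and description, and shifts the 1-based label to 0-based. The sample rows are invented; the notebook itself reads the real files with `pandas`.

```python
import csv
import io

# Two made-up rows in the headerless (label, title, description) layout.
raw = (
    '3,"Stocks rally","Markets closed higher on Friday."\n'
    '1,"Election update","Votes are still being counted."\n'
)

rows = []
for label, title, description in csv.reader(io.StringIO(raw)):
    rows.append({
        "text": f"{title} {description}",  # title + " " + description
        "label": int(label) - 1,           # 1..4 -> 0..3
    })

print(rows)
```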


In [None]:

def load_imdb_from_dir(dir_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    tr_path = os.path.join(dir_path, "train.csv")
    va_path = os.path.join(dir_path, "validation.csv")
    te_path = os.path.join(dir_path, "test.csv")
    if not (os.path.exists(tr_path) and os.path.exists(va_path) and os.path.exists(te_path)):
        raise FileNotFoundError(f"IMDB: expected train.csv, validation.csv, test.csv in {dir_path}")

    train_df = pd.read_csv(tr_path)
    val_df   = pd.read_csv(va_path)
    test_df  = pd.read_csv(te_path)

    for df, name in [(train_df, "train"), (val_df, "validation"), (test_df, "test")]:
        if not set(["text","label"]).issubset(df.columns):
            raise ValueError(f"IMDB {name}: expected columns ['text','label'], got {list(df.columns)}")

    return train_df[["text","label"]].copy(), val_df[["text","label"]].copy(), test_df[["text","label"]].copy()

def load_agnews_from_dir(dir_path: str, seed: int) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    train_path = os.path.join(dir_path, "train.csv")
    test_path  = os.path.join(dir_path, "test.csv")
    if not (os.path.exists(train_path) and os.path.exists(test_path)):
        raise FileNotFoundError(f"AG News: expected train.csv and test.csv in {dir_path}")

    cols = ["label","title","description"]
    train_df = pd.read_csv(train_path, header=None, names=cols)
    test_df  = pd.read_csv(test_path,  header=None, names=cols)

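    # AG News CSVs use 1-based labels (1..4); shift the whole column to 0..3.
    # Deciding once per file (not per value) keeps already 0-based labels intact.
    label_shift = 1 if train_df["label"].astype(int).min() >= 1 else 0

    for df in [train_df, test_df]:
        df["label"] = df["label"].astype(int) - label_shift
        df["text"] = df["title"].astype(str) + " " + df["description"].astype(str)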

    rng = np.random.default_rng(seed)
    idx = np.arange(len(train_df))
    rng.shuffle(idx)
    split = int(0.9 * len(train_df))
    tr_idx, va_idx = idx[:split], idx[split:]

    train = train_df.iloc[tr_idx][["text","label"]].reset_index(drop=True)
    val   = train_df.iloc[va_idx][["text","label"]].reset_index(drop=True)
    test  = test_df[["text","label"]].reset_index(drop=True)
    return train, val, test

# Smoke-check
im_tr, im_va, im_te = load_imdb_from_dir(CONFIG["imdb_dir"])
ag_tr, ag_va, ag_te = load_agnews_from_dir(CONFIG["agnews_dir"], seed=0)
print("IMDB:", len(im_tr), len(im_va), len(im_te), "labels(train)=", sorted(im_tr["label"].unique()))
print("AG  :", len(ag_tr), len(ag_va), len(ag_te), "labels(train)=", sorted(ag_tr["label"].unique()))
print("AG example text:", ag_tr["text"].iloc[0][:140])


## Tokenization, vocab, encoding

Baseline: lowercase + whitespace split.
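
One consequence of this baseline is worth seeing on a toy corpus: punctuation stays attached to tokens, so `movie.` and `movie` are counted separately. The sketch below is a simplified stand-in for the functions in the next cell, with a made-up corpus and `min_freq = 2`.

```python
from collections import Counter

def simple_tokenize(text):
    return text.lower().strip().split()

corpus = ["A great movie.", "Not a great movie", "great acting"]  # made-up texts

freq = Counter(tok for t in corpus for tok in simple_tokenize(t))

# Keep tokens seen at least twice, most frequent first; ids 0 and 1 are reserved.
vocab = {"<pad>": 0, "<unk>": 1}
for tok, count in freq.most_common():
    if count >= 2:
        vocab[tok] = len(vocab)

def encode(text, vocab, max_len):
    ids = [vocab.get(t, vocab["<unk>"]) for t in simple_tokenize(text)[:max_len]]
    return ids + [vocab["<pad>"]] * (max_len - len(ids))  # right-pad to max_len

encoded = encode("A great movie.", vocab, max_len=5)
print(vocab)
print(encoded)
```

Here `movie.` and `movie` each occur once, so both fall below the frequency threshold and map to `<unk>`.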


In [None]:

def simple_tokenize(text: str) -> List[str]:
    return text.lower().strip().split()

def token_len(text: str) -> int:
    return len(simple_tokenize(text))

def build_vocab(texts: List[str], min_freq: int, max_size: int) -> Dict[str, int]:
    freq: Dict[str,int] = {}
    for t in texts:
        for tok in simple_tokenize(t):
            freq[tok] = freq.get(tok, 0) + 1
    vocab = {"<pad>": 0, "<unk>": 1}
    words = [w for w, c in freq.items() if c >= min_freq]
    words.sort(key=lambda w: -freq[w])
    for w in words[: max_size - len(vocab)]:
        vocab[w] = len(vocab)
    return vocab

def encode(text: str, vocab: Dict[str,int], max_len: int) -> List[int]:
    toks = simple_tokenize(text)
    ids = [vocab.get(t, vocab["<unk>"]) for t in toks[:max_len]]
    if len(ids) < max_len:
        ids = ids + [vocab["<pad>"]] * (max_len - len(ids))
    return ids

class TextClsDataset(Dataset):
    def __init__(self, texts: List[str], labels: List[int], vocab: Dict[str,int], max_len: int):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self): 
        return len(self.texts)

    def __getitem__(self, idx: int):
        x = torch.tensor(encode(self.texts[idx], self.vocab, self.max_len), dtype=torch.long)
        y = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        return x, y

def make_dataloaders(dataset_name: str, seed: int):
    seed_everything(seed)

    if dataset_name == "imdb":
        train_df, val_df, test_df = load_imdb_from_dir(CONFIG["imdb_dir"])
        num_classes = 2
    elif dataset_name == "agnews":
        train_df, val_df, test_df = load_agnews_from_dir(CONFIG["agnews_dir"], seed=seed)
        num_classes = 4
    else:
        raise ValueError("Unknown dataset")

    train_texts = train_df["text"].tolist()
    train_labels = train_df["label"].tolist()
    val_texts = val_df["text"].tolist()
    val_labels = val_df["label"].tolist()
    test_texts = test_df["text"].tolist()
    test_labels = test_df["label"].tolist()

    vocab = build_vocab(train_texts, CONFIG["vocab_min_freq"], CONFIG["vocab_max_size"])

    train_ds = TextClsDataset(train_texts, train_labels, vocab, CONFIG["max_len"])
    val_ds   = TextClsDataset(val_texts, val_labels, vocab, CONFIG["max_len"])
    test_ds  = TextClsDataset(test_texts, test_labels, vocab, CONFIG["max_len"])

    g = torch.Generator()
    g.manual_seed(seed)

    train_loader = DataLoader(
        train_ds, batch_size=CONFIG["batch_size"], shuffle=True,
        num_workers=CONFIG["num_workers"], pin_memory=(device.type=="cuda"),
        worker_init_fn=worker_init_fn, generator=g
    )
    val_loader = DataLoader(
        val_ds, batch_size=CONFIG["batch_size"], shuffle=False,
        num_workers=CONFIG["num_workers"], pin_memory=(device.type=="cuda"),
        worker_init_fn=worker_init_fn
    )
    test_loader = DataLoader(
        test_ds, batch_size=CONFIG["batch_size"], shuffle=False,
        num_workers=CONFIG["num_workers"], pin_memory=(device.type=="cuda"),
        worker_init_fn=worker_init_fn
    )

    meta = {
        "num_classes": num_classes,
        "vocab_size": len(vocab),
        "train_n": len(train_ds),
        "val_n": len(val_ds),
        "test_n": len(test_ds),
    }
    return train_loader, val_loader, test_loader, vocab, meta, (train_df, val_df, test_df)


## Model (matches Arxiv fusion equations exactly)
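
For reference, the sinusoidal encoding used below can be written out for a single position in plain Python. This is only a sketch for checking values by hand (the function name is made up); the notebook's `sinusoidal_pe` builds the full `[L, D]` tensor with `torch` and, as written, appears to assume an even `d_model`.

```python
import math

def sinusoidal_pe_row(pos: int, d_model: int):
    """PE for one position: even dims use sin, odd dims use cos."""
    row = [0.0] * d_model
    for i in range(0, d_model, 2):
        # Same frequency term as exp(i * (-log(10000) / d_model)) for even i.
        angle = pos * math.exp(-math.log(10000.0) * i / d_model)
        row[i] = math.sin(angle)
        if i + 1 < d_model:
            row[i + 1] = math.cos(angle)
    return row

pe0 = sinusoidal_pe_row(0, 4)
pe1 = sinusoidal_pe_row(1, 4)
print(pe0)
print(pe1)
```

At position 0 every sine entry is 0 and every cosine entry is 1, which is a quick sanity check for any implementation.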

In [None]:

def sinusoidal_pe(max_len: int, d_model: int, device: torch.device):
    pe = torch.zeros(max_len, d_model, device=device)
    position = torch.arange(0, max_len, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device) * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # [L, D]

class PEFuse(nn.Module):
    def __init__(self, d_model: int, fusion: str):
        super().__init__()
        assert fusion in ["add", "concat", "gate"]
        self.fusion = fusion
        if fusion == "concat":
            self.fuse_layer = nn.Linear(d_model * 2, d_model)
        elif fusion == "gate":
            self.fuse_gate = nn.Linear(d_model * 2, d_model)

    def forward(self, E: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
        if P.dim() == 2:
            P = P.unsqueeze(0).expand(E.size(0), -1, -1)

        if self.fusion == "add":
            return E + P
        elif self.fusion == "concat":
            return self.fuse_layer(torch.cat([E, P], dim=-1))
        elif self.fusion == "gate":
            gate = torch.sigmoid(self.fuse_gate(torch.cat([E, P], dim=-1)))
            return gate * E + (1 - gate) * P
        else:
            return E + P

class TransformerTextClassifier(nn.Module):
    def __init__(self, vocab_size: int, num_classes: int, d_model: int, nhead: int,
                 num_layers: int, dim_ff: int, dropout: float, max_len: int, fusion: str):
        super().__init__()
        self.d_model = d_model
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.register_buffer("pe_cpu", sinusoidal_pe(max_len, d_model, device=torch.device("cpu")), persistent=False)
        self.fuse = PEFuse(d_model, fusion)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, x_ids: torch.Tensor):
        E = self.emb(x_ids) * math.sqrt(self.d_model)  # [B,L,D]
        pad_mask = (x_ids == 0)  # [B,L]
        P = self.pe_cpu[: x_ids.size(1)].to(E.device)  # [L,D], sliced to the input sequence length
        E = self.fuse(E, P)
        H = self.encoder(E, src_key_padding_mask=pad_mask)  # [B,L,D]

        nonpad = (~pad_mask).unsqueeze(-1)
        H = H * nonpad
        denom = nonpad.sum(dim=1).clamp(min=1)
        pooled = H.sum(dim=1) / denom

        return self.cls(self.dropout(pooled))


In [None]:

def run_id(dataset: str, fusion: str, seed: int) -> str:
    return f"{dataset}_pe=sinusoidal_fusion={fusion}_seed={seed}"

def ensure_dirs():
    for d in ["checkpoints_last", "checkpoints_best", "logs"]:
        os.makedirs(os.path.join(CONFIG["out_dir"], d), exist_ok=True)

def paths_for_run(dataset: str, fusion: str, seed: int) -> Dict[str,str]:
    ensure_dirs()
    rid = run_id(dataset, fusion, seed)
    return {
        "rid": rid,
        "ckpt_last": os.path.join(CONFIG["out_dir"], "checkpoints_last", rid + ".pt"),
        "ckpt_best": os.path.join(CONFIG["out_dir"], "checkpoints_best", rid + ".pt"),
        "log":       os.path.join(CONFIG["out_dir"], "logs", rid + ".jsonl"),
        "runs_csv":  os.path.join(CONFIG["out_dir"], "results_runs.csv"),
    }

def results_csv_has_run(runs_csv: str, rid: str) -> bool:
    if not os.path.exists(runs_csv):
        return False
    df = pd.read_csv(runs_csv)
    return (df["run_id"] == rid).any()

def append_row(csv_path: str, row: Dict):
    df = pd.DataFrame([row])
    if os.path.exists(csv_path):
        df.to_csv(csv_path, mode="a", header=False, index=False)
    else:
        df.to_csv(csv_path, index=False)

def jsonl_append(path: str, record: Dict):
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:
    model.eval()
    ce = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_n = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = ce(logits, y)
        b = y.size(0)
        total_loss += loss.item() * b
        total_correct += (logits.argmax(-1) == y).sum().item()
        total_n += b
    return total_loss / max(total_n,1), total_correct / max(total_n,1)

def measure_latency_ms(model: nn.Module, vocab_size: int, max_len: int, batch_size: int) -> float:
    model.eval()
    x = torch.randint(low=1, high=max(2, vocab_size), size=(batch_size, max_len), device=device)
    for _ in range(CONFIG["timing_warmup"]):
        _ = model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()  # monotonic, high-resolution timer
    for _ in range(CONFIG["timing_repeats"]):
        _ = model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    return ((t1 - t0) / CONFIG["timing_repeats"]) * 1000.0


In [None]:

def train_one_run(dataset: str, fusion: str, seed: int):
    p = paths_for_run(dataset, fusion, seed)

    if results_csv_has_run(p["runs_csv"], p["rid"]):
        print("[SKIP]", p["rid"])
        return

    seed_everything(seed)

    train_loader, val_loader, test_loader, vocab, meta, _ = make_dataloaders(dataset, seed)
    model = TransformerTextClassifier(
        vocab_size=meta["vocab_size"],
        num_classes=meta["num_classes"],
        d_model=CONFIG["d_model"],
        nhead=CONFIG["nhead"],
        num_layers=CONFIG["num_layers"],
        dim_ff=CONFIG["dim_ff"],
        dropout=CONFIG["dropout"],
        max_len=CONFIG["max_len"],
        fusion=fusion,
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])
    ce = nn.CrossEntropyLoss()

    start_epoch = 1
    best_val_acc = -1.0
    best_epoch = 0

    if os.path.exists(p["ckpt_last"]):
        ckpt = torch.load(p["ckpt_last"], map_location=device)
        model.load_state_dict(ckpt["model"])
        opt.load_state_dict(ckpt["opt"])
        start_epoch = ckpt["epoch"] + 1
        best_val_acc = ckpt.get("best_val_acc", -1.0)
        best_epoch = ckpt.get("best_epoch", 0)
        print(f"[RESUME] {p['rid']} from epoch {start_epoch} (best_val_acc={best_val_acc:.4f} @ epoch {best_epoch})")

    patience = CONFIG["early_stop_patience"]
    bad = 0

    for epoch in range(start_epoch, CONFIG["epochs"] + 1):
        model.train()
        t0 = time.time()
        total_loss = 0.0
        total_n = 0

        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = ce(logits, y)
            loss.backward()
            if CONFIG["grad_clip"] and CONFIG["grad_clip"] > 0:
                nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"])
            opt.step()
            b = y.size(0)
            total_loss += loss.item() * b
            total_n += b

        train_loss = total_loss / max(total_n,1)
        val_loss, val_acc = evaluate(model, val_loader)
        dt = time.time() - t0

        jsonl_append(p["log"], {
            "run_id": p["rid"], "dataset": dataset, "pe_type": "sinusoidal", "fusion": fusion, "seed": seed,
            "epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, "val_acc": val_acc,
            "epoch_time_sec": dt
        })
        print(f"{p['rid']} | epoch {epoch:02d}/{CONFIG['epochs']} | {dt:.1f}s | train_loss={train_loss:.4f} val_acc={val_acc:.4f}")

        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            best_epoch = epoch
            bad = 0
            torch.save({
                "model": model.state_dict(),
                "opt": opt.state_dict(),
                "epoch": epoch,
                "best_val_acc": best_val_acc,
                "best_epoch": best_epoch,
                "config": CONFIG,
                "meta": meta,
            }, p["ckpt_best"])
        else:
            if patience and patience > 0:
                bad += 1

        torch.save({
            "model": model.state_dict(),
            "opt": opt.state_dict(),
            "epoch": epoch,
            "best_val_acc": best_val_acc,
            "best_epoch": best_epoch,
            "config": CONFIG,
            "meta": meta,
        }, p["ckpt_last"])

        if patience and patience > 0 and bad >= patience:
            print(f"[EARLY STOP] patience={patience} reached at epoch {epoch}")
            break

    ckpt_path = p["ckpt_best"] if os.path.exists(p["ckpt_best"]) else p["ckpt_last"]
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"])

    test_loss, test_acc = evaluate(model, test_loader)
    latency_ms = measure_latency_ms(model, vocab_size=meta["vocab_size"], max_len=CONFIG["max_len"], batch_size=CONFIG["batch_size"])

    append_row(p["runs_csv"], {
        "run_id": p["rid"], "dataset": dataset, "pe_type": "sinusoidal", "fusion": fusion, "seed": seed,
        "epochs_configured": CONFIG["epochs"],
        "best_val_acc": best_val_acc, "best_epoch": best_epoch,
        "test_acc_at_best": test_acc, "test_loss_at_best": test_loss,
        "latency_ms": latency_ms,
        "vocab_size": meta["vocab_size"], "max_len": CONFIG["max_len"],
        "d_model": CONFIG["d_model"], "nhead": CONFIG["nhead"], "num_layers": CONFIG["num_layers"],
        "dim_ff": CONFIG["dim_ff"], "dropout": CONFIG["dropout"],
        "batch_size": CONFIG["batch_size"], "lr": CONFIG["lr"],
        "ckpt_used": os.path.basename(ckpt_path),
    })
    print("[DONE]", p["rid"], f"test_acc_at_best={test_acc:.4f} latency_ms={latency_ms:.2f}")


## Length statistics (writes `length_stats.csv`)
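
The clipping rate reported here is the fraction of texts whose whitespace-token count exceeds `max_len`. A minimal sketch on made-up texts, with a deliberately tiny `max_len` (the notebook uses `CONFIG["max_len"]`):

```python
import statistics

def token_len(text):
    return len(text.lower().strip().split())

texts = ["one two three", "one two three four five six", "one"]  # made-up texts
max_len = 4  # tiny value for illustration only

lens = [token_len(t) for t in texts]
clip_rate = sum(n > max_len for n in lens) / len(lens)  # share of texts that get truncated
median_len = statistics.median(lens)
print(lens, median_len, clip_rate)
```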

In [None]:

def length_stats_for_split(df: pd.DataFrame, split_name: str, dataset: str) -> Dict:
    lens = df["text"].astype(str).apply(token_len)
    max_len = CONFIG["max_len"]
    clipped = (lens > max_len).mean()
    return {
        "dataset": dataset,
        "split": split_name,
        "n": int(len(df)),
        "len_mean": float(lens.mean()),
        "len_median": float(lens.median()),
        "len_p90": float(lens.quantile(0.90)),
        "len_p95": float(lens.quantile(0.95)),
        "len_p99": float(lens.quantile(0.99)),
        "clip_rate_at_max_len": float(clipped),
        "max_len": int(max_len),
    }

def compute_and_save_length_stats():
    rows = []
    tr, va, te = load_imdb_from_dir(CONFIG["imdb_dir"])
    rows += [length_stats_for_split(tr,"train","imdb"),
             length_stats_for_split(va,"val","imdb"),
             length_stats_for_split(te,"test","imdb")]

    tr, va, te = load_agnews_from_dir(CONFIG["agnews_dir"], seed=0)
    rows += [length_stats_for_split(tr,"train","agnews"),
             length_stats_for_split(va,"val","agnews"),
             length_stats_for_split(te,"test","agnews")]

    out = pd.DataFrame(rows)
    path = os.path.join(CONFIG["out_dir"], "length_stats.csv")
    out.to_csv(path, index=False)
    display(out)
    print("Saved:", path)

compute_and_save_length_stats()


## Run the grid (resume-safe)
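
The resume logic amounts to enumerating the full grid and skipping every run whose ID already appears in `results_runs.csv`. The sketch below mimics that with an in-memory CSV; the two "finished" rows and their accuracy values are placeholders, not real results.

```python
import csv
import io
from itertools import product

def run_id(dataset, fusion, seed):
    return f"{dataset}_pe=sinusoidal_fusion={fusion}_seed={seed}"

# Stand-in for results_runs.csv with two finished runs (placeholder values).
finished_csv = io.StringIO(
    "run_id,test_acc_at_best\n"
    "imdb_pe=sinusoidal_fusion=add_seed=0,0.0\n"
    "imdb_pe=sinusoidal_fusion=add_seed=1,0.0\n"
)
done = {row["run_id"] for row in csv.DictReader(finished_csv)}

# 2 datasets x 3 fusions x 5 seeds
grid = list(product(["imdb", "agnews"], ["add", "concat", "gate"], range(5)))
todo = [cfg for cfg in grid if run_id(*cfg) not in done]
print(len(grid), len(done), len(todo))
```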

In [None]:

def run_full_grid():
    ensure_dirs()
    runs_csv = os.path.join(CONFIG["out_dir"], "results_runs.csv")

    total = 0
    skipped = 0
    for dataset in CONFIG["datasets"]:
        for fusion in CONFIG["fusions"]:
            for seed in CONFIG["seeds"]:
                total += 1
                if results_csv_has_run(runs_csv, run_id(dataset, fusion, seed)):
                    skipped += 1

    print(f"Planned runs: {total} | already completed: {skipped} | to run now: {total - skipped}")

    for dataset in CONFIG["datasets"]:
        for fusion in CONFIG["fusions"]:
            for seed in CONFIG["seeds"]:
                print("="*100)
                print("RUN:", run_id(dataset, fusion, seed))
                print("="*100)
                train_one_run(dataset, fusion, seed)

run_full_grid()


## Aggregate over seeds (writes `results_agg.csv`)
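
When reading the aggregated table, note that the pandas `std` aggregation defaults to the sample standard deviation (`ddof=1`), which with five seeds is noticeably larger than the population value. A small sketch with hypothetical accuracies (not results from the paper):

```python
import statistics

# Hypothetical per-seed accuracies, for illustration only.
acc_by_seed = [0.80, 0.82, 0.81, 0.79, 0.83]

mean_acc = statistics.mean(acc_by_seed)
std_sample = statistics.stdev(acc_by_seed)       # divides by n - 1, like pandas' default
std_population = statistics.pstdev(acc_by_seed)  # divides by n
print(round(mean_acc, 4), round(std_sample, 4), round(std_population, 4))
```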

In [None]:

runs_path = os.path.join(CONFIG["out_dir"], "results_runs.csv")
if not os.path.exists(runs_path):
    print("No results_runs.csv found yet. Run the grid first.")
else:
    df = pd.read_csv(runs_path)
    df = df[(df["pe_type"]=="sinusoidal") & (df["fusion"].isin(CONFIG["fusions"])) & (df["dataset"].isin(CONFIG["datasets"]))]

    agg = (df.groupby(["dataset","fusion"])
             .agg(
                 runs=("run_id","count"),
                 mean_test_acc=("test_acc_at_best","mean"),
                 std_test_acc=("test_acc_at_best","std"),
                 mean_best_val_acc=("best_val_acc","mean"),
                 std_best_val_acc=("best_val_acc","std"),
                 mean_latency_ms=("latency_ms","mean"),
                 std_latency_ms=("latency_ms","std"),
             )
             .reset_index()
          )
    display(agg.sort_values(["dataset","mean_test_acc"], ascending=[True, False]))

    out_path = os.path.join(CONFIG["out_dir"], "results_agg.csv")
    agg.to_csv(out_path, index=False)
    print("Saved:", out_path)


In [None]:
import os
import pandas as pd

OUT_DIR = CONFIG["out_dir"]  # same output directory as the main grid
RESULTS_CSV = os.path.join(OUT_DIR, "results_runs.csv")

def load_done(results_csv):
    if not os.path.exists(results_csv) or os.path.getsize(results_csv) == 0:
        return set()
    data = pd.read_csv(results_csv)  # first row is the header written by append_row

    # DONE = (dataset, pe_type, fusion, seed)
    done = set(zip(
        data["dataset"].astype(str),
        data["pe_type"].astype(str),
        data["fusion"].astype(str),
        data["seed"].astype(int),
    ))
    return done

done = load_done(RESULTS_CSV)

DATASET = "agnews"
PE_TYPE = "sinusoidal"
FUSIONS = ["add", "concat", "gate"]
SEEDS = [0,1,2,3,4]

todo = []
for fusion in FUSIONS:
    for seed in SEEDS:
        key = (DATASET, PE_TYPE, fusion, seed)
        if key not in done:
            todo.append({"dataset": DATASET, "pe_type": PE_TYPE, "fusion": fusion, "seed": seed})

print("DONE keys:", len(done))
print("TO RUN:", len(todo))
print(todo)


In [None]:
# run_id() with an optional pe_type; the default yields the same IDs as the earlier definition.
def run_id(dataset: str, fusion: str, seed: int, pe_type: str = "sinusoidal") -> str:
    # Must match how you name checkpoints and how run_id appears in results_runs.csv
    return f"{dataset}_pe={pe_type}_fusion={fusion}_seed={seed}"

# Sanity check
print(run_id("agnews", "gate", 0))


In [None]:
for cfg in todo:
    rid_str = f"{cfg['dataset']}_pe=sinusoidal_fusion={cfg['fusion']}_seed={cfg['seed']}"
    print("="*100)
    print("RUN:", rid_str)
    print("="*100)

    train_one_run(cfg["dataset"], cfg["fusion"], int(cfg["seed"]))


## Export (paper-facing artifacts)
The following cell **must** write the final figures and (optionally) CSV summaries
to `../results/` so they correspond exactly to the paper.


In [None]:
# =========================
# EXPORT
# =========================
# Wire your actual result variables here.

# Example:
# df.to_csv(RESULTS_DIR / "table1_agnews_imdb.csv", index=False)
# plt.savefig(FIG_DIR / "figX_agnews_imdb.png", dpi=300, bbox_inches="tight")

print("Export cell executed. Ensure outputs are written to ../results/.")
