# Reproducibility Notebook — ArXiv PE-Fusion Experiments

**Purpose (paper-facing):** Reproduce the ArXiv fusion experiments reported in the paper *“Fusion Matters: Length-Aware Analysis of Positional-Encoding Fusion in Transformers”*.

This notebook is intended to be run **top-to-bottom** without manual edits.

## What this reproduces
- Fusion operators: **Add**, **Concat+Projection**, **Gate-Scalar** (and **Gate-CNN** if included in this notebook)
- Seeds: **0–4** (paired-seed comparisons)
- Outputs exported to `../results/`:
  - `results/table1_arxiv.csv` (or similar)
  - figure files used in the paper (e.g., paired-seed deltas)

## Notes
- Datasets are **not** included in the repository. You must download/prepare them as described in `docs/REPRODUCIBILITY.md`.
- If you change any hyperparameter, you must treat results as *new* and update the paper figures/tables accordingly.


In [None]:
# =========================
# CONFIG (single source of truth)
# =========================
# Edit ONLY this cell to change experiment settings.

from pathlib import Path
import os
import random
import numpy as np
import torch

# Output directories. The path is relative, so it resolves against the current
# working directory; both directories are created as soon as this cell runs.
RESULTS_DIR = Path("../results")
FIG_DIR = RESULTS_DIR / "figures"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Seeds used in the paper (paired-seed protocol)
SEEDS = [0, 1, 2, 3, 4]

# Dataset location: taken from the ARXIV_DATA_DIR environment variable when set,
# otherwise ./data/arxiv.
ARXIV_DATA_DIR = Path(os.environ.get("ARXIV_DATA_DIR", "./data/arxiv"))  # override via env var

# Core training hyperparameters (must match the paper settings)
# NOTE(review): the experiment cell's `args` dict hard-codes its own
# d_model / layers / dim_ff / batch / epochs values instead of reading
# D_MODEL / NUM_LAYERS / DIM_FF / BATCH_SIZE / EPOCHS — confirm which set
# matches the paper.
EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 3e-4

# Model hyperparameters (must match the paper settings)
D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 4
DIM_FF = 1024
DROPOUT = 0.1

# Sequence constraints: maximum number of tokens kept per document.
MAX_LEN = 4096

def set_global_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed to every seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Echo the effective configuration so the run log records where artifacts go.
# RESULTS_DIR and FIG_DIR are printed as absolute paths; ARXIV_DATA_DIR is
# printed exactly as configured (not resolved).
print("CONFIG loaded.")
print("RESULTS_DIR:", RESULTS_DIR.resolve())
print("FIG_DIR:", FIG_DIR.resolve())
print("ARXIV_DATA_DIR:", ARXIV_DATA_DIR)


---


In [None]:
!pip install datasets
!pip install pandas
!pip install torch torchvision torchaudio


# Setup & Imports

In [None]:
# =============== Environment Setup ===============
import os
import time
import json
import math
import random
import numpy as np
import pandas as pd
import statistics
import matplotlib.pyplot as plt
from typing import List, Dict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset

# =============== Reproducibility ===============
def set_seed(seed: int):
    """Seed every RNG used by the notebook and enable deterministic cuDNN.

    Sets ``PYTHONHASHSEED`` in ``os.environ`` and seeds ``random``, NumPy and
    PyTorch (CPU, current CUDA device and all CUDA devices).

    Args:
        seed: Value applied to every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same call order as before: Python, NumPy, then the PyTorch generators.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    # Determinism (can slow down a bit; keep ON for research reproducibility)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def seed_worker(worker_id: int):
    """Seed NumPy and ``random`` from torch's initial seed.

    Passed as ``worker_init_fn`` to the training DataLoader. ``worker_id`` is
    accepted to match that callback signature and is not used.
    """
    import random
    import numpy as np
    import torch

    # Reduce torch's initial seed modulo 2**32 before handing it on.
    derived_seed = torch.initial_seed() % (2**32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)

# Multi-seed study configuration.
# Respect the SEEDS list from the CONFIG cell (the single source of truth) when it
# is already defined, instead of silently overriding it; fall back to the paper's
# seeds only when this cell is run on its own.
SEEDS = globals().get("SEEDS", [0, 1, 2, 3, 4])

# =============== Device Info ===============
# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
    print(f"Memory Cached: {torch.cuda.memory_reserved(0)/1024**3:.2f} GB")


In [None]:
def simple_tokenize(text: str) -> List[str]:
    """Lower-case ``text`` and split it on runs of whitespace."""
    normalized = text.lower().strip()
    return normalized.split()

def build_vocab(texts: List[str], min_freq: int = 2, max_size: int = 50000) -> Dict[str, int]:
    """Build a token-to-id mapping ordered by descending corpus frequency.

    Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``. Tokens seen fewer
    than ``min_freq`` times are dropped, and the mapping is capped at
    ``max_size`` entries including the two reserved ones.

    Args:
        texts: Documents to count tokens over (tokenized with ``simple_tokenize``).
        min_freq: Minimum occurrence count for a token to be kept.
        max_size: Upper bound on the vocabulary size.

    Returns:
        Dict mapping each kept token to its integer id.
    """
    counts = {}
    for document in texts:
        for token in simple_tokenize(document):
            counts[token] = counts.get(token, 0) + 1

    vocab = {"<pad>": 0, "<unk>": 1}

    # Stable descending sort: tokens with equal counts keep first-seen order.
    ranked = sorted(
        (token for token, count in counts.items() if count >= min_freq),
        key=counts.get,
        reverse=True,
    )

    budget = max_size - len(vocab)
    for next_id, token in enumerate(ranked[:budget], start=len(vocab)):
        vocab[token] = next_id
    return vocab

def encode(text, vocab, max_len):
    """Convert ``text`` to token ids, truncated or right-padded to ``max_len``.

    Unknown tokens map to ``vocab["<unk>"]``; padding uses ``vocab["<pad>"]``.
    """
    tokens = simple_tokenize(text)
    clipped = [vocab.get(token, vocab["<unk>"]) for token in tokens][:max_len]
    pad_id = vocab["<pad>"]
    shortfall = max_len - len(clipped)
    return clipped + [pad_id] * max(0, shortfall)

class TextClsDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(token_ids, label)`` tensor pairs.

    Each text is encoded on access with ``encode`` using the stored vocabulary
    and maximum length; the label is returned as a ``torch.long`` tensor.
    """

    def __init__(self, texts, labels, vocab, max_len):
        self.texts, self.labels = texts, labels
        self.vocab, self.max_len = vocab, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        token_ids = torch.tensor(encode(self.texts[idx], self.vocab, self.max_len))
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return token_ids, label


In [None]:
class SinusoidalPE(nn.Module):
    """Fixed sine/cosine positional table stored as a non-trainable buffer.

    Even columns hold sines and odd columns hold cosines of
    ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Width of each positional vector. Odd widths are supported.
        max_len: Number of positions precomputed.
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine columns,
        # so trim the angles to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return the table rows for the first ``x.size(1)`` positions.

        The result has a leading dimension of 1 and is sliced from the
        precomputed buffer, so it never exceeds ``max_len`` positions.
        """
        return self.pe[:, :x.size(1)]

class LearnedPE(nn.Module):
    """Trainable positional table backed by ``nn.Embedding(max_len, d_model)``."""

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Look up embeddings for positions ``0 .. x.size(1) - 1`` (leading dim 1)."""
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return self.pe(positions.unsqueeze(0))

class RotaryPE(nn.Module):
    """Sin/cos table built from rotary-style inverse frequencies.

    NOTE(review): despite the name, this module does not rotate anything; it
    only returns a table whose first half is sines and second half is cosines,
    which the caller then fuses with token embeddings.
    """

    def __init__(self, d_model):
        super().__init__()
        exponents = torch.arange(0, d_model, 2).float() / d_model
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, x):
        """Return the table expanded over the batch dimension of ``x``."""
        positions = torch.arange(x.size(1), device=x.device).type_as(self.inv_freq)
        # Outer product position x frequency via broadcasting.
        angles = positions.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table.unsqueeze(0).expand(x.size(0), -1, -1)

class RelativePE(nn.Module):
    """Trainable ``(max_len, d_model)`` table initialised from a standard normal.

    NOTE(review): the table is indexed by absolute position in ``forward``;
    confirm that this is what the "relative" label is meant to denote.
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        initial_table = torch.randn(max_len, d_model)
        self.rel_embed = nn.Parameter(initial_table)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows with a leading dimension of 1."""
        rows = self.rel_embed[: x.size(1)]
        return rows[None]


In [None]:
class TextTransformerPEF(nn.Module):
    """Transformer-encoder text classifier with a selectable PE fusion operator.

    Token embeddings (scaled by ``sqrt(d_model)``) are combined with a
    positional encoding through the chosen fusion operator, passed through a
    ``nn.TransformerEncoder``, mean-pooled over non-padding positions and
    projected to ``num_classes`` logits. Token id 0 is treated as padding.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        layers: Number of encoder layers.
        dim_ff: Feed-forward width inside each encoder layer.
        dropout: Dropout probability of the encoder layers.
        pe_type: One of ``"sinusoidal"``, ``"learned"``, ``"rope"``, ``"relative"``.
        fusion: ``"add"``, ``"mul"``, ``"concat"``, ``"gate"`` or ``"mlp"``;
            any other value behaves like ``"add"``.
        max_len: Maximum sequence length passed to the positional module.
        num_classes: Number of output classes.

    Raises:
        ValueError: If ``pe_type`` is not one of the supported names.
    """

    # Positional-encoding names accepted by ``pe_type``.
    _PE_TYPES = ("sinusoidal", "learned", "rope", "relative")

    def __init__(self, vocab_size, d_model, nhead, layers, dim_ff, dropout, pe_type, fusion, max_len, num_classes):
        super().__init__()
        # Fail fast: an unknown pe_type would otherwise leave `pe_module`
        # undefined and only surface as an AttributeError in forward().
        if pe_type not in self._PE_TYPES:
            raise ValueError(
                f"Unknown pe_type {pe_type!r}; expected one of {self._PE_TYPES}"
            )

        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pe_type = pe_type
        self.fusion = fusion
        self.d_model = d_model

        if pe_type == "sinusoidal":
            self.pe_module = SinusoidalPE(d_model, max_len)
        elif pe_type == "learned":
            self.pe_module = LearnedPE(max_len, d_model)
        elif pe_type == "rope":
            self.pe_module = RotaryPE(d_model)
        elif pe_type == "relative":
            self.pe_module = RelativePE(d_model, max_len)

        # Only the parametric fusion operators need their own layers.
        if fusion == "concat":
            self.fuse_layer = nn.Linear(d_model * 2, d_model)
        elif fusion == "gate":
            self.fuse_gate = nn.Linear(d_model * 2, 1)
        elif fusion == "mlp":
            self.fuse_mlp = nn.Sequential(
                nn.Linear(d_model * 2, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_model)
            )

        self.tr = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True),
            num_layers=layers
        )
        self.cls = nn.Linear(d_model, num_classes)

    def fuse(self, E, P):
        """Combine token embeddings ``E`` with positional encodings ``P``.

        ``P`` is expanded along dim 0 when its leading size differs from
        ``E``'s, so that the concatenation-based operators see matching shapes.
        """
        # Ensure shapes match for concatenation-based fusion
        if P.size(0) != E.size(0):
            P = P.expand(E.size(0), -1, -1)

        if self.fusion == "add":
            return E + P
        elif self.fusion == "mul":
            return E * P
        elif self.fusion == "concat":
            return self.fuse_layer(torch.cat([E, P], dim=-1))
        elif self.fusion == "gate":
            # One scalar gate per position, blending E against P.
            gate = torch.sigmoid(self.fuse_gate(torch.cat([E, P], dim=-1)))
            return gate * E + (1 - gate) * P
        elif self.fusion == "mlp":
            return self.fuse_mlp(torch.cat([E, P], dim=-1))
        else:
            # Unrecognised fusion names fall back to additive fusion.
            return E + P

    def forward(self, x):
        """Return class logits for the token-id tensor ``x`` (0 = padding)."""
        E = self.embed(x) * math.sqrt(self.d_model)
        P = self.pe_module(x).to(x.device)
        h = self.fuse(E, P)
        out = self.tr(h, src_key_padding_mask=(x == 0))
        # Mean over non-padding positions; the clamp avoids division by zero
        # for rows that contain only padding.
        mask = (x != 0).unsqueeze(-1).float()
        pooled = (out * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        return self.cls(pooled)


## Main experiment code
Run top-to-bottom. Do not edit cells in the middle; only edit CONFIG above.


In [None]:
# ================= Reliable training / evaluation utilities =================
import time, statistics
import numpy as np
import torch
import torch.nn as nn

def train_epoch(model, loader, optim, device):
    """Run one training pass over ``loader`` with cross-entropy loss.

    Args:
        model: Module mapping an input batch to logits.
        loader: Iterable of ``(inputs, targets)`` batches.
        optim: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Wall-clock duration of the pass in seconds.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    started_at = time.time()
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optim.zero_grad(set_to_none=True)
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optim.step()
    return time.time() - started_at

@torch.no_grad()
def eval_acc(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` (0.0 if it is empty).

    The model is switched to eval mode and gradients are disabled.
    """
    model.eval()
    hits = 0
    seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predicted = torch.argmax(model(inputs), dim=-1)
        hits += (predicted == targets).sum().item()
        seen += targets.numel()
    if seen > 0:
        return hits / seen
    return 0.0

@torch.no_grad()
def measure_latency(model, loader, device, repeats=10, warmup=5):
    """Measure inference latency in milliseconds per sample.

    Runs ``warmup`` untimed full passes over ``loader``, then ``repeats``
    timed full passes. On CUDA the device is synchronized before and after
    every pass so the timer covers the queued GPU work.

    Returns:
        ``(mean, population_std)`` of the per-pass ms/sample values.
    """
    model.eval()
    on_cuda = device.type == "cuda"

    def _barrier():
        # Wait for outstanding GPU work; no-op on CPU.
        if on_cuda:
            torch.cuda.synchronize()

    def _full_pass():
        # Forward every batch once and report how many samples were seen.
        seen = 0
        for batch, _labels in loader:
            batch = batch.to(device)
            model(batch)
            seen += batch.size(0)
        return seen

    # warmup passes
    for _ in range(warmup):
        _barrier()
        _full_pass()
        _barrier()

    per_sample_ms = []
    for _ in range(repeats):
        _barrier()
        tic = time.perf_counter()
        sample_count = _full_pass()
        _barrier()
        toc = time.perf_counter()
        per_sample_ms.append(((toc - tic) * 1000.0) / max(sample_count, 1))

    return float(np.mean(per_sample_ms)), float(statistics.pstdev(per_sample_ms))

def train_with_early_stopping(model, train_loader, val_loader, optimizer, device,
                              max_epochs=20, patience=4, min_delta=1e-4):
    """Train with validation-accuracy early stopping.

    After every epoch the validation accuracy is compared with the best value
    so far; an improvement larger than ``min_delta`` snapshots the weights
    (cloned to CPU) and resets the patience counter. Training stops once
    ``patience`` consecutive epochs fail to improve, or after ``max_epochs``.

    Returns:
        ``(best_val_acc, best_epoch_index, best_state_dict)``. With no
        improving epoch these are ``(-1.0, -1, None)``.
    """
    best_val, best_epoch, best_state = -1.0, -1, None
    stale_epochs = 0

    for epoch_idx in range(max_epochs):
        elapsed = train_epoch(model, train_loader, optimizer, device)
        val_acc = float(eval_acc(model, val_loader, device))
        print(f"Epoch {epoch_idx+1}/{max_epochs} - {elapsed:.1f}s - val_acc={val_acc:.4f}")

        improved = val_acc > best_val + min_delta
        if improved:
            best_val, best_epoch = val_acc, epoch_idx
            # Detached CPU copies so later optimizer steps cannot mutate the snapshot.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            stale_epochs = 0
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            print(f"Early stopping at epoch {epoch_idx+1}. Best epoch: {best_epoch+1} (val_acc={best_val:.4f})")
            break

    return best_val, best_epoch, best_state


In [None]:
# ================= Dataset + loaders + final multi-seed experiment loop =================
import os, json
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

# -------------------- Load dataset --------------------
from pathlib import Path

# Prefer the directory configured in the CONFIG cell (ARXIV_DATA_DIR). Fall back to
# the previously hard-coded ./data location when the cleaned splits are not there,
# so existing setups keep working.
_SPLIT_FILES = {
    "train": "train_clean.jsonl",
    "validation": "validation_clean.jsonl",
    "test": "test_clean.jsonl",
}
_data_dir = (
    ARXIV_DATA_DIR
    if all((ARXIV_DATA_DIR / fname).exists() for fname in _SPLIT_FILES.values())
    else Path("data")
)
print(f"Loading splits from: {_data_dir}")

ds = load_dataset(
    "json",
    data_files={split: str(_data_dir / fname) for split, fname in _SPLIT_FILES.items()},
)
# Drop examples whose text is empty or whitespace-only.
ds = ds.filter(lambda x: len(x["text"].strip()) > 0)

print(f"train examples: {len(ds['train'])}")
print(f"validation examples: {len(ds['validation'])}")
print(f"test examples: {len(ds['test'])}")

# NOTE(review): the vocabulary is built from train + validation + test text, so
# token frequencies from the held-out splits influence which tokens are kept.
# Left unchanged to preserve the reported setup — confirm this is intended.
all_texts = list(ds["train"]["text"]) + list(ds["validation"]["text"]) + list(ds["test"]["text"])
vocab = build_vocab(all_texts, min_freq=2)

# -------------------- Experiment config --------------------
# NOTE(review): d_model / layers / dim_ff / batch / epochs are hard-coded here and
# differ from D_MODEL / NUM_LAYERS / DIM_FF / BATCH_SIZE in the CONFIG cell —
# confirm which values match the paper before wiring them to CONFIG.
args = {
    "max_len": MAX_LEN,
    "d_model": 128,
    "nhead": 8,
    "layers": 2,
    "dim_ff": 256,
    "dropout": 0.1,
    "batch": 64,
    "epochs": 20,
    # The CONFIG cell names this constant LEARNING_RATE; the former `LR` was
    # never defined and raised NameError.
    "lr": LEARNING_RATE,
    "num_classes": 11,

    # Latency measurement: 10 timed passes after 5 warmup passes
    # (fewer repeats were too noisy).
    "timing_repeats": 10,
    "timing_warmup": 5,
}

def make_loaders(seed: int):
    """Build the train / validation / test DataLoaders for one seed.

    Only the training loader shuffles; its order is driven by a dedicated
    ``torch.Generator`` seeded with ``seed``. All loaders use ``num_workers=0``
    and pin memory when running on CUDA.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    # DataLoader generator controls shuffle deterministically
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)

    datasets_by_split = {
        split: TextClsDataset(ds[split]["text"], ds[split]["label"], vocab, args["max_len"])
        for split in ("train", "validation", "test")
    }

    # Settings shared by all three loaders.
    shared = dict(
        batch_size=args["batch"],
        num_workers=0,      # keep 0 for strict determinism; increase later if needed
        pin_memory=(device.type == "cuda"),
    )

    train_loader = DataLoader(
        datasets_by_split["train"],
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=shuffle_rng,
        **shared,
    )
    val_loader = DataLoader(datasets_by_split["validation"], shuffle=False, **shared)
    test_loader = DataLoader(datasets_by_split["test"], shuffle=False, **shared)
    return train_loader, val_loader, test_loader

# -------------------- Grid --------------------
# Every (seed, positional encoding, fusion) combination is trained and evaluated
# independently; one result row (or one error row) is appended per combination.
fusion_methods = ["add", "mul", "concat", "gate", "mlp"]
pe_types = ["sinusoidal", "learned", "rope", "relative"]

EARLY_STOP_PATIENCE = 4
EARLY_STOP_MIN_DELTA = 1e-4

os.makedirs("saved_models", exist_ok=True)
rows = []

for seed in SEEDS:
    for pe in pe_types:
        for fusion in fusion_methods:
            key = f"Arxiv|pe={pe}|fusion={fusion}|seed={seed}"
            print("\n" + "="*80)
            print(key)

            try:
                # Important: re-seed before model init so init is seed-controlled
                set_seed(seed)

                # Rebuild the loaders for every configuration. The training
                # loader owns a torch.Generator whose state advances with each
                # epoch; sharing one loader across configurations made the
                # shuffle order depend on how many configurations ran before,
                # which breaks paired-seed comparisons and partial re-runs.
                train_loader, val_loader, test_loader = make_loaders(seed)

                model = TextTransformerPEF(
                    vocab_size=len(vocab),
                    d_model=args["d_model"],
                    nhead=args["nhead"],
                    layers=args["layers"],
                    dim_ff=args["dim_ff"],
                    dropout=args["dropout"],
                    pe_type=pe,
                    fusion=fusion,
                    max_len=args["max_len"],
                    num_classes=args["num_classes"],
                ).to(device)

                optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
                n_params = sum(p.numel() for p in model.parameters())
                print(f"Params={n_params:,}")

                # Train with early stopping (best checkpoint by validation)
                best_val, best_epoch, best_state = train_with_early_stopping(
                    model, train_loader, val_loader, optimizer, device,
                    max_epochs=args["epochs"],
                    patience=EARLY_STOP_PATIENCE,
                    min_delta=EARLY_STOP_MIN_DELTA
                )

                # Restore best checkpoint before final test + latency
                if best_state is not None:
                    model.load_state_dict(best_state)

                # Eval
                val_acc = float(eval_acc(model, val_loader, device))
                test_acc = float(eval_acc(model, test_loader, device))

                # Latency
                lat_mean, lat_std = measure_latency(
                    model, test_loader, device,
                    repeats=args["timing_repeats"],
                    warmup=args["timing_warmup"]
                )

                rows.append({
                    "dataset": "ArxivClassification",
                    "pe_type": pe,
                    "fusion": fusion,
                    "seed": int(seed),
                    "params": int(n_params),
                    "best_val_acc": float(best_val),
                    "best_epoch": int(best_epoch),
                    "val_acc": float(val_acc),
                    "test_acc": float(test_acc),
                    "lat_ms_mean": float(lat_mean),
                    "lat_ms_std": float(lat_std),
                })

                # Optional: save checkpoint per seed/config
                torch.save(model.state_dict(), f"saved_models/{key.replace('|','_')}.pt")

            except Exception as e:
                # Best effort: record the failure and continue with the next
                # configuration so one bad run does not abort the whole grid.
                print(f"[ERROR] Skipping {key}: {e}")
                rows.append({
                    "dataset": "ArxivClassification",
                    "pe_type": pe,
                    "fusion": fusion,
                    "seed": int(seed),
                    "error": str(e),
                })

# -------------------- Save raw + aggregated results --------------------
# Raw rows (including error rows) are always written; the aggregate covers only
# runs that finished without an error.
df = pd.DataFrame(rows)
df.to_csv("arxiv_results_by_seed.csv", index=False)
with open("arxiv_results_by_seed.json", "w") as f:
    json.dump(rows, f, indent=2)

df_ok = df[df["error"].isna()] if "error" in df.columns else df.copy()

if df_ok.empty:
    # When every run failed, the metric columns do not exist and the named
    # aggregation below would raise KeyError; emit a header-only summary instead.
    print("[WARN] No successful runs to aggregate; writing an empty summary.")
    agg = pd.DataFrame(columns=[
        "pe_type", "fusion",
        "params_mean",
        "val_acc_mean", "val_acc_std",
        "test_acc_mean", "test_acc_std",
        "lat_ms_mean_mean", "lat_ms_mean_std",
        "seeds",
    ])
else:
    agg = (
        df_ok.groupby(["pe_type", "fusion"], as_index=False)
          .agg(
              params_mean=("params", "mean"),
              val_acc_mean=("val_acc", "mean"),
              val_acc_std=("val_acc", "std"),
              test_acc_mean=("test_acc", "mean"),
              test_acc_std=("test_acc", "std"),
              lat_ms_mean_mean=("lat_ms_mean", "mean"),
              lat_ms_mean_std=("lat_ms_mean", "std"),
              seeds=("seed", "nunique"),
          )
          .sort_values("test_acc_mean", ascending=False)
    )

agg.to_csv("arxiv_results_aggregated.csv", index=False)
print("\nSaved:")
print("- arxiv_results_by_seed.csv / .json (raw per seed)")
print("- arxiv_results_aggregated.csv (mean/std across seeds)")
display(agg.head(20))


## Export checklist (must be true before you claim reproduction)
- [ ] Table-1 style summary CSV written to `../results/`
- [ ] Paired-seed delta figure(s) written to `../results/figures/`
- [ ] Notebook runs top-to-bottom on a clean kernel


In [None]:
# =========================
# EXPORT (paper-facing artifacts)
# =========================
# Writes the per-run table and a per-(pe_type, fusion) test-accuracy summary to
# RESULTS_DIR, using the `rows` list produced by the experiment cell. The file
# names follow the original export template.

from pathlib import Path
import pandas as pd

# ---- 1) Save scalar results (Table 1 inputs / deltas)
if "rows" in globals():
    runs_df = pd.DataFrame(rows)
    out_csv = RESULTS_DIR / "table1_arxiv_runs.csv"
    runs_df.to_csv(out_csv, index=False)
    print("Wrote:", out_csv)

    # ---- 2) Save aggregated summary (mean ± std)
    if "test_acc" in runs_df.columns:
        # Error rows carry no test_acc; drop them before aggregating.
        ok_runs = runs_df.dropna(subset=["test_acc"])
        summary = (
            ok_runs.groupby(["pe_type", "fusion"])["test_acc"]
            .agg(["mean", "std"])
            .reset_index()
        )
        out_csv = RESULTS_DIR / "table1_arxiv_summary.csv"
        summary.to_csv(out_csv, index=False)
        print("Wrote:", out_csv)
    else:
        print("NOTE: no successful runs found; summary CSV not written.")
else:
    print("NOTE: rows not found. Run the experiment cell before this export cell.")

# ---- 3) Ensure figures are saved to FIG_DIR
# No plotting code is visible in this notebook; when a figure is produced, save it
# with e.g.:
# plt.savefig(FIG_DIR / "fig3_paired_seed_deltas.png", dpi=300, bbox_inches="tight")

print("Export cell finished.")
