In [None]:
# fermi_transformer_sparse_benchmark.py
"""
Train the Fermi-gated Transformer (as before) and then benchmark pruning
in the same run.

Usage (Colab/Jupyter safe):
  python fermi_transformer_sparse_benchmark.py --epochs 3 --batch 32 --mu-bias 0.25 --init-T 0.25 --num-workers 0

Options of interest:
  --prune-thrs   Comma-separated thresholds to test (default "0.40,0.45,0.50")
  --finetune-epochs  Number of epochs to fine-tune each pruned model (default 0)
  --finetune-lr      LR to use for fine-tuning pruned models (default 1e-4)
  --save-csv         Path to save CSV summary (optional)
"""
import argparse, math, time, platform, copy, csv
from dataclasses import dataclass
from typing import Optional, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BertTokenizerFast, DataCollatorWithPadding

# ---------- model and helpers (same as your working definitions) ----------
class InteractingFermiLinear(nn.Module):
    """Linear layer whose output units are scaled by a Fermi-Dirac style gate.

    Every output unit ``i`` receives an occupation ``P[i]`` obtained from a few
    fixed-point iterations of::

        P = 1 / (exp((eps - mu + J P - B) / T) + 1),    J = U @ U.T

    ``eps`` is derived from the unit's mean absolute weight, ``mu``, ``logT`` and
    ``U`` are learned parameters and ``B`` is a zero-initialised buffer.  The
    effective weight row is ``W[i] * P[i]``; the bias is not gated.
    """

    def __init__(self, in_features:int, out_features:int, rank:int=8, init_mu:float=0.0, init_T:float=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=True)
        self.mu = nn.Parameter(torch.full((out_features,), init_mu))
        # NOTE(review): forward() maps this parameter through softplus (not exp),
        # so the effective initial temperature is softplus(log(init_T)).
        self.logT = nn.Parameter(torch.log(torch.full((out_features,), init_T)))
        r = min(rank, out_features)
        self.U = nn.Parameter(0.01 * torch.randn(out_features, r))
        self.register_buffer("B", torch.zeros(out_features))

    def compute_eps_from_weights(self):
        """Return the per-unit energy ``log1p(mean_j |W[i, j]| + 1e-9)``."""
        W = self.linear.weight
        eps = torch.log1p(W.abs().mean(dim=1) + 1e-9)
        return eps

    def compute_JP(self, P: torch.Tensor):
        """Return ``U @ (U.T @ P)`` without materialising the full ``U @ U.T`` matrix."""
        UP = torch.matmul(self.U.t(), P)
        JP = torch.matmul(self.U, UP)
        return JP

    def forward(self, x: torch.Tensor, P_prev: Optional[torch.Tensor]=None, iters:int=3):
        """Apply the gated linear map.

        Args:
            x: input whose last dimension is ``in_features``.
            P_prev: optional occupations from an earlier call; when given, the
                starting point is ``0.6 * P_prev + 0.4 * P_init``.
            iters: number of fixed-point iterations.

        Returns:
            ``(out, P, eps, JP)``: the gated linear output, the final occupations,
            the per-unit energies and ``J P`` evaluated at the final ``P``.
        """
        W = self.linear.weight
        eps = self.compute_eps_from_weights()
        T = F.softplus(self.logT) + 1e-6

        # The starting point is detached: gradients flow only through the iterations below.
        P = torch.sigmoid((self.mu - eps) / (T + 1e-9)).detach()
        if P_prev is not None:
            P = 0.6 * P_prev.detach() + 0.4 * P

        for _ in range(iters):
            z = (eps - self.mu + self.compute_JP(P) - self.B) / (T + 1e-9)
            # sigmoid(-z) == 1 / (exp(z) + 1), but it neither overflows for large z
            # nor produces NaN gradients (0 * inf) in the backward pass.
            P = torch.sigmoid(-z)

        W_tilde = W * P.view(-1, 1)
        out = F.linear(x, W_tilde, self.linear.bias)
        JP = self.compute_JP(P)
        return out, P, eps, JP

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    The table is stored as a non-trainable buffer ``pe`` of shape
    ``(1, max_len, d_model)``; even columns hold sines, odd columns cosines.
    """

    def __init__(self, d_model:int, max_len:int=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so trim the frequency vector instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence dimension (dim 1) exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]

class GatedTransformerLayer(nn.Module):
    """Encoder layer (self-attention + feed-forward, LayerNorm after each residual)
    whose first feed-forward projection is an ``InteractingFermiLinear`` gate."""

    def __init__(self, d_model:int, nhead:int, dim_ff:int, rank:int=8, dropout:float=0.1, init_mu=0.0, init_T=1.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.gate_fc1 = InteractingFermiLinear(d_model, dim_ff, rank=rank, init_mu=init_mu, init_T=init_T)
        self.fc2 = nn.Linear(dim_ff, d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x, P_prev=None, iters=3):
        """Run the layer on a 3-D input and return ``(x, P, eps, JP)``.

        ``P``, ``eps`` and ``JP`` are passed through unchanged from the gate;
        ``P_prev`` and ``iters`` are forwarded to it.
        """
        # Self-attention sub-block: residual add, then LayerNorm.
        attended, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + self.dropout(attended))

        # The gate receives the tokens flattened to two dimensions.
        batch, seq_len, width = x.size()
        tokens = x.view(batch * seq_len, width)
        gated, P, eps, JP = self.gate_fc1(tokens, P_prev=P_prev, iters=iters)

        # Feed-forward sub-block: activation, second projection, residual, LayerNorm.
        ff_out = self.fc2(self.activation(gated)).view(batch, seq_len, width)
        x = self.norm2(x + self.dropout(ff_out))
        return x, P, eps, JP

class SmallGatedTransformer(nn.Module):
    """Small encoder-only classifier built from ``GatedTransformerLayer`` blocks.

    Token embeddings are scaled by ``sqrt(d_model)``, given positional encodings,
    passed through the layers, mean-pooled over the sequence and classified into
    two classes.
    """

    def __init__(self, vocab_size:int, d_model:int=128, nhead:int=4, num_layers:int=2, dim_ff:int=256, max_len:int=128, rank:int=8, init_mu=0.0, init_T=1.0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([GatedTransformerLayer(d_model, nhead, dim_ff, rank=rank, init_mu=init_mu, init_T=init_T) for _ in range(num_layers)])
        self.pooler = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(d_model, 2)
        self.d_model = d_model

    def forward(self, input_ids, attention_mask=None, P_prev_list:Optional[List[torch.Tensor]]=None, iters=3):
        """Classify a batch of token-id sequences.

        Args:
            input_ids: integer token ids (presumably ``(batch, seq)`` — dim 1 is
                treated as the sequence axis).
            attention_mask: optional mask with the same leading shape as
                ``input_ids``; assumed to be non-zero for real tokens and zero
                for padding (the Hugging Face convention). When given, padded
                positions are excluded from the pooled mean. When ``None``, all
                positions are averaged.
            P_prev_list: optional per-layer occupations used to warm-start the gates.
            iters: gate fixed-point iterations per layer.

        Returns:
            ``(logits, P_list, eps_list, JP_list)`` with one list entry per layer.
        """
        x = self.embed(input_ids) * math.sqrt(self.d_model)
        x = self.pos(x)
        P_list, eps_list, JP_list = [], [], []
        for i, layer in enumerate(self.layers):
            P_prev = None if P_prev_list is None else P_prev_list[i]
            x, P, eps, JP = layer(x, P_prev=P_prev, iters=iters)
            P_list.append(P); eps_list.append(eps); JP_list.append(JP)

        if attention_mask is None:
            pooled = self.pooler(x.transpose(1, 2)).squeeze(-1)
        else:
            # Masked mean: padding must not dilute the sentence representation.
            # NOTE(review): the attention layers themselves still see pad positions,
            # because the layer API takes no mask.
            keep = attention_mask.unsqueeze(-1).to(x.dtype)
            # clamp guards against division by zero for an all-padding row
            pooled = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        logits = self.classifier(pooled)
        return logits, P_list, eps_list, JP_list

def free_energy_for_layer(P: torch.Tensor, eps: torch.Tensor, JP: torch.Tensor, T_vec: torch.Tensor):
    """Compute the free energy ``F = E + I - mean(T) * H`` of one gated layer.

    ``E = sum(P * eps)``, ``I = 0.5 * sum(P * JP)`` and ``H`` is the summed
    binary entropy of the occupations ``P``.

    Returns:
        ``(F, E, I, H, Tbar)`` where ``F`` is a differentiable tensor and the
        remaining four values are Python floats for logging.
    """
    E = (P * eps).sum()
    I = 0.5 * (P * JP).sum()
    # The clamp bounds must be representable in P's dtype: in float32,
    # 1 - 1e-8 rounds to exactly 1.0, which let a saturated P == 1 through and
    # turned the entropy (and its gradient) into NaN via 0 * log(0).
    finfo = torch.finfo(P.dtype)
    lo = max(1e-8, finfo.tiny)
    hi = 1.0 - max(1e-8, finfo.eps)
    p = P.clamp(lo, hi)
    H = -(p * torch.log(p) + (1 - p) * torch.log(1 - p)).sum()
    Tbar = T_vec.mean()
    Fval = E + I - Tbar * H
    return Fval, float(E.item()), float(I.item()), float(H.item()), float(Tbar.item())

@dataclass
class TokenizedDataset:
    """Map-style dataset over column-oriented encodings plus a label list.

    ``encodings`` maps a field name to a per-example sequence; ``labels`` holds
    one label per example and defines the dataset length.
    """
    encodings: dict
    labels: List[int]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, with the label under ``"labels"``."""
        item = {}
        for name, column in self.encodings.items():
            item[name] = torch.tensor(column[idx])
        item["labels"] = torch.tensor(self.labels[idx])
        return item

def train_epoch(model, dataloader, optimizer, device, lambda_F, lambda_bud, N_target_frac, P_prev_list=None, iters=3):
    """Train ``model`` for one pass over ``dataloader``.

    The loss is cross-entropy plus ``lambda_F`` times the summed per-layer free
    energy plus ``lambda_bud`` times the summed squared deviation of each layer's
    total occupation from ``N_target_frac * P.numel()``.

    After every step the running per-layer occupations are updated as
    ``0.6 * previous + 0.4 * current`` (or just ``current`` when there is no
    previous value) and fed back into the next forward pass.

    Returns:
        ``(avg_loss, acc, P_prev_list)``: mean batch loss, training accuracy and
        the final running occupations.
    """
    model.train()
    if P_prev_list is None:
        P_prev_list = [None] * len(model.layers)

    running_loss = 0.0
    predictions, targets = [], []
    for batch in tqdm(dataloader, desc="train", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        logits, P_list, eps_list, JP_list = model(input_ids, attention_mask, P_prev_list, iters=iters)

        loss_task = F.cross_entropy(logits, labels)
        loss_F = 0.0
        loss_bud = 0.0
        for i, P in enumerate(P_list):
            Tvec = F.softplus(model.layers[i].gate_fc1.logT) + 1e-6
            loss_F = loss_F + free_energy_for_layer(P, eps_list[i], JP_list[i], Tvec)[0]
            budget = N_target_frac * P.numel()
            loss_bud = loss_bud + (P.sum() - budget) ** 2

        loss = loss_task + lambda_F * loss_F + lambda_bud * loss_bud
        loss.backward()
        optimizer.step()

        running_loss += float(loss.item())
        predictions.extend(logits.argmax(dim=1).detach().cpu().numpy().tolist())
        targets.extend(labels.detach().cpu().numpy().tolist())

        # Exponential moving average of the occupations, detached from the graph.
        P_prev_list = [
            cur.detach() if P_prev_list[i] is None else (0.6 * P_prev_list[i].detach() + 0.4 * cur.detach())
            for i, cur in enumerate(P_list)
        ]

    avg_loss = running_loss / len(dataloader)
    acc = accuracy_score(targets, predictions)
    return avg_loss, acc, P_prev_list

def eval_model(model, dataloader, device, P_prev_list=None, iters=3):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    The running per-layer occupations are carried from batch to batch exactly
    as in training (``0.6 * previous + 0.4 * current``).

    Returns:
        ``(acc, P_prev_list)``: accuracy and the final running occupations.
    """
    model.eval()
    predictions, targets = [], []
    if P_prev_list is None:
        P_prev_list = [None] * len(model.layers)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="eval", leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits, P_list, _eps_list, _JP_list = model(input_ids, attention_mask, P_prev_list, iters=iters)
            predictions.extend(logits.argmax(dim=1).cpu().numpy().tolist())
            targets.extend(labels.cpu().numpy().tolist())

            P_prev_list = [
                cur.detach() if P_prev_list[i] is None else (0.6 * P_prev_list[i].detach() + 0.4 * cur.detach())
                for i, cur in enumerate(P_list)
            ]

    return accuracy_score(targets, predictions), P_prev_list

# ---------- pruning utilities ----------
def compute_prune_masks(P_prev_list, thr=0.5):
    """Return one boolean CPU mask per layer; ``True`` marks units with ``P > thr``."""
    return [P.detach().cpu().gt(thr).to(torch.bool) for P in P_prev_list]

def apply_prune_masks_to_model(mdl, masks):
    """Zero, in place, the gated weight rows and biases whose mask entry is ``False``.

    ``masks[i]`` is applied to ``mdl.layers[i].gate_fc1.linear``; rows with a
    ``True`` entry are left untouched.
    """
    with torch.no_grad():
        for idx, mask in enumerate(masks):
            linear = mdl.layers[idx].gate_fc1.linear
            keep = mask.to(linear.weight.device).float()
            # one multiplier per output unit, broadcast across that unit's weight row
            linear.weight.mul_(keep.view(-1, 1))
            if linear.bias is not None:
                linear.bias.mul_(keep)

def finetune_pruned_model(pruned_model, train_loader, device, finetune_epochs=1, finetune_lr=1e-4):
    """Fine-tune ``pruned_model`` in place with plain cross-entropy.

    A fresh Adam optimizer over all parameters is used; the gates run a single
    iteration with no warm start and no free-energy or budget terms are added.
    The model is left in training mode.
    """
    pruned_model.train()
    optimizer = torch.optim.Adam(pruned_model.parameters(), lr=finetune_lr)
    for epoch_no in range(1, finetune_epochs + 1):
        progress = tqdm(train_loader, desc=f"finetune epoch {epoch_no}", leave=False)
        for batch in progress:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            logits = pruned_model(input_ids, attention_mask, P_prev_list=None, iters=1)[0]
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()

# -------------- main --------------
def main():
    """Train the gated transformer, then benchmark threshold pruning in the same run.

    Steps: parse CLI flags (unknown flags are ignored so the script also runs
    inside a notebook kernel), tokenize the dataset, train for ``--epochs``
    epochs while printing per-layer gate statistics, then for every threshold
    in ``--prune-thrs`` zero the gated units whose occupation is not above the
    threshold, evaluate (optionally after fine-tuning), print a summary table
    and optionally write it to ``--save-csv``.

    Raises:
        ValueError: if ``--model-dim`` is not divisible by ``--nhead`` or a GLUE
            dataset is given without a task name (expected ``glue/<task>``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="glue/sst2")
    parser.add_argument("--model-dim", type=int, default=128)
    parser.add_argument("--nhead", type=int, default=4)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--dim-ff", type=int, default=256)
    parser.add_argument("--rank", type=int, default=8)
    parser.add_argument("--max-len", type=int, default=64)
    parser.add_argument("--batch", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--lambda_F", type=float, default=5e-5)
    parser.add_argument("--lambda_bud", type=float, default=5e-5)
    parser.add_argument("--N_target_frac", type=float, default=0.25)
    parser.add_argument("--subset", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=None)
    parser.add_argument("--mu-bias", type=float, default=0.25)
    parser.add_argument("--init-T", type=float, default=0.25)
    parser.add_argument("--prune-thrs", type=str, default="0.40,0.45,0.50",
                        help="comma-separated thresholds to test")
    parser.add_argument("--finetune-epochs", type=int, default=0, help="fine-tune pruned models (per threshold)")
    parser.add_argument("--finetune-lr", type=float, default=1e-4)
    parser.add_argument("--save-csv", type=str, default="", help="optional path to save CSV summary")
    # parse_known_args: ignore the extra flags a notebook kernel puts on argv
    args, _ = parser.parse_known_args()

    # Fail fast on bad configuration, before any download or tokenization work.
    if args.model_dim % args.nhead != 0:
        raise ValueError("model-dim must be divisible by nhead (e.g. 128 & 4).")
    is_glue = args.dataset.lower().startswith("glue")
    if is_glue and "/" not in args.dataset:
        raise ValueError("GLUE datasets must be given as 'glue/<task>', e.g. glue/sst2.")

    interactive = ("get_ipython" in globals()) or hasattr(__builtins__, "__IPYTHON__")
    if args.num_workers is None:
        args.num_workers = 0 if interactive else 2

    torch.manual_seed(args.seed); np.random.seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device, "num_workers:", args.num_workers, "mu-bias:", args.mu_bias, "init-T:", args.init_T)

    # tokenizer + dataset
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)
    if is_glue:
        dsname = args.dataset.split("/", 1)[1]
        dataset = load_dataset("glue", dsname)
        train_raw = dataset["train"]; test_raw = dataset["validation"]
        # NOTE(review): "sentence" is hard-coded here; presumably only
        # single-sentence GLUE tasks such as sst2 are intended.
        text_key = "sentence"
    else:
        dataset = load_dataset(args.dataset); train_raw = dataset["train"]; test_raw = dataset["test"]
        if "sentence" in train_raw.column_names:
            text_key = "sentence"
        elif "text" in train_raw.column_names:
            text_key = "text"
        else:
            text_key = train_raw.column_names[0]

    if args.subset > 0:
        train_raw = train_raw.select(range(min(len(train_raw), args.subset)))
        test_raw = test_raw.select(range(min(len(test_raw), max(1000, args.subset//4))))

    def tokenize_batch(batch):
        return tokenizer(batch[text_key], truncation=True, padding="max_length", max_length=args.max_len)

    print("Tokenizing...")
    train_tok = train_raw.map(tokenize_batch, batched=True)
    test_tok = test_raw.map(tokenize_batch, batched=True)

    train_enc = {"input_ids": train_tok["input_ids"], "attention_mask": train_tok["attention_mask"]}
    test_enc = {"input_ids": test_tok["input_ids"], "attention_mask": test_tok["attention_mask"]}
    train_labels = train_tok["label"]; test_labels = test_tok["label"]

    train_dataset = TokenizedDataset(train_enc, train_labels)
    test_dataset = TokenizedDataset(test_enc, test_labels)
    collator = DataCollatorWithPadding(tokenizer, padding="longest")

    train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True, collate_fn=collator, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch, shuffle=False, collate_fn=collator, num_workers=args.num_workers)

    # build model with bias & init T
    vocab_size = tokenizer.vocab_size
    model = SmallGatedTransformer(vocab_size, d_model=args.model_dim, nhead=args.nhead, num_layers=args.layers,
                                  dim_ff=args.dim_ff, max_len=args.max_len, rank=args.rank,
                                  init_mu=args.mu_bias, init_T=args.init_T).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Re-centre each gate: mu = median unit energy + bias, so that about half of
    # the units start above the chemical potential when the bias is zero.
    with torch.no_grad():
        for layer in model.layers:
            gate = layer.gate_fc1
            eps0 = gate.compute_eps_from_weights().to(device)
            base = float(eps0.median().item())
            gate.mu.data.fill_(base + args.mu_bias)
            # NOTE(review): the gate applies softplus (not exp) to logT, so the
            # effective initial temperature is softplus(log(init_T)), not init_T.
            gate.logT.data.fill_(math.log(args.init_T + 1e-9))

    # Gating parameters train from the first epoch (loop-invariant, so set once).
    for layer in model.layers:
        layer.gate_fc1.mu.requires_grad = True
        layer.gate_fc1.logT.requires_grad = True
        layer.gate_fc1.U.requires_grad = True

    print("Starting training...")
    best_acc = 0.0; P_prev_list = None
    for epoch in range(1, args.epochs+1):
        t0 = time.time()
        avg_loss, _, P_prev_list = train_epoch(model, train_loader, optimizer, device,
                                               lambda_F=args.lambda_F, lambda_bud=args.lambda_bud,
                                               N_target_frac=args.N_target_frac, P_prev_list=P_prev_list, iters=3)
        t1 = time.time()
        test_acc, P_prev_list = eval_model(model, test_loader, device, P_prev_list=P_prev_list, iters=3)
        print(f"\nEpoch {epoch}/{args.epochs} time {(t1-t0):.1f}s train_loss={avg_loss:.4f} test_acc={test_acc*100:.2f}%")
        for i, P in enumerate(P_prev_list):
            Pcpu = P.detach().cpu()
            meanP = float(Pcpu.mean().item()); med = float(Pcpu.median().item()); std = float(Pcpu.std().item())
            q25 = float(Pcpu.quantile(0.25).item()); q75 = float(Pcpu.quantile(0.75).item())
            keep40 = float((Pcpu > 0.40).float().mean().item()*100)
            keep45 = float((Pcpu > 0.45).float().mean().item()*100)
            keep50 = float((Pcpu > 0.50).float().mean().item()*100)
            print(f" Layer {i}: mean={meanP:.3f} med={med:.3f} std={std:.3f} 25%={q25:.3f} 75%={q75:.3f}")
            print(f"   kept% @0.40/0.45/0.50 = {keep40:.2f}/{keep45:.2f}/{keep50:.2f}   sumP={Pcpu.sum().item():.2f}/{Pcpu.numel()}")

        best_acc = max(best_acc, test_acc)

    print("\nTraining complete. Best test accuracy: {:.2f}%".format(best_acc*100.0))

    # Recompute final P by running eval once with P_prev_list=None to get stable final P
    gated_acc, final_P_list = eval_model(model, test_loader, device, P_prev_list=None, iters=3)
    print(f"\nRecomputed gated test accuracy: {gated_acc*100:.2f}%")

    # parse thresholds
    thr_list = [float(x) for x in args.prune_thrs.split(",") if x.strip()]
    summary_rows = []
    print("\nPruning benchmarks:")
    for thr in thr_list:
        masks = compute_prune_masks(final_P_list, thr=thr)
        print(f"\n Threshold {thr:.2f}:")
        for i, mask in enumerate(masks):
            kept = int(mask.sum().item()); total = mask.numel()
            print(f"  Layer {i}: kept {kept}/{total} ({100.0*kept/total:.2f}%)  sumP={float(final_P_list[i].sum()):.2f}")

        # copy and prune
        pruned_model = copy.deepcopy(model).to(device)
        apply_prune_masks_to_model(pruned_model, masks)

        # set mu very large so remaining units are not attenuated (P ~ 1)
        with torch.no_grad():
            for layer in pruned_model.layers:
                layer.gate_fc1.mu.data.fill_(1e6)

        # evaluate pruned model (no gating warm-starts)
        pruned_acc_before = eval_model(pruned_model, test_loader, device, P_prev_list=None, iters=1)[0]
        pruned_acc_after_ft = pruned_acc_before

        # optional fine-tune
        if args.finetune_epochs > 0:
            finetune_pruned_model(pruned_model, train_loader, device, finetune_epochs=args.finetune_epochs, finetune_lr=args.finetune_lr)
            pruned_acc_after_ft = eval_model(pruned_model, test_loader, device, P_prev_list=None, iters=1)[0]

        delta_before = (pruned_acc_before - gated_acc) * 100.0
        delta_after = (pruned_acc_after_ft - gated_acc) * 100.0

        print(f" Pruned acc (immediate) : {pruned_acc_before*100:.2f}%  (delta {delta_before:+.2f} pp)")
        if args.finetune_epochs > 0:
            print(f" Pruned acc (after ft): {pruned_acc_after_ft*100:.2f}%  (delta {delta_after:+.2f} pp)")

        # record summary row
        kept_frac_total = sum(int(m.sum().item()) for m in masks) / sum(m.numel() for m in masks)
        summary_rows.append({
            "threshold": thr,
            "gated_acc": float(gated_acc),
            "pruned_acc_before_finetune": float(pruned_acc_before),
            "pruned_acc_after_finetune": float(pruned_acc_after_ft),
            "delta_before_pp": float(delta_before),
            "delta_after_pp": float(delta_after),
            "kept_frac_overall": float(kept_frac_total)
        })

    # print compact summary table
    print("\nSUMMARY TABLE:")
    print("thr | gated_acc% | pruned_acc_before% | pruned_acc_after% | delta_before_pp | delta_after_pp | kept_frac")
    for r in summary_rows:
        print(f"{r['threshold']:4.2f} | {r['gated_acc']*100:8.2f} | {r['pruned_acc_before_finetune']*100:17.2f} | {r['pruned_acc_after_finetune']*100:15.2f} | {r['delta_before_pp']:14.2f} | {r['delta_after_pp']:13.2f} | {r['kept_frac_overall']:.3f}")

    # optional save CSV
    if args.save_csv:
        if summary_rows:
            with open(args.save_csv, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=list(summary_rows[0].keys()))
                writer.writeheader(); writer.writerows(summary_rows)
            print(f"\nSaved summary to {args.save_csv}")
        else:
            # An empty --prune-thrs used to crash here on summary_rows[0].
            print("\nNo pruning thresholds were evaluated; CSV summary not written.")

    # Drop the loaders explicitly so any worker processes can shut down.
    del train_loader, test_loader

if __name__ == "__main__":
    import multiprocessing as mp

    # Linux uses the "fork" start method; every other platform uses "spawn".
    system = platform.system().lower()
    if system == "linux":
        preferred = "fork"
    else:
        preferred = "spawn"
    try:
        mp.set_start_method(preferred, force=True)
    except RuntimeError:
        # best effort: keep whatever start method is already configured
        pass
    main()


Device: cuda num_workers: 0 mu-bias: 0.25 init-T: 0.25
Tokenizing...


Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Starting training...


train:   0%|          | 0/2105 [00:00<?, ?it/s]

eval:   0%|          | 0/28 [00:00<?, ?it/s]


Epoch 1/3 time 96.3s train_loss=0.5701 test_acc=72.94%
 Layer 0: mean=0.255 med=0.180 std=0.215 25%=0.092 75%=0.344
   kept% @0.40/0.45/0.50 = 21.09/18.75/15.62   sumP=65.29/256
 Layer 1: mean=0.252 med=0.233 std=0.089 25%=0.195 75%=0.283
   kept% @0.40/0.45/0.50 = 4.69/2.73/2.34   sumP=64.41/256


train:   0%|          | 0/2105 [00:00<?, ?it/s]

eval:   0%|          | 0/28 [00:00<?, ?it/s]


Epoch 2/3 time 95.8s train_loss=0.3995 test_acc=75.69%
 Layer 0: mean=0.254 med=0.109 std=0.293 25%=0.032 75%=0.383
   kept% @0.40/0.45/0.50 = 24.61/22.27/21.09   sumP=64.97/256
 Layer 1: mean=0.245 med=0.178 std=0.177 25%=0.151 75%=0.254
   kept% @0.40/0.45/0.50 = 12.11/10.55/7.81   sumP=62.61/256


train:   0%|          | 0/2105 [00:00<?, ?it/s]

eval:   0%|          | 0/28 [00:00<?, ?it/s]


Epoch 3/3 time 95.9s train_loss=0.3216 test_acc=76.26%
 Layer 0: mean=0.257 med=0.085 std=0.314 25%=0.021 75%=0.394
   kept% @0.40/0.45/0.50 = 24.61/23.05/21.48   sumP=65.77/256
 Layer 1: mean=0.251 med=0.153 std=0.220 25%=0.132 75%=0.244
   kept% @0.40/0.45/0.50 = 14.06/13.67/12.89   sumP=64.19/256

Training complete. Best test accuracy: 76.26%


eval:   0%|          | 0/28 [00:00<?, ?it/s]


Recomputed gated test accuracy: 76.26%

Pruning benchmarks:

 Threshold 0.40:
  Layer 0: kept 63/256 (24.61%)  sumP=65.77
  Layer 1: kept 36/256 (14.06%)  sumP=64.19


eval:   0%|          | 0/28 [00:00<?, ?it/s]

 Pruned acc (immediate) : 76.03%  (delta -0.23 pp)

 Threshold 0.45:
  Layer 0: kept 59/256 (23.05%)  sumP=65.77
  Layer 1: kept 35/256 (13.67%)  sumP=64.19


eval:   0%|          | 0/28 [00:00<?, ?it/s]

 Pruned acc (immediate) : 74.66%  (delta -1.61 pp)

 Threshold 0.50:
  Layer 0: kept 55/256 (21.48%)  sumP=65.77
  Layer 1: kept 33/256 (12.89%)  sumP=64.19


eval:   0%|          | 0/28 [00:00<?, ?it/s]

 Pruned acc (immediate) : 74.43%  (delta -1.83 pp)

SUMMARY TABLE:
thr | gated_acc% | pruned_acc_before% | pruned_acc_after% | delta_before_pp | delta_after_pp | kept_frac
0.40 |    76.26 |             76.03 |           76.03 |          -0.23 |         -0.23 | 0.193
0.45 |    76.26 |             74.66 |           74.66 |          -1.61 |         -1.61 | 0.184
0.50 |    76.26 |             74.43 |           74.43 |          -1.83 |         -1.83 | 0.172


In [None]:
# plain_transformer_baseline.py
"""
Plain Transformer baseline (no Fermi gating, no free-energy/budget loss).
Designed to be comparable to the gated model but lower-capacity so accuracy
is likely ~70% on SST-2 (i.e. lower than the Fermi-gated model).

Notebook / Colab safe (uses parse_known_args to ignore kernel flags).

Example:
  python plain_transformer_baseline.py --epochs 3 --batch 32 --num-workers 0
"""

import argparse, math, time, platform
from dataclasses import dataclass
from typing import Optional, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BertTokenizerFast, DataCollatorWithPadding

# ---------------------------
# Plain Transformer (encoder-only) components
# ---------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    The table is stored as a non-trainable buffer ``pe`` of shape
    ``[1, max_len, d_model]``; even columns hold sines, odd columns cosines.
    """

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so trim the frequency vector instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` ([batch, seq_len, d_model]) plus the encodings of its positions.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]

class PlainTransformerLayer(nn.Module):
    """Encoder layer: self-attention and a GELU feed-forward block, each followed
    by a residual connection and LayerNorm."""

    def __init__(self, d_model: int, nhead: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Transform ``x`` of shape [batch, seq, d_model]; the shape is preserved."""
        attended = self.attn(x, x, x, need_weights=False)[0]
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.ffn(hidden))

class SmallPlainTransformer(nn.Module):
    """Small encoder-only binary classifier without gating.

    Token embeddings are scaled by ``sqrt(d_model)``, given positional encodings,
    passed through ``num_layers`` plain layers, mean-pooled over the sequence and
    classified into two classes.
    """

    def __init__(self, vocab_size: int, d_model: int = 96, nhead: int = 3,
                 num_layers: int = 2, dim_ff: int = 192, max_len: int = 128, dropout: float = 0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([PlainTransformerLayer(d_model, nhead, dim_ff, dropout=dropout) for _ in range(num_layers)])
        self.pooler = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(d_model, 2)  # binary classification (SST-2)
        self.d_model = d_model

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape [batch, 2].

        Args:
            input_ids: token ids, [batch, seq].
            attention_mask: optional [batch, seq] mask; assumed to be non-zero
                for real tokens and zero for padding (the Hugging Face
                convention). When given, padded positions are excluded from the
                pooled mean. When ``None``, all positions are averaged.
        """
        x = self.embed(input_ids) * math.sqrt(self.d_model)  # [batch, seq, d]
        x = self.pos(x)
        for layer in self.layers:
            x = layer(x)

        if attention_mask is None:
            # plain mean-pool over the sequence axis
            pooled = self.pooler(x.transpose(1, 2)).squeeze(-1)  # [batch, d]
        else:
            # Masked mean: padding must not dilute the sentence representation.
            # NOTE(review): the attention layers themselves still see pad positions,
            # because the layer API takes no mask.
            keep = attention_mask.unsqueeze(-1).to(x.dtype)  # [batch, seq, 1]
            # clamp guards against division by zero for an all-padding row
            pooled = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        logits = self.classifier(pooled)
        return logits

# ---------------------------
# Dataset wrapper and loops (simple cross-entropy only)
# ---------------------------
@dataclass
class TokenizedDataset:
    """Map-style dataset wrapping column-oriented encodings and a label list."""
    encodings: dict
    labels: List[int]

    def __len__(self):
        # one label per example, so the label list defines the length
        return len(self.labels)

    def __getitem__(self, idx):
        """Build the tensor dict for example ``idx``; the label is stored under ``"labels"``."""
        sample = {name: torch.tensor(values[idx]) for name, values in self.encodings.items()}
        label = torch.tensor(self.labels[idx])
        sample["labels"] = label
        return sample

def train_epoch_plain(model, dataloader, optimizer, device):
    """Train ``model`` for one pass over ``dataloader`` with cross-entropy only.

    Returns:
        ``(avg_loss, acc)``: mean batch loss and training accuracy for the epoch.
    """
    model.train()
    running_loss = 0.0
    predictions, targets = [], []

    for batch in tqdm(dataloader, desc="train", leave=False):
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ("input_ids", "attention_mask", "labels")
        )

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += float(loss.item())
        predictions.extend(logits.argmax(dim=1).detach().cpu().numpy().tolist())
        targets.extend(labels.detach().cpu().numpy().tolist())

    return running_loss / len(dataloader), accuracy_score(targets, predictions)

def eval_model_plain(model, dataloader, device):
    """Return the accuracy of ``model`` on ``dataloader`` (eval mode, no gradients)."""
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="eval", leave=False):
            input_ids, attention_mask, labels = (
                batch[key].to(device) for key in ("input_ids", "attention_mask", "labels")
            )
            logits = model(input_ids, attention_mask)
            predictions.extend(logits.argmax(dim=1).cpu().numpy().tolist())
            targets.extend(labels.cpu().numpy().tolist())
    return accuracy_score(targets, predictions)

# ---------------------------
# Main: arg parsing and run
# ---------------------------
def main():
    """Train and evaluate the plain (ungated) transformer baseline.

    Parses CLI flags (unknown flags are ignored so the script also runs inside
    a notebook kernel), tokenizes the dataset, trains for ``--epochs`` epochs
    with cross-entropy and prints per-epoch and best test accuracy.

    Raises:
        ValueError: if ``--model-dim`` is not divisible by ``--nhead`` or a GLUE
            dataset is given without a task name (expected ``glue/<task>``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="glue/sst2")
    parser.add_argument("--model-dim", type=int, default=96, help="embedding / model dimension (must be divisible by nhead)")
    parser.add_argument("--nhead", type=int, default=3)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--dim-ff", type=int, default=192)
    parser.add_argument("--max-len", type=int, default=64)
    parser.add_argument("--batch", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--subset", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=None)
    # ignore notebook kernel args
    args, _ = parser.parse_known_args()

    # Validate configuration up front, so a bad flag fails before the tokenizer
    # download and dataset tokenization rather than after them.
    if args.model_dim % args.nhead != 0:
        raise ValueError("model-dim must be divisible by nhead (e.g. 96 & 3).")
    is_glue = args.dataset.lower().startswith("glue")
    if is_glue and "/" not in args.dataset:
        raise ValueError("GLUE datasets must be given as 'glue/<task>', e.g. glue/sst2.")

    interactive = ("get_ipython" in globals()) or hasattr(__builtins__, "__IPYTHON__")
    if args.num_workers is None:
        args.num_workers = 0 if interactive else 2

    # seeds for reproducibility
    torch.manual_seed(args.seed); np.random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device, "num_workers:", args.num_workers)

    # tokenizer & dataset
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)
    if is_glue:
        dsname = args.dataset.split("/", 1)[1]
        dataset = load_dataset("glue", dsname)
        train_raw = dataset["train"]
        test_raw = dataset["validation"]
        # NOTE(review): "sentence" is hard-coded here; presumably only
        # single-sentence GLUE tasks such as sst2 are intended.
        text_key = "sentence"
    else:
        dataset = load_dataset(args.dataset)
        train_raw = dataset["train"]
        test_raw = dataset["test"]
        if "sentence" in train_raw.column_names:
            text_key = "sentence"
        elif "text" in train_raw.column_names:
            text_key = "text"
        else:
            text_key = train_raw.column_names[0]

    if args.subset > 0:
        train_raw = train_raw.select(range(min(len(train_raw), args.subset)))
        test_raw = test_raw.select(range(min(len(test_raw), max(1000, args.subset // 4))))

    def tokenize_batch(batch):
        return tokenizer(batch[text_key], truncation=True, padding="max_length", max_length=args.max_len)

    print("Tokenizing...")
    train_tok = train_raw.map(tokenize_batch, batched=True)
    test_tok = test_raw.map(tokenize_batch, batched=True)

    train_enc = {"input_ids": train_tok["input_ids"], "attention_mask": train_tok["attention_mask"]}
    test_enc = {"input_ids": test_tok["input_ids"], "attention_mask": test_tok["attention_mask"]}
    train_labels = train_tok["label"]; test_labels = test_tok["label"]

    train_dataset = TokenizedDataset(train_enc, train_labels)
    test_dataset = TokenizedDataset(test_enc, test_labels)

    collator = DataCollatorWithPadding(tokenizer, padding="longest")

    train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True, collate_fn=collator, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch, shuffle=False, collate_fn=collator, num_workers=args.num_workers)

    # Build model
    model = SmallPlainTransformer(vocab_size=tokenizer.vocab_size,
                                  d_model=args.model_dim,
                                  nhead=args.nhead,
                                  num_layers=args.layers,
                                  dim_ff=args.dim_ff,
                                  max_len=args.max_len,
                                  dropout=args.dropout).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print("Starting training...")
    best_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        avg_loss, _ = train_epoch_plain(model, train_loader, optimizer, device)
        t1 = time.time()
        test_acc = eval_model_plain(model, test_loader, device)
        print(f"Epoch {epoch}/{args.epochs} time {(t1-t0):.1f}s train_loss={avg_loss:.4f} test_acc={test_acc*100:.2f}%")
        best_acc = max(best_acc, test_acc)

    print("\nFinal summary (plain baseline):")
    print(f"Best test accuracy: {best_acc*100:.2f}%")
    print("Model configuration: d_model={}, nhead={}, dim_ff={}, layers={}".format(
        args.model_dim, args.nhead, args.dim_ff, args.layers
    ))

    # Drop the loaders explicitly so any worker processes can shut down.
    del train_loader, test_loader

if __name__ == "__main__":
    import multiprocessing as mp

    # "fork" on Linux, "spawn" on every other platform
    system = platform.system().lower()
    preferred = {"linux": "fork"}.get(system, "spawn")
    try:
        mp.set_start_method(preferred, force=True)
    except RuntimeError:
        pass  # best effort: keep the start method already in effect
    main()


Device: cuda num_workers: 0
Tokenizing...


Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Starting training...


train:   0%|          | 0/2105 [00:00<?, ?it/s]

eval:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 1/3 time 79.0s train_loss=0.5669 test_acc=71.44%


train:   0%|          | 0/2105 [00:00<?, ?it/s]

eval:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 2/3 time 78.8s train_loss=0.4243 test_acc=75.46%


train:   0%|          | 0/2105 [00:00<?, ?it/s]

eval:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 3/3 time 79.1s train_loss=0.3502 test_acc=74.20%

Final summary (plain baseline):
Best test accuracy: 75.46%
Model configuration: d_model=96, nhead=3, dim_ff=192, layers=2
