# Report https://www.overleaf.com/read/psbbcwrqtmbv#cab870

# DevPost https://devpost.com/software/tbd-xgpy9i

# Data https://huggingface.co/datasets/osunlp/TravelPlanner

# Google Colab https://colab.research.google.com/drive/167wiVvHBHVgoBUw4JYgsYuhArUYEM-KB?usp=sharing

# General Instruction https://docs.google.com/document/d/1MJcxo7tTPXfAky8qrIPrav3WzR_tfdfUDPKpp1BDN_Y/edit?tab=t.0

>helpful links:
https://pytorch-dendritic-optimization.devpost.com/
https://docs.google.com/document/d/1MJcxo7tTPXfAky8qrIPrav3WzR_tfdfUDPKpp1BDN_Y/edit
https://github.com/PerforatedAI/PerforatedAI
https://github.com/PerforatedAI/PerforatedAI/tree/main/API
https://docs.wandb.ai/models/tutorials/sweeps
https://huggingface.co/models?sort=downloads
https://huggingface.co/datasets?sort=downloads
https://docs.pytorch.org/vision/main/models.html
https://docs.pytorch.org/vision/main/datasets.html
https://waymo.com/open/
https://wandb.ai/perforated-ai/Dendritic%20Edge%20Impulse%20Audio%20-%20Combo/reports/Edge-Impulse-Keyword-Spotting--VmlldzoxNTIxNjE5Ng?accessToken=3lm4jm5f9npsu45vs180ybo6150ed4gnhos9rrkk6seqb4bmf458me28seynu0xb


In [None]:
import sys, os, subprocess, textwrap, platform
print("python:", sys.executable)
print("version:", sys.version)
print("cwd:", os.getcwd())

python: /usr/bin/python3
version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
cwd: /content


In [None]:
!pip -q install -U requests tqdm scikit-learn transformers accelerate

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
if device == "cuda":
    print("gpu:", torch.cuda.get_device_name(0))

device: cpu


In [None]:
from google.colab import drive
drive.mount("/content/drive")

DRIVE_ROOT = "/content/drive/MyDrive/dendrites_fraud"
print("Drive root:", DRIVE_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive root: /content/drive/MyDrive/dendrites_fraud


In [None]:
!pip -q install "numpy==2.0.2" "scikit-learn==1.6.1" pandas plotly wandb datasets tqdm

In [None]:
import os, re, json, math, time, random
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.metrics import (
    precision_recall_curve, average_precision_score,
    roc_curve, roc_auc_score,
    f1_score, precision_score, recall_score, accuracy_score
)
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# OPTIONAL: checkpoints to Drive
USE_DRIVE = False
DRIVE_DIR = Path("/content/drive/MyDrive/travelplanner_runs")  # change if you want
if USE_DRIVE:
    from google.colab import drive
    drive.mount("/content/drive")
    DRIVE_DIR.mkdir(parents=True, exist_ok=True)
    print("Drive run dir:", DRIVE_DIR)

device: cpu


In [None]:
from datasets import load_dataset

REPO = "osunlp/TravelPlanner"

train_ds = load_dataset(REPO, "train")["train"]
val_ds   = load_dataset(REPO, "validation")["validation"]

# test sometimes exists, sometimes not; handle safely
try:
    test_ds = load_dataset(REPO, "test")["test"]
except Exception as e:
    print("No test config (or failed to load). Using validation as test for now.")
    test_ds = val_ds

print("Loaded |")
print("train:", len(train_ds))
print("val  :", len(val_ds))
print("test :", len(test_ds))

print("\nColumns:", train_ds.column_names)

# peek one example (helps you confirm fields)
ex0 = train_ds[0]
print("\nExample keys:", list(ex0.keys()))
for k,v in ex0.items():
    if isinstance(v, str):
        print(f"{k}: {v[:140]}{'...' if len(v)>140 else ''}")
    else:
        print(f"{k}: ({type(v).__name__})")

test.csv:   0%|          | 0.00/26.7M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Loaded |
train: 45
val  : 180
test : 1000

Columns: ['org', 'dest', 'days', 'visiting_city_number', 'date', 'people_number', 'local_constraint', 'budget', 'query', 'level', 'annotated_plan', 'reference_information']

Example keys: ['org', 'dest', 'days', 'visiting_city_number', 'date', 'people_number', 'local_constraint', 'budget', 'query', 'level', 'annotated_plan', 'reference_information']
org: St. Petersburg
dest: Rockford
days: (int)
visiting_city_number: (int)
date: ['2022-03-16', '2022-03-17', '2022-03-18']
people_number: (int)
local_constraint: {'house rule': None, 'cuisine': None, 'room type': None, 'transportation': None}
budget: (int)
query: Please help me plan a trip from St. Petersburg to Rockford spanning 3 days from March 16th to March 18th, 2022. The travel should be planned...
level: easy
annotated_plan: [{'org': 'St. Petersburg', 'dest': 'Rockford', 'days': 3, 'visiting_city_number': 1, 'date': ['2022-03-16', '2022-03-17', '2022-03-18'], 'pe...
reference_information: [

In [None]:
def pick_field(cols, candidates):
    for c in candidates:
        if c in cols:
            return c
    for c in candidates:
        for col in cols:
            if c.lower() in col.lower():
                return col
    return None

COLS = train_ds.column_names

QUERY_F = pick_field(COLS, ["query", "instruction", "prompt", "question", "input"])
PLAN_F  = pick_field(COLS, ["plan", "answer", "output", "response", "solution", "itinerary"])

print("Auto-picked fields:")
print("QUERY_F =", QUERY_F)
print("PLAN_F  =", PLAN_F)

assert QUERY_F is not None, f"Could not find query-like field in columns={COLS}"
assert PLAN_F  is not None, f"Could not find plan/answer-like field in columns={COLS}"

# quick sanity print
print("\nSanity sample:")
print("Q:", train_ds[0][QUERY_F][:160])
print("P:", train_ds[0][PLAN_F][:160])

Auto-picked fields:
QUERY_F = query
PLAN_F  = annotated_plan

Sanity sample:
Q: Please help me plan a trip from St. Petersburg to Rockford spanning 3 days from March 16th to March 18th, 2022. The travel should be planned for a single person
P: [{'org': 'St. Petersburg', 'dest': 'Rockford', 'days': 3, 'visiting_city_number': 1, 'date': ['2022-03-16', '2022-03-17', '2022-03-18'], 'people_number': 1, 'lo


#Build positive/negative pair dataset

In [None]:
import numpy as np
from datasets import Dataset

SEED = 42
rng = np.random.default_rng(SEED)

# We ONLY have labels (annotated_plan) in train_ds
assert "query" in train_ds.column_names
assert "annotated_plan" in train_ds.column_names

N = len(train_ds)
idx = np.arange(N)
rng.shuffle(idx)

# With only 45 examples, keep val/test small but non-empty.
# 70/15/15 split -> 31/7/7 (roughly)
n_train = int(0.70 * N)
n_val   = int(0.15 * N)
n_test  = N - n_train - n_val

train_idx = idx[:n_train]
val_idx   = idx[n_train:n_train + n_val]
test_idx  = idx[n_train + n_val:]

train_lab = train_ds.select(train_idx.tolist())
val_lab   = train_ds.select(val_idx.tolist())
test_lab  = train_ds.select(test_idx.tolist())

print("Labeled pool split ✅")
print("train:", len(train_lab))
print("val  :", len(val_lab))
print("test :", len(test_lab))

print("\nColumns:", train_lab.column_names)
print("\nSanity example:")
print("Q:", train_lab[0]["query"][:140])
print("P type:", type(train_lab[0]["annotated_plan"]))

Labeled pool split ✅
train: 31
val  : 6
test : 8

Columns: ['org', 'dest', 'days', 'visiting_city_number', 'date', 'people_number', 'local_constraint', 'budget', 'query', 'level', 'annotated_plan', 'reference_information']

Sanity example:
Q: Could you help craft a week-long travel plan for 2 people? We'll be departing from Oakland and aim to visit 3 different cities in Oregon fro
P type: <class 'str'>


In [None]:
import json, random
random.seed(42)

def to_text(x):
    """Make any value safe to put into text (dict/list -> JSON)."""
    if isinstance(x, str):
        return x
    try:
        return json.dumps(x, ensure_ascii=False)
    except Exception:
        return str(x)

def make_pairs(ds, query_f="query", plan_f="annotated_plan", ref_f="reference_information", neg_per_pos=6, include_ref=True):
    pos = []
    for ex in ds:
        q = ex.get(query_f, None)
        p = ex.get(plan_f, None)
        if q is None or p is None:
            continue

        q_txt = to_text(q).strip()
        p_txt = to_text(p).strip()

        # Optional: include reference info into the input (helps realism / task grounding)
        if include_ref and ref_f in ex and ex[ref_f] is not None:
            ref_txt = to_text(ex[ref_f]).strip()
            x_pos = f"QUERY:\n{q_txt}\n\nREFERENCE:\n{ref_txt}\n\nPLAN:\n{p_txt}"
        else:
            x_pos = f"QUERY:\n{q_txt}\n\nPLAN:\n{p_txt}"

        pos.append((x_pos, p_txt))

    assert len(pos) >= 3, f"Not enough usable positives. Got {len(pos)}."

    plans_only = [p for _, p in pos]

    X, y = [], []

    # positives
    for x_pos, _ in pos:
        X.append(x_pos); y.append(1)

    # negatives: same query/reference, but random wrong plan
    for x_pos, p_true in pos:
        for _ in range(neg_per_pos):
            p_neg = random.choice(plans_only)
            while p_neg == p_true and len(plans_only) > 1:
                p_neg = random.choice(plans_only)

            # Replace the PLAN block only (keeps same query/reference)
            x_neg = x_pos.rsplit("\n\nPLAN:\n", 1)[0] + "\n\nPLAN:\n" + p_neg
            X.append(x_neg); y.append(0)

    return X, y

train_texts, train_y = make_pairs(train_lab, neg_per_pos=8, include_ref=True)
val_texts,   val_y   = make_pairs(val_lab,   neg_per_pos=8, include_ref=True)
test_texts,  test_y  = make_pairs(test_lab,  neg_per_pos=8, include_ref=True)

def pos_rate(y):
    return sum(y)/len(y)

print("Pair dataset sizes ")
print("train:", len(train_texts), "pos rate:", pos_rate(train_y))
print("val  :", len(val_texts),   "pos rate:", pos_rate(val_y))
print("test :", len(test_texts),  "pos rate:", pos_rate(test_y))

Pair dataset sizes 
train: 279 pos rate: 0.1111111111111111
val  : 54 pos rate: 0.1111111111111111
test : 72 pos rate: 0.1111111111111111


# TF-IDF vectorizer , tensors , dataloaders

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter

# ---- 1) TF-IDF vectorization (simple + strong baseline for text matching) ----
# With only a few hundred samples, we can safely convert to dense arrays.
# (Later, if you scale up, we can keep it sparse.)

MAX_FEATURES = 20000   # cap vocabulary size (keeps things stable)
NGRAM_RANGE  = (1, 2)  # unigrams + bigrams usually helps

vectorizer = TfidfVectorizer(
    max_features=MAX_FEATURES,
    ngram_range=NGRAM_RANGE,
    lowercase=True,
)

X_train = vectorizer.fit_transform(train_texts).toarray().astype(np.float32)
X_val   = vectorizer.transform(val_texts).toarray().astype(np.float32)
X_test  = vectorizer.transform(test_texts).toarray().astype(np.float32)

y_train = np.array(train_y, dtype=np.int64)
y_val   = np.array(val_y, dtype=np.int64)
y_test  = np.array(test_y, dtype=np.int64)

D_IN = X_train.shape[1]
print(" TF-IDF shapes:")
print("X_train:", X_train.shape, "y:", y_train.shape)
print("X_val  :", X_val.shape,   "y:", y_val.shape)
print("X_test :", X_test.shape,  "y:", y_test.shape)
print("D_IN =", D_IN)

print("\nLabel counts:")
print("train:", Counter(y_train.tolist()))
print("val  :", Counter(y_val.tolist()))
print("test :", Counter(y_test.tolist()))

# ---- 2) Torch dataset/dataloader ----
class VecDS(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X)          # float32
        self.y = torch.from_numpy(y)          # int64
    def __len__(self):
        return len(self.y)
    def __getitem__(self, i):
        return self.X[i], self.y[i]

train_ds_vec = VecDS(X_train, y_train)
val_ds_vec   = VecDS(X_val,   y_val)
test_ds_vec  = VecDS(X_test,  y_test)

# ---- 3) Handle class imbalance (important) ----
# Since we generate many negatives per positive, sampler helps training learn positives.
counts = np.bincount(y_train, minlength=2)
class_w = 1.0 / np.maximum(counts, 1)
sample_w = class_w[y_train]
sampler = WeightedRandomSampler(
    weights=torch.tensor(sample_w, dtype=torch.double),
    num_samples=len(sample_w),
    replacement=True,
)

BATCH_SIZE = 64
train_dl = DataLoader(train_ds_vec, batch_size=BATCH_SIZE, sampler=sampler)
val_dl   = DataLoader(val_ds_vec,   batch_size=BATCH_SIZE, shuffle=False)
test_dl  = DataLoader(test_ds_vec,  batch_size=BATCH_SIZE, shuffle=False)

print("\n Dataloaders ready:")
print("train batches:", len(train_dl), "val batches:", len(val_dl), "test batches:", len(test_dl))

# ---- 4) Tiny sanity check ----
xb, yb = next(iter(train_dl))
print("batch X:", xb.shape, xb.dtype, "| batch y:", yb.shape, yb.dtype)

 TF-IDF shapes:
X_train: (279, 20000) y: (279,)
X_val  : (54, 20000) y: (54,)
X_test : (72, 20000) y: (72,)
D_IN = 20000

Label counts:
train: Counter({0: 248, 1: 31})
val  : Counter({0: 48, 1: 6})
test : Counter({0: 64, 1: 8})

 Dataloaders ready:
train batches: 5 val batches: 1 test batches: 2
batch X: torch.Size([64, 20000]) torch.float32 | batch y: torch.Size([64]) torch.int64


In [None]:
import os, time, json, math, random
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
    f1_score,
    accuracy_score,
)

# -----------------------------
# 0) Repro + device
# -----------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

print("train label dist:", Counter([int(y) for y in y_train]))
print("val   label dist:", Counter([int(y) for y in y_val]))
print("test  label dist:", Counter([int(y) for y in y_test]))

# -----------------------------
# 1) Model: simple MLP classifier
#    Output = logits (before sigmoid)
# -----------------------------
class MLP(nn.Module):
    def __init__(self, d_in, hidden=(512, 256), dropout=0.2):
        super().__init__()
        layers = []
        prev = d_in
        for h in hidden:
            layers += [
                nn.Linear(prev, h),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            prev = h
        layers += [nn.Linear(prev, 1)]  # binary logits
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x).squeeze(-1)  # [B]

def count_params(m):
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable

# -----------------------------
# 2) Evaluation helpers
# -----------------------------
@torch.no_grad()
def predict_probs(model, dl):
    model.eval()
    ys, probs = [], []
    for xb, yb in dl:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        p = torch.sigmoid(logits)
        ys.append(yb.detach().cpu().numpy())
        probs.append(p.detach().cpu().numpy())
    y = np.concatenate(ys).astype(int)
    prob = np.concatenate(probs).astype(float)
    return y, prob

def best_f1_threshold(y_true, prob):
    # pick threshold maximizing F1 on validation
    # (precision_recall_curve returns thresholds of length = len(prec)-1)
    prec, rec, thr = precision_recall_curve(y_true, prob)
    if len(thr) == 0:
        return 0.5, 0.0
    f1s = (2 * prec[:-1] * rec[:-1]) / np.maximum(prec[:-1] + rec[:-1], 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])

def report_metrics(y_true, prob, threshold=0.5):
    pred = (prob >= threshold).astype(int)

    acc = accuracy_score(y_true, pred)
    f1  = f1_score(y_true, pred, pos_label=1, zero_division=0)
    ap  = average_precision_score(y_true, prob)  # PR-AUC

    # also helpful for narrative:
    tp = int(((pred == 1) & (y_true == 1)).sum())
    fp = int(((pred == 1) & (y_true == 0)).sum())
    tn = int(((pred == 0) & (y_true == 0)).sum())
    fn = int(((pred == 0) & (y_true == 1)).sum())

    return {
        "acc": float(acc),
        "f1": float(f1),
        "pr_auc": float(ap),
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "threshold": float(threshold),
    }

# -----------------------------
# 3) Checkpoint helpers (resume-safe)
# -----------------------------
def save_ckpt(run_dir, name, model, opt, epoch, extra):
    run_dir = Path(run_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "epoch": int(epoch),
        "model": model.state_dict(),
        "opt": opt.state_dict(),
        "extra": extra,
    }
    path = run_dir / f"{name}.pt"
    torch.save(payload, path)
    return str(path)

def load_ckpt_if_exists(run_dir, model, opt):
    run_dir = Path(run_dir)
    path = run_dir / "last.pt"
    if not path.exists():
        return 0, None
    ck = torch.load(path, map_location="cpu")
    model.load_state_dict(ck["model"])
    opt.load_state_dict(ck["opt"])
    start_epoch = int(ck["epoch"]) + 1
    return start_epoch, ck

# -----------------------------
# 4) Training loop (baseline)
# -----------------------------
def train_baseline_mlp(
    run_name="baseline_mlp",
    run_root="/content/runs_travelplanner",
    epochs=10,
    lr=1e-3,
    weight_decay=0.0,
    grad_clip=1.0,
    log_every=50,
    resume=True,
):
    run_dir = Path(run_root) / run_name
    run_dir.mkdir(parents=True, exist_ok=True)

    model = MLP(D_IN, hidden=(512, 256), dropout=0.2).to(device)
    total, trainable = count_params(model)
    print(f"[{run_name}] params total/trainable:", total, trainable)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # class imbalance: use pos_weight so positives matter
    # pos_weight = (#neg/#pos)
    n_pos = int((y_train == 1).sum())
    n_neg = int((y_train == 0).sum())
    pos_weight = torch.tensor([n_neg / max(1, n_pos)], dtype=torch.float32, device=device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    start_epoch = 0
    best_val = -1.0
    best_path = None

    if resume:
        start_epoch, ck = load_ckpt_if_exists(run_dir, model, opt)
        if ck is not None:
            best_val = float(ck["extra"].get("best_val_pr_auc", -1.0))
            best_path = ck["extra"].get("best_path", None)
            print(f"[{run_name}] Resumed from epoch {start_epoch} (best_val_pr_auc={best_val:.4f})")
        else:
            print(f"[{run_name}] No checkpoint found. Starting fresh.")

    history = []
    t_all = time.time()

    for ep in range(start_epoch, epochs):
        model.train()
        t0 = time.time()
        running = 0.0
        seen = 0

        for step, (xb, yb) in enumerate(train_dl):
            xb = xb.to(device)
            yb = yb.float().to(device)

            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = loss_fn(logits, yb)

            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            running += float(loss.detach().cpu())
            seen += 1

            if step % log_every == 0:
                print(f"[{run_name}] ep {ep} step {step}/{len(train_dl)} loss {running/max(1,seen):.5f}")

        train_loss = running / max(1, seen)

        # ---- validation: pick threshold that maximizes F1 on VAL
        yv, pv = predict_probs(model, val_dl)
        val_thr, val_best_f1 = best_f1_threshold(yv, pv)
        val_rep = report_metrics(yv, pv, threshold=val_thr)

        # ---- test: evaluate using the VAL-chosen threshold (fair + realistic)
        yt, pt = predict_probs(model, test_dl)
        test_rep = report_metrics(yt, pt, threshold=val_thr)

        row = {
            "epoch": ep,
            "train_loss": float(train_loss),
            "val_best_f1_at_thr": float(val_best_f1),
            **{f"val_{k}": v for k, v in val_rep.items()},
            **{f"test_{k}": v for k, v in test_rep.items()},
            "epoch_sec": round(time.time() - t0, 2),
        }
        history.append(row)

        print(f"\n[{run_name}] ep {ep} train_loss={train_loss:.5f}")
        print(f"[{run_name}] VAL  PR-AUC={val_rep['pr_auc']:.4f} F1={val_rep['f1']:.4f} thr={val_thr:.4f}")
        print(f"[{run_name}] TEST PR-AUC={test_rep['pr_auc']:.4f} F1={test_rep['f1']:.4f} thr={val_thr:.4f}\n")

        # ---- save last
        extra = {
            "row": row,
            "best_val_pr_auc": best_val,
            "best_path": best_path,
            "pos_weight": float(pos_weight.item()),
            "params_total": int(total),
            "params_trainable": int(trainable),
        }
        last_path = save_ckpt(run_dir, "last", model, opt, ep, extra)

        # ---- save best (by val PR-AUC, since dataset is imbalanced)
        if val_rep["pr_auc"] > best_val:
            best_val = float(val_rep["pr_auc"])
            best_path = save_ckpt(run_dir, "best", model, opt, ep, {**extra, "best_val_pr_auc": best_val})
            print(f"[{run_name}]  new best saved: val_pr_auc={best_val:.6f}")

        # write a small json history so you can inspect quickly later
        with open(run_dir / "history.json", "w") as f:
            json.dump(history, f, indent=2)

    print(f"[{run_name}] done. total time: {round(time.time()-t_all, 1)} sec")
    print("run_dir:", str(run_dir))
    print("best:", str(run_dir / "best.pt"))
    print("last:", str(run_dir / "last.pt"))
    return model, history, str(run_dir)

# ---- Run it ----
baseline_model, baseline_hist, baseline_dir = train_baseline_mlp(
    run_name="baseline_mlp",
    run_root="/content/runs_travelplanner",
    epochs=8,         # start small; we can sweep later
    lr=1e-3,
    resume=True,      # if Colab restarts, rerun this cell and it continues
    log_every=50
)

device: cpu
train label dist: Counter({0: 248, 1: 31})
val   label dist: Counter({0: 48, 1: 6})
test  label dist: Counter({0: 64, 1: 8})
[baseline_mlp] params total/trainable: 10372097 10372097
[baseline_mlp] Resumed from epoch 2 (best_val_pr_auc=0.1423)
[baseline_mlp] ep 2 step 0/5 loss 1.92818

[baseline_mlp] ep 2 train_loss=1.76770
[baseline_mlp] VAL  PR-AUC=0.1358 F1=0.2609 thr=0.9522
[baseline_mlp] TEST PR-AUC=0.1561 F1=0.1250 thr=0.9522

[baseline_mlp] ep 3 step 0/5 loss 1.80503

[baseline_mlp] ep 3 train_loss=1.68211
[baseline_mlp] VAL  PR-AUC=0.1358 F1=0.2609 thr=0.9729
[baseline_mlp] TEST PR-AUC=0.1624 F1=0.1429 thr=0.9729

[baseline_mlp] ep 4 step 0/5 loss 1.70945

[baseline_mlp] ep 4 train_loss=1.64146
[baseline_mlp] VAL  PR-AUC=0.1423 F1=0.2609 thr=0.9547
[baseline_mlp] TEST PR-AUC=0.1682 F1=0.1667 thr=0.9547

[baseline_mlp] ep 5 step 0/5 loss 1.67863

[baseline_mlp] ep 5 train_loss=1.68796
[baseline_mlp] VAL  PR-AUC=0.1423 F1=0.2609 thr=0.9031
[baseline_mlp] TEST PR-AUC=0.

# Dendrites MLP (PerforatedAI) + checkpointing + PR-AUC/F1

In [None]:
import os, time, json, sys, subprocess
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, precision_recall_curve, f1_score, accuracy_score

# -----------------------------
# 0) Sanity checks
# -----------------------------
need = ["train_dl", "val_dl", "test_dl", "D_IN", "y_train", "y_val", "y_test"]
missing = [x for x in need if x not in globals()]
assert not missing, f"Missing globals: {missing}. Rerun your data/vectorizer cell first."

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
print("train label dist:", Counter([int(y) for y in y_train]))
print("val   label dist:", Counter([int(y) for y in y_val]))
print("test  label dist:", Counter([int(y) for y in y_test]))

# -----------------------------
# 1) Install / import PerforatedAI
# -----------------------------
def ensure_perforatedai():
    try:
        import perforatedai
        return perforatedai
    except Exception:
        pass

    repo = Path("/content/PerforatedAI")
    if not repo.exists():
        subprocess.check_call(["bash", "-lc", "git clone https://github.com/PerforatedAI/PerforatedAI.git /content/PerforatedAI"])

    # editable install
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-e", "/content/PerforatedAI"])

    import perforatedai
    return perforatedai

perforatedai = ensure_perforatedai()
print("perforatedai:", perforatedai.__file__)

from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

# -----------------------------
# 2) Helper functions
# -----------------------------
@torch.no_grad()
def predict_probs(model, dl):
    model.eval()
    ys, probs = [], []
    for xb, yb in dl:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        p = torch.sigmoid(logits)
        ys.append(yb.detach().cpu().numpy())
        probs.append(p.detach().cpu().numpy())
    y = np.concatenate(ys).astype(int)
    prob = np.concatenate(probs).astype(float)
    return y, prob

def best_f1_threshold(y_true, prob):
    prec, rec, thr = precision_recall_curve(y_true, prob)
    if len(thr) == 0:
        return 0.5, 0.0
    f1s = (2 * prec[:-1] * rec[:-1]) / np.maximum(prec[:-1] + rec[:-1], 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])

def report_metrics(y_true, prob, threshold=0.5):
    pred = (prob >= threshold).astype(int)
    acc = accuracy_score(y_true, pred)
    f1  = f1_score(y_true, pred, pos_label=1, zero_division=0)
    ap  = average_precision_score(y_true, prob)  # PR-AUC

    tp = int(((pred == 1) & (y_true == 1)).sum())
    fp = int(((pred == 1) & (y_true == 0)).sum())
    tn = int(((pred == 0) & (y_true == 0)).sum())
    fn = int(((pred == 0) & (y_true == 1)).sum())

    return {"acc": float(acc), "f1": float(f1), "pr_auc": float(ap),
            "tp": tp, "fp": fp, "tn": tn, "fn": fn,
            "threshold": float(threshold)}

def save_ckpt(run_dir, name, model, opt, epoch, extra):
    run_dir = Path(run_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    payload = {"epoch": int(epoch), "model": model.state_dict(), "opt": opt.state_dict(), "extra": extra}
    path = run_dir / f"{name}.pt"
    torch.save(payload, path)
    return str(path)

def load_ckpt_if_exists(run_dir, model, opt):
    run_dir = Path(run_dir)
    path = run_dir / "last.pt"
    if not path.exists():
        return 0, None
    ck = torch.load(path, map_location="cpu")
    model.load_state_dict(ck["model"], strict=False)  # strict=False avoids minor key mismatches
    opt.load_state_dict(ck["opt"])
    return int(ck["epoch"]) + 1, ck

def count_params(m):
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable

# -----------------------------
# 3) Base MLP (same as baseline shape)
# -----------------------------
class BaseMLP(nn.Module):
    def __init__(self, d_in, hidden=(512, 256), dropout=0.2):
        super().__init__()
        layers = []
        prev = d_in
        for h in hidden:
            layers += [nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)]
            prev = h
        layers += [nn.Linear(prev, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x).squeeze(-1)

# -----------------------------
# 4) Make dendrites model
#    Key trick: initialize_pai wraps the model and adds dendritic modules.
# -----------------------------
def make_dendrites_model():
    model = BaseMLP(D_IN, hidden=(512, 256), dropout=0.2)

    # Avoid interactive pauses / extra checks
    if hasattr(GPA.pc, "set_unwrapped_modules_confirmed"):
        GPA.pc.set_unwrapped_modules_confirmed(True)

    # Some configs rely on input dimension hints (batch, features)
    # For MLP features, treat last dim as 0-th index in dims list.
    if hasattr(GPA.pc, "set_input_dimensions"):
        # [batch, features] -> dims [-1, 0]
        GPA.pc.set_input_dimensions([-1, 0])

    # Initialize dendrites
    model = UPA.initialize_pai(
        model,
        doing_pai=True,
        save_name="travelplanner_mlp",
        making_graphs=False,
        maximizing_score=True,
    )
    return model

# -----------------------------
# 5) Train dendrites model (resume-safe)
# -----------------------------
def train_dendrites_mlp(
    run_name="pai_dendrites_mlp",
    run_root="/content/runs_travelplanner",
    epochs=10,
    lr=1e-3,
    weight_decay=0.0,     # PAI recommends no weight decay in many cases
    grad_clip=1.0,
    log_every=1,
    resume=True,
):
    run_dir = Path(run_root) / run_name
    run_dir.mkdir(parents=True, exist_ok=True)

    model = make_dendrites_model().to(device)
    total, trainable = count_params(model)
    print(f"[{run_name}] params total/trainable:", total, trainable)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # class imbalance handling
    n_pos = int((y_train == 1).sum())
    n_neg = int((y_train == 0).sum())
    pos_weight = torch.tensor([n_neg / max(1, n_pos)], dtype=torch.float32, device=device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    start_epoch = 0
    best_val = -1.0
    best_path = None

    if resume:
        start_epoch, ck = load_ckpt_if_exists(run_dir, model, opt)
        if ck is not None:
            best_val = float(ck["extra"].get("best_val_pr_auc", -1.0))
            best_path = ck["extra"].get("best_path", None)
            print(f"[{run_name}] Resumed from epoch {start_epoch} (best_val_pr_auc={best_val:.4f})")
        else:
            print(f"[{run_name}] No checkpoint found. Starting fresh.")

    history = []
    for ep in range(start_epoch, epochs):
        model.train()
        t0 = time.time()
        running = 0.0
        seen = 0

        for step, (xb, yb) in enumerate(train_dl):
            xb = xb.to(device)
            yb = yb.float().to(device)

            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = loss_fn(logits, yb)

            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            running += float(loss.detach().cpu())
            seen += 1

            if step % max(1, log_every) == 0:
                print(f"[{run_name}] ep {ep} step {step}/{len(train_dl)} loss {running/max(1,seen):.5f}")

        train_loss = running / max(1, seen)

        # validation threshold by best F1 on VAL
        yv, pv = predict_probs(model, val_dl)
        val_thr, _ = best_f1_threshold(yv, pv)
        val_rep = report_metrics(yv, pv, threshold=val_thr)

        # test using that threshold (fair comparison)
        yt, pt = predict_probs(model, test_dl)
        test_rep = report_metrics(yt, pt, threshold=val_thr)

        row = {
            "epoch": ep,
            "train_loss": float(train_loss),
            **{f"val_{k}": v for k, v in val_rep.items()},
            **{f"test_{k}": v for k, v in test_rep.items()},
            "epoch_sec": round(time.time() - t0, 2),
        }
        history.append(row)

        print(f"\n[{run_name}] ep {ep} train_loss={train_loss:.5f}")
        print(f"[{run_name}] VAL  PR-AUC={val_rep['pr_auc']:.4f} F1={val_rep['f1']:.4f} thr={val_thr:.4f}")
        print(f"[{run_name}] TEST PR-AUC={test_rep['pr_auc']:.4f} F1={test_rep['f1']:.4f} thr={val_thr:.4f}\n")

        extra = {
            "row": row,
            "best_val_pr_auc": best_val,
            "best_path": best_path,
            "pos_weight": float(pos_weight.item()),
            "params_total": int(total),
            "params_trainable": int(trainable),
        }
        save_ckpt(run_dir, "last", model, opt, ep, extra)

        if val_rep["pr_auc"] > best_val:
            best_val = float(val_rep["pr_auc"])
            best_path = save_ckpt(run_dir, "best", model, opt, ep, {**extra, "best_val_pr_auc": best_val})
            print(f"[{run_name}] new best saved: val_pr_auc={best_val:.6f}")

        with open(run_dir / "history.json", "w") as f:
            json.dump(history, f, indent=2)

    print(f"[{run_name}] done.")
    print("run_dir:", str(run_dir))
    print("best:", str(run_dir / "best.pt"))
    print("last:", str(run_dir / "last.pt"))
    return model, history, str(run_dir)

# ---- Run dendrites training ----
pai_model, pai_hist, pai_dir = train_dendrites_mlp(
    run_name="pai_dendrites_mlp",
    run_root="/content/runs_travelplanner",
    epochs=50,
    lr=1e-3,
    resume=True,
    log_every=1
)

device: cpu
train label dist: Counter({0: 248, 1: 31})
val   label dist: Counter({0: 48, 1: 6})
test  label dist: Counter({0: 64, 1: 8})
perforatedai: /content/PerforatedAI/perforatedai/__init__.py
Variable 'input_dimensions' does not exist.  Ignoring set attempt.
Variable 'input_dimensions' does not exist.  Ignoring set attempt.
Running Dendrite Experiment
[pai_dendrites_mlp] params total/trainable: 20744194 20744194
[pai_dendrites_mlp] Resumed from epoch 15 (best_val_pr_auc=0.1542)
[pai_dendrites_mlp] ep 15 step 0/5 loss 1.66194
[pai_dendrites_mlp] ep 15 step 1/5 loss 1.61878
[pai_dendrites_mlp] ep 15 step 2/5 loss 1.53288
[pai_dendrites_mlp] ep 15 step 3/5 loss 1.51892
[pai_dendrites_mlp] ep 15 step 4/5 loss 1.48656

[pai_dendrites_mlp] ep 15 train_loss=1.48656
[pai_dendrites_mlp] VAL  PR-AUC=0.1328 F1=0.2632 thr=0.8516
[pai_dendrites_mlp] TEST PR-AUC=0.2254 F1=0.2667 thr=0.8516

[pai_dendrites_mlp] ep 16 step 0/5 loss 1.54994
[pai_dendrites_mlp] ep 16 step 1/5 loss 1.57482
[pai_den

# Load BEST checkpoints (baseline + dendrites)

In [None]:
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_curve, average_precision_score, confusion_matrix, f1_score, accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# ---- set run dirs  ----
baseline_dir = "/content/runs_travelplanner/baseline_mlp"
pai_dir      = "/content/runs_travelplanner/pai_dendrites_mlp"

base_best_path = Path(baseline_dir) / "best.pt"
pai_best_path  = Path(pai_dir) / "best.pt"
assert base_best_path.exists(), f"Missing {base_best_path}"
assert pai_best_path.exists(),  f"Missing {pai_best_path}"

# ---- baseline model definition (must match training) ----
class BaseMLP(nn.Module):
    def __init__(self, d_in, hidden=(512, 256), dropout=0.2):
        super().__init__()
        layers = []
        prev = d_in
        for h in hidden:
            layers += [nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)]
            prev = h
        layers += [nn.Linear(prev, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x).squeeze(-1)

# ---- load ckpts ----
def load_best_ckpt(path):
    ck = torch.load(path, map_location="cpu")
    return ck

base_ck = load_best_ckpt(base_best_path)
pai_ck  = load_best_ckpt(pai_best_path)

print("baseline best epoch:", base_ck.get("epoch"))
print("pai best epoch:", pai_ck.get("epoch"))

# threshold stored in ckpt extra (from your trainer)
base_thr = float(base_ck["extra"]["row"]["val_threshold"] if "val_threshold" in base_ck["extra"]["row"] else base_ck["extra"]["row"].get("threshold", 0.5))
pai_thr  = float(pai_ck["extra"]["row"]["val_threshold"] if "val_threshold" in pai_ck["extra"]["row"] else pai_ck["extra"]["row"].get("threshold", 0.5))

print("baseline best thr:", base_thr)
print("pai best thr     :", pai_thr)

# ---- build baseline model + load ----
base_model = BaseMLP(D_IN, hidden=(512, 256), dropout=0.2).to(device)
base_model.load_state_dict(base_ck["model"], strict=True)

# ---- build dendrites model the SAME WAY you did in training ----
# (Assumes you still have PerforatedAI imported from your earlier cell)
def make_dendrites_model():
    model = BaseMLP(D_IN, hidden=(512, 256), dropout=0.2)
    model = UPA.initialize_pai(
        model,
        doing_pai=True,
        save_name="travelplanner_mlp",
        making_graphs=False,
        maximizing_score=True,
    )
    return model

pai_model = make_dendrites_model().to(device)

# IMPORTANT: strict=False avoids those "Unexpected key ... shape" errors
pai_model.load_state_dict(pai_ck["model"], strict=False)

print("bop 's models loaded")

device: cpu
baseline best epoch: 7
pai best epoch: 30
baseline best thr: 0.8686298727989197
pai best thr     : 0.664389967918396
Running Dendrite Experiment
bop 's models loaded


# Get probabilities on TEST + compute curves + confusion matrices

In [None]:
@torch.no_grad()
def predict_probs(model, dl):
    model.eval()
    ys, probs = [], []
    for xb, yb in dl:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        p = torch.sigmoid(logits)
        ys.append(yb.detach().cpu().numpy())
        probs.append(p.detach().cpu().numpy())
    y = np.concatenate(ys).astype(int)
    prob = np.concatenate(probs).astype(float)
    return y, prob

def metrics_at_threshold(y, prob, thr):
    pred = (prob >= thr).astype(int)
    acc = accuracy_score(y, pred)
    f1  = f1_score(y, pred, pos_label=1, zero_division=0)
    cm  = confusion_matrix(y, pred, labels=[0,1])
    pr_auc = average_precision_score(y, prob)
    return {
        "acc": float(acc),
        "f1": float(f1),
        "pr_auc": float(pr_auc),
        "tn": int(cm[0,0]), "fp": int(cm[0,1]),
        "fn": int(cm[1,0]), "tp": int(cm[1,1]),
        "thr": float(thr),
    }

# TEST predictions
y_test_true, base_prob = predict_probs(base_model, test_dl)
_,           pai_prob  = predict_probs(pai_model,  test_dl)

# PR curve points
base_prec, base_rec, _ = precision_recall_curve(y_test_true, base_prob)
pai_prec,  pai_rec,  _ = precision_recall_curve(y_test_true, pai_prob)

# metrics at their own "best-val" threshold (fair: threshold chosen on val)
base_rep = metrics_at_threshold(y_test_true, base_prob, base_thr)
pai_rep  = metrics_at_threshold(y_test_true, pai_prob,  pai_thr)

print("BASE (test):", base_rep)
print("PAI  (test):", pai_rep)

BASE (test): {'acc': 0.5972222222222222, 'f1': 0.21621621621621623, 'pr_auc': 0.1865753620771634, 'tn': 39, 'fp': 25, 'fn': 4, 'tp': 4, 'thr': 0.8686298727989197}
PAI  (test): {'acc': 0.8888888888888888, 'f1': 0.42857142857142855, 'pr_auc': 0.3358808410751215, 'tn': 61, 'fp': 3, 'fn': 5, 'tp': 3, 'thr': 0.664389967918396}


# Plotly interactive dashboard

In [None]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import precision_recall_curve

def f1_curve(y, prob):
    prec, rec, thr = precision_recall_curve(y, prob)
    if len(thr) == 0:
        return np.array([0.5]), np.array([0.0])
    f1s = (2 * prec[:-1] * rec[:-1]) / np.maximum(prec[:-1] + rec[:-1], 1e-12)
    return thr, f1s

# F1 vs threshold curves
base_thr_grid, base_f1s = f1_curve(y_test_true, base_prob)
pai_thr_grid,  pai_f1s  = f1_curve(y_test_true, pai_prob)

# Confusion matrices (2x2)
base_cm = np.array([[base_rep["tn"], base_rep["fp"]],
                    [base_rep["fn"], base_rep["tp"]]])
pai_cm  = np.array([[pai_rep["tn"],  pai_rep["fp"]],
                    [pai_rep["fn"],  pai_rep["tp"]]])

# Summary table
summary = pd.DataFrame([
    {"model":"baseline", **base_rep},
    {"model":"dendrites", **pai_rep},
])

# --- build dashboard ---
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=("Precision-Recall Curve (TEST)",
                    "F1 vs Threshold (TEST)",
                    "Confusion Matrix (Baseline @ val-thr)",
                    "Confusion Matrix (Dendrites @ val-thr)"),
    specs=[[{"type":"xy"}, {"type":"xy"}],
           [{"type":"heatmap"}, {"type":"heatmap"}]]
)

# PR curve
fig.add_trace(go.Scatter(x=base_rec, y=base_prec, mode="lines",
                         name=f"Baseline PR-AUC={base_rep['pr_auc']:.3f}"),
              row=1, col=1)
fig.add_trace(go.Scatter(x=pai_rec, y=pai_prec, mode="lines",
                         name=f"Dendrites PR-AUC={pai_rep['pr_auc']:.3f}"),
              row=1, col=1)

# Add markers at chosen thresholds (approx location by nearest prob threshold)
# (Plotly PR curve doesn't directly param by threshold; this is just a visual indicator)
fig.add_trace(go.Scatter(x=[0], y=[0], mode="markers", name=f"Baseline thr={base_thr:.3f}", showlegend=True),
              row=1, col=1)
fig.add_trace(go.Scatter(x=[0], y=[0], mode="markers", name=f"Dendrites thr={pai_thr:.3f}", showlegend=True),
              row=1, col=1)

# F1 vs threshold
fig.add_trace(go.Scatter(x=base_thr_grid, y=base_f1s, mode="lines", name="Baseline F1(thr)"),
              row=1, col=2)
fig.add_trace(go.Scatter(x=pai_thr_grid, y=pai_f1s, mode="lines", name="Dendrites F1(thr)"),
              row=1, col=2)
fig.add_vline(x=base_thr, line_dash="dash", row=1, col=2)
fig.add_vline(x=pai_thr,  line_dash="dash", row=1, col=2)

# Confusion heatmaps
fig.add_trace(go.Heatmap(z=base_cm, x=["Pred 0","Pred 1"], y=["True 0","True 1"],
                         showscale=False, text=base_cm, texttemplate="%{text}"),
              row=2, col=1)
fig.add_trace(go.Heatmap(z=pai_cm, x=["Pred 0","Pred 1"], y=["True 0","True 1"],
                         showscale=False, text=pai_cm, texttemplate="%{text}"),
              row=2, col=2)

fig.update_xaxes(title_text="Recall", row=1, col=1)
fig.update_yaxes(title_text="Precision", row=1, col=1)
fig.update_xaxes(title_text="Threshold", row=1, col=2)
fig.update_yaxes(title_text="F1 (positive class)", row=1, col=2)

fig.update_layout(
    height=850,
    title="TravelPlanner Pair-Matching: Baseline vs Dendrites (TEST)",
    hovermode="x unified"
)

fig.show()

# also show summary table nicely
summary

Unnamed: 0,model,acc,f1,pr_auc,tn,fp,fn,tp,thr
0,baseline,0.597222,0.216216,0.186575,39,25,4,4,0.86863
1,dendrites,0.888889,0.428571,0.335881,61,3,5,3,0.66439


# Diagnose data shape type

In [None]:
import numpy as np

def describe(x, name):
    print(f"\n{name}:")
    print("  type:", type(x))
    if hasattr(x, "shape"):
        print("  shape:", x.shape)
    try:
        arr = np.asarray(x)
        print("  np.asarray shape:", arr.shape, "dtype:", arr.dtype)
        # show a tiny sample (safe)
        flat = arr.reshape(-1)[:5]
        print("  head:", flat)
    except Exception as e:
        print("  np.asarray failed:", repr(e))

describe(base_prob, "base_prob")
describe(base_y,    "base_y")
describe(pai_prob,  "pai_prob")
describe(pai_y,     "pai_y")


base_prob:
  type: <class 'numpy.ndarray'>
  shape: (72,)
  np.asarray shape: (72,) dtype: float64
  head: [0.85073292 0.86881131 0.83157241 0.87938887 0.92229575]

base_y:
  type: <class 'numpy.ndarray'>
  shape: (28481,)
  np.asarray shape: (28481,) dtype: int64
  head: [0 0 0 0 0]

pai_prob:
  type: <class 'numpy.ndarray'>
  shape: (72,)
  np.asarray shape: (72,) dtype: float64
  head: [0.24384825 0.89279282 0.235192   0.20944589 0.15790349]

pai_y:
  type: <class 'numpy.ndarray'>
  shape: (28481,)
  np.asarray shape: (28481,) dtype: int64
  head: [0 0 0 0 0]


In [None]:
print("len(base_prob):", len(base_prob))
print("len(pai_prob): ", len(pai_prob))
print("len(base_y):   ", len(base_y))
print("len(pai_y):    ", len(pai_y))

# What does your CURRENT test_dl represent?
try:
    print("test_dl.dataset length:", len(test_dl.dataset))
except Exception as e:
    print("could not read len(test_dl.dataset):", e)

len(base_prob): 72
len(pai_prob):  72
len(base_y):    28481
len(pai_y):     28481
test_dl.dataset length: 72


# print batch

In [None]:
batch = next(iter(test_dl))
print("batch type:", type(batch))
print("batch len :", len(batch))

# show first element structure
b0 = batch[0]
print("elem[0] type:", type(b0))

# If dict-like
if hasattr(b0, "keys"):
    print("elem[0] keys:", list(b0.keys()))
    for k in list(b0.keys())[:8]:
        v = b0[k]
        if hasattr(v, "shape"):
            print(" ", k, "shape:", tuple(v.shape), "dtype:", getattr(v, "dtype", type(v)))
        else:
            print(" ", k, "type:", type(v))

# If tuple/list sample
elif isinstance(b0, (list, tuple)):
    print("elem[0] len:", len(b0))
    for i, v in enumerate(b0[:4]):
        if hasattr(v, "shape"):
            print(" ", i, "shape:", tuple(v.shape), "dtype:", getattr(v, "dtype", type(v)))
        else:
            print(" ", i, "type:", type(v))

else:
    print("elem[0] repr:", repr(b0)[:300])

batch type: <class 'list'>
batch len : 2
elem[0] type: <class 'torch.Tensor'>
elem[0] repr: tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0048, 0.0048],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0024, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0024, 0.0000, 0.0000


In [None]:
import torch

batch = next(iter(test_dl))
Xb, yb = batch[0], batch[1]   # because your batch is a list [X, y]

print("Xb:", type(Xb), Xb.shape, Xb.dtype)
print("yb:", type(yb), yb.shape, yb.dtype if torch.is_tensor(yb) else type(yb))

INPUT_DIM = Xb.shape[1]
print("TravelPlanner input dim =", INPUT_DIM)

Xb: <class 'torch.Tensor'> torch.Size([64, 20000]) torch.float32
yb: <class 'torch.Tensor'> torch.Size([64]) torch.int64
TravelPlanner input dim = 20000


In [None]:
import torch
from pathlib import Path

def infer_input_dim_from_state_dict(sd):
    # Find the FIRST 2D weight matrix (Linear layer), use its second dim as input_dim
    # For MLP: usually "net.0.weight" is [hidden, input_dim]
    for k, v in sd.items():
        if isinstance(v, torch.Tensor) and v.ndim == 2 and k.endswith("weight"):
            return int(v.shape[1]), k
    raise RuntimeError("Could not infer input_dim from checkpoint state_dict.")

def load_best_auto(run_dir, model_builder, allow_dendrites_shape_keys=True):
    run_dir = Path(run_dir)
    ckpt_path = run_dir / "best.pt"
    assert ckpt_path.exists(), f"Missing: {ckpt_path}"
    ck = torch.load(ckpt_path, map_location="cpu")

    sd = ck["model"] if "model" in ck else ck["model_state"] if "model_state" in ck else ck
    in_dim, kname = infer_input_dim_from_state_dict(sd)
    print(f"[load_best_auto] inferred input_dim={in_dim} from '{kname}'")

    model = model_builder(in_dim)

    # Handle dendrites “*.shape” keys (they show up in saved state_dict)
    if allow_dendrites_shape_keys:
        drop = [k for k in sd.keys() if k.endswith(".shape")]
        if drop:
            for k in drop:
                sd.pop(k, None)
            print(f"[load_best_auto] dropped {len(drop)} '*.shape' keys (OK for dendrites)")

    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print("[load_best_auto] missing keys (ok if expected):", missing[:6], "..." if len(missing)>6 else "")
    if unexpected:
        print("[load_best_auto] unexpected keys (ok if expected):", unexpected[:6], "..." if len(unexpected)>6 else "")

    return model, ck, in_dim

In [None]:
import torch
import torch.nn as nn

# If you already have these from the PerforatedAI notebook, keep them:
# - GPA (globals_perforatedai)
# - UPA (utils_perforatedai)
# - dendrite_layer / dendrite module creator you used earlier
#
# The only goal here: make_dendrites_model(in_dim) builds the same model,
# but with the correct input feature size.

def make_dendrites_model(in_dim: int):
    """
    Build the dendrites MLP with the correct input dimension.
    This MUST match whatever architecture you trained with earlier,
    except in_dim is now configurable.
    """
    # --- IMPORTANT ---
    # Replace the body below with the SAME dendrites model you used in training.
    # The only change is: first Linear layer uses `in_dim`.

    # Example skeleton (you must align with your training architecture):
    hidden = 128

    model = MLP_for_pai(in_dim, hidden=hidden)  # <-- if your class exists
    return model

In [None]:
import inspect

def make_dendrites_model(in_dim: int):
    """
    Build the dendrites model using the SAME class you trained with,
    but with the correct input dimension.
    This version adapts to whether your class supports a `hidden` kwarg.
    """
    ctor = MLP_for_pai  # whatever class/function you used during training

    # If it's a class/function with signature we can inspect:
    try:
        sig = inspect.signature(ctor)
        params = sig.parameters

        if "hidden" in params:
            return ctor(in_dim, hidden=128)  # only if supported
        else:
            return ctor(in_dim)              # your current case
    except Exception:
        # fallback: just try simplest
        return ctor(in_dim)

In [None]:
base_best_model, base_ck, base_in_dim = load_best_auto(
    baseline_dir, lambda d: MLP(d), allow_dendrites_shape_keys=False
)

pai_best_model,  pai_ck,  pai_in_dim  = load_best_auto(
    pai_dir, make_dendrites_model, allow_dendrites_shape_keys=True
)

print("baseline input_dim:", base_in_dim)
print("pai input_dim     :", pai_in_dim)

# sanity: does test_dl match?
xb, yb = next(iter(test_dl))
print("test batch X:", xb.shape, xb.dtype)
print("test batch y:", yb.shape, yb.dtype)

[load_best_auto] inferred input_dim=20000 from 'net.0.weight'
[load_best_auto] inferred input_dim=20000 from 'net.0.main_module.weight'
[load_best_auto] missing keys (ok if expected): ['net.0.weight', 'net.0.bias', 'net.3.weight', 'net.3.bias', 'net.6.weight', 'net.6.bias'] 
[load_best_auto] unexpected keys (ok if expected): ['tracker_string', 'net.0.this_node_index', 'net.0.this_output_dimensions', 'net.0.main_module.weight', 'net.0.main_module.bias', 'net.0.dendrite_module.num_cycles'] ...
baseline input_dim: 20000
pai input_dim     : 20000
test batch X: torch.Size([64, 20000]) torch.float32
test batch y: torch.Size([64]) torch.int64


# collect probs/labels from your DataLoader (your batches are [X, y])

In [None]:
import numpy as np
import torch

@torch.no_grad()
def collect_probs_labels(model, dl, device):
    model.eval()
    probs, ys = [], []

    for batch in dl:
        # Your loader returns a list/tuple: [X, y]
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            x, y = batch
        else:
            raise TypeError(f"Unexpected batch type: {type(batch)}. Expected (X,y).")

        x = x.to(device)
        y = y.to(device)

        out = model(x)

        # Handle different model outputs:
        if isinstance(out, (list, tuple)):
            logits = out[0]
        elif hasattr(out, "logits"):
            logits = out.logits
        else:
            logits = out

        # logits can be shape (B,2) or (B,)
        if logits.ndim == 2 and logits.shape[1] == 2:
            p1 = torch.softmax(logits, dim=1)[:, 1]
        else:
            p1 = torch.sigmoid(logits.view(-1))

        probs.append(p1.detach().cpu())
        ys.append(y.detach().cpu())

    probs = torch.cat(probs).numpy()
    ys    = torch.cat(ys).numpy().astype(int)
    return probs, ys

base_prob, base_y = collect_probs_labels(base_best_model, test_dl, device)
pai_prob,  pai_y  = collect_probs_labels(pai_best_model,  test_dl, device)

print("base:", base_prob.shape, base_y.shape, "pos:", base_y.sum())
print("pai :", pai_prob.shape,  pai_y.shape,  "pos:", pai_y.sum())

base: (72,) (72,) pos: 8
pai : (72,) (72,) pos: 8


In [None]:
import pandas as pd
import plotly.graph_objects as go
from sklearn.metrics import precision_recall_curve, average_precision_score

def pr_data(y, prob):
    p, r, thr = precision_recall_curve(y, prob)
    ap = average_precision_score(y, prob)
    return p, r, thr, ap

bp, br, bthr, bap = pr_data(base_y, base_prob)
pp, pr, pthr, pap = pr_data(pai_y,  pai_prob)

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=br, y=bp, mode="lines",
    name=f"Baseline PR-AUC={bap:.3f}"
))
fig.add_trace(go.Scatter(
    x=pr, y=pp, mode="lines",
    name=f"Dendrites PR-AUC={pap:.3f}"
))

fig.update_layout(
    title="TravelPlanner — Precision-Recall Curve (Test)",
    xaxis_title="Recall",
    yaxis_title="Precision",
    hovermode="x unified",
    template="plotly_white"
)

fig.show()

# ---- Threshold sweep for F1 (shows why your selected threshold matters)
def sweep_f1(y, prob, n=200):
    ts = np.linspace(0.0, 1.0, n)
    rows = []
    for t in ts:
        pred = (prob >= t).astype(int)
        tp = int(((pred==1) & (y==1)).sum())
        fp = int(((pred==1) & (y==0)).sum())
        fn = int(((pred==0) & (y==1)).sum())
        prec = tp / (tp + fp) if (tp+fp) else 0.0
        rec  = tp / (tp + fn) if (tp+fn) else 0.0
        f1   = (2*prec*rec/(prec+rec)) if (prec+rec) else 0.0
        rows.append((t, prec, rec, f1, tp, fp, fn))
    return pd.DataFrame(rows, columns=["thr","precision","recall","f1","tp","fp","fn"])

bdf = sweep_f1(base_y, base_prob)
pdf = sweep_f1(pai_y,  pai_prob)

fig2 = go.Figure()
fig2.add_trace(go.Scatter(x=bdf["thr"], y=bdf["f1"], mode="lines", name="Baseline F1 vs threshold"))
fig2.add_trace(go.Scatter(x=pdf["thr"], y=pdf["f1"], mode="lines", name="Dendrites F1 vs threshold"))
fig2.update_layout(
    title="Test F1 vs Threshold (shows best operating point)",
    xaxis_title="Threshold",
    yaxis_title="F1",
    hovermode="x unified",
    template="plotly_white"
)
fig2.show()

print("Best baseline threshold by sweep:", bdf.loc[bdf.f1.idxmax(), ["thr","f1","precision","recall","tp","fp","fn"]].to_dict())
print("Best dendrites threshold by sweep:", pdf.loc[pdf.f1.idxmax(), ["thr","f1","precision","recall","tp","fp","fn"]].to_dict())

Best baseline threshold by sweep: {'thr': 0.8743718592964824, 'f1': 0.23076923076923078, 'precision': 0.16666666666666666, 'recall': 0.375, 'tp': 3.0, 'fp': 15.0, 'fn': 5.0}
Best dendrites threshold by sweep: {'thr': 0.0, 'f1': 0.19999999999999998, 'precision': 0.1111111111111111, 'recall': 1.0, 'tp': 8.0, 'fp': 64.0, 'fn': 0.0}


In [None]:
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_curve, average_precision_score
import plotly.graph_objects as go

@torch.no_grad()
def collect_probs_labels_auto(model, dl, device):
    model.eval()
    probs, ys = [], []

    for batch in dl:
        # Your DataLoader gives [X, y]
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            x, y = batch
        else:
            raise TypeError(f"Unexpected batch type: {type(batch)}; expected (X,y)")

        x = x.to(device)
        y = y.to(device)

        out = model(x)

        # unwrap common structures
        if isinstance(out, (list, tuple)):
            out = out[0]
        if hasattr(out, "logits"):
            out = out.logits

        out = out.detach()

        # ---- AUTO DETECT PROBABILITY VS LOGITS ----
        # Case A: already probabilities in [0,1] with shape (B,) or (B,1)
        if out.ndim == 2 and out.shape[1] == 1:
            out = out[:, 0]

        if out.ndim == 1 and (out.min().item() >= 0.0) and (out.max().item() <= 1.0):
            p1 = out  # already prob

        # Case B: 2-class logits (B,2)
        elif out.ndim == 2 and out.shape[1] == 2:
            p1 = torch.softmax(out, dim=1)[:, 1]

        # Case C: single-logit (B,) -> sigmoid
        else:
            p1 = torch.sigmoid(out.view(-1))

        probs.append(p1.cpu())
        ys.append(y.cpu())

    probs = torch.cat(probs).numpy().astype(np.float64)
    ys    = torch.cat(ys).numpy().astype(int)
    return probs, ys


def pr_metrics(y, prob):
    p, r, thr = precision_recall_curve(y, prob)
    ap = average_precision_score(y, prob)
    return p, r, thr, ap


def sweep_f1_unique(y, prob):
    # evaluate only at meaningful thresholds (unique probs)
    ts = np.unique(prob)
    ts = np.concatenate(([0.0], ts, [1.0]))
    best = None
    rows = []
    for t in ts:
        pred = (prob >= t).astype(int)
        tp = int(((pred==1) & (y==1)).sum())
        fp = int(((pred==1) & (y==0)).sum())
        fn = int(((pred==0) & (y==1)).sum())
        prec = tp / (tp + fp) if (tp+fp) else 0.0
        rec  = tp / (tp + fn) if (tp+fn) else 0.0
        f1   = (2*prec*rec/(prec+rec)) if (prec+rec) else 0.0
        rows.append((t, prec, rec, f1, tp, fp, fn))
    df = pd.DataFrame(rows, columns=["thr","precision","recall","f1","tp","fp","fn"])
    best_row = df.loc[df["f1"].idxmax()].to_dict()
    return df, best_row


# --- recompute aligned probs/labels from test_dl ---
base_prob, y_test = collect_probs_labels_auto(base_best_model, test_dl, device)
pai_prob,  y2     = collect_probs_labels_auto(pai_best_model,  test_dl, device)
assert y_test.shape == y2.shape, "y mismatch between models"
assert base_prob.shape == y_test.shape == pai_prob.shape, "shape mismatch"

print("Shapes:", base_prob.shape, y_test.shape)
print("Positives:", y_test.sum(), "/", len(y_test))
print("Base prob range:", float(base_prob.min()), float(base_prob.max()))
print("PAI  prob range:", float(pai_prob.min()), float(pai_prob.max()))

bp, br, _, bap = pr_metrics(y_test, base_prob)
pp, pr, _, pap = pr_metrics(y_test, pai_prob)

bdf, bbest = sweep_f1_unique(y_test, base_prob)
pdf, pbest = sweep_f1_unique(y_test, pai_prob)

print("Baseline AP:", bap, "Best F1 row:", bbest)
print("PAI      AP:", pap, "Best F1 row:", pbest)

# --- Plotly PR curve ---
fig = go.Figure()
fig.add_trace(go.Scatter(x=br, y=bp, mode="lines", name=f"Baseline AP={bap:.3f}"))
fig.add_trace(go.Scatter(x=pr, y=pp, mode="lines", name=f"Dendrites AP={pap:.3f}"))
fig.update_layout(
    title="TravelPlanner — Precision-Recall Curve (Test)",
    xaxis_title="Recall",
    yaxis_title="Precision",
    hovermode="x unified",
    template="plotly_white"
)
fig.show()

# --- Plotly F1 vs threshold (unique thresholds, not linspace) ---
fig2 = go.Figure()
fig2.add_trace(go.Scatter(x=bdf["thr"], y=bdf["f1"], mode="lines+markers", name="Baseline F1"))
fig2.add_trace(go.Scatter(x=pdf["thr"], y=pdf["f1"], mode="lines+markers", name="Dendrites F1"))
fig2.update_layout(
    title="Test F1 vs Threshold (unique thresholds)",
    xaxis_title="Threshold",
    yaxis_title="F1",
    hovermode="x unified",
    template="plotly_white"
)
fig2.show()

# Optional: save interactive HTML for Devpost
fig.write_html("/content/pr_curve.html", include_plotlyjs="cdn")
fig2.write_html("/content/f1_vs_threshold.html", include_plotlyjs="cdn")
print("Saved: /content/pr_curve.html and /content/f1_vs_threshold.html")

Shapes: (72,) (72,)
Positives: 8 / 72
Base prob range: 0.7979910373687744 0.9244235157966614
PAI  prob range: 0.47809940576553345 0.478707492351532
Baseline AP: 0.1865753620771634 Best F1 row: {'thr': 0.8793888688087463, 'precision': 0.23076923076923078, 'recall': 0.375, 'f1': 0.2857142857142857, 'tp': 3.0, 'fp': 10.0, 'fn': 5.0}
PAI      AP: 0.1682115269546449 Best F1 row: {'thr': 0.4784682095050812, 'precision': 0.23076923076923078, 'recall': 0.375, 'f1': 0.2857142857142857, 'tp': 3.0, 'fp': 10.0, 'fn': 5.0}


Saved: /content/pr_curve.html and /content/f1_vs_threshold.html


In [None]:
# === Cell: Plotly dashboard (PR curve + F1 sweep + confusion matrices) ===
import numpy as np
import torch
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

from sklearn.metrics import (
    precision_recall_curve, average_precision_score,
    confusion_matrix
)

pio.renderers.default = "colab"  # IMPORTANT: prevents huge HTML dumps

@torch.no_grad()
def collect_probs_labels(model, dl, device):
    model.eval()
    probs, ys = [], []
    for batch in dl:
        # dataloader may return: (X,y) OR {"x":X,"y":y} OR {"input":X,"label":y} etc.
        if isinstance(batch, (list, tuple)):
            if len(batch) < 2:
                raise ValueError(f"Expected (X,y) but got len={len(batch)}")
            X, y = batch[0], batch[1]
        elif isinstance(batch, dict):
            # try common keys
            X = batch.get("x") or batch.get("X") or batch.get("inputs") or batch.get("input_ids") or batch.get("input")
            y = batch.get("y") or batch.get("Y") or batch.get("labels") or batch.get("label")
            if X is None or y is None:
                raise KeyError(f"Batch dict keys={list(batch.keys())} — add mapping here.")
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

        X = X.to(device)
        y = y.to(device)

        out = model(X)
        # handle outputs: logits could be (B,), (B,1), or tuple
        if isinstance(out, (tuple, list)):
            out = out[0]
        logits = out.squeeze(-1)

        prob = torch.sigmoid(logits).detach().cpu().numpy().reshape(-1)
        probs.append(prob)
        ys.append(y.detach().cpu().numpy().reshape(-1))

    probs = np.concatenate(probs).astype(np.float64)
    ys    = np.concatenate(ys).astype(int)
    return probs, ys

def best_threshold_by_f1(y_true, prob):
    # Use unique probs as candidate thresholds (best for tiny test sets like 72)
    uniq = np.unique(prob)
    # also include 0/1 edges
    cand = np.unique(np.concatenate(([0.0], uniq, [1.0])))

    best = None
    for thr in cand:
        pred = (prob >= thr).astype(int)
        tp = int(((pred == 1) & (y_true == 1)).sum())
        fp = int(((pred == 1) & (y_true == 0)).sum())
        fn = int(((pred == 0) & (y_true == 1)).sum())
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec  = tp / (tp + fn) if (tp + fn) else 0.0
        f1   = (2 * prec * rec / (prec + rec)) if (prec + rec) else 0.0

        row = dict(thr=float(thr), f1=float(f1), precision=float(prec), recall=float(rec),
                   tp=tp, fp=fp, fn=fn)
        if (best is None) or (row["f1"] > best["f1"]) or (row["f1"] == best["f1"] and row["precision"] > best["precision"]):
            best = row
    return best, cand

def metrics_at_threshold(y_true, prob, thr):
    pred = (prob >= thr).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, pred, labels=[0,1]).ravel()
    acc = (tp + tn) / (tp + tn + fp + fn)
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec  = tp / (tp + fn) if (tp + fn) else 0.0
    f1   = (2 * prec * rec / (prec + rec)) if (prec + rec) else 0.0
    ap   = float(average_precision_score(y_true, prob))
    return dict(acc=float(acc), f1=float(f1), pr_auc=ap, tn=int(tn), fp=int(fp), fn=int(fn), tp=int(tp), thr=float(thr))

# 1) Recompute probs/labels from loaders (prevents old-length mismatch)
base_val_prob, base_val_y = collect_probs_labels(base_best_model, val_dl, device)
pai_val_prob,  pai_val_y  = collect_probs_labels(pai_best_model,  val_dl, device)

base_test_prob, base_test_y = collect_probs_labels(base_best_model, test_dl, device)
pai_test_prob,  pai_test_y  = collect_probs_labels(pai_best_model,  test_dl, device)

print("VAL sizes:", base_val_prob.shape, base_val_y.shape, "|", pai_val_prob.shape, pai_val_y.shape)
print("TEST sizes:", base_test_prob.shape, base_test_y.shape, "|", pai_test_prob.shape, pai_test_y.shape)

# 2) Choose threshold on VAL (max F1), then evaluate on TEST
base_best_val, base_thr_grid = best_threshold_by_f1(base_val_y, base_val_prob)
pai_best_val,  pai_thr_grid  = best_threshold_by_f1(pai_val_y,  pai_val_prob)

base_test = metrics_at_threshold(base_test_y, base_test_prob, base_best_val["thr"])
pai_test  = metrics_at_threshold(pai_test_y,  pai_test_prob,  pai_best_val["thr"])

display(pd.DataFrame([
    {"model":"baseline",  **base_test, "val_best_f1": base_best_val["f1"]},
    {"model":"dendrites", **pai_test,  "val_best_f1": pai_best_val["f1"]},
]))

# 3) PR curves (TEST)
b_prec, b_rec, _ = precision_recall_curve(base_test_y, base_test_prob)
p_prec, p_rec, _ = precision_recall_curve(pai_test_y,  pai_test_prob)
b_ap = average_precision_score(base_test_y, base_test_prob)
p_ap = average_precision_score(pai_test_y,  pai_test_prob)

# 4) F1 sweeps (TEST) over unique thresholds
def f1_sweep(y_true, prob):
    cand = np.unique(np.concatenate(([0.0], np.unique(prob), [1.0])))
    f1s = []
    for thr in cand:
        pred = (prob >= thr).astype(int)
        tp = ((pred==1)&(y_true==1)).sum()
        fp = ((pred==1)&(y_true==0)).sum()
        fn = ((pred==0)&(y_true==1)).sum()
        prec = tp/(tp+fp) if (tp+fp) else 0.0
        rec  = tp/(tp+fn) if (tp+fn) else 0.0
        f1 = (2*prec*rec/(prec+rec)) if (prec+rec) else 0.0
        f1s.append(f1)
    return cand, np.array(f1s, dtype=float)

b_thr, b_f1 = f1_sweep(base_test_y, base_test_prob)
p_thr, p_f1 = f1_sweep(pai_test_y,  pai_test_prob)

# 5) Confusion matrices at VAL-chosen threshold (TEST eval)
b_pred = (base_test_prob >= base_best_val["thr"]).astype(int)
p_pred = (pai_test_prob  >= pai_best_val["thr"]).astype(int)
b_cm = confusion_matrix(base_test_y, b_pred, labels=[0,1])
p_cm = confusion_matrix(pai_test_y,  p_pred, labels=[0,1])

# 6) Plotly dashboard
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=(
        f"Precision–Recall (TEST) — Baseline AP={b_ap:.3f}, Dendrites AP={p_ap:.3f}",
        "F1 vs Threshold (TEST) — dashed = threshold picked on VAL",
        "Confusion Matrix (Baseline @ val-thr, TEST)",
        "Confusion Matrix (Dendrites @ val-thr, TEST)"
    ),
    specs=[[{"type":"xy"},{"type":"xy"}],
           [{"type":"heatmap"},{"type":"heatmap"}]]
)

# PR
fig.add_trace(go.Scatter(x=b_rec, y=b_prec, mode="lines", name=f"Baseline AP={b_ap:.3f}"), row=1, col=1)
fig.add_trace(go.Scatter(x=p_rec, y=p_prec, mode="lines", name=f"Dendrites AP={p_ap:.3f}"), row=1, col=1)
fig.update_xaxes(title_text="Recall", row=1, col=1)
fig.update_yaxes(title_text="Precision", row=1, col=1)

# F1 sweep
fig.add_trace(go.Scatter(x=b_thr, y=b_f1, mode="lines+markers", name="Baseline F1"), row=1, col=2)
fig.add_trace(go.Scatter(x=p_thr, y=p_f1, mode="lines+markers", name="Dendrites F1"), row=1, col=2)
fig.add_vline(x=base_best_val["thr"], line_dash="dash", row=1, col=2)
fig.add_vline(x=pai_best_val["thr"],  line_dash="dash", row=1, col=2)
fig.update_xaxes(title_text="Threshold", row=1, col=2)
fig.update_yaxes(title_text="F1", row=1, col=2)

# Confusion matrices
fig.add_trace(go.Heatmap(z=b_cm, x=["Pred 0","Pred 1"], y=["True 0","True 1"], showscale=False), row=2, col=1)
fig.add_trace(go.Heatmap(z=p_cm, x=["Pred 0","Pred 1"], y=["True 0","True 1"], showscale=False), row=2, col=2)

fig.update_layout(
    height=850, width=1200,
    title="TravelPlanner Pair-Matching — Baseline vs Dendrites (VAL-chosen threshold, TEST eval)",
    hovermode="x unified"
)
fig.show()

print("VAL-chosen baseline thr:", base_best_val)
print("VAL-chosen dendrites thr:", pai_best_val)

VAL sizes: (54,) (54,) | (54,) (54,)
TEST sizes: (72,) (72,) | (72,) (72,)


Unnamed: 0,model,acc,f1,pr_auc,tn,fp,fn,tp,thr,val_best_f1
0,baseline,0.597222,0.216216,0.186575,39,25,4,4,0.86863,0.275862
1,dendrites,0.791667,0.117647,0.168212,56,8,7,1,0.478547,0.285714


VAL-chosen baseline thr: {'thr': 0.8686298727989197, 'f1': 0.27586206896551724, 'precision': 0.17391304347826086, 'recall': 0.6666666666666666, 'tp': 4, 'fp': 19, 'fn': 2}
VAL-chosen dendrites thr: {'thr': 0.47854703664779663, 'f1': 0.2857142857142857, 'precision': 1.0, 'recall': 0.16666666666666666, 'tp': 1, 'fp': 0, 'fn': 5}


# Fix  eval utilities (handles batch as list/tuple OR dict)

In [None]:
import numpy as np
import torch

@torch.no_grad()
def collect_probs_labels(model, dl, device):
    model.eval()
    probs, ys = [], []

    for batch in dl:
        # batch could be (X, y) OR {"x":..., "y"/"labels":...}
        if isinstance(batch, (list, tuple)):
            if len(batch) == 2:
                X, y = batch
            else:
                raise ValueError(f"Unexpected batch tuple/list length={len(batch)}")
        elif isinstance(batch, dict):
            # adjust keys if your dataset uses different names
            X = batch.get("x", batch.get("X", None))
            y = batch.get("y", batch.get("labels", None))
            if X is None or y is None:
                raise KeyError(f"Batch dict keys={list(batch.keys())} (expected x/X and y/labels)")
        else:
            raise TypeError(f"Unexpected batch type: {type(batch)}")

        X = X.to(device)
        y = y.to(device)

        logits = model(X)
        if logits.ndim == 2 and logits.shape[-1] == 1:
            logits = logits.squeeze(-1)
        prob = torch.sigmoid(logits).detach().cpu().numpy()

        probs.append(prob)
        ys.append(y.detach().cpu().numpy())

    probs = np.concatenate(probs).astype(float)
    ys = np.concatenate(ys).astype(int)
    return probs, ys

#  Plotly dashboard (PR curve + F1 sweep + confusion matrices)


In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve, average_precision_score, confusion_matrix
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def pr_points(y, p):
    prec, rec, thr = precision_recall_curve(y, p)
    ap = average_precision_score(y, p)
    return prec, rec, thr, ap

def f1_sweep(y, p):
    # evaluate F1 at all unique prob thresholds (plus 0,1 edges)
    thrs = np.unique(np.clip(p, 0, 1))
    if thrs[0] != 0.0: thrs = np.r_[0.0, thrs]
    if thrs[-1] != 1.0: thrs = np.r_[thrs, 1.0]

    rows = []
    for t in thrs:
        yhat = (p >= t).astype(int)
        tn, fp, fn, tp = confusion_matrix(y, yhat, labels=[0,1]).ravel()
        precision = tp / (tp + fp + 1e-9)
        recall    = tp / (tp + fn + 1e-9)
        f1        = 2 * precision * recall / (precision + recall + 1e-9)
        rows.append((t, f1, precision, recall, tp, fp, fn, tn))
    df = pd.DataFrame(rows, columns=["thr","f1","precision","recall","tp","fp","fn","tn"])
    return df

def plot_dashboard(
    base_prob, base_y, base_thr_val,
    pai_prob,  pai_y,  pai_thr_val,
    title="TravelPlanner — Baseline vs Dendrites (VAL-chosen threshold, TEST eval)"
):
    # PR curves
    b_prec, b_rec, _, b_ap = pr_points(base_y, base_prob)
    p_prec, p_rec, _, p_ap = pr_points(pai_y,  pai_prob)

    # F1 sweeps
    b_sweep = f1_sweep(base_y, base_prob)
    p_sweep = f1_sweep(pai_y,  pai_prob)

    # confusion matrices at chosen thresholds
    def cm_at(y, prob, thr):
        yhat = (prob >= thr).astype(int)
        tn, fp, fn, tp = confusion_matrix(y, yhat, labels=[0,1]).ravel()
        return np.array([[tn, fp],[fn, tp]]), (tn, fp, fn, tp)

    b_cm, (b_tn, b_fp, b_fn, b_tp) = cm_at(base_y, base_prob, base_thr_val)
    p_cm, (p_tn, p_fp, p_fn, p_tp) = cm_at(pai_y,  pai_prob,  pai_thr_val)

    fig = make_subplots(
        rows=2, cols=2,
        specs=[[{"type":"xy"},{"type":"xy"}],
               [{"type":"heatmap"},{"type":"heatmap"}]],
        subplot_titles=(
            f"Precision–Recall (TEST) — Baseline AP={b_ap:.3f}, Dendrites AP={p_ap:.3f}",
            "F1 vs Threshold (TEST) — dashed = threshold picked on VAL",
            f"Confusion Matrix (Baseline @ val-thr={base_thr_val:.3f})",
            f"Confusion Matrix (Dendrites @ val-thr={pai_thr_val:.3f})",
        ),
        vertical_spacing=0.15, horizontal_spacing=0.10
    )

    # PR
    fig.add_trace(go.Scatter(x=b_rec, y=b_prec, mode="lines", name=f"Baseline AP={b_ap:.3f}"), row=1, col=1)
    fig.add_trace(go.Scatter(x=p_rec, y=p_prec, mode="lines", name=f"Dendrites AP={p_ap:.3f}"), row=1, col=1)
    fig.update_xaxes(title_text="Recall", row=1, col=1)
    fig.update_yaxes(title_text="Precision", row=1, col=1)

    # F1 sweep
    fig.add_trace(go.Scatter(x=b_sweep["thr"], y=b_sweep["f1"], mode="lines+markers", name="Baseline F1"), row=1, col=2)
    fig.add_trace(go.Scatter(x=p_sweep["thr"], y=p_sweep["f1"], mode="lines+markers", name="Dendrites F1"), row=1, col=2)

    # val-threshold vertical lines
    for thr, label in [(base_thr_val, "Baseline val-thr"), (pai_thr_val, "Dendrites val-thr")]:
        fig.add_vline(x=thr, line_dash="dash", row=1, col=2)

    fig.update_xaxes(title_text="Threshold", row=1, col=2)
    fig.update_yaxes(title_text="F1", row=1, col=2)

    # Confusion matrices
    fig.add_trace(go.Heatmap(
        z=b_cm, text=b_cm, texttemplate="%{text}",
        x=["Pred 0","Pred 1"], y=["True 0","True 1"],
        showscale=False,
        hovertemplate="y=%{y}<br>x=%{x}<br>count=%{z}<extra></extra>"
    ), row=2, col=1)

    fig.add_trace(go.Heatmap(
        z=p_cm, text=p_cm, texttemplate="%{text}",
        x=["Pred 0","Pred 1"], y=["True 0","True 1"],
        showscale=False,
        hovertemplate="y=%{y}<br>x=%{x}<br>count=%{z}<extra></extra>"
    ), row=2, col=2)

    fig.update_layout(
        height=900,
        title=title,
        hovermode="x unified",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        margin=dict(l=40, r=40, t=90, b=40),
    )

    # Add a compact summary annotation
    fig.add_annotation(
        x=0.5, y=-0.08, xref="paper", yref="paper", showarrow=False,
        text=(
            f"Baseline @val-thr: tp={b_tp}, fp={b_fp}, fn={b_fn}, tn={b_tn} | "
            f"Dendrites @val-thr: tp={p_tp}, fp={p_fp}, fn={p_fn}, tn={p_tn}"
        )
    )
    return fig

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix

def best_f1_threshold(y, prob):
    y = np.asarray(y).astype(int)
    prob = np.asarray(prob).astype(float)

    thrs = np.unique(np.clip(prob, 0, 1))
    if thrs[0] != 0.0: thrs = np.r_[0.0, thrs]
    if thrs[-1] != 1.0: thrs = np.r_[thrs, 1.0]

    best = None
    for t in thrs:
        yhat = (prob >= t).astype(int)
        tn, fp, fn, tp = confusion_matrix(y, yhat, labels=[0,1]).ravel()
        precision = tp / (tp + fp + 1e-9)
        recall    = tp / (tp + fn + 1e-9)
        f1        = 2 * precision * recall / (precision + recall + 1e-9)
        row = dict(thr=float(t), f1=float(f1), precision=float(precision), recall=float(recall),
                   tp=int(tp), fp=int(fp), fn=int(fn), tn=int(tn))
        if best is None or row["f1"] > best["f1"]:
            best = row
    return best

# collect VAL probs/labels
base_prob_val, base_y_val = collect_probs_labels(base_best_model, val_dl, device)
pai_prob_val,  pai_y_val  = collect_probs_labels(pai_best_model,  val_dl, device)

best_base_val = best_f1_threshold(base_y_val, base_prob_val)
best_pai_val  = best_f1_threshold(pai_y_val,  pai_prob_val)

base_thr_val = best_base_val["thr"]
pai_thr_val  = best_pai_val["thr"]

print("VAL-chosen baseline thr:", best_base_val)
print("VAL-chosen dendrites thr:", best_pai_val)

VAL-chosen baseline thr: {'thr': 0.8686298727989197, 'f1': 0.27586206861831153, 'precision': 0.17391304347069944, 'recall': 0.6666666665555555, 'tp': 4, 'fp': 19, 'fn': 2, 'tn': 29}
VAL-chosen dendrites thr: {'thr': 0.47854703664779663, 'f1': 0.28571428538775506, 'precision': 0.9999999989999999, 'recall': 0.16666666663888888, 'tp': 1, 'fp': 0, 'fn': 5, 'tn': 48}


In [None]:
# collect TEST probs/labels
base_prob_test, base_y_test = collect_probs_labels(base_best_model, test_dl, device)
pai_prob_test,  pai_y_test  = collect_probs_labels(pai_best_model,  test_dl, device)

fig = plot_dashboard(
    base_prob_test, base_y_test, base_thr_val,
    pai_prob_test,  pai_y_test,  pai_thr_val
)
fig.show()

In [None]:
print("VAL sizes:", base_prob_val.shape, base_y_val.shape, "|", pai_prob_val.shape, pai_y_val.shape)
print("TEST sizes:", base_prob_test.shape, base_y_test.shape, "|", pai_prob_test.shape, pai_y_test.shape)
assert base_prob_val.shape == base_y_val.shape
assert pai_prob_val.shape  == pai_y_val.shape
assert base_prob_test.shape == base_y_test.shape
assert pai_prob_test.shape  == pai_y_test.shape

VAL sizes: (54,) (54,) | (54,) (54,)
TEST sizes: (72,) (72,) | (72,) (72,)


In [None]:
import textwrap
import numpy as np

def demo_example(i, ds_test, base_prob_test, pai_prob_test, base_thr_val, pai_thr_val, max_chars=450):
    ex = ds_test[i]  # adapt this if your dataset returns (x, y) instead of dict

    # ---- Adapt these fields to your dataset keys ----
    # Common patterns: ex["query"], ex["plan"], ex["org"], ex["dest"], ex["days"]
    query = ex.get("query", str(ex))
    label = int(ex.get("label", ex.get("y", -1)))

    base_p = float(base_prob_test[i])
    pai_p  = float(pai_prob_test[i])

    base_pred = int(base_p >= base_thr_val)
    pai_pred  = int(pai_p  >= pai_thr_val)

    print("="*100)
    print(f"Example #{i} | y={label}")
    print("-"*100)
    print("Query (truncated):")
    print(textwrap.fill(query[:max_chars], width=110))
    print("-"*100)
    print(f"Baseline   prob={base_p:.4f}  thr={base_thr_val:.3f}  pred={base_pred}")
    print(f"Dendrites  prob={pai_p:.4f}   thr={pai_thr_val:.3f}  pred={pai_pred}")
    print("="*100)

# usage:
# demo_example(0, test_dataset, base_prob_test, pai_prob_test, base_thr_val, pai_thr_val)

In [None]:
import pandas as pd
import numpy as np

def _safe_get(ex, key, default=None):
    # supports dict, HF Dataset row, etc.
    if ex is None:
        return default
    if hasattr(ex, "get"):
        return ex.get(key, default)
    try:
        return ex[key]
    except Exception:
        return default

def error_table(ds, probs, thr, name, y=None, k=5, query_key="query"):
    probs = np.asarray(probs).astype(float)

    # y can come from: explicit y, or ds["label"]/ds["y"]
    if y is None:
        if ds is None:
            raise ValueError("Pass y=... if ds is None")
        y = np.array([int(_safe_get(ds[i], "label", _safe_get(ds[i], "y"))) for i in range(len(ds))])
    else:
        y = np.asarray(y).astype(int)
        assert len(y) == len(probs), f"len(y)={len(y)} must match len(probs)={len(probs)}"

    pred = (probs >= thr).astype(int)

    rows = []
    N = len(probs) if ds is None else len(ds)
    for i in range(N):
        ex = None if ds is None else ds[i]
        q = _safe_get(ex, query_key, "")
        rows.append({
            "i": i,
            "y": int(y[i]),
            "prob": float(probs[i]),
            "pred": int(pred[i]),
            "type": ("TP" if pred[i]==1 and y[i]==1 else
                     "FP" if pred[i]==1 and y[i]==0 else
                     "FN" if pred[i]==0 and y[i]==1 else "TN"),
            "query": (q[:180] + "…") if isinstance(q, str) and len(q) > 180 else q
        })

    df = pd.DataFrame(rows)
    fp = df[df["type"]=="FP"].sort_values("prob", ascending=False).head(k)
    fn = df[df["type"]=="FN"].sort_values("prob", ascending=True).head(k)

    print(f"\n{name} — top {k} FP and FN (VAL-chosen thr={thr:.3f})")
    return pd.concat([fp.assign(bucket="FP"), fn.assign(bucket="FN")], axis=0)

In [None]:
tbl_base = error_table(None, base_prob_test, base_thr_val, "Baseline", y=base_y_test, k=5)
tbl_pai  = error_table(None,  pai_prob_test,  pai_thr_val,  "Dendrites", y=pai_y_test, k=5)

display(tbl_base)
display(tbl_pai)


Baseline — top 5 FP and FN (VAL-chosen thr=0.869)

Dendrites — top 5 FP and FN (VAL-chosen thr=0.479)


Unnamed: 0,i,y,prob,pred,type,query,bucket
45,45,0,0.924424,1,FP,,FP
46,46,0,0.920129,1,FP,,FP
36,36,0,0.892958,1,FP,,FP
37,37,0,0.892958,1,FP,,FP
48,48,0,0.89221,1,FP,,FP
7,7,1,0.797991,0,FN,,FN
6,6,1,0.812066,0,FN,,FN
2,2,1,0.831572,0,FN,,FN
0,0,1,0.850733,0,FN,,FN


Unnamed: 0,i,y,prob,pred,type,query,bucket
46,46,0,0.478707,1,FP,,FP
31,31,0,0.478686,1,FP,,FP
24,24,0,0.478621,1,FP,,FP
29,29,0,0.478587,1,FP,,FP
45,45,0,0.478582,1,FP,,FP
1,1,1,0.478118,0,FN,,FN
6,6,1,0.47822,0,FN,,FN
5,5,1,0.478273,0,FN,,FN
3,3,1,0.478281,0,FN,,FN
7,7,1,0.478356,0,FN,,FN


In [None]:
import numpy as np
import pandas as pd

# indices you showed in the screenshot (union)
idx = sorted(set([45,46,36,37,48,7,6,2,0,31,24,29,1,5,3]))

def get_query(ds=None, df=None, i=None, query_key="query"):
    if df is not None:
        # pandas df
        if query_key in df.columns:
            return str(df.loc[i, query_key])
        # fallback: first column that looks like query
        for c in df.columns:
            if "query" in c.lower():
                return str(df.loc[i, c])
        return ""
    if ds is not None:
        ex = ds[i]
        if hasattr(ex, "get"):
            return str(ex.get(query_key, ""))  # list[dict] / HF row-like
        try:
            return str(ex[query_key])
        except Exception:
            return ""
    return ""

def make_case_table(indices, *, ds=None, df=None,
                    base_probs=None, dend_probs=None,
                    y=None, base_thr=None, dend_thr=None,
                    query_key="query"):
    base_probs = np.asarray(base_probs).astype(float)
    dend_probs = np.asarray(dend_probs).astype(float)
    y = np.asarray(y).astype(int)

    rows = []
    for i in indices:
        q = get_query(ds=ds, df=df, i=i, query_key=query_key)
        bprob = float(base_probs[i]); dprob = float(dend_probs[i])
        bpred = int(bprob >= base_thr); dpred = int(dprob >= dend_thr)
        yi = int(y[i])

        def t(pred, y):
            return "TP" if pred==1 and y==1 else "FP" if pred==1 and y==0 else "FN" if pred==0 and y==1 else "TN"

        rows.append({
            "i": i,
            "y": yi,
            "baseline_prob": bprob,
            "baseline_pred": bpred,
            "baseline_type": t(bpred, yi),
            "dend_prob": dprob,
            "dend_pred": dpred,
            "dend_type": t(dpred, yi),
            "query": q
        })

    out = pd.DataFrame(rows).sort_values("i").reset_index(drop=True)
    return out

# ---- Pick ONE of these sources for queries ----
# Option A: you have a dataset object (HF Dataset / list of dicts)
# ds_test = ...

# Option B: you have a dataframe with a query column
# df_test = ...

cases = make_case_table(
    idx,
    ds=globals().get("ds_test", None),   # if you named it ds_test
    df=globals().get("df_test", None),   # if you named it df_test
    base_probs=base_prob_test,
    dend_probs=pai_prob_test,
    y=base_y_test,                       # labels (same for both)
    base_thr=base_thr_val,
    dend_thr=pai_thr_val,
    query_key="query"
)

display(cases)

Unnamed: 0,i,y,baseline_prob,baseline_pred,baseline_type,dend_prob,dend_pred,dend_type,query
0,0,1,0.850733,0,FN,0.478468,0,FN,Please plan a trip for me starting from Saraso...
1,1,1,0.868811,1,TP,0.478118,0,FN,Seeking assistance to develop a travel itinera...
2,2,1,0.831572,0,FN,0.478682,1,TP,Please create a travel plan that starts in Cha...
3,3,1,0.879389,1,TP,0.478281,0,FN,Please design a travel plan that departs from ...
4,5,1,0.880438,1,TP,0.478273,0,FN,Please help me plan a travel itinerary from Ph...
5,6,1,0.812066,0,FN,0.47822,0,FN,Please arrange a travel plan for one person fr...
6,7,1,0.797991,0,FN,0.478356,0,FN,Please help me build a travel plan that depart...
7,24,0,0.853872,0,TN,0.478621,1,FP,Could you help with a travel itinerary that be...
8,29,0,0.829239,0,TN,0.478587,1,FP,Please help me prepare a 3-day travel plan fro...
9,31,0,0.85828,0,TN,0.478686,1,FP,"Please organize a travel plan for one, leaving..."


In [None]:
print("Have ds_test?", "ds_test" in globals())
print("Have df_test?", "df_test" in globals())

Have ds_test? True
Have df_test? False


In [None]:
bpred = (np.asarray(base_prob_test) >= base_thr_val).astype(int)
dpred = (np.asarray(pai_prob_test)  >= pai_thr_val).astype(int)
y = np.asarray(base_y_test).astype(int)

disagree = np.where(bpred != dpred)[0]

# Sort disagreements by "confidence gap"
gap = np.abs(np.asarray(base_prob_test) - np.asarray(pai_prob_test))
top = disagree[np.argsort(-gap[disagree])][:15]

demo = make_case_table(
    top.tolist(),
    ds=globals().get("ds_test", None),
    df=globals().get("df_test", None),
    base_probs=base_prob_test,
    dend_probs=pai_prob_test,
    y=y,
    base_thr=base_thr_val,
    dend_thr=pai_thr_val,
)

display(demo)
print("Total disagreements:", len(disagree))

Unnamed: 0,i,y,baseline_prob,baseline_pred,baseline_type,dend_prob,dend_pred,dend_type,query
0,3,1,0.879389,1,TP,0.478281,0,FN,Please design a travel plan that departs from ...
1,4,1,0.922296,1,TP,0.478525,0,FN,Please create a travel itinerary for a solo tr...
2,5,1,0.880438,1,TP,0.478273,0,FN,Please help me plan a travel itinerary from Ph...
3,21,0,0.876423,1,FP,0.478179,0,TN,Please craft a travel plan starting from Dalla...
4,33,0,0.877251,1,FP,0.478206,0,TN,Kindly assist in plotting a travel plan starti...
5,35,0,0.877251,1,FP,0.478206,0,TN,Please construct a 3-day travel itinerary from...
6,36,0,0.892958,1,FP,0.478367,0,TN,Could you help create a travel plan that start...
7,37,0,0.892958,1,FP,0.478367,0,TN,Please help me create a travel plan departing ...
8,38,0,0.877251,1,FP,0.478206,0,TN,Can you assist me in creating a 3-day travel p...
9,39,0,0.88344,1,FP,0.478438,0,TN,Could you help organize a travel plan from Mol...


Total disagreements: 34


In [None]:
i = 46  # try 46, 45, 7, 6...
q = get_query(ds=globals().get("ds_test", None), df=globals().get("df_test", None), i=i)

print("QUERY:\n", q)
print("\nTruth y =", int(base_y_test[i]))
print("\nBaseline:  prob =", float(base_prob_test[i]), "pred =", int(base_prob_test[i] >= base_thr_val), "thr =", base_thr_val)
print("Dendrites:  prob =", float(pai_prob_test[i]),  "pred =", int(pai_prob_test[i]  >= pai_thr_val), "thr =", pai_thr_val)

QUERY:
 Could you create a travel plan departing from Gulfport to Charlotte for a 3-day trip from March 5th to 7th, 2022 for an individual? Our budget is now set at $1,200.

Truth y = 0

Baseline:  prob = 0.9201290607452393 pred = 1 thr = 0.8686298727989197
Dendrites:  prob = 0.478707492351532 pred = 1 thr = 0.47854703664779663


In [None]:
import numpy as np
import pandas as pd

# --- pick the dataset variable automatically ---
def pick_first_defined(names):
    for n in names:
        if n in globals():
            return globals()[n], n
    return None, None

ds_test, ds_name = pick_first_defined(["test_dataset", "test_ds", "ds_test", "test_data", "dataset_test"])
assert ds_test is not None, "Couldn't find your test dataset variable. Rename it to test_dataset or set ds_test manually."

# --- ensure thresholds exist (fallback to your printed values) ---
base_thr_val = globals().get("base_thr_val", 0.8686298727989197)
pai_thr_val  = globals().get("pai_thr_val",  0.47854703664779663)

# --- ensure probs exist ---
for v in ["base_prob_test", "pai_prob_test"]:
    assert v in globals(), f"Missing {v}. Run inference to get probability arrays first."

base_prob_test = np.asarray(base_prob_test).astype(float)
pai_prob_test  = np.asarray(pai_prob_test).astype(float)

print("Using dataset:", ds_name, " | N =", len(ds_test))
print("base_thr_val =", base_thr_val, " | pai_thr_val =", pai_thr_val)
print("base_prob_test shape:", base_prob_test.shape, " pai_prob_test shape:", pai_prob_test.shape)

Using dataset: test_ds  | N = 1000
base_thr_val = 0.8686298727989197  | pai_thr_val = 0.47854703664779663
base_prob_test shape: (72,)  pai_prob_test shape: (72,)


# Build disagreement table (baseline vs dendrites)

In [None]:
ex0 = ds_test[0]
print("type(ex0) =", type(ex0))

# Try common ways to introspect
if hasattr(ex0, "__dict__"):
    print("attrs:", list(ex0.__dict__.keys())[:50])

if isinstance(ex0, dict):
    print("dict keys:", list(ex0.keys())[:50])

# Try pretty print if possible
try:
    import pprint
    pprint.pprint(ex0)
except Exception as e:
    print("pprint failed:", e)

type(ex0) = <class 'dict'>
dict keys: ['org', 'dest', 'days', 'date', 'query', 'level', 'reference_information']
{'date': "['2022-03-22', '2022-03-23', '2022-03-24']",
 'days': 3,
 'dest': 'Chicago',
 'level': 'easy',
 'org': 'Sarasota',
 'query': 'Please plan a trip for me starting from Sarasota to Chicago for 3 '
          'days, from March 22nd to March 24th, 2022. The budget for this trip '
          'is set at $1,900.',
 'reference_information': "[{'Description': 'Attractions in Chicago', "
                          '\'Content\': "                                   '
                          'Name  Latitude  '
                          'Longitude                                               '
                          'Address          '
                          'Phone                                                                                                     '
                          'Website    City\\n                              '
                          'Navy P

In [None]:
def _get_any(ex, keys):
    """Try to fetch a value by keys from dict or object attrs."""
    # dict
    if isinstance(ex, dict):
        for k in keys:
            if k in ex:
                return ex[k]
    # object attrs
    for k in keys:
        if hasattr(ex, k):
            return getattr(ex, k)
    return None

def _dig_for_key(ex, target_keys=("label","y","target"), max_depth=4):
    """Recursively search nested dicts/lists/objects for target keys."""
    seen = set()

    def rec(x, depth):
        if depth > max_depth:
            return None
        xid = id(x)
        if xid in seen:
            return None
        seen.add(xid)

        # direct hit (dict)
        if isinstance(x, dict):
            for k in target_keys:
                if k in x:
                    return x[k]
            # recurse values
            for v in x.values():
                out = rec(v, depth+1)
                if out is not None:
                    return out

        # list/tuple
        elif isinstance(x, (list, tuple)):
            for v in x:
                out = rec(v, depth+1)
                if out is not None:
                    return out

        # object: check attrs + recurse into __dict__
        else:
            for k in target_keys:
                if hasattr(x, k):
                    return getattr(x, k)
            if hasattr(x, "__dict__"):
                out = rec(vars(x), depth+1)
                if out is not None:
                    return out

        return None

    return rec(ex, 0)

def get_label(ex):
    # fast paths
    v = _get_any(ex, ["label", "y", "target"])
    if v is None:
        # tuple/list conventional (x,y)
        if isinstance(ex, (tuple, list)) and len(ex) >= 2:
            v = ex[1]
        else:
            # deep search (nested)
            v = _dig_for_key(ex, target_keys=("label","y","target","gold","gt"))
    if v is None:
        # helpful error message
        msg = f"Couldn't find label. type={type(ex)}"
        if isinstance(ex, dict):
            msg += f" keys={list(ex.keys())[:30]}"
        elif hasattr(ex, "__dict__"):
            msg += f" attrs={list(ex.__dict__.keys())[:30]}"
        raise KeyError(msg)

    # normalize to int
    if isinstance(v, (np.ndarray, list, tuple)) and len(np.array(v).shape) == 0:
        v = np.array(v).item()
    return int(v)

def get_query(ex):
    v = _get_any(ex, ["query", "text", "prompt", "instruction"])
    if v is None:
        v = _dig_for_key(ex, target_keys=("query","text","prompt","instruction"))
    return "" if v is None else str(v)

In [None]:
import pandas as pd
import numpy as np

def conf_type(y, pred):
    if pred==1 and y==1: return "TP"
    if pred==1 and y==0: return "FP"
    if pred==0 and y==1: return "FN"
    return "TN"

def make_disagreement_df_from_arrays(ds, y_true, base_probs, dend_probs, base_thr, dend_thr):
    n = min(len(y_true), len(base_probs), len(dend_probs), len(ds))
    ds = ds[:n]
    y_true = np.asarray(y_true)[:n].astype(int)
    base_probs = np.asarray(base_probs)[:n]
    dend_probs = np.asarray(dend_probs)[:n]

    rows = []
    for i in range(n):
        ex = ds[i]
        q = ex.get("query", "") if isinstance(ex, dict) else ""
        bp = float(base_probs[i]); dp = float(dend_probs[i])
        bpred = int(bp >= base_thr)
        dpred = int(dp >= dend_thr)
        y = int(y_true[i])

        rows.append({
            "i": i,
            "y": y,
            "baseline_prob": bp,
            "baseline_pred": bpred,
            "baseline_type": conf_type(y, bpred),
            "dend_prob": dp,
            "dend_pred": dpred,
            "dend_type": conf_type(y, dpred),
            "query": (q[:140] + "…") if len(q) > 140 else q
        })

    df = pd.DataFrame(rows)
    disagree = df[df["baseline_pred"] != df["dend_pred"]].copy()
    return df, disagree

In [None]:
import numpy as np
import pandas as pd

def _get_n(ds):
    # HF Dataset / list-like
    try:
        return len(ds)
    except Exception:
        pass
    # dict-of-columns
    if isinstance(ds, dict):
        # pick a column to infer length
        for v in ds.values():
            try:
                return len(v)
            except Exception:
                continue
    raise TypeError("Can't determine length of ds.")

def _get_row(ds, i):
    # HF Dataset or list of dicts
    try:
        return ds[i]
    except Exception:
        pass

    # pandas DataFrame
    if isinstance(ds, pd.DataFrame):
        return ds.iloc[i].to_dict()

    # dict-of-columns -> reconstruct row dict
    if isinstance(ds, dict):
        row = {}
        for k, v in ds.items():
            try:
                row[k] = v[i]
            except Exception:
                # if column isn't indexable, skip
                continue
        return row

    raise TypeError(f"ds is not indexable: type={type(ds)}")

def conf_type(y, pred):
    if pred==1 and y==1: return "TP"
    if pred==1 and y==0: return "FP"
    if pred==0 and y==1: return "FN"
    return "TN"

def make_disagreement_df_from_arrays(ds, y_true, base_probs, dend_probs, base_thr, dend_thr):
    n = min(_get_n(ds), len(y_true), len(base_probs), len(dend_probs))

    y_true = np.asarray(y_true)[:n].astype(int)
    base_probs = np.asarray(base_probs)[:n]
    dend_probs = np.asarray(dend_probs)[:n]

    rows = []
    for i in range(n):
        ex = _get_row(ds, i)
        q = ex.get("query", "") if isinstance(ex, dict) else ""

        bp = float(base_probs[i]); dp = float(dend_probs[i])
        bpred = int(bp >= base_thr)
        dpred = int(dp >= dend_thr)
        y = int(y_true[i])

        rows.append({
            "i": i,
            "y": y,
            "baseline_prob": bp,
            "baseline_pred": bpred,
            "baseline_type": conf_type(y, bpred),
            "dend_prob": dp,
            "dend_pred": dpred,
            "dend_type": conf_type(y, dpred),
            "query": (q[:140] + "…") if isinstance(q, str) and len(q) > 140 else q
        })

    df = pd.DataFrame(rows)
    disagree = df[df["baseline_pred"] != df["dend_pred"]].copy()
    return df, disagree

In [None]:
df_all, disagree = make_disagreement_df_from_arrays(
    ds_test, base_y_test, base_prob_test, pai_prob_test, base_thr_val, pai_thr_val
)

print("Total disagreements:", len(disagree), "out of", len(df_all))
display(disagree.head(20))

Total disagreements: 34 out of 72


Unnamed: 0,i,y,baseline_prob,baseline_pred,baseline_type,dend_prob,dend_pred,dend_type,query
1,1,1,0.868811,1,TP,0.478118,0,FN,Seeking assistance to develop a travel itinera...
2,2,1,0.831572,0,FN,0.478682,1,TP,Please create a travel plan that starts in Cha...
3,3,1,0.879389,1,TP,0.478281,0,FN,Please design a travel plan that departs from ...
4,4,1,0.922296,1,TP,0.478525,0,FN,Please create a travel itinerary for a solo tr...
5,5,1,0.880438,1,TP,0.478273,0,FN,Please help me plan a travel itinerary from Ph...
16,16,0,0.873877,1,FP,0.478099,0,TN,Could you help me arrange a 3-day trip for one...
17,17,0,0.872792,1,FP,0.478153,0,TN,Please create a travel plan for me starting in...
18,18,0,0.873877,1,FP,0.478099,0,TN,Kindly assist in creating a travel plan for a ...
19,19,0,0.873877,1,FP,0.478099,0,TN,Please devise a travel plan that departs from ...
20,20,0,0.871624,1,FP,0.478155,0,TN,Please provide a travel plan departing from Sa...


In [None]:
import numpy as np
import pandas as pd

# ---------- 0) Grab y_true if you have it (optional) ----------
# If you *do* have labels for this evaluated subset, put it in y_true.
# This tries common variable names automatically.
y_true = None
for cand in ["base_y_test", "y_test", "pai_y_test", "labels_test", "y_true_test"]:
    if cand in globals():
        arr = np.asarray(globals()[cand])
        if arr.ndim == 1:
            y_true = arr.astype(int)
            print(f"Using y_true from `{cand}` with shape {y_true.shape}")
            break
if y_true is None:
    print("No y_true found in globals() — demo will run without ground-truth types.")

# ---------- 1) Robust dataset indexing (handles HF Dataset, list, dict-of-columns) ----------
def safe_get_ex(ds, i):
    if ds is None:
        return {}
    # HF Dataset / list-like
    try:
        ex = ds[i]
        if isinstance(ex, dict):
            return ex
    except Exception:
        pass

    # dict-of-columns (e.g., {"query":[...], "org":[...], ...})
    if isinstance(ds, dict):
        out = {}
        for k, v in ds.items():
            try:
                out[k] = v[i]
            except Exception:
                pass
        return out

    return {}

def get_query_from_ex(ex):
    if not isinstance(ex, dict):
        return ""
    for k in ["query", "text", "prompt", "input", "question"]:
        if k in ex and ex[k] is not None:
            return str(ex[k])
    return ""

def conf_type(y, pred):
    if y is None: return "NA"
    if pred==1 and y==1: return "TP"
    if pred==1 and y==0: return "FP"
    if pred==0 and y==1: return "FN"
    return "TN"

# ---------- 2) Build a unified table for the evaluated subset ----------
def build_demo_df(ds_subset, base_probs, dend_probs, base_thr, dend_thr, y_true=None, q_clip=170):
    base_probs = np.asarray(base_probs).astype(float)
    dend_probs = np.asarray(dend_probs).astype(float)
    n = len(base_probs)
    assert len(dend_probs) == n, "base_probs and dend_probs must have same length"
    if y_true is not None:
        y_true = np.asarray(y_true).astype(int)
        assert len(y_true) == n, f"y_true length {len(y_true)} != probs length {n}"

    rows = []
    for i in range(n):
        ex = safe_get_ex(ds_subset, i)
        q = get_query_from_ex(ex)
        bp = float(base_probs[i])
        dp = float(dend_probs[i])
        bpred = int(bp >= base_thr)
        dpred = int(dp >= dend_thr)
        y = int(y_true[i]) if y_true is not None else None

        rows.append({
            "i": i,
            "y": y if y is not None else -1,
            "query": (q[:q_clip] + "…") if len(q) > q_clip else q,
            "baseline_prob": bp,
            "baseline_pred": bpred,
            "baseline_type": conf_type(y, bpred),
            "dend_prob": dp,
            "dend_pred": dpred,
            "dend_type": conf_type(y, dpred),
            "pred_changed": (bpred != dpred),
            "abs_prob_gap": abs(bp - dp),
        })

    df = pd.DataFrame(rows)
    return df

# You should already have these in your notebook:
#   ds_test (or your 72-example subset), base_prob_test, pai_prob_test, base_thr_val, pai_thr_val
df = build_demo_df(ds_test, base_prob_test, pai_prob_test, base_thr_val, pai_thr_val, y_true=y_true)
display(df.head(5))

print("N =", len(df), "| disagreements =", int(df["pred_changed"].sum()))

# ---------- 3) Auto-pick 5 showcase cases ----------
def pick_showcase(df, k_each=2, final_k=5):
    # If we have y labels, pick meaningful buckets. Otherwise pick high-gap disagreements.
    has_y = (df["baseline_type"].iloc[0] != "NA")

    picks = []

    if has_y:
        # 1) Baseline FP -> Dend TN  (dend fixes)
        a = df[(df.y==0) & (df.baseline_pred==1) & (df.dend_pred==0)].sort_values("baseline_prob", ascending=False).head(k_each)
        picks.append(("Baseline FP → Dend TN (fixed)", a))

        # 2) Baseline TP -> Dend FN  (dend hurts)
        b = df[(df.y==1) & (df.baseline_pred==1) & (df.dend_pred==0)].sort_values("baseline_prob", ascending=False).head(k_each)
        picks.append(("Baseline TP → Dend FN (lost)", b))

        # 3) Both FP (hard negatives)
        c = df[(df.y==0) & (df.baseline_pred==1) & (df.dend_pred==1)].sort_values("baseline_prob", ascending=False).head(1)
        picks.append(("Both FP (hard negative)", c))

        # 4) Both FN (hard positives)
        d = df[(df.y==1) & (df.baseline_pred==0) & (df.dend_pred==0)].sort_values("baseline_prob", ascending=True).head(1)
        picks.append(("Both FN (hard positive)", d))

        out = pd.concat([x for _, x in picks if len(x)], axis=0).drop_duplicates(subset=["i"])
        # If we still have too many, keep the most "interesting" by prob gap
        out = out.sort_values("abs_prob_gap", ascending=False).head(final_k)
        return out, picks
    else:
        # no labels: pick disagreements with biggest gap + a couple stable high-confidence
        dis = df[df.pred_changed].sort_values("abs_prob_gap", ascending=False).head(final_k)
        if len(dis) < final_k:
            stable = df[~df.pred_changed].sort_values(["baseline_prob"], ascending=False).head(final_k - len(dis))
            dis = pd.concat([dis, stable], axis=0)
        return dis, [("No y_true: picked by abs prob gap / confidence", dis)]

show_df, bucket_info = pick_showcase(df, k_each=2, final_k=5)
display(show_df[["i","y","baseline_prob","baseline_pred","baseline_type","dend_prob","dend_pred","dend_type","query"]])

# ---------- 4) Pretty-print “Input → Output” demo (copy/paste into writeup) ----------
def print_demo_cases(df_cases, base_thr, dend_thr, header="INPUT → OUTPUT DEMO"):
    print("\n" + "="*90)
    print(header)
    print("="*90)
    for _, r in df_cases.iterrows():
        i = int(r["i"])
        y = None if int(r["y"]) == -1 else int(r["y"])
        q = r["query"]

        print(f"\nCase i={i}" + (f" | y={y}" if y is not None else " | y=NA"))
        print("-"*90)
        print("QUERY:")
        print(q)

        print("\nBaseline:")
        print(f"  prob={r['baseline_prob']:.6f}  pred={int(r['baseline_pred'])}  thr={base_thr:.6f}  type={r['baseline_type']}")
        print("Dendrites:")
        print(f"  prob={r['dend_prob']:.6f}  pred={int(r['dend_pred'])}  thr={dend_thr:.6f}  type={r['dend_type']}")

        if r["baseline_pred"] != r["dend_pred"]:
            who = "Dendrites flips the decision"
        else:
            who = "Same decision"
        print(f"\nResult: {who} | abs_prob_gap={r['abs_prob_gap']:.6f}")
    print("\n" + "="*90)

print_demo_cases(show_df, base_thr_val, pai_thr_val)

# ---------- 5) (Optional) Save as markdown text block for Devpost ----------
md_lines = []
md_lines.append("### Input → Output Demo (VAL-chosen thresholds, TEST eval)\n")
md_lines.append(f"- Baseline thr (VAL): `{base_thr_val:.6f}`")
md_lines.append(f"- Dendrites thr (VAL): `{pai_thr_val:.6f}`\n")
for _, r in show_df.iterrows():
    i = int(r["i"])
    y = None if int(r["y"]) == -1 else int(r["y"])
    md_lines.append(f"**Case {i}**" + (f" (y={y})" if y is not None else " (y=NA)"))
    md_lines.append(f"- Query: {r['query']}")
    md_lines.append(f"- Baseline: prob={r['baseline_prob']:.6f}, pred={int(r['baseline_pred'])}, type={r['baseline_type']}")
    md_lines.append(f"- Dendrites: prob={r['dend_prob']:.6f}, pred={int(r['dend_pred'])}, type={r['dend_type']}\n")

demo_md = "\n".join(md_lines)
print("\n\n--- COPY BELOW INTO YOUR WRITEUP ---\n")
print(demo_md)

Using y_true from `base_y_test` with shape (72,)


Unnamed: 0,i,y,query,baseline_prob,baseline_pred,baseline_type,dend_prob,dend_pred,dend_type,pred_changed,abs_prob_gap
0,0,1,Please plan a trip for me starting from Saraso...,0.850733,0,FN,0.478468,0,FN,False,0.372265
1,1,1,Seeking assistance to develop a travel itinera...,0.868811,1,TP,0.478118,0,FN,True,0.390693
2,2,1,Please create a travel plan that starts in Cha...,0.831572,0,FN,0.478682,1,TP,True,0.35289
3,3,1,Please design a travel plan that departs from ...,0.879389,1,TP,0.478281,0,FN,True,0.401108
4,4,1,Please create a travel itinerary for a solo tr...,0.922296,1,TP,0.478525,0,FN,True,0.443771


N = 72 | disagreements = 34


Unnamed: 0,i,y,baseline_prob,baseline_pred,baseline_type,dend_prob,dend_pred,dend_type,query
45,45,0,0.924424,1,FP,0.478582,1,FP,Could you create a 3-day travel plan for one p...
4,4,1,0.922296,1,TP,0.478525,0,FN,Please create a travel itinerary for a solo tr...
36,36,0,0.892958,1,FP,0.478367,0,TN,Could you help create a travel plan that start...
37,37,0,0.892958,1,FP,0.478367,0,TN,Please help me create a travel plan departing ...
5,5,1,0.880438,1,TP,0.478273,0,FN,Please help me plan a travel itinerary from Ph...



INPUT → OUTPUT DEMO

Case i=45 | y=0
------------------------------------------------------------------------------------------
QUERY:
Could you create a 3-day travel plan for one person, departing from Houston and heading to Punta Gorda from March 20th to March 22nd, 2022, with a budget of $1,700?

Baseline:
  prob=0.924424  pred=1  thr=0.868630  type=FP
Dendrites:
  prob=0.478582  pred=1  thr=0.478547  type=FP

Result: Same decision | abs_prob_gap=0.445842

Case i=4 | y=1
------------------------------------------------------------------------------------------
QUERY:
Please create a travel itinerary for a solo traveler departing from Jacksonville and heading to Los Angeles for a period of 3 days, specifically from March 25th to March …

Baseline:
  prob=0.922296  pred=1  thr=0.868630  type=TP
Dendrites:
  prob=0.478525  pred=0  thr=0.478547  type=FN

Result: Dendrites flips the decision | abs_prob_gap=0.443771

Case i=36 | y=0
-------------------------------------------------------

In [None]:
# 1) Transition summary: what flipped into what?
df["transition"] = df["baseline_type"].astype(str) + " → " + df["dend_type"].astype(str)

summary = (
    df.groupby("transition")
      .size()
      .reset_index(name="count")
      .sort_values("count", ascending=False)
)

print("Transition counts:")
display(summary)

# Also show just flips vs no flips
flip_summary = (
    df.assign(flipped=df["baseline_pred"] != df["dend_pred"])
      .groupby(["flipped", "transition"]).size()
      .reset_index(name="count")
      .sort_values(["flipped","count"], ascending=[False, False])
)
print("Flip-only breakdown:")
display(flip_summary)


# 2) Pick a better showcase: 2 fixed FPs + 2 hurt TPs + 1 hard case (both FP or both FN)
def pick_cases(df, transition, k=2, sort_by="abs_prob_gap", ascending=False):
    sub = df[df["transition"] == transition].copy()
    if len(sub) == 0:
        return sub
    return sub.sort_values(sort_by, ascending=ascending).head(k)

# Buckets (edit k’s if you want)
fixed_fp = pick_cases(df, "FP → TN", k=2, sort_by="baseline_prob", ascending=False)   # dend fixed baseline over-triggering
fixed_fn = pick_cases(df, "FN → TP", k=1, sort_by="dend_prob", ascending=False)       # dend rescued a miss (optional)
hurt_tp  = pick_cases(df, "TP → FN", k=2, sort_by="baseline_prob", ascending=False)   # dend lost true positives
hard_fp  = pick_cases(df, "FP → FP", k=1, sort_by="baseline_prob", ascending=False)   # both wrong (hard negative)
hard_fn  = pick_cases(df, "FN → FN", k=1, sort_by="baseline_prob", ascending=True)    # both wrong (hard positive)

# Choose ONE hard case (prefer FP→FP if it exists, else FN→FN)
hard = hard_fp if len(hard_fp) else hard_fn

show2 = pd.concat([fixed_fp, fixed_fn, hurt_tp, hard], axis=0).drop_duplicates(subset=["i"]).head(5)

display(show2[["i","y","baseline_prob","baseline_pred","baseline_type",
               "dend_prob","dend_pred","dend_type","transition","abs_prob_gap","query"]])

# 3) Print the nicer demo output for the new selection
print_demo_cases(show2, base_thr_val, pai_thr_val, header="INPUT → OUTPUT DEMO (Curated Buckets)")

# 4) Generate a writeup-friendly markdown block for the new selection
md_lines = []
md_lines.append("### Input → Output Demo (curated cases)\n")
md_lines.append(f"- Baseline thr (VAL): `{base_thr_val:.6f}`")
md_lines.append(f"- Dendrites thr (VAL): `{pai_thr_val:.6f}`\n")
md_lines.append("Buckets: **FP→TN (fixed over-triggering)**, **FN→TP (rescued)**, **TP→FN (lost)**, **hard case (both wrong)**.\n")

for _, r in show2.iterrows():
    i = int(r["i"])
    y = int(r["y"])
    md_lines.append(f"**Case {i}** (y={y}) — {r['transition']}")
    md_lines.append(f"- Query: {r['query']}")
    md_lines.append(f"- Baseline: prob={r['baseline_prob']:.6f}, pred={int(r['baseline_pred'])}, type={r['baseline_type']}")
    md_lines.append(f"- Dendrites: prob={r['dend_prob']:.6f}, pred={int(r['dend_pred'])}, type={r['dend_type']}\n")

demo_md2 = "\n".join(md_lines)
print("\n\n--- COPY BELOW INTO YOUR WRITEUP (CURATED) ---\n")
print(demo_md2)

Transition counts:


Unnamed: 0,transition,count
5,TN → TN,33
3,FP → TN,23
4,TN → FP,6
6,TP → FN,4
0,FN → FN,3
2,FP → FP,2
1,FN → TP,1


Flip-only breakdown:


Unnamed: 0,flipped,transition,count
4,True,FP → TN,23
5,True,TN → FP,6
6,True,TP → FN,4
3,True,FN → TP,1
2,False,TN → TN,33
0,False,FN → FN,3
1,False,FP → FP,2


Unnamed: 0,i,y,baseline_prob,baseline_pred,baseline_type,dend_prob,dend_pred,dend_type,transition,abs_prob_gap,query
37,37,0,0.892958,1,FP,0.478367,0,TN,FP → TN,0.414591,Please help me create a travel plan departing ...
36,36,0,0.892958,1,FP,0.478367,0,TN,FP → TN,0.414591,Could you help create a travel plan that start...
2,2,1,0.831572,0,FN,0.478682,1,TP,FN → TP,0.35289,Please create a travel plan that starts in Cha...
4,4,1,0.922296,1,TP,0.478525,0,FN,TP → FN,0.443771,Please create a travel itinerary for a solo tr...
5,5,1,0.880438,1,TP,0.478273,0,FN,TP → FN,0.402165,Please help me plan a travel itinerary from Ph...



INPUT → OUTPUT DEMO (Curated Buckets)

Case i=37 | y=0
------------------------------------------------------------------------------------------
QUERY:
Please help me create a travel plan departing from Atlanta and heading to Minneapolis for 3 days, from March 5th to March 7th, 2022, within a budget of $1,900.

Baseline:
  prob=0.892958  pred=1  thr=0.868630  type=FP
Dendrites:
  prob=0.478367  pred=0  thr=0.478547  type=TN

Result: Dendrites flips the decision | abs_prob_gap=0.414591

Case i=36 | y=0
------------------------------------------------------------------------------------------
QUERY:
Could you help create a travel plan that starts in Atlanta and ends in Bozeman, taking place over 3 days from March 25th to March 27th, 2022? The travel budget at hand is…

Baseline:
  prob=0.892958  pred=1  thr=0.868630  type=FP
Dendrites:
  prob=0.478367  pred=0  thr=0.478547  type=TN

Result: Dendrites flips the decision | abs_prob_gap=0.414591

Case i=2 | y=1
---------------------------

# Add calibration + Brier score

In [None]:
import numpy as np
from sklearn.metrics import brier_score_loss

def calibration_bins(y, p, n_bins=10):
    y = np.asarray(y); p = np.asarray(p)
    bins = np.linspace(0, 1, n_bins+1)
    idx = np.digitize(p, bins) - 1
    rows = []
    for b in range(n_bins):
        m = idx == b
        if m.sum() == 0:
            continue
        rows.append({
            "bin": b,
            "count": int(m.sum()),
            "p_mean": float(p[m].mean()),
            "y_rate": float(y[m].mean()),
            "gap(p-y)": float(p[m].mean() - y[m].mean())
        })
    return pd.DataFrame(rows)

print("Brier (lower=better)")
print("Baseline:", brier_score_loss(base_y_test, base_prob_test))
print("Dendrites:", brier_score_loss(base_y_test, pai_prob_test))

display(calibration_bins(base_y_test, base_prob_test, 10))
display(calibration_bins(base_y_test, pai_prob_test, 10))

Brier (lower=better)
Baseline: 0.6524088877377674
Dendrites: 0.23362196985091416


Unnamed: 0,bin,count,p_mean,y_rate,gap(p-y)
0,7,1,0.797991,1.0,-0.202009
1,8,68,0.852562,0.088235,0.764327
2,9,3,0.922283,0.333333,0.588949


Unnamed: 0,bin,count,p_mean,y_rate,gap(p-y)
0,4,72,0.478345,0.111111,0.367234


# Threshold sweep table (VAL + TEST)

In [None]:
from sklearn.metrics import precision_recall_fscore_support

def sweep_thresholds(y, p, thrs=np.linspace(0,1,101)):
    out = []
    for t in thrs:
        pred = (p >= t).astype(int)
        pr, rc, f1, _ = precision_recall_fscore_support(y, pred, average="binary", zero_division=0)
        out.append({"thr": float(t), "precision": pr, "recall": rc, "f1": f1})
    return pd.DataFrame(out)

base_sweep = sweep_thresholds(base_y_test, base_prob_test)
dend_sweep = sweep_thresholds(base_y_test, pai_prob_test)

display(base_sweep.sort_values("f1", ascending=False).head(10))
display(dend_sweep.sort_values("f1", ascending=False).head(10))

Unnamed: 0,thr,precision,recall,f1
88,0.88,0.181818,0.25,0.210526
1,0.01,0.111111,1.0,0.2
2,0.02,0.111111,1.0,0.2
3,0.03,0.111111,1.0,0.2
0,0.0,0.111111,1.0,0.2
5,0.05,0.111111,1.0,0.2
6,0.06,0.111111,1.0,0.2
7,0.07,0.111111,1.0,0.2
8,0.08,0.111111,1.0,0.2
9,0.09,0.111111,1.0,0.2


Unnamed: 0,thr,precision,recall,f1
0,0.0,0.111111,1.0,0.2
1,0.01,0.111111,1.0,0.2
2,0.02,0.111111,1.0,0.2
3,0.03,0.111111,1.0,0.2
4,0.04,0.111111,1.0,0.2
5,0.05,0.111111,1.0,0.2
6,0.06,0.111111,1.0,0.2
7,0.07,0.111111,1.0,0.2
8,0.08,0.111111,1.0,0.2
9,0.09,0.111111,1.0,0.2


# Bootstrap confidence intervals

In [None]:
from sklearn.metrics import average_precision_score

def bootstrap_ci(y, p, metric_fn, n=2000, seed=0):
    rng = np.random.default_rng(seed)
    y = np.asarray(y); p = np.asarray(p)
    vals = []
    for _ in range(n):
        idx = rng.integers(0, len(y), len(y))
        vals.append(metric_fn(y[idx], p[idx]))
    vals = np.sort(vals)
    return float(vals[int(0.025*n)]), float(vals[int(0.975*n)]), float(np.mean(vals))

ap_fn = lambda yy, pp: average_precision_score(yy, pp)

base_ci = bootstrap_ci(base_y_test, base_prob_test, ap_fn)
dend_ci = bootstrap_ci(base_y_test, pai_prob_test, ap_fn)

print("AP bootstrap 95% CI")
print("Baseline:", base_ci)
print("Dendrites:", dend_ci)

AP bootstrap 95% CI
Baseline: (0.06059035756962956, 0.4809523809523809, 0.22810234975282237)
Dendrites: (0.05963553085001597, 0.4530797644005191, 0.20664608167091894)


# Error slicing by metadata (easy/medium/hard or days)

In [None]:
from sklearn.metrics import average_precision_score

def slice_ap(ds, y, p, key):
    rows = []
    for i in range(len(y)):
        ex = ds[i]
        rows.append({"i": i, key: ex.get(key, None)})
    meta = pd.DataFrame(rows)
    meta["y"] = y
    meta["p"] = p
    out = []
    for g, sub in meta.groupby(key):
        if sub["y"].nunique() < 2:
            continue
        out.append({
            key: g,
            "n": len(sub),
            "pos_rate": float(sub["y"].mean()),
            "AP": float(average_precision_score(sub["y"], sub["p"]))
        })
    return pd.DataFrame(out).sort_values("AP", ascending=False)

display(slice_ap(ds_test, base_y_test, base_prob_test, "level"))
display(slice_ap(ds_test, base_y_test, pai_prob_test, "level"))

display(slice_ap(ds_test, base_y_test, base_prob_test, "days"))
display(slice_ap(ds_test, base_y_test, pai_prob_test, "days"))

Unnamed: 0,level,n,pos_rate,AP
0,easy,72,0.111111,0.186575


Unnamed: 0,level,n,pos_rate,AP
0,easy,72,0.111111,0.168212


Unnamed: 0,days,n,pos_rate,AP
0,3,72,0.111111,0.186575


Unnamed: 0,days,n,pos_rate,AP
0,3,72,0.111111,0.168212


In [None]:
def show_case(ds, i):
    ex = ds[i]
    q = ex.get("query","")
    ref = ex.get("reference_information","")
    print("QUERY:\n", q)
    print("\nref_information chars:", len(ref))

show_case(ds_test, 0)

QUERY:
 Please plan a trip for me starting from Sarasota to Chicago for 3 days, from March 22nd to March 24th, 2022. The budget for this trip is set at $1,900.

ref_information chars: 13378


In [None]:
!pip -q install wandb plotly
import os, math
import numpy as np
import pandas as pd
import wandb
import matplotlib.pyplot as plt

import plotly.express as px

In [None]:
wandb.login()

True

# Load sweep into a DataFrame

In [None]:
import wandb
api = wandb.Api()

viewer = api.viewer
print("Logged in as:", viewer.username)

# teams can be in different places depending on version
if hasattr(viewer, "teams"):
    ts = viewer.teams
    print("Teams I can access:")
    for t in ts:
        # t can be string or Team object
        print(" -", getattr(t, "name", t))
else:
    print("No teams attribute on viewer in this wandb version.")

Logged in as: npnallstar
Teams I can access:
 - npnallstar
 - vtpy


In [None]:
import wandb
api = wandb.Api()

entities = ["npnallstar", "vtpy"]

for ent in entities:
    try:
        projs = list(api.projects(ent))
        print(f"\nEntity: {ent} | {len(projs)} projects")
        for p in projs:
            print(" -", p.name)
    except Exception as e:
        print(f"\nEntity: {ent} | ERROR:", e)


Entity: npnallstar | 0 projects

Entity: vtpy | 1 projects
 - dendrites-hackathon


In [None]:
ENTITY = "vtpy"
PROJECT = "dendrites-hackathon"

proj = api.project(PROJECT, entity=ENTITY)
print("Project OK:", proj.name)

Project OK: dendrites-hackathon


In [None]:
import wandb
api = wandb.Api()

target = "protbert_baseline_splitA"
entities = ["npnallstar", "vtpy"]

found = []
for ent in entities:
    try:
        for p in api.projects(ent):
            path = f"{ent}/{p.name}"
            runs = api.runs(path, per_page=50)
            for r in runs:
                if r.name == target:
                    print("FOUND:", path, "| run id:", r.id, "| url:", r.url)
                    found.append((path, r.url))
                    raise StopIteration
    except StopIteration:
        break
    except Exception:
        pass

if not found:
    print("Not found in last 50 runs per project. Increase per_page or search by regex.")

FOUND: vtpy/dendrites-hackathon | run id: 130atbjh | url: https://wandb.ai/vtpy/dendrites-hackathon/runs/130atbjh


In [None]:
import wandb
api = wandb.Api()

ENTITY = "vtpy"
PROJECT = "dendrites-hackathon"

proj = api.project(PROJECT, entity=ENTITY)
print("Project OK:", f"{ENTITY}/{proj.name}")

sweeps = proj.sweeps()
print("Sweeps found:", len(sweeps))
for s in sweeps[:30]:
    print("id:", s.id, "| name:", s.name, "| state:", s.state, "| url:", s.url)

Project OK: vtpy/dendrites-hackathon
Sweeps found: 0


In [None]:
import pandas as pd, wandb
api = wandb.Api()

ENTITY="vtpy"
PROJECT="dendrites-hackathon"

runs = list(api.runs(f"{ENTITY}/{PROJECT}", per_page=2000))
print("Runs:", len(runs))

def runs_to_df(runs):
    rows = []
    for r in runs:
        summ = dict(r.summary)
        cfg  = {k:v for k,v in dict(r.config).items() if not k.startswith("_")}
        rows.append({
            "run_id": r.id,
            "run_name": r.name,
            "state": r.state,
            "url": r.url,
            **summ,
            **cfg
        })
    return pd.DataFrame(rows)

df = runs_to_df(runs)
print("df shape:", df.shape)
df.head()

Runs: 1
df shape: (1, 23)


Unnamed: 0,run_id,run_name,state,url,_runtime,_step,_timestamp,_wandb,epoch,params_total,...,train_loss,val_acc,val_f1,lr,kind,epochs,max_len,batch_size,grad_accum,model_name
0,130atbjh,protbert_baseline_splitA,crashed,https://wandb.ai/vtpy/dendrites-hackathon/runs...,5849,6500,1766014000.0,{'runtime': 5849},1,419933186,...,0.092017,0.96853,0,2e-05,baseline,2,128,4,8,Rostlab/prot_bert_bfd


In [None]:
def pick_first(cols, candidates):
    for c in candidates:
        if c in cols: return c
    return None

cols = set(df.columns)

x = pick_first(cols, ["params_trainable","final_param_count","param_count","n_params","params"])
y = pick_first(cols, ["test_f1","test_acc","final_max_test","test_score","val_f1","val_acc"])

mode = pick_first(cols, ["dendrite_mode","model_format","method","use_dendrites","dendrites"])

print("x:", x, "| y:", y, "| mode:", mode)

x: params_trainable | y: test_f1 | mode: None


In [None]:
sorted([c for c in df.columns if "test" in c.lower() or "acc" in c.lower() or "f1" in c.lower()])[:80]

['grad_accum', 'test_acc', 'test_f1', 'val_acc', 'val_f1']

In [None]:
import plotly.express as px

d = df.dropna(subset=[x, y]).copy()

fig = px.scatter(
    d,
    x=x, y=y,
    color=mode if mode else None,
    hover_data=["run_name","run_id","url"]
)
fig.update_xaxes(type="log")
fig.show()

In [None]:
baseline_target = d[y].quantile(0.75)
dz = d[d[y] >= baseline_target].copy()

fig2 = px.scatter(
    dz, x=x, y=y,
    color=mode if mode else None,
    hover_data=["run_name","run_id","url"]
)
fig2.update_xaxes(type="log")
fig2.show()

print("Zoom filter:", baseline_target, "| kept:", len(dz), "of", len(d))

Zoom filter: 0.0 | kept: 1 of 1


In [None]:
import pandas as pd
import plotly.express as px

metric = y
d2 = d.dropna(subset=[metric]).copy()

# numeric columns only
num_cols = [c for c in d2.columns if pd.api.types.is_numeric_dtype(d2[c])]
# keep a manageable number (edit candidates to match your run config keys)
hp_candidates = [
    "dropout","learning_rate","lr","weight_decay",
    "num_conv","num_linear","hidden_size",
    "noise_std","max_dendrites","switch_threshold",
]
dims = [c for c in [x] + hp_candidates if c in num_cols]
if metric not in dims: dims = [metric] + dims
else: dims = [metric] + [c for c in dims if c != metric]

topN = min(150, len(d2))
dtop = d2.sort_values(metric, ascending=False).head(topN)

figp = px.parallel_coordinates(
    dtop,
    dimensions=dims,
    color=metric
)
figp.show()

In [None]:
cfg = {
    "kind": "baseline",          # or "dendrites"
    "lr": 2e-5,
    "batch_size": 4,
    "max_len": 128,
    "epochs": 2,
    "grad_accum": 8,
    "model_name": "Rostlab/prot_bert_bfd",
    "split": "A",
}

In [None]:
[c for c in globals().keys() if "model" in c.lower()][:50]

['train_model',
 'baseline_model',
 'make_dendrites_model',
 'pai_model',
 'base_best_model',
 'pai_best_model',
 'model_size_mb_state_dict',
 'base_model']

In [None]:
# Baseline eval
model = base_best_model if "base_best_model" in globals() else baseline_model

# Dendrites eval
# (your naming suggests PAI = dendrites)
dend_model = pai_best_model if "pai_best_model" in globals() else pai_model

# eval collector (binary or multiclass)

In [None]:
import torch
import numpy as np

def eval_collect_preds(model, loader, device, threshold=0.5):
    model.eval()
    model.to(device)

    y_true_all, y_pred_all = [], []

    with torch.no_grad():
        for batch in loader:
            # --- move batch to device (works for dict batches)
            if isinstance(batch, dict):
                y_true = batch.get("labels", batch.get("label"))
                inputs = {k: v.to(device) for k, v in batch.items() if k not in ["labels","label"]}
                y_true = y_true.cpu().numpy()
                out = model(**inputs)
                logits = out.logits if hasattr(out, "logits") else out
            else:
                # tuple style (x, y)
                x, y_true = batch
                x = x.to(device)
                y_true = y_true.cpu().numpy()
                logits = model(x)

            logits = logits.detach().cpu()

            # --- convert logits -> predictions
            if logits.ndim == 1 or logits.shape[-1] == 1:
                # binary (logits)
                prob = torch.sigmoid(logits.view(-1)).numpy()
                y_pred = (prob >= threshold).astype(int)
            else:
                # multiclass
                y_pred = torch.argmax(logits, dim=-1).numpy()

            y_true_all.append(y_true)
            y_pred_all.append(y_pred)

    y_true_all = np.concatenate(y_true_all)
    y_pred_all = np.concatenate(y_pred_all)
    return y_true_all, y_pred_all

In [None]:
[x for x in globals().keys() if ("test" in x.lower() and ("loader" in x.lower() or "dl" in x.lower()))]

['test_dl']

In [None]:
import wandb
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
import torch

# ---- start a run (do this before wandb.log)
run = wandb.init(project="dendrites-hackathon", entity="vtpy", reinit=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
threshold = 0.5  # or base_thr_val / pai_thr_val

# choose model
model = base_best_model
# model = pai_best_model

y_true, y_pred = eval_collect_preds(model, test_dl, device, threshold=threshold)

avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
test_acc = float(accuracy_score(y_true, y_pred))

wandb.log({
    "test_f1": test_f1,
    "test_acc": test_acc,
    "threshold": float(threshold),
})

print("test_f1:", test_f1, "test_acc:", test_acc)

wandb.finish()



test_f1: 0.2 test_acc: 0.1111111111111111


0,1
test_acc,▁
test_f1,▁
threshold,▁

0,1
test_acc,0.11111
test_f1,0.2
threshold,0.5


In [None]:
import wandb
import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score, average_precision_score, precision_recall_curve, confusion_matrix

ENTITY  = "vtpy"
PROJECT = "dendrites-hackathon"

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

def log_pr_curve(y_true, y_prob, title="pr_curve"):
    # small + reliable: log as a table (W&B can chart it)
    p, r, t = precision_recall_curve(y_true, y_prob)
    tbl = wandb.Table(columns=["precision","recall","threshold"])
    for i in range(len(t)):
        tbl.add_data(float(p[i]), float(r[i]), float(t[i]))
    wandb.log({title: tbl})

def run_eval_wandb(model, test_dl, threshold=0.5, kind="baseline", extra_cfg=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # You said you already have eval_collect_preds; it should return y_true, y_pred, y_prob (if possible)
    # If yours returns only (y_true, y_pred), modify it to also return probs for AP + PR curve.
    out = eval_collect_preds(model, test_dl, device, threshold=threshold, return_probs=True)
    y_true, y_pred, y_prob = out

    y_true = np.array(y_true).astype(int)
    y_pred = np.array(y_pred).astype(int)
    y_prob = np.array(y_prob).astype(float)

    avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
    test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    test_acc = float(accuracy_score(y_true, y_pred))
    test_ap  = float(average_precision_score(y_true, y_prob)) if len(np.unique(y_true))==2 else float("nan")

    params_total, params_trainable = count_params(model)
    cm = confusion_matrix(y_true, y_pred).tolist()

    cfg = dict(kind=kind, threshold=float(threshold))
    if extra_cfg: cfg.update(extra_cfg)

    with wandb.init(entity=ENTITY, project=PROJECT, config=cfg, reinit="finish_previous") as run:
        wandb.log({
            "test_f1": test_f1,
            "test_acc": test_acc,
            "test_ap": test_ap,
            "params_total": int(params_total),
            "params_trainable": int(params_trainable),
            "confusion_matrix": cm,
        })

        # Helpful for the report panels
        log_pr_curve(y_true, y_prob, title="test_pr_curve")

        # Optional: histogram of probabilities
        wandb.log({"test_prob_hist": wandb.Histogram(y_prob)})

        print("Logged:", {"test_f1": test_f1, "test_acc": test_acc, "test_ap": test_ap,
                         "params_trainable": params_trainable, "kind": kind})

# Create a Sweep

In [None]:
import wandb

sweep_config = {
  "method": "random",
  "metric": {"name": "test_f1", "goal": "maximize"},
  "parameters": {
    "kind": {"values": ["baseline", "dendrites"]},
    "lr": {"values": [1e-5, 2e-5, 5e-5]},
    "batch_size": {"values": [4, 8, 16]},
    "max_len": {"values": [128, 256]},
    # dendrite knobs (only used when kind="dendrites")
    "max_dendrites": {"values": [0, 4, 8, 16]},
    "dend_switch_threshold": {"values": [0.0, 0.1, 0.2]},
    "dend_init_mag": {"values": [0.5, 1.0, 2.0]},
  }
}

sweep_id = wandb.sweep(sweep=sweep_config, entity=ENTITY, project=PROJECT)
print("Created sweep:", sweep_id)

Create sweep with ID: ywuxq1gt
Sweep URL: https://wandb.ai/vtpy/dendrites-hackathon/sweeps/ywuxq1gt
Created sweep: ywuxq1gt


In [None]:
def sweep_run():
    with wandb.init(entity=ENTITY, project=PROJECT, reinit="finish_previous") as run:
        cfg = dict(wandb.config)

        # --- build/train based on cfg ---
        # You already have these in globals:
        # - train_model(...)
        # - base_model / baseline_model(...)
        # - make_dendrites_model(...) / pai_model(...)
        #
        # PSEUDOCODE: replace with your actual build/train calls
        if cfg["kind"] == "baseline":
            model = baseline_model(lr=cfg["lr"], max_len=cfg["max_len"])  # adapt to your signature
        else:
            model = make_dendrites_model(
                lr=cfg["lr"], max_len=cfg["max_len"],
                max_dendrites=cfg["max_dendrites"],
                dend_switch_threshold=cfg["dend_switch_threshold"],
                dend_init_mag=cfg["dend_init_mag"],
            )

        model = train_model(model, lr=cfg["lr"], batch_size=cfg["batch_size"], epochs=cfg.get("epochs", 2))

        # --- evaluate + log ---
        run_eval_wandb(
            model=model,
            test_dl=test_dl,
            threshold=0.5,
            kind=cfg["kind"],
            extra_cfg=cfg
        )

In [None]:
import inspect
import wandb
import torch
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, average_precision_score

ENTITY  = "vtpy"
PROJECT = "dendrites-hackathon"

def _filter_kwargs(fn, kwargs: dict):
    """Return kwargs that are accepted by fn(**kwargs)."""
    sig = inspect.signature(fn)
    params = sig.parameters
    # if **kwargs present, pass everything
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    allowed = set(params.keys())
    return {k:v for k,v in kwargs.items() if k in allowed}

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return int(total), int(trainable)

def evaluate_and_log(model, test_dl, threshold=0.5):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # Your eval_collect_preds must exist. If it can't return probs, we still log F1/Acc.
    try:
        y_true, y_pred, y_prob = eval_collect_preds(model, test_dl, device, threshold=threshold, return_probs=True)
        y_prob = np.asarray(y_prob, dtype=float)
    except TypeError:
        y_true, y_pred = eval_collect_preds(model, test_dl, device, threshold=threshold)
        y_prob = None

    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)

    avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
    test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    test_acc = float(accuracy_score(y_true, y_pred))

    logs = {"test_f1": test_f1, "test_acc": test_acc, "threshold": float(threshold)}

    if y_prob is not None and len(np.unique(y_true)) == 2:
        logs["test_ap"] = float(average_precision_score(y_true, y_prob))

    params_total, params_trainable = count_params(model)
    logs["params_total"] = params_total
    logs["params_trainable"] = params_trainable

    wandb.log(logs)
    return logs

In [None]:
def sweep_run():
    with wandb.init() as run:
        cfg = dict(wandb.config)

        # --- Build model (only pass args the function actually supports)
        if cfg["kind"] == "baseline":
            build_fn = baseline_model
        else:
            build_fn = make_dendrites_model

        build_kwargs = {
            "lr": cfg.get("lr"),
            "learning_rate": cfg.get("lr"),
            "max_len": cfg.get("max_len"),
            "seq_len": cfg.get("max_len"),
            "batch_size": cfg.get("batch_size"),
            "max_dendrites": cfg.get("max_dendrites"),
            "dend_switch_threshold": cfg.get("dend_switch_threshold"),
            "dend_init_mag": cfg.get("dend_init_mag"),
        }
        build_kwargs = {k:v for k,v in build_kwargs.items() if v is not None}
        build_kwargs = _filter_kwargs(build_fn, build_kwargs)

        model = build_fn(**build_kwargs)

        # --- Train (again: only pass supported args)
        train_kwargs = {
            "lr": cfg.get("lr"),
            "learning_rate": cfg.get("lr"),
            "batch_size": cfg.get("batch_size"),
            "max_len": cfg.get("max_len"),
            "epochs": cfg.get("epochs", 2),
            "grad_accum": cfg.get("grad_accum"),
        }
        train_kwargs = {k:v for k,v in train_kwargs.items() if v is not None}
        train_kwargs = _filter_kwargs(train_model, train_kwargs)

        model = train_model(model, **train_kwargs)

        # --- Evaluate + log
        evaluate_and_log(model, test_dl, threshold=cfg.get("threshold", 0.5))

# infer in_dim + safe builder selection

In [None]:
import inspect, types
import torch
import wandb
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, average_precision_score

def _filter_kwargs(fn, kwargs: dict):
    sig = inspect.signature(fn)
    params = sig.parameters
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    allowed = set(params.keys())
    return {k:v for k,v in kwargs.items() if k in allowed}

def infer_in_dim(dl):
    """Try to infer feature dimension from first batch of a DataLoader."""
    batch = next(iter(dl))

    # common patterns: (x,y), (x,y,...) or dict batches
    if isinstance(batch, (tuple, list)):
        x = batch[0]
    elif isinstance(batch, dict):
        # try common keys
        for k in ["x", "features", "inputs", "emb", "embedding"]:
            if k in batch:
                x = batch[k]
                break
        else:
            # transformer-style dict: input_ids -> (B, L); that's not "in_dim" for an MLP
            # but if your MLP consumes pooled embeddings, you likely have "x"/"features" instead
            raise ValueError(f"Can't infer in_dim from dict batch keys={list(batch.keys())[:30]}")
    else:
        raise ValueError(f"Unknown batch type: {type(batch)}")

    if not torch.is_tensor(x):
        raise ValueError(f"Expected tensor inputs, got {type(x)}")

    # if MLP input is (B, D), in_dim = D
    if x.dim() >= 2:
        return int(x.shape[-1])
    # if (B,), not enough info
    raise ValueError(f"Input tensor has shape {tuple(x.shape)}; can't infer in_dim.")

def pick_baseline_builder():
    """
    We want a FUNCTION that returns a new model, not an nn.Module instance.
    If you accidentally overwrote baseline_model with a model instance, we handle it.
    """
    candidates = ["baseline_model", "base_model", "make_baseline_model", "build_baseline_model"]
    for name in candidates:
        obj = globals().get(name, None)
        if obj is None:
            continue
        if isinstance(obj, torch.nn.Module):
            # this is already a model, not a builder
            continue
        if callable(obj):
            return obj

    # fallback: if baseline_model exists but is an nn.Module instance, reuse it (not ideal for sweeps)
    obj = globals().get("baseline_model", None)
    if isinstance(obj, torch.nn.Module):
        return lambda **kwargs: obj

    raise ValueError("Couldn't find a baseline model builder. Expected something like baseline_model() or base_model().")

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return int(total), int(trainable)

def evaluate_and_log(model, test_dl, threshold=0.5):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # If your eval_collect_preds can return probs, great. Otherwise log F1/Acc only.
    try:
        y_true, y_pred, y_prob = eval_collect_preds(model, test_dl, device, threshold=threshold, return_probs=True)
        y_prob = np.asarray(y_prob, dtype=float)
    except TypeError:
        y_true, y_pred = eval_collect_preds(model, test_dl, device, threshold=threshold)
        y_prob = None

    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)

    avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
    test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    test_acc = float(accuracy_score(y_true, y_pred))

    logs = {"test_f1": test_f1, "test_acc": test_acc, "threshold": float(threshold)}

    # AP only makes sense for binary + probabilities
    if y_prob is not None and len(np.unique(y_true)) == 2:
        logs["test_ap"] = float(average_precision_score(y_true, y_prob))

    params_total, params_trainable = count_params(model)
    logs["params_total"] = params_total
    logs["params_trainable"] = params_trainable

    wandb.log(logs)
    return logs

# fixed sweep_run() (adds in_dim, removes lr from model build)

In [None]:
def sweep_run():
    with wandb.init() as run:
        cfg = dict(wandb.config)

        # --- choose builders
        if cfg["kind"] == "baseline":
            build_fn = pick_baseline_builder()
        else:
            build_fn = make_dendrites_model  # you have this

        # --- infer in_dim if needed (use whatever loader exists)
        dl_for_dim = globals().get("train_dl", None) or globals().get("test_dl", None)
        if dl_for_dim is None:
            raise ValueError("Need train_dl or test_dl defined to infer in_dim.")

        inferred_in_dim = infer_in_dim(dl_for_dim)

        # --- MODEL BUILD KWARGS (NO lr here)
        build_kwargs = {
            "in_dim": inferred_in_dim,
            "max_len": cfg.get("max_len"),
            "seq_len": cfg.get("max_len"),
            "batch_size": cfg.get("batch_size"),
            "max_dendrites": cfg.get("max_dendrites"),
            "dend_switch_threshold": cfg.get("dend_switch_threshold"),
            "dend_init_mag": cfg.get("dend_init_mag"),
        }
        build_kwargs = {k:v for k,v in build_kwargs.items() if v is not None}
        build_kwargs = _filter_kwargs(build_fn, build_kwargs)

        model = build_fn(**build_kwargs)

        # --- TRAIN KWARGS (lr belongs here)
        train_kwargs = {
            "lr": cfg.get("lr"),
            "learning_rate": cfg.get("lr"),
            "batch_size": cfg.get("batch_size"),
            "max_len": cfg.get("max_len"),
            "epochs": cfg.get("epochs", 2),
            "grad_accum": cfg.get("grad_accum"),
        }
        train_kwargs = {k:v for k,v in train_kwargs.items() if v is not None}
        train_kwargs = _filter_kwargs(train_model, train_kwargs)

        model = train_model(model, **train_kwargs)

        # --- EVAL + LOG
        evaluate_and_log(model, test_dl, threshold=cfg.get("threshold", 0.5))

# Fix sweep_run for train_model(make_model_fn, ...)

In [None]:
import inspect
import wandb
import torch
import numpy as np

def _call_train_model(train_model, make_model_fn, train_kwargs):
    """
    Handles both patterns:
      1) train_model(make_model_fn, **kwargs)
      2) train_model(make_model_fn=..., **kwargs)
    and returns the trained/best model (first element if tuple).
    """
    try:
        out = train_model(make_model_fn, **train_kwargs)
    except TypeError:
        out = train_model(make_model_fn=make_model_fn, **train_kwargs)

    # train_model sometimes returns (best_model, history, metrics, ...)
    if isinstance(out, (tuple, list)):
        return out[0]
    return out

def sweep_run():
    with wandb.init() as run:
        cfg = dict(wandb.config)

        # ---- infer in_dim from a dataloader
        dl_for_dim = globals().get("train_dl", None) or globals().get("test_dl", None)
        if dl_for_dim is None:
            raise ValueError("Need train_dl or test_dl defined to infer in_dim.")
        in_dim = infer_in_dim(dl_for_dim)

        # ---- choose builder function (baseline vs dendrites)
        baseline_builder = pick_baseline_builder()

        def make_model_fn():
            if cfg.get("kind") == "baseline":
                # Build baseline model (pass in_dim if builder supports it)
                kwargs = {"in_dim": in_dim, "max_len": cfg.get("max_len")}
                kwargs = {k:v for k,v in kwargs.items() if v is not None}
                kwargs = _filter_kwargs(baseline_builder, kwargs)
                return baseline_builder(**kwargs)

            # Build dendrites model (pass only what it supports)
            dend_kwargs = {
                "in_dim": in_dim,
                "max_len": cfg.get("max_len"),
                "seq_len": cfg.get("max_len"),
                "max_dendrites": cfg.get("max_dendrites"),
                "dend_switch_threshold": cfg.get("dend_switch_threshold"),
                "dend_init_mag": cfg.get("dend_init_mag"),
            }
            dend_kwargs = {k:v for k,v in dend_kwargs.items() if v is not None}
            dend_kwargs = _filter_kwargs(make_dendrites_model, dend_kwargs)
            return make_dendrites_model(**dend_kwargs)

        # ---- training kwargs (lr belongs here, NOT in model build)
        train_kwargs = {
            "lr": cfg.get("lr"),
            "learning_rate": cfg.get("lr"),
            "batch_size": cfg.get("batch_size"),
            "max_len": cfg.get("max_len"),
            "epochs": cfg.get("epochs", 2),
            "grad_accum": cfg.get("grad_accum"),
        }
        train_kwargs = {k:v for k,v in train_kwargs.items() if v is not None}
        train_kwargs = _filter_kwargs(train_model, train_kwargs)

        # ---- train
        model = _call_train_model(train_model, make_model_fn, train_kwargs)

        # ---- eval + log
        evaluate_and_log(model, test_dl, threshold=cfg.get("threshold", 0.5))

# Run Agent Wandb

In [None]:
import inspect
print("train_model:", inspect.signature(train_model))
print("make_dendrites_model:", inspect.signature(make_dendrites_model))
print("baseline_builder:", pick_baseline_builder(), "sig:", inspect.signature(pick_baseline_builder()))

train_model: (run_name, make_model_fn, epochs=3, lr=0.001, weight_decay=0.0, resume=True)
make_dendrites_model: (in_dim: int)
baseline_builder: <function pick_baseline_builder.<locals>.<lambda> at 0x795ffdf276a0> sig: (**kwargs)


In [None]:
import wandb, torch, numpy as np
from sklearn.metrics import f1_score, accuracy_score

# --- 1) infer in_dim from train_dl/test_dl (works for (x,y) or {"x":...,"y":...})
def infer_in_dim(dl):
    batch = next(iter(dl))
    if isinstance(batch, dict):
        # try common keys
        for k in ["x", "features", "emb", "embedding", "inputs"]:
            if k in batch:
                x = batch[k]
                break
        else:
            # otherwise take first tensor-like value
            x = next(v for v in batch.values() if torch.is_tensor(v))
    else:
        # tuple/list: (x,y,...) or (x,y)
        x = batch[0]

    if torch.is_tensor(x):
        x = x.detach()
        # expect [B, D] or [D]
        return int(x.shape[-1])
    # numpy fallback
    x = np.asarray(x)
    return int(x.shape[-1])

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

# --- 2) build baseline model safely (your baseline_builder is **kwargs lambda)
def build_baseline_model(baseline_builder, in_dim):
    # try with in_dim, otherwise call with no args
    try:
        return baseline_builder(in_dim=in_dim)
    except TypeError:
        return baseline_builder()

# --- 3) evaluation (uses your test_dl variable)
@torch.no_grad()
def eval_collect_preds_generic(model, dl, device="cuda", threshold=0.5):
    model.eval()
    model.to(device)

    y_true, y_pred = [], []
    for batch in dl:
        if isinstance(batch, dict):
            # try keys
            x = batch.get("x", None)
            y = batch.get("y", None) if "y" in batch else batch.get("label", None)
            if x is None:
                # pick first tensor value as x if needed
                x = next(v for v in batch.values() if torch.is_tensor(v))
        else:
            x, y = batch[0], batch[1]

        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        # binary: logits shape [B] or [B,1] or [B,2]
        if logits.ndim == 2 and logits.shape[-1] == 2:
            prob = torch.softmax(logits, dim=-1)[:, 1]
        else:
            prob = torch.sigmoid(logits.squeeze(-1))

        pred = (prob >= threshold).long()
        y_true.append(y.long().view(-1).cpu())
        y_pred.append(pred.view(-1).cpu())

    y_true = torch.cat(y_true).numpy()
    y_pred = torch.cat(y_pred).numpy()
    return y_true, y_pred

# --- 4) THE SWEEP RUNNER (matches: train_model(run_name, make_model_fn, ...))
def sweep_run():
    with wandb.init() as run:
        cfg = dict(wandb.config)

        # pick dataloaders you already have
        train_dl_local = globals().get("train_dl", None) or globals().get("train_loader", None)
        test_dl_local  = globals().get("test_dl", None)  or globals().get("test_loader", None)

        if train_dl_local is None or test_dl_local is None:
            raise ValueError("Need train_dl and test_dl defined (your notebook has test_dl already).")

        in_dim = infer_in_dim(train_dl_local)

        baseline_builder = pick_baseline_builder()

        # model factory required by your train_model signature
        def make_model_fn():
            kind = cfg.get("kind", "baseline")
            if kind == "dendrites":
                return make_dendrites_model(in_dim=in_dim)
            return build_baseline_model(baseline_builder, in_dim=in_dim)

        # train (IMPORTANT: lr/epochs/etc go here, not into model)
        model = train_model(
            run.name,
            make_model_fn,
            epochs=cfg.get("epochs", 3),
            lr=cfg.get("lr", 1e-3),
            weight_decay=cfg.get("weight_decay", 0.0),
            resume=cfg.get("resume", True),
        )

        # if your train_model returns (model, history, ...) handle it
        if isinstance(model, (tuple, list)):
            model = model[0]

        # eval + log
        device = "cuda" if torch.cuda.is_available() else "cpu"
        threshold = cfg.get("threshold", 0.5)

        # if you already have eval_collect_preds(), use it; else fallback
        if "eval_collect_preds" in globals():
            y_true, y_pred = eval_collect_preds(model, test_dl_local, device, threshold=threshold)
        else:
            y_true, y_pred = eval_collect_preds_generic(model, test_dl_local, device=device, threshold=threshold)

        avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
        test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
        test_acc = float(accuracy_score(y_true, y_pred))

        params_total, params_trainable = count_params(model)

        wandb.log({
            "test_f1": test_f1,
            "test_acc": test_acc,
            "params_total": params_total,
            "params_trainable": params_trainable,
            "threshold": threshold,
        })

        # optional: confusion matrix (nice in report)
        try:
            cm_plot = wandb.plot.confusion_matrix(
                probs=None,
                y_true=y_true,
                preds=y_pred,
                class_names=["0","1"] if len(np.unique(y_true))<=2 else None,
            )
            wandb.log({"confusion_matrix": cm_plot})
        except Exception:
            pass

In [None]:
sweep_config = {
  "method": "random",
  "metric": {"name": "test_f1", "goal": "maximize"},
  "parameters": {
    "kind": {"values": ["baseline", "dendrites"]},
    "lr": {"values": [1e-5, 2e-5, 5e-5, 1e-4, 3e-4, 1e-3]},
    "epochs": {"values": [1, 2, 3, 4]},
    "weight_decay": {"values": [0.0, 1e-4, 1e-3]},
    "threshold": {"values": [0.3, 0.4, 0.5, 0.6]},
    "resume": {"values": [True]},
  }
}

In [None]:
import wandb, torch, numpy as np
from sklearn.metrics import f1_score, accuracy_score

def infer_in_dim(dl):
    batch = next(iter(dl))
    if isinstance(batch, dict):
        x = batch.get("x", None)
        if x is None:
            x = next(v for v in batch.values() if torch.is_tensor(v))
    else:
        x = batch[0]
    return int(x.shape[-1])

@torch.no_grad()
def eval_collect_preds_generic(model, dl, device="cuda", threshold=0.5):
    model.eval().to(device)
    y_true, y_pred = [], []
    for batch in dl:
        if isinstance(batch, dict):
            x = batch.get("x", None)
            y = batch.get("y", None) if "y" in batch else batch.get("label", None)
            if x is None:
                x = next(v for v in batch.values() if torch.is_tensor(v))
        else:
            x, y = batch[0], batch[1]

        x = x.to(device)
        y = y.to(device)

        logits = model(x)

        if logits.ndim == 2 and logits.shape[-1] == 2:
            prob = torch.softmax(logits, dim=-1)[:, 1]
        else:
            prob = torch.sigmoid(logits.squeeze(-1))

        pred = (prob >= threshold).long()
        y_true.append(y.long().view(-1).cpu())
        y_pred.append(pred.view(-1).cpu())

    y_true = torch.cat(y_true).numpy()
    y_pred = torch.cat(y_pred).numpy()
    return y_true, y_pred

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

def sweep_run():
    with wandb.init() as run:
        cfg = dict(wandb.config)

        train_dl_local = globals().get("train_dl", None) or globals().get("train_loader", None)
        test_dl_local  = globals().get("test_dl", None)  or globals().get("test_loader", None)

        if train_dl_local is None or test_dl_local is None:
            raise ValueError("Need train_dl (or train_loader) and test_dl (or test_loader) defined.")

        in_dim = infer_in_dim(train_dl_local)

        baseline_builder = pick_baseline_builder()  # your notebook function

        def make_model_fn():
            if cfg.get("kind", "baseline") == "dendrites":
                return make_dendrites_model(in_dim=in_dim)

            # baseline_builder is (**kwargs). Try in_dim; fallback to no args.
            try:
                return baseline_builder(in_dim=in_dim)
            except TypeError:
                return baseline_builder()

        model = train_model(
            run.name,
            make_model_fn,
            epochs=cfg.get("epochs", 3),
            lr=cfg.get("lr", 1e-3),
            weight_decay=cfg.get("weight_decay", 0.0),
            resume=cfg.get("resume", True),
        )

        if isinstance(model, (tuple, list)):
            model = model[0]

        device = "cuda" if torch.cuda.is_available() else "cpu"
        threshold = cfg.get("threshold", 0.5)

        # use your eval_collect_preds if it exists, else fallback
        if "eval_collect_preds" in globals():
            y_true, y_pred = eval_collect_preds(model, test_dl_local, device, threshold=threshold)
        else:
            y_true, y_pred = eval_collect_preds_generic(model, test_dl_local, device=device, threshold=threshold)

        avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
        test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
        test_acc = float(accuracy_score(y_true, y_pred))

        params_total, params_trainable = count_params(model)

        wandb.log({
            "test_f1": test_f1,
            "test_acc": test_acc,
            "params_total": params_total,
            "params_trainable": params_trainable,
            "threshold": threshold,
        })

In [None]:
import numpy as np, inspect

print("best_f1_threshold sig:", inspect.signature(best_f1_threshold))

toy_y = np.array([0,1,0,1])
toy_p = np.array([0.1,0.9,0.2,0.8])
out = best_f1_threshold(toy_y, toy_p)

print("type(out):", type(out))
try:
    print("len(out):", len(out))
except Exception:
    print("no len(out)")
print("out:", out)

best_f1_threshold sig: (y, prob)
type(out): <class 'dict'>
len(out): 8
out: {'thr': 0.8, 'f1': 0.9999999989999999, 'precision': 0.9999999995, 'recall': 0.9999999995, 'tp': 2, 'fp': 0, 'fn': 0, 'tn': 2}


In [None]:
import numpy as np

_best_f1_threshold_orig = best_f1_threshold  # keep original

def best_f1_threshold(val_y, val_prob, *args, **kwargs):
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)

    # If it returns a dict
    if isinstance(out, dict):
        thr = out.get("thr", out.get("threshold", out.get("best_thr")))
        f1  = out.get("f1", out.get("best_f1"))
        p   = out.get("p", out.get("precision", out.get("best_p")))
        r   = out.get("r", out.get("recall", out.get("best_r")))
        return float(thr), float(f1), float(p), float(r)

    # If it returns a tuple/list/np array
    if isinstance(out, (tuple, list, np.ndarray)):
        out = list(out)
        if len(out) >= 4:
            return float(out[0]), float(out[1]), float(out[2]), float(out[3])

    # Fallback (will likely still error, but gives clearer behavior)
    return out

# before start wandb agent : Patch best_f1_threshold to match train_model() (dict → 4-tuple)

In [None]:
import numpy as np

_best_f1_threshold_orig = best_f1_threshold  # keep original

def best_f1_threshold(y, prob):
    out = _best_f1_threshold_orig(y, prob)
    # your function returns a dict with 8 keys
    if isinstance(out, dict):
        return (float(out["thr"]),
                float(out["f1"]),
                float(out["precision"]),
                float(out["recall"]))
    # fallback if it ever changes
    if isinstance(out, (tuple, list, np.ndarray)) and len(out) >= 4:
        return float(out[0]), float(out[1]), float(out[2]), float(out[3])
    raise TypeError(f"Unexpected best_f1_threshold output type: {type(out)} -> {out}")

# Fix baseline builder resolution (avoid the “build_baseline_model missing …” issue)

In [None]:
import inspect

def resolve_baseline_builder():
    # Preferred: if you have a function that returns the builder lambda
    if "pick_baseline_builder" in globals() and callable(globals()["pick_baseline_builder"]):
        fn = globals()["pick_baseline_builder"]()
        print("Using pick_baseline_builder() ->", fn, "sig:", inspect.signature(fn))
        return fn

    # Otherwise: use whatever is currently in baseline_builder *if* it looks like a builder
    cand = globals().get("baseline_builder", None)
    if cand is not None and callable(cand):
        sig = str(inspect.signature(cand))
        # your real builder earlier looked like (**kwargs)
        if "kwargs" in sig and "baseline_builder" not in sig:
            print("Using existing baseline_builder ->", cand, "sig:", sig)
            return cand

    raise RuntimeError(
        "Could not resolve a usable baseline builder. "
        "Make sure you still have pick_baseline_builder() or the original baseline_builder lambda."
    )

BASELINE_BUILDER_FN = resolve_baseline_builder()

Using pick_baseline_builder() -> <function build_baseline_model at 0x795fff16a7a0> sig: (baseline_builder, in_dim)


In [None]:
def infer_in_dim():
    for k in ["in_dim", "input_dim", "emb_dim", "hidden_size"]:
        if k in globals() and isinstance(globals()[k], int):
            print("Using", k, "=", globals()[k])
            return globals()[k]

    # If your model is already built somewhere and has a first Linear
    for name in ["base_best_model", "pai_best_model", "base_model", "baseline_model"]:
        m = globals().get(name, None)
        if m is None:
            continue
        try:
            import torch.nn as nn
            for mod in m.modules():
                if isinstance(mod, nn.Linear):
                    print("Inferred in_dim from", name, "first Linear.in_features =", mod.in_features)
                    return int(mod.in_features)
        except Exception:
            pass

    raise RuntimeError("Couldn't infer in_dim automatically. Define `in_dim = <int>` manually.")

in_dim = infer_in_dim()

Inferred in_dim from base_best_model first Linear.in_features = 20000


In [None]:
import wandb
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

# pick correct test loader var
test_loader = test_dl if "test_dl" in globals() else test_loader

def count_trainable_params(model):
    import torch
    return int(sum(p.numel() for p in model.parameters() if p.requires_grad))

def eval_collect_preds_from_loader(model, loader, device, threshold=0.5):
    import torch
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for batch in loader:
            # adapt this part if your batch format differs
            # common cases: batch = (x, y) or dict with "y"/"labels"
            if isinstance(batch, (tuple, list)):
                x, y = batch[0], batch[1]
            elif isinstance(batch, dict):
                y = batch.get("y", batch.get("label", batch.get("labels")))
                x = {k:v for k,v in batch.items() if k not in ["y","label","labels"]}
            else:
                raise TypeError(f"Unknown batch type: {type(batch)}")

            # move to device
            if isinstance(x, dict):
                x = {k: (v.to(device) if hasattr(v, "to") else v) for k,v in x.items()}
                logits = model(**x) if callable(getattr(model, "forward", None)) else model(x)
            else:
                x = x.to(device) if hasattr(x, "to") else x
                logits = model(x)

            # logits -> prob (binary)
            if hasattr(logits, "detach"):
                logits = logits.detach()
            if logits.ndim > 1 and logits.shape[-1] > 1:
                # multiclass: argmax
                pred = logits.argmax(dim=-1).cpu().numpy()
            else:
                # binary: sigmoid + threshold
                prob = torch.sigmoid(logits).view(-1).cpu().numpy()
                pred = (prob >= threshold).astype(int)

            y_np = y.detach().cpu().numpy() if hasattr(y, "detach") else np.array(y)
            ys.append(y_np.reshape(-1))
            preds.append(pred.reshape(-1))

    y_true = np.concatenate(ys)
    y_pred = np.concatenate(preds)
    return y_true, y_pred

def sweep_run():
    run = wandb.init()
    cfg = wandb.config

    device = "cuda" if ("torch" in globals() and torch.cuda.is_available()) else "cpu"

    def make_model_fn():
        if cfg.kind == "baseline":
            # builder expects **kwargs; give in_dim if accepted
            try:
                return BASELINE_BUILDER_FN(in_dim=in_dim)
            except TypeError:
                return BASELINE_BUILDER_FN()
        else:
            return make_dendrites_model(in_dim)

    # ---- train using your provided API ----
    model = train_model(
        run_name=run.name,
        make_model_fn=make_model_fn,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        resume=bool(cfg.resume),
    )

    # ---- evaluate + log ----
    y_true, y_pred = eval_collect_preds_from_loader(model, test_loader, device, threshold=float(cfg.threshold))
    avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
    test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    test_acc = float(accuracy_score(y_true, y_pred))

    params_trainable = count_trainable_params(model)

    wandb.log({
        "test_f1": test_f1,
        "test_acc": test_acc,
        "params_trainable": params_trainable,
        "kind": cfg.kind,
        "threshold": float(cfg.threshold),
        "in_dim": int(in_dim),
    })

In [None]:
import inspect

# snapshot so globals can't change during iteration
items = list(globals().items())

def safe_sig(obj):
    try:
        return str(inspect.signature(obj))
    except Exception:
        return None

# show any objects that look like "baseline builder" candidates
cands = []
for name, obj in items:
    if callable(obj):
        sig = safe_sig(obj)
        if sig and ("baseline" in name.lower() or "builder" in name.lower()):
            cands.append((name, sig, obj))

print("Baseline/builder-ish callables:")
for name, sig, obj in cands[:50]:
    print(f"- {name:30s} sig={sig}")

Baseline/builder-ish callables:
- baseline_model                 sig=(*args, **kwargs)
- train_baseline_mlp             sig=(run_name='baseline_mlp', run_root='/content/runs_travelplanner', epochs=10, lr=0.001, weight_decay=0.0, grad_clip=1.0, log_every=50, resume=True)
- pick_baseline_builder          sig=()
- build_baseline_model           sig=(baseline_builder, in_dim)
- resolve_baseline_builder       sig=()
- BASELINE_BUILDER_FN            sig=(baseline_builder, in_dim)
- BASELINE_CORE_BUILDER          sig=(baseline_builder, in_dim)


In [None]:
import inspect

print("baseline_builder in globals()? ->", "baseline_builder" in globals())
if "baseline_builder" in globals():
    print("baseline_builder =", baseline_builder)
    print("sig =", inspect.signature(baseline_builder))

baseline_builder in globals()? -> False


In [None]:
import inspect

print("base_model in globals()? ->", "base_model" in globals())
print("build_baseline_model in globals()? ->", "build_baseline_model" in globals())

# inspect base_model signature (helps debugging)
if "base_model" in globals():
    try:
        print("base_model sig:", inspect.signature(base_model))
    except Exception as e:
        print("couldn't inspect base_model signature:", e)

def BASELINE_CORE_BUILDER(in_dim):
    # Preferred path: build_baseline_model(base_model, in_dim)
    if "build_baseline_model" in globals() and "base_model" in globals():
        return build_baseline_model(base_model, in_dim)

    # Fallback: maybe base_model itself builds the model
    if "base_model" in globals():
        try:
            return base_model(in_dim=in_dim)
        except TypeError:
            return base_model(in_dim)

    raise RuntimeError("Couldn't reconstruct a baseline builder. base_model/build_baseline_model missing.")

print("BASELINE_CORE_BUILDER sig:", inspect.signature(BASELINE_CORE_BUILDER))

base_model in globals()? -> True
build_baseline_model in globals()? -> True
base_model sig: (*args, **kwargs)
BASELINE_CORE_BUILDER sig: (in_dim)


In [None]:
def best_f1_threshold_unpack(y, prob):
    out = best_f1_threshold(y, prob)
    # out is dict like {'thr':..., 'f1':..., 'precision':..., 'recall':...}
    return out["thr"], out["f1"], out["precision"], out["recall"], out

# example usage:
# thr, best_f1, best_p, best_r, out = best_f1_threshold_unpack(val_y, val_prob)

In [None]:
import wandb, torch, numpy as np
from sklearn.metrics import f1_score, accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"

def sweep_run():
    run = wandb.init()
    cfg = wandb.config

    # --- choose model
    in_dim = globals().get("in_dim", None)
    if in_dim is None:
        # if you have embeddings / tokenizer pipeline, in_dim should come from that
        # last resort: infer from one batch of your train_dl
        xb, yb = next(iter(train_dl))
        in_dim = xb.shape[-1]

    if cfg.kind == "baseline":
        make_model_fn = lambda: BASELINE_CORE_BUILDER(in_dim)
    else:
        make_model_fn = lambda: make_dendrites_model(in_dim)

    # --- train via your existing API
    model = train_model(
        run_name=wandb.run.name,
        make_model_fn=make_model_fn,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        resume=bool(cfg.resume),
    )

    # --- eval (use your test_dl)
    threshold = float(cfg.threshold)

    y_true, y_pred = eval_collect_preds(model, test_dl, device, threshold=threshold)
    avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"

    test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    test_acc = float(accuracy_score(y_true, y_pred))

    # params
    params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_total = sum(p.numel() for p in model.parameters())

    wandb.log({
        "test_f1": test_f1,
        "test_acc": test_acc,
        "threshold": threshold,
        "params_trainable": params_trainable,
        "params_total": params_total,
    })

    wandb.finish()

In [None]:
import inspect
print("base_model =", base_model)
print("base_model sig =", inspect.signature(base_model))
print("build_baseline_model sig =", inspect.signature(build_baseline_model))
print("baseline_builder exists?", "baseline_builder" in globals())

base_model = BaseMLP(
  (net): Sequential(
    (0): Linear(in_features=20000, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
)
base_model sig = (*args, **kwargs)
build_baseline_model sig = (baseline_builder, in_dim)
baseline_builder exists? False


In [None]:
import copy
import torch
import torch.nn as nn

def infer_in_dim_from_base_model(m):
    # assumes base_model.net[0] is Linear
    first = m.net[0]
    assert isinstance(first, nn.Linear), f"Expected first layer Linear, got {type(first)}"
    return first.in_features

IN_DIM = infer_in_dim_from_base_model(base_model)
print("IN_DIM =", IN_DIM)  # should print 20000

def clone_base_model_with_in_dim(in_dim: int):
    m = copy.deepcopy(base_model)
    first = m.net[0]
    if hasattr(first, "in_features") and first.in_features != in_dim:
        # patch first layer to new input dim (keep out_features/bias)
        new_first = nn.Linear(in_dim, first.out_features, bias=(first.bias is not None))
        m.net[0] = new_first
    return m

# This is the builder sweep/training expects
def BASELINE_CORE_BUILDER(in_dim: int):
    return clone_base_model_with_in_dim(in_dim)

IN_DIM = 20000


In [None]:
def best_f1_threshold_unpack(y, prob):
    out = best_f1_threshold(y, prob)   # returns dict
    return out["thr"], out["f1"], out["precision"], out["recall"], out

In [None]:
import wandb, numpy as np, torch
from sklearn.metrics import f1_score, accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"

def sweep_run():
    run = wandb.init()
    cfg = wandb.config

    in_dim = IN_DIM

    # choose model builder (ONLY model args here)
    if cfg.kind == "baseline":
        make_model_fn = lambda: BASELINE_CORE_BUILDER(in_dim)
    else:
        make_model_fn = lambda: make_dendrites_model(in_dim)

    # train args (ONLY training args here)
    model = train_model(
        run_name=wandb.run.name,
        make_model_fn=make_model_fn,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        resume=bool(cfg.resume),
    )

    # eval
    threshold = float(cfg.threshold)
    y_true, y_pred = eval_collect_preds(model, test_dl, device, threshold=threshold)

    avg = "binary" if len(np.unique(y_true)) <= 2 else "macro"
    test_f1  = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    test_acc = float(accuracy_score(y_true, y_pred))

    params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_total = sum(p.numel() for p in model.parameters())

    wandb.log({
        "test_f1": test_f1,
        "test_acc": test_acc,
        "params_trainable": params_trainable,
        "params_total": params_total,
        "threshold": threshold,
        "kind": cfg.kind,
    })

    wandb.finish()

In [None]:
sweep_config = {
  "method": "random",
  "metric": {"name": "test_f1", "goal": "maximize"},
  "parameters": {
    "kind": {"values": ["baseline", "dendrites"]},
    "lr": {"values": [1e-5, 2e-5, 5e-5, 1e-4]},
    "weight_decay": {"values": [0.0, 1e-4, 1e-3]},
    "epochs": {"values": [1, 2, 3]},
    "threshold": {"values": [0.3, 0.5, 0.7]},
    "resume": {"values": [True]},
  }
}

sweep_id = wandb.sweep(sweep_config, project="dendrites-hackathon", entity="vtpy")
print("SWEEP:", sweep_id)
wandb.agent(sweep_id, function=sweep_run, count=30)

Create sweep with ID: 7mt1bror
Sweep URL: https://wandb.ai/vtpy/dendrites-hackathon/sweeps/7mt1bror
SWEEP: 7mt1bror


[34m[1mwandb[0m: Agent Starting Run: 5jokvd12 with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	resume: True
[34m[1mwandb[0m: 	threshold: 0.3
[34m[1mwandb[0m: 	weight_decay: 0.0001


[polished-sweep-1] params total/trainable: 10372097 10372097
[polished-sweep-1] ep 0 step 0/5 loss 1.62413


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/wandb/agents/pyagent.py", line 296, in _run_job
    self._function()
  File "/tmp/ipython-input-3416870044.py", line 19, in sweep_run
    model = train_model(
            ^^^^^^^^^^^^
  File "/tmp/ipython-input-209466060.py", line 68, in train_model
    thr, best_f1, best_p, best_r = best_f1_threshold(val_y, val_prob)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3905526601.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(y, prob)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-300826175.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-300826175.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)
          ^^^^^^^^^^^^^

[34m[1mwandb[0m: [32m[41mERROR[0m Run 5jokvd12 errored: maximum recursion depth exceeded
[34m[1mwandb[0m: Agent Starting Run: 6qeiu6c8 with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	resume: True
[34m[1mwandb[0m: 	threshold: 0.7
[34m[1mwandb[0m: 	weight_decay: 0.001


[wise-sweep-2] params total/trainable: 10372097 10372097
[wise-sweep-2] ep 0 step 0/5 loss 1.65821


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/wandb/agents/pyagent.py", line 296, in _run_job
    self._function()
  File "/tmp/ipython-input-3416870044.py", line 19, in sweep_run
    model = train_model(
            ^^^^^^^^^^^^
  File "/tmp/ipython-input-209466060.py", line 68, in train_model
    thr, best_f1, best_p, best_r = best_f1_threshold(val_y, val_prob)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3905526601.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(y, prob)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-300826175.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-300826175.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)
          ^^^^^^^^^^^^^

[34m[1mwandb[0m: [32m[41mERROR[0m Run 6qeiu6c8 errored: maximum recursion depth exceeded
[34m[1mwandb[0m: Agent Starting Run: rchj8i78 with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	resume: True
[34m[1mwandb[0m: 	threshold: 0.7
[34m[1mwandb[0m: 	weight_decay: 0.0001


[youthful-sweep-3] params total/trainable: 10372097 10372097
[youthful-sweep-3] ep 0 step 0/5 loss 1.62294


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/wandb/agents/pyagent.py", line 296, in _run_job
    self._function()
  File "/tmp/ipython-input-3416870044.py", line 19, in sweep_run
    model = train_model(
            ^^^^^^^^^^^^
  File "/tmp/ipython-input-209466060.py", line 68, in train_model
    thr, best_f1, best_p, best_r = best_f1_threshold(val_y, val_prob)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3905526601.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(y, prob)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-300826175.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-300826175.py", line 6, in best_f1_threshold
    out = _best_f1_threshold_orig(val_y, val_prob, *args, **kwargs)
          ^^^^^^^^^^^^^

[34m[1mwandb[0m: [32m[41mERROR[0m Run rchj8i78 errored: maximum recursion depth exceeded
[34m[1mwandb[0m: [32m[41mERROR[0m Detected 3 failed runs in the first 60 seconds, killing sweep.
[34m[1mwandb[0m: To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true


In [None]:
import numpy as np

# If you previously created these, wipe them so they can't cause loops
for _name in ["_best_f1_threshold_orig", "_best_f1_threshold_orig2"]:
    if _name in globals():
        del globals()[_name]

def best_f1_threshold(y, prob, num_grid=200):
    """
    Returns (best_thr, best_f1, best_precision, best_recall)
    so train_model can do: thr, best_f1, best_p, best_r = best_f1_threshold(...)
    """
    y = np.asarray(y).astype(int)
    prob = np.asarray(prob).astype(float)

    # Guardrails
    if len(y) == 0:
        return 0.5, 0.0, 0.0, 0.0

    # Grid thresholds (include exact probs too for more precise choice)
    grid = np.linspace(0.0, 1.0, num_grid)
    grid = np.unique(np.concatenate([grid, np.clip(prob, 0, 1)]))

    best = (-1.0, 0.5, 0.0, 0.0)  # (f1, thr, p, r)

    for thr in grid:
        pred = (prob >= thr).astype(int)

        tp = int(((pred == 1) & (y == 1)).sum())
        fp = int(((pred == 1) & (y == 0)).sum())
        fn = int(((pred == 0) & (y == 1)).sum())

        precision = tp / (tp + fp + 1e-12)
        recall    = tp / (tp + fn + 1e-12)
        f1        = 2 * precision * recall / (precision + recall + 1e-12)

        if f1 > best[0]:
            best = (f1, float(thr), float(precision), float(recall))

    best_f1, best_thr, best_p, best_r = best
    return best_thr, best_f1, best_p, best_r

# quick sanity (should not recurse)
print(best_f1_threshold([0,1,0,1], [0.1,0.9,0.2,0.8]))

(0.20100502512562815, 0.9999999999989999, 0.9999999999995, 0.9999999999995)


In [None]:
import os
os.environ["WANDB_AGENT_DISABLE_FLAPPING"] = "true"

In [None]:
import numpy as np
import torch
from sklearn.metrics import average_precision_score, accuracy_score, precision_recall_fscore_support

def eval_report(model, dl, threshold=0.5, device=None):
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    ys = []
    probs = []

    with torch.no_grad():
        for batch in dl:
            # --- adapt these two lines if your batch format differs ---
            x = batch[0].to(device) if isinstance(batch, (tuple, list)) else batch["x"].to(device)
            y = batch[1]            if isinstance(batch, (tuple, list)) else batch["y"]

            y = y.detach().cpu().numpy()
            ys.append(y)

            out = model(x)
            out = out.detach().cpu().numpy()

            # convert logits -> prob if needed
            # if your model already outputs probabilities, keep it
            if out.ndim > 1 and out.shape[-1] == 1:
                out = out.reshape(-1)
            prob = 1 / (1 + np.exp(-out))  # sigmoid
            probs.append(prob)

    yt = np.concatenate(ys).reshape(-1)
    prob = np.concatenate(probs).reshape(-1)

    # ---- CRITICAL PART: make y_true valid for PR-AUC/AP ----
    # If y_true isn't strictly {0,1}, binarize it.
    uniq = np.unique(yt)
    if not set(uniq.tolist()).issubset({0, 1}):
        # common case: floats in [0,1]
        if yt.min() >= 0 and yt.max() <= 1:
            yt = (yt >= 0.5).astype(int)
        else:
            # fallback: treat as continuous target, threshold at median
            yt = (yt >= np.median(yt)).astype(int)
    else:
        yt = yt.astype(int)

    pred = (prob >= threshold).astype(int)

    # metrics
    acc = float(accuracy_score(yt, pred))
    p, r, f1, _ = precision_recall_fscore_support(
        yt, pred, average="binary", zero_division=0
    )

    # AP only valid if both classes appear
    ap = float("nan")
    if len(np.unique(yt)) == 2:
        ap = float(average_precision_score(yt, prob))

    rep = {
        "acc": acc,
        "precision": float(p),
        "recall": float(r),
        "f1": float(f1),
        "ap": ap,
        "n": int(len(yt)),
        "pos_rate": float(yt.mean()),
    }
    return rep, yt, prob

In [None]:
val_rep, val_y, val_prob = eval_report(pai_best_model, val_dl, threshold=0.5)
print("val_rep:", val_rep)
print("unique y:", np.unique(val_y)[:10], "count:", len(np.unique(val_y)))

val_rep: {'acc': 0.8888888888888888, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'ap': 0.2871881205214538, 'n': 54, 'pos_rate': 0.1111111111111111}
unique y: [0 1] count: 2


In [None]:
import numpy as np
from sklearn.metrics import precision_recall_curve

def best_f1_threshold(y, prob):
    y = np.asarray(y).reshape(-1).astype(int)
    prob = np.asarray(prob).reshape(-1)

    p, r, thr = precision_recall_curve(y, prob)
    f1 = (2 * p * r) / (p + r + 1e-12)

    i = int(np.nanargmax(f1))
    best_thr = float(thr[i]) if i < len(thr) else 0.5
    return best_thr, float(f1[i]), float(p[i]), float(r[i])

In [None]:
val_rep, val_y, val_prob = eval_report(pai_best_model, val_dl, threshold=0.5)
print("val_rep:", val_rep)
print("unique y:", np.unique(val_y)[:10], "count:", len(np.unique(val_y)))

val_rep: {'acc': 0.8888888888888888, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'ap': 0.2871881205214538, 'n': 54, 'pos_rate': 0.1111111111111111}
unique y: [0 1] count: 2


In [None]:
import numpy as np
from sklearn.metrics import precision_recall_curve

def best_f1_threshold(y, prob):
    y = np.asarray(y).reshape(-1).astype(int)
    prob = np.asarray(prob).reshape(-1).astype(float)

    p, r, thr = precision_recall_curve(y, prob)
    f1 = (2 * p * r) / (p + r + 1e-12)

    i = int(np.nanargmax(f1))
    best_thr = float(thr[i]) if i < len(thr) else 0.5
    return best_thr, float(f1[i]), float(p[i]), float(r[i])

In [None]:
import copy, wandb, torch
import numpy as np

def fresh_copy(model):
    m = copy.deepcopy(model)
    # reinit if possible
    for mod in m.modules():
        rp = getattr(mod, "reset_parameters", None)
        if callable(rp):
            try:
                rp()
            except Exception:
                pass
    return m

IN_DIM = int(base_model.net[0].in_features)  # 20000 in your printout

def sweep_run():
    run = wandb.init()  # sweep provides project/entity automatically
    cfg = wandb.config

    kind = str(cfg.kind)
    run_name = f"{kind}_lr{cfg.lr}_wd{cfg.weight_decay}_ep{cfg.epochs}_thr{cfg.threshold}"

    def make_model_fn():
        if kind == "baseline":
            return fresh_copy(base_model)
        elif kind == "dendrites":
            return make_dendrites_model(IN_DIM)
        else:
            raise ValueError(f"Unknown kind={kind}")

    # ---- train (your trainer will .to(device) internally)
    model = train_model(
        run_name=run_name,
        make_model_fn=make_model_fn,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        resume=bool(cfg.resume),
    )

    # ---- eval (threshold is swept; later you can switch to val-chosen thr*)
    thr = float(cfg.threshold)
    val_rep, _, _  = eval_report(model, val_dl,  threshold=thr)
    test_rep, _, _ = eval_report(model, test_dl, threshold=thr)

    params_trainable = int(sum(p.numel() for p in model.parameters() if p.requires_grad))
    params_total     = int(sum(p.numel() for p in model.parameters()))

    wandb.log({
        "kind": kind,
        "params_trainable": params_trainable,
        "params_total": params_total,
        "threshold": thr,

        "val_acc": float(val_rep["acc"]),
        "val_f1":  float(val_rep["f1"]),
        "val_ap":  float(val_rep["ap"]),

        "test_acc": float(test_rep["acc"]),
        "test_f1":  float(test_rep["f1"]),
        "test_ap":  float(test_rep["ap"]),
    })

    wandb.finish()

In [None]:
sweep_cfg = {
  "method": "random",
  "metric": {"name": "test_f1", "goal": "maximize"},
  "parameters": {
    "kind": {"values": ["baseline", "dendrites"]},
    "lr": {"values": [1e-3, 3e-4, 1e-4, 5e-5, 2e-5]},
    "weight_decay": {"values": [0.0, 1e-4, 1e-3]},
    "epochs": {"values": [1, 2, 3, 4]},
    "threshold": {"values": [0.2, 0.3, 0.5, 0.7]},
    "resume": {"values": [True]},
  }
}

sweep_id = wandb.sweep(sweep_cfg, project="dendrites-hackathon", entity="vtpy")
print("SWEEP:", sweep_id)

Create sweep with ID: 8znxrbqq
Sweep URL: https://wandb.ai/vtpy/dendrites-hackathon/sweeps/8znxrbqq
SWEEP: 8znxrbqq


In [None]:
# --- Patch eval_report output keys to match train_model expectation ---
_eval_report_orig = eval_report  # keep original

def eval_report(*args, **kwargs):
    rep, prob, yt = _eval_report_orig(*args, **kwargs)

    # Your eval_report currently uses 'ap' for PR-AUC (average precision)
    if "pr_auc" not in rep and "ap" in rep:
        rep["pr_auc"] = rep["ap"]
    # Optional: keep both names present
    if "ap" not in rep and "pr_auc" in rep:
        rep["ap"] = rep["pr_auc"]

    return rep, prob, yt

In [None]:
import numpy as np
from sklearn.metrics import precision_recall_curve

def best_f1_threshold(y, prob):
    y = np.asarray(y).reshape(-1)
    prob = np.asarray(prob).reshape(-1)

    # force binary ints
    if y.dtype != np.int64 and y.dtype != np.int32:
        y = (y > 0.5).astype(int)
    else:
        y = y.astype(int)

    # If no positives OR all positives, PR curve is degenerate → avoid warnings/crashes
    pos = int(y.sum())
    if pos == 0 or pos == len(y):
        return 0.5, 0.0, 0.0, 0.0  # thr, f1, precision, recall

    p, r, thr = precision_recall_curve(y, prob)
    f1 = (2 * p * r) / (p + r + 1e-12)
    i = int(np.nanargmax(f1))
    best_thr = float(thr[i]) if i < len(thr) else 0.5
    return best_thr, float(f1[i]), float(p[i]), float(r[i])

In [None]:
import inspect

# keep original
_save_ckpt_orig = save_ckpt

def save_ckpt(*args, **kwargs):
    """
    Fix: if caller passes `extra` both positionally and by keyword,
    drop the keyword copy (or vice-versa) to avoid crash.
    """
    sig = inspect.signature(_save_ckpt_orig)
    params = list(sig.parameters.keys())

    if "extra" in params:
        extra_pos = params.index("extra")
        # If `extra` was already passed positionally, remove keyword version
        if "extra" in kwargs and len(args) > extra_pos:
            kwargs.pop("extra", None)

    return _save_ckpt_orig(*args, **kwargs)

print("**YAY patched save_ckpt. signature:", inspect.signature(save_ckpt))

**YAY patched save_ckpt. signature: (*args, **kwargs)


In [None]:
import os, re

# you already have _save_ckpt_orig from your previous patch
# if not, set it once:
# _save_ckpt_orig = save_ckpt

def _sanitize_ckpt_path(p: str) -> str:
    p = str(p).replace("\\", "/")

    # if someone made a folder named "last.pt/..."
    if ".pt/" in p:
        p = p.split(".pt/")[0] + ".pt"

    # remove newlines + illegal filename chars
    p = p.replace("\n", "_")
    p = re.sub(r"[^0-9A-Za-z._/\-]", "_", p)

    # make sure parent exists
    parent = os.path.dirname(p)
    if parent:
        os.makedirs(parent, exist_ok=True)

    return p

def save_ckpt(*args, **kwargs):
    # try to locate the checkpoint path argument
    # (works whether it's passed as kwarg or positional)
    path_idx = None
    path_val = kwargs.get("path", None)

    if path_val is None:
        for i, a in enumerate(args):
            if isinstance(a, (str, os.PathLike)) and (
                str(a).endswith((".pt", ".pth")) or ".pt/" in str(a) or "/content/" in str(a)
            ):
                path_idx = i
                path_val = str(a)
                break

    if path_val is not None:
        safe_path = _sanitize_ckpt_path(str(path_val))
        if "path" in kwargs:
            kwargs["path"] = safe_path
        elif path_idx is not None:
            args = list(args)
            args[path_idx] = safe_path
            args = tuple(args)

    # attempt save; if still fails, skip (metrics are already logged)
    try:
        return _save_ckpt_orig(*args, **kwargs)
    except Exception as e:
        print("⚠️ save_ckpt: skipping checkpoint save to keep sweep running.\n   ", repr(e))
        return None

print("*yay! save_ckpt patched to sanitize path / skip on failure")

*yay! save_ckpt patched to sanitize path / skip on failure


In [None]:
def unwrap_model(obj):
    # already a torch model?
    if hasattr(obj, "eval") and hasattr(obj, "parameters"):
        return obj

    # tuple/list: find first thing that looks like a model
    if isinstance(obj, (tuple, list)):
        for it in obj:
            m = unwrap_model(it)
            if m is not None:
                return m

    # dict: search values
    if isinstance(obj, dict):
        for it in obj.values():
            m = unwrap_model(it)
            if m is not None:
                return m

    return None

In [None]:
import wandb, numpy as np

def unwrap_model(obj):
    if hasattr(obj, "eval") and hasattr(obj, "parameters"):
        return obj
    if isinstance(obj, (tuple, list)):
        for it in obj:
            m = unwrap_model(it)
            if m is not None:
                return m
    if isinstance(obj, dict):
        for it in obj.values():
            m = unwrap_model(it)
            if m is not None:
                return m
    return None

def get_pr_auc(rep):
    return float(rep.get("pr_auc", rep.get("ap", 0.0)))

def sweep_run():
    # W&B gives you a config object here
    wandb.init()
    cfg = wandb.config

    # make a readable run name
    run_name = f"{cfg.kind}_lr{cfg.lr}_wd{cfg.weight_decay}_ep{cfg.epochs}_thr{cfg.threshold}"
    wandb.run.name = run_name

    # ---- pick builder based on cfg.kind ----
    # you already have these functions in globals:
    # - make_dendrites_model(in_dim)
    # - BaseMLP model instance exists as base_model (or build baseline some other way)
    # - train_model(run_name, make_model_fn, epochs, lr, weight_decay, resume)

    # infer in_dim from an existing model if possible (you showed 20000)
    in_dim = 20000

    def make_model_fn():
        if cfg.kind == "dendrites":
            return make_dendrites_model(in_dim=in_dim)
        else:
            # baseline: easiest is to rebuild same architecture; if you already have a builder, use it.
            # If you only have `base_model` instance, you MUST create a fresh one, not reuse the trained object.
            # Here’s a simple fallback: call the class constructor if you have one.
            # Replace this line with your baseline factory if needed:
            return BaseMLP(in_dim=in_dim)  # <-- change if your baseline ctor differs

    # ---- train (train_model may return tuple/dict/etc) ----
    trained = train_model(
        run_name=run_name,
        make_model_fn=make_model_fn,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        resume=bool(cfg.resume),
    )

    model = unwrap_model(trained)
    assert model is not None, f"Could not unwrap model from train_model output: {type(trained)}"

    # ---- eval + log metrics needed for sweep scatter ----
    thr = float(cfg.threshold)

    val_rep, _, _  = eval_report(model, val_dl,  threshold=thr)
    test_rep, _, _ = eval_report(model, test_dl, threshold=thr)

    params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    wandb.log({
        "params_trainable": params_trainable,
        "val_pr_auc": get_pr_auc(val_rep),
        "val_f1": float(val_rep.get("f1", 0.0)),
        "test_pr_auc": get_pr_auc(test_rep),
        "test_f1": float(test_rep.get("f1", 0.0)),
    })

    # also push into summary (helps reports)
    wandb.run.summary["params_trainable"] = params_trainable
    wandb.run.summary["test_pr_auc"] = get_pr_auc(test_rep)
    wandb.run.summary["test_f1"] = float(test_rep.get("f1", 0.0))

In [None]:
import copy
import torch.nn as nn

def fresh_clone_with_reset(model):
    """Deep-copy a model and reset all layers that support reset_parameters()."""
    m = copy.deepcopy(model)

    def _reset(mod):
        if hasattr(mod, "reset_parameters"):
            mod.reset_parameters()

    m.apply(_reset)
    return m

def infer_in_dim_from_base(model):
    # tries common patterns; you printed base_model.net[0] is Linear
    if hasattr(model, "net") and isinstance(model.net, nn.Sequential):
        for layer in model.net:
            if isinstance(layer, nn.Linear):
                return int(layer.in_features)
    # fallback
    for mod in model.modules():
        if isinstance(mod, nn.Linear):
            return int(mod.in_features)
    raise ValueError("Couldn't infer in_dim from base_model.")

In [None]:
in_dim = infer_in_dim_from_base(base_model)

def make_model_fn():
    if cfg.kind == "dendrites":
        # your signature is make_dendrites_model(in_dim: int) (positional!)
        return make_dendrites_model(in_dim)
    else:
        # baseline: fresh untrained copy of the known-good architecture
        return fresh_clone_with_reset(base_model)

In [None]:
import copy
import torch.nn as nn

def fresh_clone_with_reset(model):
    """Deep-copy a model and reset all layers that support reset_parameters()."""
    m = copy.deepcopy(model)

    def _reset(mod):
        if hasattr(mod, "reset_parameters"):
            mod.reset_parameters()

    m.apply(_reset)
    return m

def infer_in_dim_from_base(model):
    # tries common patterns; you printed base_model.net[0] is Linear
    if hasattr(model, "net") and isinstance(model.net, nn.Sequential):
        for layer in model.net:
            if isinstance(layer, nn.Linear):
                return int(layer.in_features)
    # fallback
    for mod in model.modules():
        if isinstance(mod, nn.Linear):
            return int(mod.in_features)
    raise ValueError("Couldn't infer in_dim from base_model.")

In [None]:
# =========================
# FULL UPDATED SWEEP CODE
# (Paste as ONE CELL in Colab)
# =========================

import os, copy, math, traceback
import numpy as np
import torch
import torch.nn as nn
import wandb
from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score, accuracy_score


# -------------------------
# 0) HARD REQUIREMENTS (must already exist in your notebook)
# -------------------------
# - base_model : a working baseline model instance (you printed it; in_dim=20000)
# - make_dendrites_model(in_dim: int) : builds dendrites model
# - train_model(run_name, make_model_fn, epochs=..., lr=..., weight_decay=..., resume=...) : trains
# - val_dl, test_dl : dataloaders
# - (optional) eval_report(model, dl, threshold=thr) : if you have it; we provide a safe fallback below

# If you DON'T have eval_report in your notebook, this cell defines one that works for binary classifiers
# assuming batch yields (x, y) and model(x) returns logits or probs.


# -------------------------
# 1) Utilities: clone/reset + in_dim infer + unwrap train_model output
# -------------------------
def fresh_clone_with_reset(model):
    """Deep-copy a model and reset all layers that support reset_parameters()."""
    m = copy.deepcopy(model)

    def _reset(mod):
        if hasattr(mod, "reset_parameters"):
            mod.reset_parameters()

    m.apply(_reset)
    return m


def infer_in_dim_from_base(model):
    """Infer input dimension from first Linear layer."""
    # common: model.net is Sequential
    if hasattr(model, "net") and isinstance(model.net, nn.Sequential):
        for layer in model.net:
            if isinstance(layer, nn.Linear):
                return int(layer.in_features)

    # fallback: scan modules
    for mod in model.modules():
        if isinstance(mod, nn.Linear):
            return int(mod.in_features)

    raise ValueError("Couldn't infer in_dim from base_model.")


def unwrap_model(trained):
    """
    train_model sometimes returns:
      - model
      - (model, stuff...)
      - dict with 'model'
    This unwraps to nn.Module or returns None.
    """
    if isinstance(trained, nn.Module):
        return trained
    if isinstance(trained, (tuple, list)) and len(trained) > 0:
        if isinstance(trained[0], nn.Module):
            return trained[0]
    if isinstance(trained, dict):
        for k in ["model", "best_model", "net"]:
            if k in trained and isinstance(trained[k], nn.Module):
                return trained[k]
    return None


# -------------------------
# 2) Metrics helpers
# -------------------------
def best_f1_threshold(y_true, y_prob):
    """
    Returns a dict: {'thr','f1','precision','recall','tp','fp','fn','tn'}
    Avoids recursion bugs from earlier patches.
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)

    # If all probs are same, just use 0.5
    uniq = np.unique(y_prob)
    thr_candidates = uniq
    if len(thr_candidates) > 200:
        # sample thresholds if too many unique values
        thr_candidates = np.quantile(y_prob, np.linspace(0, 1, 200))

    best = {"thr": 0.5, "f1": -1.0, "precision": 0.0, "recall": 0.0, "tp": 0, "fp": 0, "fn": 0, "tn": 0}

    for thr in thr_candidates:
        y_pred = (y_prob >= thr).astype(int)
        tp = int(((y_pred == 1) & (y_true == 1)).sum())
        fp = int(((y_pred == 1) & (y_true == 0)).sum())
        fn = int(((y_pred == 0) & (y_true == 1)).sum())
        tn = int(((y_pred == 0) & (y_true == 0)).sum())

        prec = tp / (tp + fp + 1e-9)
        rec  = tp / (tp + fn + 1e-9)
        f1   = 2 * prec * rec / (prec + rec + 1e-9)

        if f1 > best["f1"]:
            best = {"thr": float(thr), "f1": float(f1), "precision": float(prec), "recall": float(rec),
                    "tp": tp, "fp": fp, "fn": fn, "tn": tn}
    return best


@torch.no_grad()
def collect_probs_and_y(model, dl, device):
    model.eval()
    ys, probs = [], []
    for batch in dl:
        # supports (x,y) or dict batches
        if isinstance(batch, (tuple, list)) and len(batch) >= 2:
            x, y = batch[0], batch[1]
        elif isinstance(batch, dict):
            # try common keys
            x = batch.get("x", batch.get("input", batch.get("features", None)))
            y = batch.get("y", batch.get("label", None))
        else:
            raise ValueError(f"Unsupported batch type: {type(batch)}")

        x = x.to(device)
        y = y.detach().cpu().numpy().astype(int)

        out = model(x)
        # out can be shape [B], [B,1], logits or probs
        if isinstance(out, (tuple, list)):
            out = out[0]
        out = out.detach().float()

        if out.ndim == 2 and out.shape[1] == 1:
            out = out[:, 0]

        # If values look like logits, apply sigmoid; if already [0,1], sigmoid won’t hurt much
        prob = torch.sigmoid(out).detach().cpu().numpy().astype(float)

        ys.append(y)
        probs.append(prob)

    y_true = np.concatenate(ys, axis=0)
    y_prob = np.concatenate(probs, axis=0)
    return y_true, y_prob


def eval_report_safe(model, dl, device, threshold=None, choose_thr_on="val"):
    """
    Returns: (rep_dict, y_prob, y_true)
    rep_dict includes: acc, precision, recall, f1, ap(pr_auc), n, pos_rate, thr_used
    """
    y_true, y_prob = collect_probs_and_y(model, dl, device)

    # threshold selection
    if threshold is None and choose_thr_on == "val":
        best = best_f1_threshold(y_true, y_prob)
        thr = best["thr"]
    else:
        thr = 0.5 if threshold is None else float(threshold)

    y_pred = (y_prob >= thr).astype(int)

    # PR-AUC / AP
    # If a split has no positives, sklearn warns; catch and set AP=0.0
    try:
        ap = float(average_precision_score(y_true, y_prob))
    except Exception:
        ap = 0.0

    rep = {
        "acc": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "ap": ap,                 # <- this is PR-AUC / Average Precision
        "pr_auc": ap,             # <- alias (your print earlier expected pr_auc sometimes)
        "n": int(len(y_true)),
        "pos_rate": float(np.mean(y_true == 1)),
        "thr_used": float(thr),
    }
    return rep, y_prob, y_true


def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return int(total), int(trainable)


# -------------------------
# 3) Disable checkpoint saving safely (prevents sweep crashes)
# -------------------------
# If your notebook has save_ckpt, patch it to no-op.
if "save_ckpt" in globals():
    _save_ckpt_orig = save_ckpt
    def save_ckpt(*args, **kwargs):
        print("⚠️ save_ckpt: skipping checkpoint save to keep sweep running.")
        return None
    globals()["save_ckpt"] = save_ckpt


# -------------------------
# 4) Sweep config (edit as needed)
# -------------------------
sweep_config = {
    "method": "random",
    "metric": {"name": "test_pr_auc", "goal": "maximize"},
    "parameters": {
        "kind": {"values": ["baseline", "dendrites"]},
        "lr": {"values": [2e-5, 1e-4, 5e-4, 1e-3]},
        "weight_decay": {"values": [0.0, 1e-4, 1e-3]},
        "epochs": {"values": [1, 2, 3, 4]},
        # optional; if you want fixed threshold evaluation rather than val-chosen:
        # "threshold": {"values": [0.3, 0.5, 0.7]},
    },
}

# -------------------------
# 5) The sweep run function
# -------------------------
def sweep_run():
    cfg = wandb.config

    # IMPORTANT: always init inside sweep_run
    run_name = f"{cfg.kind}_lr{cfg.lr}_wd{cfg.weight_decay}_ep{cfg.epochs}"
    wandb.init(name=run_name)

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        in_dim = infer_in_dim_from_base(base_model)

        # model factory expected by train_model: make_model_fn() -> model
        def make_model_fn():
            if cfg.kind == "dendrites":
                # signature: make_dendrites_model(in_dim: int) positional
                return make_dendrites_model(in_dim)
            else:
                # baseline: fresh clone of the known-good architecture
                return fresh_clone_with_reset(base_model)

        # ---- train
        trained = train_model(
            run_name=run_name,
            make_model_fn=make_model_fn,
            epochs=int(cfg.epochs),
            lr=float(cfg.lr),
            weight_decay=float(cfg.weight_decay),
            resume=True,
        )

        model = unwrap_model(trained)
        assert model is not None, f"Could not unwrap model from train_model output: {type(trained)}"
        model = model.to(device)

        # ---- evaluate (VAL chooses best-F1 threshold, then report VAL + TEST)
        val_rep, _, _ = eval_report_safe(model, val_dl, device, threshold=None, choose_thr_on="val")
        thr = val_rep["thr_used"]

        test_rep, _, _ = eval_report_safe(model, test_dl, device, threshold=thr, choose_thr_on=None)

        # params
        total_params, trainable_params = count_params(model)

        # log
        wandb.log({
            "kind": cfg.kind,
            "lr": float(cfg.lr),
            "weight_decay": float(cfg.weight_decay),
            "epochs": int(cfg.epochs),

            "params_total": total_params,
            "params_trainable": trainable_params,

            "val_pr_auc": val_rep["pr_auc"],
            "val_f1": val_rep["f1"],
            "val_acc": val_rep["acc"],
            "val_precision": val_rep["precision"],
            "val_recall": val_rep["recall"],
            "val_thr": thr,

            "test_pr_auc": test_rep["pr_auc"],
            "test_f1": test_rep["f1"],
            "test_acc": test_rep["acc"],
            "test_precision": test_rep["precision"],
            "test_recall": test_rep["recall"],
            "test_thr": thr,
        })

        print(f"[{run_name}] DONE | val_pr_auc={val_rep['pr_auc']:.6f} val_f1={val_rep['f1']:.4f} thr={thr:.4f} | "
              f"test_pr_auc={test_rep['pr_auc']:.6f} test_f1={test_rep['f1']:.4f} | params_trainable={trainable_params}")

    except Exception as e:
        wandb.alert(title="Run crashed", text=str(e))
        print("❌ sweep_run crashed:\n", traceback.format_exc())
        raise
    finally:
        wandb.finish()


# -------------------------
# 6) Create + run the sweep
# -------------------------
# (Run this once to create the sweep)
sweep_id = wandb.sweep(sweep_config, project="dendrites-hackathon", entity="vtpy")
print("Created sweep:", sweep_id)
print("Sweep URL:", f"https://wandb.ai/vtpy/dendrites-hackathon/sweeps/{sweep_id}")

# (Then start agent)
# Example:
# wandb.agent(f"vtpy/dendrites-hackathon/{sweep_id}", function=sweep_run, count=30)

Create sweep with ID: 2xsntta0
Sweep URL: https://wandb.ai/vtpy/dendrites-hackathon/sweeps/2xsntta0
Created sweep: 2xsntta0
Sweep URL: https://wandb.ai/vtpy/dendrites-hackathon/sweeps/2xsntta0


In [None]:
def sweep_run():
    # IMPORTANT: init FIRST so wandb.config exists
    run = wandb.init(project="dendrites-hackathon", entity="vtpy")

    cfg = wandb.config  # now valid

    # now you can build a name from cfg and optionally rename the run
    run_name = f"{cfg.kind}_lr{cfg.lr}_wd{cfg.weight_decay}_ep{cfg.epochs}"
    try:
        run.name = run_name  # works in most cases
        run.save()
    except Exception:
        pass

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        in_dim = infer_in_dim_from_base(base_model)

        def make_model_fn():
            if cfg.kind == "dendrites":
                return make_dendrites_model(in_dim)   # positional
            else:
                return fresh_clone_with_reset(base_model)

        trained = train_model(
            run_name=run_name,
            make_model_fn=make_model_fn,
            epochs=int(cfg.epochs),
            lr=float(cfg.lr),
            weight_decay=float(cfg.weight_decay),
            resume=True,
        )

        model = unwrap_model(trained)
        assert model is not None, f"Could not unwrap model from train_model output: {type(trained)}"
        model = model.to(device)

        # VAL chooses threshold; TEST uses it
        val_rep, _, _ = eval_report_safe(model, val_dl, device, threshold=None, choose_thr_on="val")
        thr = val_rep["thr_used"]
        test_rep, _, _ = eval_report_safe(model, test_dl, device, threshold=thr, choose_thr_on=None)

        total_params, trainable_params = count_params(model)

        wandb.log({
            "kind": cfg.kind,
            "lr": float(cfg.lr),
            "weight_decay": float(cfg.weight_decay),
            "epochs": int(cfg.epochs),

            "params_total": total_params,
            "params_trainable": trainable_params,

            "val_pr_auc": val_rep["pr_auc"],
            "val_f1": val_rep["f1"],
            "val_acc": val_rep["acc"],
            "val_thr": thr,

            "test_pr_auc": test_rep["pr_auc"],
            "test_f1": test_rep["f1"],
            "test_acc": test_rep["acc"],
            "test_thr": thr,
        })

        print(f"[{run_name}] DONE | val_pr_auc={val_rep['pr_auc']:.6f} val_f1={val_rep['f1']:.4f} thr={thr:.4f} | "
              f"test_pr_auc={test_rep['pr_auc']:.6f} test_f1={test_rep['f1']:.4f} | params_trainable={trainable_params}")

    except Exception as e:
        wandb.alert(title="Run crashed", text=str(e))
        print("❌ sweep_run crashed:\n", traceback.format_exc())
        raise
    finally:
        wandb.finish()

In [None]:
import numpy as np
from sklearn.metrics import precision_recall_curve

def best_f1_threshold(y, prob):
    """
    Compatibility shim for train_model():
    returns (thr, best_f1, best_precision, best_recall)
    """
    y = np.asarray(y).astype(int)
    prob = np.asarray(prob).astype(float)

    # handle degenerate cases cleanly
    pos = int(y.sum())
    if pos == 0:
        return 0.5, 0.0, 0.0, 0.0
    if pos == len(y):
        return 0.5, 1.0, 1.0, 1.0

    prec, rec, thr = precision_recall_curve(y, prob)
    # prec/rec length = len(thr)+1
    prec_t = prec[:-1]
    rec_t  = rec[:-1]
    denom = (prec_t + rec_t)
    f1 = np.where(denom > 0, 2.0 * prec_t * rec_t / (denom + 1e-12), 0.0)

    if len(f1) == 0 or len(thr) == 0:
        return 0.5, 0.0, float(prec[-1]), float(rec[-1])

    idx = int(np.nanargmax(f1))
    return float(thr[idx]), float(f1[idx]), float(prec_t[idx]), float(rec_t[idx])

# quick sanity check (should print a 4-tuple)
print("best_f1_threshold sanity:", best_f1_threshold([0,1,0,1], [0.1,0.9,0.2,0.8]))

best_f1_threshold sanity: (0.8, 0.9999999999995, 1.0, 1.0)


In [None]:
import numpy as np

def _best_to_thr(best):
    # supports dict output: {"thr": ...}
    if isinstance(best, dict):
        return float(best["thr"])
    # supports tuple/list output: (thr, f1, p, r)
    if isinstance(best, (tuple, list)) and len(best) >= 1:
        return float(best[0])
    raise TypeError(f"Unexpected best_f1_threshold output type: {type(best)}")

def eval_report_safe(model, dl, device, threshold=None, choose_thr_on="val"):
    """
    Safe wrapper:
    - gets prob + y once
    - optionally chooses thr via best_f1_threshold
    - re-runs eval_report at that thr
    Returns: (rep, prob, y_true)
    """
    # 1) run once at any fixed thr just to fetch prob + y
    rep0, prob, yt = eval_report(model, dl, threshold=0.5)  # eval_report already exists in your notebook
    yt = np.asarray(yt).astype(int)
    prob = np.asarray(prob).astype(float)

    # 2) choose threshold
    thr = threshold
    if thr is None:
        best = best_f1_threshold(yt, prob)   # could be dict OR tuple depending on your patch
        thr = _best_to_thr(best)

    # 3) compute final report at chosen thr
    rep, prob2, yt2 = eval_report(model, dl, threshold=float(thr))

    # normalize naming: ensure rep has "pr_auc"
    if "pr_auc" not in rep and "ap" in rep:
        rep["pr_auc"] = rep["ap"]
    rep["thr_used"] = float(thr)

    return rep, prob2, yt2

print("Patched eval_report_safe OK")

Patched eval_report_safe OK


In [None]:
wandb.agent("vtpy/dendrites-hackathon/2xsntta0", function=sweep_run, count=30)

[34m[1mwandb[0m: Agent Starting Run: fz1f248f with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr0.0005_wd0.001_ep4] params total/trainable: 2568449 2568449
[dendrites_lr0.0005_wd0.001_ep4] ep 0 step 0/5 loss 2.98529
[dendrites_lr0.0005_wd0.001_ep4] ep 0 train_loss 2.87833 | val PR-AUC 0.126892 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_wd0.001_ep4]             test PR-AUC 0.141664 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0.001_ep4] *** new best saved: val_pr_auc=0.126892
[dendrites_lr0.0005_wd0.001_ep4] ep 1 step 0/5 loss 2.79545
[dendrites_lr0.0005_wd0.001_ep4] ep 1 train_loss 2.82381 | val PR-AUC 0.132400 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_wd0.001_ep4]             test PR-AUC 0.131993 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0.001_ep4] *** new best saved: val_pr_auc=0.132400
[dendrites_lr0.0005_wd0.001_ep4] ep 2 step 0/5 lo

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,dendrites
lr,0.0005
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.13373
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 0xqg1m6u with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0


[dendrites_lr2e-05_wd0_ep4] params total/trainable: 2568449 2568449
[dendrites_lr2e-05_wd0_ep4] ep 0 step 0/5 loss 3.39261
[dendrites_lr2e-05_wd0_ep4] ep 0 train_loss 3.18762 | val PR-AUC 0.133890 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0_ep4]             test PR-AUC 0.304625 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0_ep4] *** new best saved: val_pr_auc=0.133890
[dendrites_lr2e-05_wd0_ep4] ep 1 step 0/5 loss 2.64433
[dendrites_lr2e-05_wd0_ep4] ep 1 train_loss 2.93990 | val PR-AUC 0.147854 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0_ep4]             test PR-AUC 0.146308 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0_ep4] *** new best saved: val_pr_auc=0.147854
[dendrites_lr2e-05_wd0_ep4] ep 2 step 0/5 loss 2.93809
[dendrites_lr2e-05_wd0_ep4] ep 2 train_

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,dendrites
lr,2e-05
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.15043
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 3weuyrpb with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0.0001


[dendrites_lr2e-05_wd0.0001_ep4] params total/trainable: 2568449 2568449
[dendrites_lr2e-05_wd0.0001_ep4] ep 0 step 0/5 loss 2.96459
[dendrites_lr2e-05_wd0.0001_ep4] ep 0 train_loss 3.02561 | val PR-AUC 0.129814 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.0001_ep4]             test PR-AUC 0.109317 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.0001_ep4] *** new best saved: val_pr_auc=0.129814
[dendrites_lr2e-05_wd0.0001_ep4] ep 1 step 0/5 loss 3.48983
[dendrites_lr2e-05_wd0.0001_ep4] ep 1 train_loss 3.16149 | val PR-AUC 0.273817 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.0001_ep4]             test PR-AUC 0.103794 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.0001_ep4] *** new best saved: val_pr_auc=0.273817
[dendrites_lr2e-05_wd0.0001_ep4] ep 2 step 0/5 lo

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,dendrites
lr,2e-05
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.11659
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: yu5ukp7l with config:
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	weight_decay: 0


[baseline_lr0.0005_wd0_ep1] params total/trainable: 10372097 10372097
[baseline_lr0.0005_wd0_ep1] ep 0 step 0/5 loss 2.83551
[baseline_lr0.0005_wd0_ep1] ep 0 train_loss 3.18332 | val PR-AUC 0.138356 F1 0.200000 (thr=0.5000)
[baseline_lr0.0005_wd0_ep1]             test PR-AUC 0.123340 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0005_wd0_ep1] *** new best saved: val_pr_auc=0.138356
[baseline_lr0.0005_wd0_ep1] DONE | val_pr_auc=0.138356 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.123340 test_f1=0.2000 | params_trainable=10372097


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,1
kind,baseline
lr,0.0005
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.12334
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: twepaj7o with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0.0001


[baseline_lr2e-05_wd0.0001_ep3] params total/trainable: 10372097 10372097
[baseline_lr2e-05_wd0.0001_ep3] ep 0 step 0/5 loss 3.53384
[baseline_lr2e-05_wd0.0001_ep3] ep 0 train_loss 3.24137 | val PR-AUC 0.146029 F1 0.000000 (thr=0.5000)
[baseline_lr2e-05_wd0.0001_ep3]             test PR-AUC 0.240580 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr2e-05_wd0.0001_ep3] *** new best saved: val_pr_auc=0.146029
[baseline_lr2e-05_wd0.0001_ep3] ep 1 step 0/5 loss 3.44574
[baseline_lr2e-05_wd0.0001_ep3] ep 1 train_loss 3.19768 | val PR-AUC 0.145265 F1 0.000000 (thr=0.5000)
[baseline_lr2e-05_wd0.0001_ep3]             test PR-AUC 0.172521 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr2e-05_wd0.0001_ep3] ep 2 step 0/5 loss 2.75028
[baseline_lr2e-05_wd0.0001_ep3] ep 2 train_loss 3.03057 | val PR-AUC 0.151949 F1 0.000000 (thr=0.5000)
[baseline_lr2e-05_wd0.0001_e

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,baseline
lr,2e-05
params_total,10372097
params_trainable,10372097
test_acc,0.88889
test_f1,0.2
test_pr_auc,0.17019
test_thr,0.5
val_acc,0.88889


[34m[1mwandb[0m: Agent Starting Run: xach3azy with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.0001


[dendrites_lr0.001_wd0.0001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0.0001_ep2] ep 0 step 0/5 loss 2.97245
[dendrites_lr0.001_wd0.0001_ep2] ep 0 train_loss 3.04785 | val PR-AUC 0.139179 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep2]             test PR-AUC 0.116850 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep2] *** new best saved: val_pr_auc=0.139179
[dendrites_lr0.001_wd0.0001_ep2] ep 1 step 0/5 loss 2.92593
[dendrites_lr0.001_wd0.0001_ep2] ep 1 train_loss 2.88724 | val PR-AUC 0.138583 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep2]             test PR-AUC 0.119438 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep2] DONE | val_pr_auc=0.138583 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.119438 test_f1=0.2000 | params_trainable=2568449


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.11944
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: ha4zwos1 with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr2e-05_wd0.001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr2e-05_wd0.001_ep2] ep 0 step 0/5 loss 3.42146
[dendrites_lr2e-05_wd0.001_ep2] ep 0 train_loss 2.97097 | val PR-AUC 0.268488 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.001_ep2]             test PR-AUC 0.130085 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.001_ep2] *** new best saved: val_pr_auc=0.268488
[dendrites_lr2e-05_wd0.001_ep2] ep 1 step 0/5 loss 2.85349
[dendrites_lr2e-05_wd0.001_ep2] ep 1 train_loss 2.83427 | val PR-AUC 0.259992 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.001_ep2]             test PR-AUC 0.120485 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.001_ep2] DONE | val_pr_auc=0.259992 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.120485 test_f1=0.2000 | params_trainable=2568449


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,2e-05
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.12049
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 1zkbvewx with config:
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr2e-05_wd0.001_ep1] params total/trainable: 2568449 2568449
[dendrites_lr2e-05_wd0.001_ep1] ep 0 step 0/5 loss 2.59748
[dendrites_lr2e-05_wd0.001_ep1] ep 0 train_loss 2.87712 | val PR-AUC 0.140572 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.001_ep1]             test PR-AUC 0.118587 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.001_ep1] *** new best saved: val_pr_auc=0.140572
[dendrites_lr2e-05_wd0.001_ep1] DONE | val_pr_auc=0.140572 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.118587 test_f1=0.2000 | params_trainable=2568449


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,1
kind,dendrites
lr,2e-05
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.11859
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: hqsd4m0a with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	weight_decay: 0


[dendrites_lr0.0005_wd0_ep3] params total/trainable: 2568449 2568449
[dendrites_lr0.0005_wd0_ep3] ep 0 step 0/5 loss 2.62058
[dendrites_lr0.0005_wd0_ep3] ep 0 train_loss 3.27938 | val PR-AUC 0.138894 F1 0.222222 (thr=0.5000)
[dendrites_lr0.0005_wd0_ep3]             test PR-AUC 0.168093 F1 0.184211
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0_ep3] *** new best saved: val_pr_auc=0.138894
[dendrites_lr0.0005_wd0_ep3] ep 1 step 0/5 loss 2.88856
[dendrites_lr0.0005_wd0_ep3] ep 1 train_loss 3.15950 | val PR-AUC 0.138894 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_wd0_ep3]             test PR-AUC 0.160846 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0_ep3] ep 2 step 0/5 loss 3.29340
[dendrites_lr0.0005_wd0_ep3] ep 2 train_loss 3.25154 | val PR-AUC 0.138894 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_wd0_ep3]             test PR-AUC 0.13739

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,dendrites
lr,0.0005
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.13739
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: hj7udzgo with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0.0001


[dendrites_lr2e-05_wd0.0001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr2e-05_wd0.0001_ep2] ep 0 step 0/5 loss 2.93285
[dendrites_lr2e-05_wd0.0001_ep2] ep 0 train_loss 3.20800 | val PR-AUC 0.158054 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.0001_ep2]             test PR-AUC 0.146175 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.0001_ep2] *** new best saved: val_pr_auc=0.158054
[dendrites_lr2e-05_wd0.0001_ep2] ep 1 step 0/5 loss 2.92741
[dendrites_lr2e-05_wd0.0001_ep2] ep 1 train_loss 3.34073 | val PR-AUC 0.160577 F1 0.200000 (thr=0.5000)
[dendrites_lr2e-05_wd0.0001_ep2]             test PR-AUC 0.132836 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.0001_ep2] *** new best saved: val_pr_auc=0.160577
[dendrites_lr2e-05_wd0.0001_ep2] DONE | val_pr_au

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,2e-05
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.13284
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: yzqlz566 with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0


[dendrites_lr0.001_wd0_ep3] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0_ep3] ep 0 step 0/5 loss 3.05200
[dendrites_lr0.001_wd0_ep3] ep 0 train_loss 3.14665 | val PR-AUC 0.138894 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0_ep3]             test PR-AUC 0.124357 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0_ep3] *** new best saved: val_pr_auc=0.138894
[dendrites_lr0.001_wd0_ep3] ep 1 step 0/5 loss 2.42915
[dendrites_lr0.001_wd0_ep3] ep 1 train_loss 2.84518 | val PR-AUC 0.135216 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0_ep3]             test PR-AUC 0.119056 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0_ep3] ep 2 step 0/5 loss 2.83367
[dendrites_lr0.001_wd0_ep3] ep 2 train_loss 2.80422 | val PR-AUC 0.132429 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0_ep3]             test PR-AUC 0.120837 F1 0.2000

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.12084
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: zo813f5w with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr2e-05_wd0.001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr2e-05_wd0.001_ep2] ep 0 step 0/5 loss 3.35119
[dendrites_lr2e-05_wd0.001_ep2] ep 0 train_loss 3.32463 | val PR-AUC 0.145752 F1 0.000000 (thr=0.5000)
[dendrites_lr2e-05_wd0.001_ep2]             test PR-AUC 0.127104 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.001_ep2] *** new best saved: val_pr_auc=0.145752
[dendrites_lr2e-05_wd0.001_ep2] ep 1 step 0/5 loss 3.18279
[dendrites_lr2e-05_wd0.001_ep2] ep 1 train_loss 3.33265 | val PR-AUC 0.264805 F1 0.000000 (thr=0.5000)
[dendrites_lr2e-05_wd0.001_ep2]             test PR-AUC 0.132249 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr2e-05_wd0.001_ep2] *** new best saved: val_pr_auc=0.264805
[dendrites_lr2e-05_wd0.001_ep2] DONE | val_pr_auc=0.264805

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,2e-05
params_total,2568449
params_trainable,2568449
test_acc,0.88889
test_f1,0
test_pr_auc,0.13225
test_thr,0.5
val_acc,0.88889


[34m[1mwandb[0m: Agent Starting Run: 3vnq7l72 with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.0001


[dendrites_lr0.001_wd0.0001_ep3] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0.0001_ep3] ep 0 step 0/5 loss 3.20503
[dendrites_lr0.001_wd0.0001_ep3] ep 0 train_loss 3.20343 | val PR-AUC 0.137900 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep3]             test PR-AUC 0.149419 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep3] *** new best saved: val_pr_auc=0.137900
[dendrites_lr0.001_wd0.0001_ep3] ep 1 step 0/5 loss 2.46088
[dendrites_lr0.001_wd0.0001_ep3] ep 1 train_loss 2.87917 | val PR-AUC 0.138254 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep3]             test PR-AUC 0.142630 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep3] *** new best saved: val_pr_auc=0.138254
[dendrites_lr0.001_wd0.0001_ep3] ep 2 step 0/5 lo

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14719
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 8yx0d7wi with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr0.0001_wd0.001_ep4] params total/trainable: 2568449 2568449
[dendrites_lr0.0001_wd0.001_ep4] ep 0 step 0/5 loss 2.78730
[dendrites_lr0.0001_wd0.001_ep4] ep 0 train_loss 2.88179 | val PR-AUC 0.134810 F1 0.000000 (thr=0.5000)
[dendrites_lr0.0001_wd0.001_ep4]             test PR-AUC 0.122189 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0001_wd0.001_ep4] *** new best saved: val_pr_auc=0.134810
[dendrites_lr0.0001_wd0.001_ep4] ep 1 step 0/5 loss 2.77109
[dendrites_lr0.0001_wd0.001_ep4] ep 1 train_loss 3.27606 | val PR-AUC 0.134113 F1 0.000000 (thr=0.5000)
[dendrites_lr0.0001_wd0.001_ep4]             test PR-AUC 0.146629 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0001_wd0.001_ep4] ep 2 step 0/5 loss 3.15018
[dendrites_lr0.0001_wd0.001_ep4] ep 2 train_loss 3.36954 | val PR-AUC 0.157076 F1 0.117647 (thr=0.5000)
[dendrites_lr0.0001_

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,dendrites
lr,0.0001
params_total,2568449
params_trainable,2568449
test_acc,0.20833
test_f1,0.19718
test_pr_auc,0.14885
test_thr,0.5
val_acc,0.18519


[34m[1mwandb[0m: Agent Starting Run: re2yro8t with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr0.0005_wd0.001_ep4] params total/trainable: 2568449 2568449
[dendrites_lr0.0005_wd0.001_ep4] ep 0 step 0/5 loss 2.76954
[dendrites_lr0.0005_wd0.001_ep4] ep 0 train_loss 2.77827 | val PR-AUC 0.140882 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_wd0.001_ep4]             test PR-AUC 0.152906 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0.001_ep4] *** new best saved: val_pr_auc=0.140882
[dendrites_lr0.0005_wd0.001_ep4] ep 1 step 0/5 loss 3.27267
[dendrites_lr0.0005_wd0.001_ep4] ep 1 train_loss 2.96741 | val PR-AUC 0.135591 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_wd0.001_ep4]             test PR-AUC 0.152783 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0.001_ep4] ep 2 step 0/5 loss 2.92841
[dendrites_lr0.0005_wd0.001_ep4] ep 2 train_loss 2.81991 | val PR-AUC 0.136874 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0005_

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,dendrites
lr,0.0005
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.13426
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: qr15z1zc with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.001


[dendrites_lr0.001_wd0.001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0.001_ep2] ep 0 step 0/5 loss 2.77645
[dendrites_lr0.001_wd0.001_ep2] ep 0 train_loss 3.06744 | val PR-AUC 0.136617 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.001_ep2]             test PR-AUC 0.120306 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.001_ep2] *** new best saved: val_pr_auc=0.136617
[dendrites_lr0.001_wd0.001_ep2] ep 1 step 0/5 loss 2.82077
[dendrites_lr0.001_wd0.001_ep2] ep 1 train_loss 2.72184 | val PR-AUC 0.141235 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.001_ep2]             test PR-AUC 0.115449 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.001_ep2] *** new best saved: val_pr_auc=0.141235
[dendrites_lr0.001_wd0.001_ep2] DONE | val_pr_auc=0.141235

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.11545
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: hbged74a with config:
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 2e-05
[34m[1mwandb[0m: 	weight_decay: 0


[baseline_lr2e-05_wd0_ep1] params total/trainable: 10372097 10372097
[baseline_lr2e-05_wd0_ep1] ep 0 step 0/5 loss 2.90338
[baseline_lr2e-05_wd0_ep1] ep 0 train_loss 3.07600 | val PR-AUC 0.189611 F1 0.200000 (thr=0.5000)
[baseline_lr2e-05_wd0_ep1]             test PR-AUC 0.115256 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr2e-05_wd0_ep1] *** new best saved: val_pr_auc=0.189611
[baseline_lr2e-05_wd0_ep1] DONE | val_pr_auc=0.189611 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.115256 test_f1=0.2000 | params_trainable=10372097


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,1
kind,baseline
lr,2e-05
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.11526
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: un6b91m9 with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.0001


[dendrites_lr0.001_wd0.0001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0.0001_ep2] ep 0 step 0/5 loss 2.96654
[dendrites_lr0.001_wd0.0001_ep2] ep 0 train_loss 3.00129 | val PR-AUC 0.136358 F1 0.000000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep2]             test PR-AUC 0.151980 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep2] *** new best saved: val_pr_auc=0.136358
[dendrites_lr0.001_wd0.0001_ep2] ep 1 step 0/5 loss 3.02216
[dendrites_lr0.001_wd0.0001_ep2] ep 1 train_loss 3.10712 | val PR-AUC 0.142261 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep2]             test PR-AUC 0.175645 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep2] *** new best saved: val_pr_auc=0.142261
[dendrites_lr0.001_wd0.0001_ep2] DONE | val_pr_au

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.17564
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: bakfgf59 with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	weight_decay: 0.001


[baseline_lr0.0001_wd0.001_ep2] params total/trainable: 10372097 10372097
[baseline_lr0.0001_wd0.001_ep2] ep 0 step 0/5 loss 3.62349
[baseline_lr0.0001_wd0.001_ep2] ep 0 train_loss 3.15482 | val PR-AUC 0.151431 F1 0.000000 (thr=0.5000)
[baseline_lr0.0001_wd0.001_ep2]             test PR-AUC 0.151361 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0001_wd0.001_ep2] *** new best saved: val_pr_auc=0.151431
[baseline_lr0.0001_wd0.001_ep2] ep 1 step 0/5 loss 3.51563
[baseline_lr0.0001_wd0.001_ep2] ep 1 train_loss 3.24216 | val PR-AUC 0.138931 F1 0.200000 (thr=0.5000)
[baseline_lr0.0001_wd0.001_ep2]             test PR-AUC 0.143308 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0001_wd0.001_ep2] DONE | val_pr_auc=0.138931 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.143308 test_f1=0.2000 | params_trainable=10372097


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,baseline
lr,0.0001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14331
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 7grwqnyj with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	weight_decay: 0.001


[baseline_lr0.0001_wd0.001_ep3] params total/trainable: 10372097 10372097
[baseline_lr0.0001_wd0.001_ep3] ep 0 step 0/5 loss 3.53680
[baseline_lr0.0001_wd0.001_ep3] ep 0 train_loss 3.17501 | val PR-AUC 0.147165 F1 0.200000 (thr=0.5000)
[baseline_lr0.0001_wd0.001_ep3]             test PR-AUC 0.112156 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0001_wd0.001_ep3] *** new best saved: val_pr_auc=0.147165
[baseline_lr0.0001_wd0.001_ep3] ep 1 step 0/5 loss 2.32606
[baseline_lr0.0001_wd0.001_ep3] ep 1 train_loss 3.10533 | val PR-AUC 0.151520 F1 0.200000 (thr=0.5000)
[baseline_lr0.0001_wd0.001_ep3]             test PR-AUC 0.166416 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0001_wd0.001_ep3] *** new best saved: val_pr_auc=0.151520
[baseline_lr0.0001_wd0.001_ep3] ep 2 step 0/5 loss 2.600

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,baseline
lr,0.0001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.1691
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 4sfgwjm7 with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	weight_decay: 0.001


[baseline_lr0.0005_wd0.001_ep4] params total/trainable: 10372097 10372097
[baseline_lr0.0005_wd0.001_ep4] ep 0 step 0/5 loss 2.95334
[baseline_lr0.0005_wd0.001_ep4] ep 0 train_loss 3.15111 | val PR-AUC 0.138386 F1 0.200000 (thr=0.5000)
[baseline_lr0.0005_wd0.001_ep4]             test PR-AUC 0.146337 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0005_wd0.001_ep4] *** new best saved: val_pr_auc=0.138386
[baseline_lr0.0005_wd0.001_ep4] ep 1 step 0/5 loss 2.34211
[baseline_lr0.0005_wd0.001_ep4] ep 1 train_loss 2.53145 | val PR-AUC 0.138386 F1 0.200000 (thr=0.5000)
[baseline_lr0.0005_wd0.001_ep4]             test PR-AUC 0.153075 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0005_wd0.001_ep4] ep 2 step 0/5 loss 2.92862
[baseline_lr0.0005_wd0.001_ep4] ep 2 train_loss 2.56843 | val PR-AUC 0.142261 F1 0.200000 (thr=0.5000)
[baseline_lr0.0005_wd0.001_e

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,baseline
lr,0.0005
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14699
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 7m7vhznc with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.0001


[baseline_lr0.001_wd0.0001_ep4] params total/trainable: 10372097 10372097
[baseline_lr0.001_wd0.0001_ep4] ep 0 step 0/5 loss 2.98969
[baseline_lr0.001_wd0.0001_ep4] ep 0 train_loss 2.77152 | val PR-AUC 0.132224 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.0001_ep4]             test PR-AUC 0.123866 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0.0001_ep4] *** new best saved: val_pr_auc=0.132224
[baseline_lr0.001_wd0.0001_ep4] ep 1 step 0/5 loss 2.98479
[baseline_lr0.001_wd0.0001_ep4] ep 1 train_loss 2.30585 | val PR-AUC 0.137598 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.0001_ep4]             test PR-AUC 0.148748 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0.0001_ep4] *** new best saved: val_pr_auc=0.137598
[baseline_lr0.001_wd0.0001_ep4] ep 2 step 0/5 loss 1.885

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,baseline
lr,0.001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.15022
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: g39cnrgo with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.0001


[baseline_lr0.001_wd0.0001_ep4] params total/trainable: 10372097 10372097
[baseline_lr0.001_wd0.0001_ep4] ep 0 step 0/5 loss 3.45308
[baseline_lr0.001_wd0.0001_ep4] ep 0 train_loss 3.04676 | val PR-AUC 0.151520 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.0001_ep4]             test PR-AUC 0.149365 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0.0001_ep4] *** new best saved: val_pr_auc=0.151520
[baseline_lr0.001_wd0.0001_ep4] ep 1 step 0/5 loss 2.79860
[baseline_lr0.001_wd0.0001_ep4] ep 1 train_loss 2.64346 | val PR-AUC 0.142261 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.0001_ep4]             test PR-AUC 0.151988 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0.0001_ep4] ep 2 step 0/5 loss 2.38090
[baseline_lr0.001_wd0.0001_ep4] ep 2 train_loss 1.84647 | val PR-AUC 0.142261 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.0001_e

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,baseline
lr,0.001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.15067
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: ub1i9k9l with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.0001


[dendrites_lr0.001_wd0.0001_ep2] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0.0001_ep2] ep 0 step 0/5 loss 2.65545
[dendrites_lr0.001_wd0.0001_ep2] ep 0 train_loss 2.90593 | val PR-AUC 0.134174 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep2]             test PR-AUC 0.120219 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep2] *** new best saved: val_pr_auc=0.134174
[dendrites_lr0.001_wd0.0001_ep2] ep 1 step 0/5 loss 2.91313
[dendrites_lr0.001_wd0.0001_ep2] ep 1 train_loss 2.69795 | val PR-AUC 0.135204 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0.0001_ep2]             test PR-AUC 0.117880 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0.0001_ep2] *** new best saved: val_pr_auc=0.135204
[dendrites_lr0.001_wd0.0001_ep2] DONE | val_pr_au

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.11788
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: i50x2z3v with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	weight_decay: 0


[dendrites_lr0.0005_wd0_ep2] params total/trainable: 2568449 2568449
[dendrites_lr0.0005_wd0_ep2] ep 0 step 0/5 loss 3.53084
[dendrites_lr0.0005_wd0_ep2] ep 0 train_loss 3.18071 | val PR-AUC 0.137481 F1 0.000000 (thr=0.5000)
[dendrites_lr0.0005_wd0_ep2]             test PR-AUC 0.121560 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0_ep2] *** new best saved: val_pr_auc=0.137481
[dendrites_lr0.0005_wd0_ep2] ep 1 step 0/5 loss 3.12317
[dendrites_lr0.0005_wd0_ep2] ep 1 train_loss 3.33956 | val PR-AUC 0.142261 F1 0.000000 (thr=0.5000)
[dendrites_lr0.0005_wd0_ep2]             test PR-AUC 0.118639 F1 0.000000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0005_wd0_ep2] *** new best saved: val_pr_auc=0.142261
[dendrites_lr0.0005_wd0_ep2] DONE | val_pr_auc=0.142261 val_f1=0.0000 thr=0.5000 | te

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,0.0005
params_total,2568449
params_trainable,2568449
test_acc,0.88889
test_f1,0
test_pr_auc,0.11864
test_thr,0.5
val_acc,0.88889


[34m[1mwandb[0m: Agent Starting Run: r7qk3iai with config:
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	weight_decay: 0


[dendrites_lr0.0001_wd0_ep1] params total/trainable: 2568449 2568449
[dendrites_lr0.0001_wd0_ep1] ep 0 step 0/5 loss 3.09531
[dendrites_lr0.0001_wd0_ep1] ep 0 train_loss 2.88507 | val PR-AUC 0.153786 F1 0.200000 (thr=0.5000)
[dendrites_lr0.0001_wd0_ep1]             test PR-AUC 0.130879 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.0001_wd0_ep1] *** new best saved: val_pr_auc=0.153786
[dendrites_lr0.0001_wd0_ep1] DONE | val_pr_auc=0.153786 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.130879 test_f1=0.2000 | params_trainable=2568449


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,1
kind,dendrites
lr,0.0001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.13088
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: gh6y6o67 with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	weight_decay: 0.001


[baseline_lr0.0001_wd0.001_ep3] params total/trainable: 10372097 10372097
[baseline_lr0.0001_wd0.001_ep3] ep 0 step 0/5 loss 2.99889
[baseline_lr0.0001_wd0.001_ep3] ep 0 train_loss 3.10710 | val PR-AUC 0.143138 F1 0.193548 (thr=0.5000)
[baseline_lr0.0001_wd0.001_ep3]             test PR-AUC 0.152984 F1 0.166667
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0001_wd0.001_ep3] *** new best saved: val_pr_auc=0.143138
[baseline_lr0.0001_wd0.001_ep3] ep 1 step 0/5 loss 3.35455
[baseline_lr0.0001_wd0.001_ep3] ep 1 train_loss 3.13740 | val PR-AUC 0.140285 F1 0.200000 (thr=0.5000)
[baseline_lr0.0001_wd0.001_ep3]             test PR-AUC 0.141364 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.0001_wd0.001_ep3] ep 2 step 0/5 loss 3.02126
[baseline_lr0.0001_wd0.001_ep3] ep 2 train_loss 3.05349 | val PR-AUC 0.138583 F1 0.200000 (thr=0.5000)
[baseline_lr0.0001_wd0.001_e

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,baseline
lr,0.0001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14033
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 2bbfdnc9 with config:
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0


[baseline_lr0.001_wd0_ep4] params total/trainable: 10372097 10372097
[baseline_lr0.001_wd0_ep4] ep 0 step 0/5 loss 2.87862
[baseline_lr0.001_wd0_ep4] ep 0 train_loss 3.00024 | val PR-AUC 0.135591 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0_ep4]             test PR-AUC 0.152122 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0_ep4] *** new best saved: val_pr_auc=0.135591
[baseline_lr0.001_wd0_ep4] ep 1 step 0/5 loss 2.76188
[baseline_lr0.001_wd0_ep4] ep 1 train_loss 2.73388 | val PR-AUC 0.137598 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0_ep4]             test PR-AUC 0.147236 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0_ep4] *** new best saved: val_pr_auc=0.137598
[baseline_lr0.001_wd0_ep4] ep 2 step 0/5 loss 2.09834
[baseline_lr0.001_wd0_ep4] ep 2 train_loss 1.92

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,4
kind,baseline
lr,0.001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14164
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: y2daek4v with config:
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	kind: dendrites
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0


[dendrites_lr0.001_wd0_ep2] params total/trainable: 2568449 2568449
[dendrites_lr0.001_wd0_ep2] ep 0 step 0/5 loss 2.88076
[dendrites_lr0.001_wd0_ep2] ep 0 train_loss 2.97133 | val PR-AUC 0.149782 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0_ep2]             test PR-AUC 0.145245 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0_ep2] *** new best saved: val_pr_auc=0.149782
[dendrites_lr0.001_wd0_ep2] ep 1 step 0/5 loss 2.75297
[dendrites_lr0.001_wd0_ep2] ep 1 train_loss 2.69678 | val PR-AUC 0.136617 F1 0.200000 (thr=0.5000)
[dendrites_lr0.001_wd0_ep2]             test PR-AUC 0.146263 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[dendrites_lr0.001_wd0_ep2] DONE | val_pr_auc=0.136617 val_f1=0.2000 thr=0.5000 | test_pr_auc=0.146263 test_f1=0.2000 | params_trainable=2568449


0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,2
kind,dendrites
lr,0.001
params_total,2568449
params_trainable,2568449
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14626
test_thr,0.5
val_acc,0.11111


[34m[1mwandb[0m: Agent Starting Run: 1i06xvzu with config:
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	kind: baseline
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	weight_decay: 0.001


[baseline_lr0.001_wd0.001_ep3] params total/trainable: 10372097 10372097
[baseline_lr0.001_wd0.001_ep3] ep 0 step 0/5 loss 2.53715
[baseline_lr0.001_wd0.001_ep3] ep 0 train_loss 2.94580 | val PR-AUC 0.135216 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.001_ep3]             test PR-AUC 0.122668 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0.001_ep3] *** new best saved: val_pr_auc=0.135216
[baseline_lr0.001_wd0.001_ep3] ep 1 step 0/5 loss 2.89042
[baseline_lr0.001_wd0.001_ep3] ep 1 train_loss 2.55804 | val PR-AUC 0.135796 F1 0.200000 (thr=0.5000)
[baseline_lr0.001_wd0.001_ep3]             test PR-AUC 0.149387 F1 0.200000
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
⚠️ save_ckpt: skipping checkpoint save to keep sweep running.
[baseline_lr0.001_wd0.001_ep3] *** new best saved: val_pr_auc=0.135796
[baseline_lr0.001_wd0.001_ep3] ep 2 step 0/5 loss 2.05946
[baseli

0,1
epochs,▁
lr,▁
params_total,▁
params_trainable,▁
test_acc,▁
test_f1,▁
test_pr_auc,▁
test_thr,▁
val_acc,▁
val_f1,▁

0,1
epochs,3
kind,baseline
lr,0.001
params_total,10372097
params_trainable,10372097
test_acc,0.11111
test_f1,0.2
test_pr_auc,0.14346
test_thr,0.5
val_acc,0.11111


In [None]:
# =========================
# W&B Sweep + Metrics (FIXED)
# - Fixes: "call wandb.init() before wandb.config"
# - Fixes: best_f1_threshold unpack mismatch
# - Fixes: eval_report_safe expecting dict but got tuple
# =========================

import os
import math
import copy
import random
from dataclasses import dataclass
from typing import Dict, Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

import wandb
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
)

# -------------------------
# Utils
# -------------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def count_params(model: nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

def _to_numpy(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

def _squeeze_prob(y_prob: np.ndarray) -> np.ndarray:
    y_prob = np.asarray(y_prob)
    # Accept (N,1) or (N,) for binary
    if y_prob.ndim == 2 and y_prob.shape[1] == 1:
        y_prob = y_prob[:, 0]
    return y_prob

def _squeeze_y(y_true: np.ndarray) -> np.ndarray:
    y_true = np.asarray(y_true)
    if y_true.ndim == 2 and y_true.shape[1] == 1:
        y_true = y_true[:, 0]
    return y_true

# -------------------------
# Threshold selection (binary)
# Returns EXACTLY 4 values to match: thr, best_f1, best_p, best_r
# -------------------------
def best_f1_threshold(y_true, y_prob) -> Tuple[float, float, float, float]:
    y_true = _squeeze_y(_to_numpy(y_true)).astype(int)
    y_prob = _squeeze_prob(_to_numpy(y_prob)).astype(float)

    # Guard
    if y_true.size == 0:
        return 0.5, 0.0, 0.0, 0.0

    precisions, recalls, thresholds = precision_recall_curve(y_true, y_prob)

    # precision_recall_curve gives thresholds of length (len(precisions)-1)
    # Use midpoints; common handling: compute f1 for each threshold index
    eps = 1e-12
    f1s = (2 * precisions * recalls) / (precisions + recalls + eps)

    # Align thresholds: f1s has same len as precisions/recalls
    # Thresholds misses first element; we map indices >=1 to thresholds[idx-1]
    best_idx = int(np.nanargmax(f1s))
    if best_idx == 0:
        best_thr = 0.5  # fallback (no threshold for idx 0)
    else:
        best_thr = float(thresholds[best_idx - 1])

    best_f1 = float(f1s[best_idx])
    best_p = float(precisions[best_idx])
    best_r = float(recalls[best_idx])
    return best_thr, best_f1, best_p, best_r

def bin_metrics_at_threshold(y_true, y_prob, thr: float) -> Dict[str, float]:
    y_true = _squeeze_y(_to_numpy(y_true)).astype(int)
    y_prob = _squeeze_prob(_to_numpy(y_prob)).astype(float)
    y_pred = (y_prob >= thr).astype(int)

    tp = float(((y_pred == 1) & (y_true == 1)).sum())
    fp = float(((y_pred == 1) & (y_true == 0)).sum())
    fn = float(((y_pred == 0) & (y_true == 1)).sum())
    tn = float(((y_pred == 0) & (y_true == 0)).sum())

    eps = 1e-12
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    acc = (tp + tn) / (tp + tn + fp + fn + eps)

    return {
        "acc": float(acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }

def pr_auc(y_true, y_prob) -> float:
    y_true = _squeeze_y(_to_numpy(y_true)).astype(int)
    y_prob = _squeeze_prob(_to_numpy(y_prob)).astype(float)
    # If only one class present, AP is defined but can be 0.0; sklearn may warn
    try:
        return float(average_precision_score(y_true, y_prob))
    except Exception:
        return 0.0

# -------------------------
# Safe eval report
# -------------------------
@torch.no_grad()
def predict_probs(model, dl, device) -> Tuple[np.ndarray, np.ndarray]:
    model.eval()
    all_y = []
    all_prob = []

    for batch in dl:
        # ---- TODO: adapt to your dataloader batch structure ----
        # Example assumes batch = (x, y)
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        logits = model(x)

        # ---- TODO: choose sigmoid/softmax depending on your setup ----
        # Binary: logits shape (N,) or (N,1)
        prob = torch.sigmoid(logits)

        all_y.append(y.detach().cpu())
        all_prob.append(prob.detach().cpu())

    y_true = torch.cat(all_y, dim=0).numpy()
    y_prob = torch.cat(all_prob, dim=0).numpy()
    return y_true, y_prob

def eval_report_safe(
    model,
    dl,
    device,
    threshold: Optional[float] = None,
    choose_thr_on: str = "val",
) -> Tuple[Dict[str, float], np.ndarray, np.ndarray]:
    y_true, y_prob = predict_probs(model, dl, device)

    # Choose threshold if not provided
    if threshold is None:
        thr, best_f1, best_p, best_r = best_f1_threshold(y_true, y_prob)
    else:
        thr = float(threshold)

    report = {
        "pr_auc": pr_auc(y_true, y_prob),
        "thr": float(thr),
    }
    report.update(bin_metrics_at_threshold(y_true, y_prob, thr))
    return report, y_prob, y_true

# -------------------------
# Training loop (generic)
# -------------------------
def train_one_epoch(model, dl, optimizer, device, criterion) -> float:
    model.train()
    running = 0.0
    n = 0

    for batch in dl:
        # ---- TODO: adapt to your batch structure ----
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)

        # Ensure shapes match criterion expectations
        # For BCEWithLogitsLoss: logits and y should be float
        loss = criterion(logits.view_as(y).float(), y.float())
        loss.backward()
        optimizer.step()

        bs = y.shape[0]
        running += float(loss.item()) * bs
        n += bs

    return running / max(n, 1)

def save_ckpt(path: str, model: nn.Module, extra: Optional[Dict[str, Any]] = None):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    payload = {"model_state": model.state_dict()}
    if extra:
        payload.update(extra)
    torch.save(payload, path)

def train_model(
    *,
    kind: str,
    epochs: int,
    lr: float,
    weight_decay: float,
    model: nn.Module,
    train_dl,
    val_dl,
    test_dl,
    device,
    run_prefix: str = "",
    save_best_path: Optional[str] = None,
    seed: int = 42,
) -> Dict[str, Any]:
    set_seed(seed)
    model = model.to(device)

    params_total, params_trainable = count_params(model)
    print(f"[{run_prefix}] params total/trainable: {params_total} {params_trainable}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # ---- TODO: if you're multi-class, replace with CrossEntropyLoss + softmax metrics
    criterion = nn.BCEWithLogitsLoss()

    best = {
        "val_pr_auc": -1.0,
        "val_f1": -1.0,
        "thr": 0.5,
        "state": None,
    }

    for ep in range(int(epochs)):
        train_loss = train_one_epoch(model, train_dl, optimizer, device, criterion)

        val_rep, _, _ = eval_report_safe(model, val_dl, device, threshold=None)
        test_rep, _, _ = eval_report_safe(model, test_dl, device, threshold=val_rep["thr"])

        # Console
        print(
            f"[{run_prefix}] ep {ep} train_loss {train_loss:.5f} | "
            f"val PR-AUC {val_rep['pr_auc']:.6f} F1 {val_rep['f1']:.6f} (thr={val_rep['thr']:.4f})"
        )
        print(
            f"[{run_prefix}]             test PR-AUC {test_rep['pr_auc']:.6f} F1 {test_rep['f1']:.6f}"
        )

        # W&B log
        wandb.log(
            {
                "epoch": ep,
                "train_loss": train_loss,
                "params_total": params_total,
                "params_trainable": params_trainable,
                "val_pr_auc": val_rep["pr_auc"],
                "val_f1": val_rep["f1"],
                "val_acc": val_rep["acc"],
                "val_thr": val_rep["thr"],
                "test_pr_auc": test_rep["pr_auc"],
                "test_f1": test_rep["f1"],
                "test_acc": test_rep["acc"],
                "test_thr": test_rep["thr"],
            },
            step=ep,
        )

        # Track best by val PR-AUC (change if you prefer F1)
        if val_rep["pr_auc"] > best["val_pr_auc"]:
            best["val_pr_auc"] = val_rep["pr_auc"]
            best["val_f1"] = val_rep["f1"]
            best["thr"] = val_rep["thr"]
            best["state"] = copy.deepcopy(model.state_dict())

            # Optional checkpoint
            if save_best_path:
                # If you previously had "skipping checkpoint save", just set save_best_path=None in sweeps.
                save_ckpt(save_best_path, model, extra={"epoch": ep, "val_rep": val_rep})
            print(f"[{run_prefix}] *** new best saved: val_pr_auc={best['val_pr_auc']:.6f}")

    # Restore best for final test eval
    if best["state"] is not None:
        model.load_state_dict(best["state"])

    final_val, _, _ = eval_report_safe(model, val_dl, device, threshold=best["thr"])
    final_test, _, _ = eval_report_safe(model, test_dl, device, threshold=best["thr"])

    return {
        "best_thr": best["thr"],
        "val_pr_auc": final_val["pr_auc"],
        "val_f1": final_val["f1"],
        "test_pr_auc": final_test["pr_auc"],
        "test_f1": final_test["f1"],
        "params_trainable": params_trainable,
    }

# -------------------------
# W&B Sweep entrypoint (FIXED)
# -------------------------
def sweep_run():
    """
    IMPORTANT:
    - wandb.init() MUST happen before reading wandb.config
    - run.name can be set after init
    """
    run = wandb.init()  # <-- FIX #1
    cfg = wandb.config  # safe now

    # Create a stable readable run name
    run_name = f"{cfg.kind}_lr{cfg.lr}_wd{cfg.weight_decay}_ep{cfg.epochs}"
    try:
        run.name = run_name
        run.save()
    except Exception:
        pass

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- TODO: YOU MUST PROVIDE THESE ----
    # 1) model = build_model(cfg.kind)
    # 2) train_dl, val_dl, test_dl = get_dataloaders()
    #
    # If you already have them globally, just reference them here.
    global build_model, get_dataloaders  # remove if not using globals
    model = build_model(cfg.kind)
    train_dl, val_dl, test_dl = get_dataloaders()

    results = train_model(
        kind=str(cfg.kind),
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        model=model,
        train_dl=train_dl,
        val_dl=val_dl,
        test_dl=test_dl,
        device=device,
        run_prefix=f"[{run_name}]",
        save_best_path=None,  # keep None for sweeps to avoid IO bottlenecks
        seed=42,
    )

    # Log final summary metrics
    wandb.log(
        {
            "final/val_pr_auc": results["val_pr_auc"],
            "final/val_f1": results["val_f1"],
            "final/test_pr_auc": results["test_pr_auc"],
            "final/test_f1": results["test_f1"],
            "final/best_thr": results["best_thr"],
            "final/params_trainable": results["params_trainable"],
        }
    )

    wandb.finish()


In [None]:
def build_model(kind: str) -> nn.Module:
    if kind == "base":
        return BaseModel(...)
    elif kind == "dendritic":
        return DendriticModel(...)
    else:
        raise ValueError(kind)

def get_dataloaders():
    return train_dl, val_dl, test_dl

In [None]:
[name for name in globals().keys() if "train" in name.lower() and ("load" in name.lower() or "dl" in name.lower())]

['train_dl']

In [None]:
train_loader = train_dl


val_loader  = val_dl
test_loader = test_dl

In [None]:
[name for name in globals().keys() if ("val" in name.lower() or "test" in name.lower()) and ("dl" in name.lower() or "load" in name.lower())]

['val_dl',
 'test_dl',
 'test_loader',
 'eval_collect_preds_from_loader',
 'val_loader']

In [None]:
import torch
import torch.nn as nn

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# criterion (adjust if your task uses something else)
criterion = nn.CrossEntropyLoss()

# optimizer (pull lr/wd from sweep config if available)
lr = float(getattr(wandb.config, "lr", 1e-3))
wd = float(getattr(wandb.config, "weight_decay", 0.0))

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

# training loop

In [None]:
import numpy as np

epochs = int(getattr(wandb.config, "epochs", 3))  # or set manually: epochs = 4
best_val = -1.0
best_thr = 0.5
best_state = None

for ep in range(epochs):
    # --- train ---
    train_loss = train_one_epoch(model, train_dl, optimizer, device, criterion)

    # --- eval (val): choose best threshold on val ---
    val_rep, val_metrics, _ = eval_report_safe(
        model, val_dl, device,
        threshold=None,          # let it search threshold
        choose_thr_on="val"      # choose on validation set
    )

    # val_metrics may be dict-like OR you can parse from val_rep; try dict first
    val_pr = val_metrics.get("pr_auc", None) if isinstance(val_metrics, dict) else None
    val_f1 = val_metrics.get("f1", None) if isinstance(val_metrics, dict) else None
    thr    = val_metrics.get("thr", None) if isinstance(val_metrics, dict) else None

    # fallback: if your function returns metrics differently, keep thr as previous
    if thr is None:
        thr = best_thr

    # --- eval (test) using chosen threshold ---
    test_rep, test_metrics, _ = eval_report_safe(
        model, test_dl, device,
        threshold=thr,
        choose_thr_on=None
    )

    test_pr = test_metrics.get("pr_auc", None) if isinstance(test_metrics, dict) else None
    test_f1 = test_metrics.get("f1", None) if isinstance(test_metrics, dict) else None

    # --- log to W&B (safe) ---
    wandb.log({
        "epoch": ep,
        "train_loss": float(train_loss),
        "val_pr_auc": float(val_pr) if val_pr is not None else np.nan,
        "val_f1": float(val_f1) if val_f1 is not None else np.nan,
        "val_thr": float(thr) if thr is not None else np.nan,
        "test_pr_auc": float(test_pr) if test_pr is not None else np.nan,
        "test_f1": float(test_f1) if test_f1 is not None else np.nan,
    })

    # --- keep best by val PR-AUC ---
    if val_pr is not None and val_pr > best_val:
        best_val = float(val_pr)
        best_thr = float(thr)
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

print(f"BEST val_pr_auc={best_val:.6f} at thr={best_thr:.4f}")

BEST val_pr_auc=-1.000000 at thr=0.5000


In [None]:
import inspect
print(inspect.signature(train_one_epoch))
print(inspect.getsource(train_one_epoch).splitlines()[0:40])

(model, dl, optimizer, device, criterion) -> float
['def train_one_epoch(model, dl, optimizer, device, criterion) -> float:', '    model.train()', '    running = 0.0', '    n = 0', '', '    for batch in dl:', '        # ---- TODO: adapt to your batch structure ----', '        x, y = batch', '        x = x.to(device)', '        y = y.to(device)', '', '        optimizer.zero_grad(set_to_none=True)', '        logits = model(x)', '', '        # Ensure shapes match criterion expectations', '        # For BCEWithLogitsLoss: logits and y should be float', '        loss = criterion(logits.view_as(y).float(), y.float())', '        loss.backward()', '        optimizer.step()', '', '        bs = y.shape[0]', '        running += float(loss.item()) * bs', '        n += bs', '', '    return running / max(n, 1)']


In [None]:
import numpy as np
import torch
from sklearn.metrics import average_precision_score, f1_score

def evaluate(model, dl, device, criterion=None, threshold=0.5):
    """
    Returns dict with:
      - loss (if criterion provided)
      - pr_auc  (Average Precision)
      - f1_at_0p5
    Uses eval_collect_preds_from_loader(model, dl, device) if available.
    """
    model.eval()
    with torch.no_grad():
        # Prefer your existing helper if present
        if "eval_collect_preds_from_loader" in globals():
            y_true, y_score = eval_collect_preds_from_loader(model, dl, device)
            y_true = np.asarray(y_true).reshape(-1)
            y_score = np.asarray(y_score).reshape(-1)
        else:
            # Fallback: assumes batch = (x, y)
            ys, scores = [], []
            for batch in dl:
                x, y = batch
                x = x.to(device)
                y = y.to(device)

                logits = model(x).view_as(y)
                prob = torch.sigmoid(logits).detach().cpu().numpy()

                ys.append(y.detach().cpu().numpy())
                scores.append(prob)

            y_true = np.concatenate(ys).reshape(-1)
            y_score = np.concatenate(scores).reshape(-1)

    # Handle edge case: PR-AUC undefined if only 1 class in y_true
    uniq = np.unique(y_true)
    if uniq.size < 2:
        pr_auc = float("nan")
    else:
        pr_auc = float(average_precision_score(y_true, y_score))

    y_pred = (y_score >= threshold).astype(int)
    f1 = float(f1_score(y_true, y_pred, zero_division=0))

    out = {"pr_auc": pr_auc, "f1_at_0p5": f1, "n": int(len(y_true))}
    return out

In [None]:
val_metrics = evaluate(model, val_dl, device, criterion)
print(val_metrics)

{'pr_auc': 0.1111111111111111, 'f1_at_0p5': 0.2, 'n': 54}


In [None]:
# quick label distribution check
ys = []
for batch in val_dl:
    y = batch[-1] if isinstance(batch, (tuple, list)) else batch["y"]
    ys.append(y.detach().cpu())
ys = torch.cat(ys).view(-1)

print("unique:", torch.unique(ys))
print("positive rate:", ys.float().mean().item())

unique: tensor([0, 1])
positive rate: 0.1111111119389534


In [None]:
import numpy as np
from sklearn.metrics import f1_score, average_precision_score

def best_threshold_from_val(model, dl, device, grid=np.linspace(0.05, 0.95, 19)):
    y_true, y_score = eval_collect_preds_from_loader(model, dl, device)
    y_true = np.asarray(y_true).reshape(-1)
    y_score = np.asarray(y_score).reshape(-1)

    pr_auc = average_precision_score(y_true, y_score)

    best = {"thr": 0.5, "f1": -1}
    for thr in grid:
        f1 = f1_score(y_true, (y_score >= thr).astype(int), zero_division=0)
        if f1 > best["f1"]:
            best = {"thr": float(thr), "f1": float(f1)}
    return {"pr_auc": float(pr_auc), **best}

best = best_threshold_from_val(model, val_dl, device)
print(best)

{'pr_auc': 0.1111111111111111, 'thr': 0.05, 'f1': 0.2}


In [None]:
import torch
pos = ys.sum().item()
neg = (len(ys) - pos)
pos_weight = torch.tensor([neg / max(pos, 1)], device=device)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
print("pos_weight:", pos_weight.item())

pos_weight: 8.0


In [None]:
import numpy as np

y_true, y_score = eval_collect_preds_from_loader(model, val_dl, device)
y_true = np.asarray(y_true).reshape(-1)
y_score = np.asarray(y_score).reshape(-1)

print("score min/mean/max:", y_score.min(), y_score.mean(), y_score.max())
print("pos mean score:", y_score[y_true==1].mean() if (y_true==1).any() else None)
print("neg mean score:", y_score[y_true==0].mean() if (y_true==0).any() else None)

score min/mean/max: 1 1.0 1
pos mean score: 1.0
neg mean score: 1.0


In [None]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

thr = best["thr"]
y_pred = (y_score >= thr).astype(int)

cm = confusion_matrix(y_true, y_pred)
p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
print("thr:", thr)
print("confusion matrix [[tn fp],[fn tp]]:\n", cm)
print("precision:", p, "recall:", r, "f1:", f1)

thr: 0.05
confusion matrix [[tn fp],[fn tp]]:
 [[ 0 48]
 [ 0  6]]
precision: 0.1111111111111111 recall: 1.0 f1: 0.2


In [None]:
import numpy as np
import torch
from torch import nn

def unpack_batch(batch, device):
    """
    Tries to support common batch formats:
    - (x, y)
    - {"x":..., "y":...} or {"inputs":..., "labels":...}
    - (x, y, *extras)
    """
    if isinstance(batch, (list, tuple)):
        x, y = batch[0], batch[1]
    elif isinstance(batch, dict):
        # adjust keys if your dataloader uses different ones
        x = batch.get("x", batch.get("inputs", batch.get("input_ids", None)))
        y = batch.get("y", batch.get("labels", batch.get("label", None)))
        if x is None or y is None:
            raise KeyError(f"Unknown batch dict keys: {batch.keys()}")
    else:
        raise TypeError(f"Unsupported batch type: {type(batch)}")

    x = x.to(device)
    y = y.to(device)
    return x, y

In [None]:
from sklearn.metrics import average_precision_score, precision_recall_curve

@torch.no_grad()
def eval_collect_logits_and_labels(model, dl, device):
    model.eval()
    all_logits = []
    all_y = []
    for batch in dl:
        x, y = unpack_batch(batch, device)
        logits = model(x)

        # make shapes consistent: [bs]
        logits = logits.squeeze(-1)
        y = y.squeeze(-1)

        all_logits.append(logits.detach().cpu())
        all_y.append(y.detach().cpu())
    logits = torch.cat(all_logits).float().numpy()
    y = torch.cat(all_y).float().numpy()
    return logits, y

def evaluate(model, dl, device):
    logits, y = eval_collect_logits_and_labels(model, dl, device)
    probs = 1 / (1 + np.exp(-logits))

    pr_auc = float(average_precision_score(y, probs))

    # f1 at 0.5
    pred = (probs >= 0.5).astype(np.int32)
    tp = int(((pred == 1) & (y == 1)).sum())
    fp = int(((pred == 1) & (y == 0)).sum())
    fn = int(((pred == 0) & (y == 1)).sum())
    precision = tp / (tp + fp + 1e-12)
    recall    = tp / (tp + fn + 1e-12)
    f1 = (2 * precision * recall) / (precision + recall + 1e-12)

    return {"pr_auc": pr_auc, "f1_at_0p5": float(f1), "n": int(len(y))}

def best_threshold_from_val(model, val_dl, device):
    logits, y = eval_collect_logits_and_labels(model, val_dl, device)
    probs = 1 / (1 + np.exp(-logits))

    prec, rec, thr = precision_recall_curve(y, probs)
    # thr has length (len(prec)-1). align:
    prec2, rec2 = prec[:-1], rec[:-1]
    f1 = 2 * prec2 * rec2 / (prec2 + rec2 + 1e-12)

    best_i = int(np.argmax(f1))
    best_thr = float(thr[best_i])
    pr_auc = float(average_precision_score(y, probs))
    return {"pr_auc": pr_auc, "thr": best_thr, "f1": float(f1[best_i])}

In [None]:
def train_one_epoch(model, dl, optimizer, device, criterion, grad_clip=None):
    model.train()
    running = 0.0
    n = 0

    for batch in dl:
        x, y = unpack_batch(batch, device)

        optimizer.zero_grad(set_to_none=True)

        logits = model(x).squeeze(-1)
        y = y.squeeze(-1)

        # BCEWithLogitsLoss expects float targets
        loss = criterion(logits.float(), y.float())
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        bs = y.shape[0]
        running += float(loss.item()) * bs
        n += bs

    return running / max(n, 1)

# ---- set these ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

epochs = 10  # change this
lr = 1e-3
weight_decay = 1e-3

# if you computed pos_weight earlier, keep it; otherwise derive from train labels if you can
# pos_weight should be a tensor on device
try:
    _ = pos_weight
except NameError:
    # fallback: set it manually or compute it from your data
    pos_weight = torch.tensor(1.0)

pos_weight = torch.tensor(float(pos_weight)).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
import wandb
from datetime import datetime

def run_training(
    model,
    train_dl, val_dl, test_dl,
    epochs=10,
    lr=1e-3,
    weight_decay=1e-3,
    pos_weight=1.0,
    grad_clip=None,
    project="dendrites-hackathon",
    entity=None,
    run_name=None,
    extra_config=None
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    pos_weight = torch.tensor(float(pos_weight)).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    cfg = {
        "epochs": epochs,
        "lr": lr,
        "weight_decay": weight_decay,
        "pos_weight": float(pos_weight.item()),
        "grad_clip": grad_clip,
    }
    if extra_config:
        cfg.update(extra_config)

    if run_name is None:
        run_name = f"run_{datetime.now().strftime('%m%d_%H%M%S')}"

    wandb.init(project=project, entity=entity, name=run_name, config=cfg)

    best = {"val_pr_auc": -1.0, "thr": 0.5, "epoch": -1}

    for ep in range(1, epochs + 1):
        train_loss = train_one_epoch(model, train_dl, optimizer, device, criterion, grad_clip=grad_clip)

        val_metrics = evaluate(model, val_dl, device)
        best_thr = best_threshold_from_val(model, val_dl, device)  # returns pr_auc, thr, f1

        # optional: evaluate test each epoch (usually do only at end; but ok for hackathon)
        test_metrics = evaluate(model, test_dl, device)

        wandb.log({
            "epoch": ep,
            "train/loss": train_loss,
            "val/pr_auc": val_metrics["pr_auc"],
            "val/f1@0.5": val_metrics["f1_at_0p5"],
            "val/best_thr": best_thr["thr"],
            "val/best_f1": best_thr["f1"],
            "test/pr_auc": test_metrics["pr_auc"],
            "test/f1@0.5": test_metrics["f1_at_0p5"],
        })

        # track best by val PR-AUC
        if val_metrics["pr_auc"] > best["val_pr_auc"]:
            best.update({"val_pr_auc": val_metrics["pr_auc"], "thr": best_thr["thr"], "epoch": ep})
            # save best checkpoint
            wandb.run.summary["best_val_pr_auc"] = best["val_pr_auc"]
            wandb.run.summary["best_epoch"] = best["epoch"]
            wandb.run.summary["best_thr"] = best["thr"]

    # Final test with best threshold
    logits, y = eval_collect_logits_and_labels(model, test_dl, device)
    probs = 1 / (1 + np.exp(-logits))
    thr = best["thr"]
    pred = (probs >= thr).astype(np.int32)
    tp = int(((pred == 1) & (y == 1)).sum())
    fp = int(((pred == 1) & (y == 0)).sum())
    fn = int(((pred == 0) & (y == 1)).sum())
    precision = tp / (tp + fp + 1e-12)
    recall    = tp / (tp + fn + 1e-12)
    f1 = (2 * precision * recall) / (precision + recall + 1e-12)

    wandb.run.summary["final_test_f1@best_thr"] = float(f1)
    wandb.run.summary["final_test_precision@best_thr"] = float(precision)
    wandb.run.summary["final_test_recall@best_thr"] = float(recall)

    wandb.finish()
    return best

# --- call it ---
best = run_training(
    model=model,
    train_dl=train_dl,
    val_dl=val_dl,
    test_dl=test_dl,
    epochs=epochs,
    lr=lr,
    weight_decay=weight_decay,
    pos_weight=float(pos_weight.item()),
    grad_clip=1.0,
    run_name="baseline_adamw",
    extra_config={"model_variant": "baseline"}
)

print("Best:", best)

0,1
epoch,▁▅█
train_loss,▇█▁
val_thr,▁▁▁
+4,...

0,1
epoch,2.0
test_f1,
test_pr_auc,
train_loss,117.7408
val_f1,
val_pr_auc,
val_thr,0.5


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▁▁▁▁▁▁▁▁▁
test/pr_auc,▄▁▆█▄▄▄▄▅▄
train/loss,█▄▆▃▄▂▂▁▂▁
val/best_f1,▁▁▇▇▇▇▇▇█▃
val/best_thr,▇█▆▄▄▅▆▅▃▁
val/f1@0.5,▁▁▁▁▁▁▁▁▁▁
val/pr_auc,▅▅▇▇▇▇▇█▁▂

0,1
best_epoch,8
best_thr,0.87788
best_val_pr_auc,0.15237
epoch,10
final_test_f1@best_thr,0
final_test_precision@best_thr,0
final_test_recall@best_thr,0
test/f1@0.5,0.2
test/pr_auc,0.24924
train/loss,1.43199


Best: {'val_pr_auc': 0.15236823154821877, 'thr': 0.8778806924819946, 'epoch': 8}


In [None]:
# 1) sweep config
sweep_config = {
    "method": "bayes",
    "metric": {"name": "val/pr_auc", "goal": "maximize"},
    "parameters": {
        "lr": {"distribution": "log_uniform_values", "min": 1e-4, "max": 3e-3},
        "weight_decay": {"distribution": "log_uniform_values", "min": 1e-6, "max": 1e-2},
        # add ONE dendrite knob here (example names; replace with your real ones)
        # "dendrite_k": {"values": [4, 8, 16]},
        # "dendrite_strength": {"distribution": "uniform", "min": 0.0, "max": 1.0},
    }
}

# 2) training function for sweep
def sweep_train():
    import wandb
    wandb.init()

    cfg = wandb.config

    # IMPORTANT: if your model depends on cfg (e.g., dendrites params),
    # you must REBUILD the model here using cfg.
    # model = build_model(dendrite_k=cfg.dendrite_k, ...)

    best = run_training(
        model=model,  # replace with newly built model if needed
        train_dl=train_dl,
        val_dl=val_dl,
        test_dl=test_dl,
        epochs=10,
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
        pos_weight=float(pos_weight.item()),
        grad_clip=1.0,
        run_name=None,
        extra_config=dict(cfg)
    )
    return best

# 3) create sweep once
sweep_id = wandb.sweep(sweep_config, project="dendrites-hackathon")

# 4) run agent (many runs)
wandb.agent(sweep_id, function=sweep_train, count=20)

Create sweep with ID: 46sgkuvo
Sweep URL: https://wandb.ai/vtpy/dendrites-hackathon/sweeps/46sgkuvo


[34m[1mwandb[0m: Agent Starting Run: tdif3p2t with config:
[34m[1mwandb[0m: 	lr: 0.0005784145230219258
[34m[1mwandb[0m: 	weight_decay: 0.001984463047868843


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▁▁▁▁▁▁▁▁▁
test/pr_auc,▁█▇▂▃▁▁▂▃▆
train/loss,██▅▅▅▄▄▃▂▁
val/best_f1,▁▂▄▃▁▁▁▂▁█
val/best_thr,▄█▇▆▁▇█▁▅▁
val/f1@0.5,▁▁▁▁▁▁▁▁▁▁
val/pr_auc,▁▁▂▃▅▅▅▅▅█

0,1
best_epoch,10
best_thr,0.75876
best_val_pr_auc,0.41344
epoch,10
final_test_f1@best_thr,0.31579
final_test_precision@best_thr,0.27273
final_test_recall@best_thr,0.375
test/f1@0.5,0.2
test/pr_auc,0.25857
train/loss,1.12508


[34m[1mwandb[0m: Agent Starting Run: hpmq62fe with config:
[34m[1mwandb[0m: 	lr: 0.001065783331350847
[34m[1mwandb[0m: 	weight_decay: 0.00022236512171524135


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,███▇▁▁▁▁▃▃
test/pr_auc,██▅▁▁██▅██
train/loss,▇█▆▆▅▄▄▃▂▁
val/best_f1,▃█▁████▃▅▃
val/best_thr,█▅█▇▃▅▅▂▁▃
val/f1@0.5,▄▄▄▄▁▁▁▄██
val/pr_auc,▆█▄████▂▂▁

0,1
best_epoch,2
best_thr,0.69435
best_val_pr_auc,0.40934
epoch,10
final_test_f1@best_thr,0.42857
final_test_precision@best_thr,0.5
final_test_recall@best_thr,0.375
test/f1@0.5,0.15385
test/pr_auc,0.44634
train/loss,0.65756


[34m[1mwandb[0m: Agent Starting Run: 3ei24bcq with config:
[34m[1mwandb[0m: 	lr: 0.00011489581841961362
[34m[1mwandb[0m: 	weight_decay: 2.154883977035616e-05


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▂▂▂▅▅▄▄██
test/pr_auc,▁▄▆███████
train/loss,▆█▅▄▆▄▁▂▄▃
val/best_f1,████████▁█
val/best_thr,▅▄▃▃▂▂▂▂█▁
val/f1@0.5,██████▂▁▂█
val/pr_auc,▁▂▂▂████▇█

0,1
best_epoch,5
best_thr,0.55035
best_val_pr_auc,0.2975
epoch,10
final_test_f1@best_thr,0.35294
final_test_precision@best_thr,0.33333
final_test_recall@best_thr,0.375
test/f1@0.5,0.27273
test/pr_auc,0.44663
train/loss,0.52503


[34m[1mwandb[0m: Agent Starting Run: jeg08kfr with config:
[34m[1mwandb[0m: 	lr: 0.001970109790861881
[34m[1mwandb[0m: 	weight_decay: 2.3210441605336214e-06


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▆▂▃▇▂▇▃▅█
test/pr_auc,▂█▃▅▂▅▁▃▄▅
train/loss,██▅▄▃▃▂▃▁▁
val/best_f1,▂▄██▄▄▁██▄
val/best_thr,██▄▅▇▇▁▆▂▂
val/f1@0.5,▅▄▅█▆▇▆█▁▁
val/pr_auc,▁▇█▁▆▆▁▁▃▂

0,1
best_epoch,3
best_thr,0.39074
best_val_pr_auc,0.2975
epoch,10
final_test_f1@best_thr,0.46154
final_test_precision@best_thr,0.6
final_test_recall@best_thr,0.375
test/f1@0.5,0.46154
test/pr_auc,0.39453
train/loss,0.22592


[34m[1mwandb[0m: Agent Starting Run: zjy5t9dq with config:
[34m[1mwandb[0m: 	lr: 0.0014207002228785993
[34m[1mwandb[0m: 	weight_decay: 0.005695829754728529


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▄▁████▁▄▁▁
test/pr_auc,▆▅█▁▁▂▂▅▅▅
train/loss,█▄▃▁▁▇▁▄▇▃
val/best_f1,▁█▁████▁██
val/best_thr,▁▅▁▇██▆▁▆█
val/f1@0.5,▅█▅█▁▂█▅█▂
val/pr_auc,▂▇▁████▂██

0,1
best_epoch,4
best_thr,0.78951
best_val_pr_auc,0.27225
epoch,10
final_test_f1@best_thr,0.18182
final_test_precision@best_thr,0.33333
final_test_recall@best_thr,0.125
test/f1@0.5,0.33333
test/pr_auc,0.3006
train/loss,0.12901


[34m[1mwandb[0m: Agent Starting Run: 48tpv2qo with config:
[34m[1mwandb[0m: 	lr: 0.0011561327603170526
[34m[1mwandb[0m: 	weight_decay: 2.0064795106231516e-05


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▅▅▅████▇▄
test/pr_auc,▅▃▃▁▂▅██▅▆
train/loss,▆▇▄▄▄▂▆▁█▅
val/best_f1,▁▁▁▁▁▁▁▁▁▁
val/best_thr,▁███▇▇▇▇█▇
val/f1@0.5,█▁▁▁▂▂▃▂▁▂
val/pr_auc,█▇▇▇▁▇▇▇▇▇

0,1
best_epoch,1
best_thr,0.556
best_val_pr_auc,0.26626
epoch,10
final_test_f1@best_thr,0.18182
final_test_precision@best_thr,0.33333
final_test_recall@best_thr,0.125
test/f1@0.5,0.18182
test/pr_auc,0.34423
train/loss,0.10178


[34m[1mwandb[0m: Agent Starting Run: gekumtcm with config:
[34m[1mwandb[0m: 	lr: 0.00024319241133520736
[34m[1mwandb[0m: 	weight_decay: 7.78249465861457e-05


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,█████▃▃█▁▁
test/pr_auc,▇▇▇▆▆▆▆▇█▁
train/loss,▅▄▄▄▁▂▃▁▅█
val/best_f1,▁▁▁▁▁▁▁▁▁▁
val/best_thr,▃▁▃▆▇██▇▇▇
val/f1@0.5,▁█▁▁▁▁▁▁▁▁
val/pr_auc,▁▁▁▁▁▁▁▁▁▁

0,1
best_epoch,1
best_thr,0.97494
best_val_pr_auc,0.26575
epoch,10
final_test_f1@best_thr,0
final_test_precision@best_thr,0
final_test_recall@best_thr,0
test/f1@0.5,0.33333
test/pr_auc,0.3156
train/loss,0.09016


[34m[1mwandb[0m: Agent Starting Run: ppuk8dyw with config:
[34m[1mwandb[0m: 	lr: 0.0015590494141355905
[34m[1mwandb[0m: 	weight_decay: 3.1207456097777996e-06


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▃▃▂█▃▂▆▄▂▁
test/pr_auc,▂█▄█▄▁▂▅▅▄
train/loss,▃▅▃█▆▃▆▁▁▅
val/best_f1,███████▁▁█
val/best_thr,█▇█████▁▁█
val/f1@0.5,▂█▁▂▂▁▁▁▁▂
val/pr_auc,███████▁▁█

0,1
best_epoch,3
best_thr,0.99683
best_val_pr_auc,0.26575
epoch,10
final_test_f1@best_thr,0.22222
final_test_precision@best_thr,1.0
final_test_recall@best_thr,0.125
test/f1@0.5,0.125
test/pr_auc,0.23132
train/loss,0.1075


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: mvju26qr with config:
[34m[1mwandb[0m: 	lr: 0.0010187014386312102
[34m[1mwandb[0m: 	weight_decay: 7.526298330899762e-05


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▅▅▇██▅▁▅▃▅
test/pr_auc,▁▂▁████▆▆█
train/loss,▂▃▄█▄▂▃▁▆▁
val/best_f1,█████▁▁███
val/best_thr,█████▁▁███
val/f1@0.5,▁▂▁▂▂▅█▁▁▁
val/pr_auc,█████▁▁███

0,1
best_epoch,3
best_thr,0.99782
best_val_pr_auc,0.26863
epoch,10
final_test_f1@best_thr,0.18182
final_test_precision@best_thr,0.33333
final_test_recall@best_thr,0.125
test/f1@0.5,0.26667
test/pr_auc,0.27915
train/loss,0.01673


[34m[1mwandb[0m: Agent Starting Run: 1ukkyivo with config:
[34m[1mwandb[0m: 	lr: 0.0023386832212159104
[34m[1mwandb[0m: 	weight_decay: 0.00014225252577319356


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▅▃▁▄▃▆▆▅▆█
test/pr_auc,▅▇▇█▃▃▂█▄▁
train/loss,▃▇▅▁▃▄█▂▂▁
val/best_f1,▁▁███████▁
val/best_thr,▁▁█▁▁████▁
val/f1@0.5,█▃▄▁▄▄▄▂▄▅
val/pr_auc,▁▁████▇██▁

0,1
best_epoch,4
best_thr,0.00432
best_val_pr_auc,0.2758
epoch,10
final_test_f1@best_thr,0.23729
final_test_precision@best_thr,0.13725
final_test_recall@best_thr,0.875
test/f1@0.5,0.33333
test/pr_auc,0.15659
train/loss,0.02836


[34m[1mwandb[0m: Agent Starting Run: 6rnmju46 with config:
[34m[1mwandb[0m: 	lr: 0.0009814815637683265
[34m[1mwandb[0m: 	weight_decay: 0.001912765611612768


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▄██▄█▃▃▁▁▃
test/pr_auc,▂▂▁▁▂▃██▂▂
train/loss,▂▅▃▂█▃▂▄▇▁
val/best_f1,██▁▃▁▁▁▃██
val/best_thr,▁▁▁▁▁▁▁▁██
val/f1@0.5,▁██▇███▇▇▇
val/pr_auc,█▂▂▁▂▂▂▂██

0,1
best_epoch,1
best_thr,0.00247
best_val_pr_auc,0.27798
epoch,10
final_test_f1@best_thr,0.25455
final_test_precision@best_thr,0.14894
final_test_recall@best_thr,0.875
test/f1@0.5,0.23529
test/pr_auc,0.17696
train/loss,0.01314


[34m[1mwandb[0m: Agent Starting Run: y8krklgw with config:
[34m[1mwandb[0m: 	lr: 0.0003044391275817301
[34m[1mwandb[0m: 	weight_decay: 1.158114530712288e-06


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▁██▁▁▁▁▁▁
test/pr_auc,▁▂▁▄▃▅█▆▅▄
train/loss,▅▃▂▂█▂▃▁▃▄
val/best_f1,▁▁▁▁▁▁▁▁▁▁
val/best_thr,▆▆██▇▄▁▄▆█
val/f1@0.5,▁▁▁▁▁▁▁▁▁▁
val/pr_auc,▁▂▁▄██████

0,1
best_epoch,6
best_thr,0.99975
best_val_pr_auc,0.28125
epoch,10
final_test_f1@best_thr,0.18182
final_test_precision@best_thr,0.33333
final_test_recall@best_thr,0.125
test/f1@0.5,0.23529
test/pr_auc,0.18707
train/loss,0.02498


[34m[1mwandb[0m: Agent Starting Run: gp84dgsz with config:
[34m[1mwandb[0m: 	lr: 0.0028425188630288862
[34m[1mwandb[0m: 	weight_decay: 0.0016236585812960055


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,█▃▄▆▅█▄▇▇▁
test/pr_auc,▂▇█▇▃▂▇▄▁▁
train/loss,▁▆▆▁█▄▃▆▂▄
val/best_f1,▄▁▁▄▁▁▁▄▅█
val/best_thr,▁█▁▁▁▁▁▁▁▁
val/f1@0.5,█▂▁▂▅▂▄▄▄█
val/pr_auc,▁████▂▂▂▁▂

0,1
best_epoch,5
best_thr,0.00145
best_val_pr_auc,0.28697
epoch,10
final_test_f1@best_thr,0.2
final_test_precision@best_thr,0.1129
final_test_recall@best_thr,0.875
test/f1@0.5,0.11111
test/pr_auc,0.16911
train/loss,0.08838


[34m[1mwandb[0m: Agent Starting Run: lhmuf9y8 with config:
[34m[1mwandb[0m: 	lr: 0.001492499352822561
[34m[1mwandb[0m: 	weight_decay: 0.00024859233679721426


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,█▅▁▆▆▁▁▁▁▅
test/pr_auc,▄▇█▆▁▂▂▁▄▄
train/loss,▅█▂▂▂▂▂▁▁▁
val/best_f1,▁▄▁▄▄▄▄▄██
val/best_thr,█▆▁▄▇▄▃▃▃▄
val/f1@0.5,▁▁▁▁▁▁▁██▁
val/pr_auc,▂███▂▂▁▁▂▂

0,1
best_epoch,2
best_thr,0.00557
best_val_pr_auc,0.28292
epoch,10
final_test_f1@best_thr,0.22222
final_test_precision@best_thr,0.13043
final_test_recall@best_thr,0.75
test/f1@0.5,0.2
test/pr_auc,0.17502
train/loss,0.01242


[34m[1mwandb[0m: Agent Starting Run: nrvhpglz with config:
[34m[1mwandb[0m: 	lr: 0.0015611762606506066
[34m[1mwandb[0m: 	weight_decay: 0.0008791269120259303


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,█▆▆▅▁▃▅▆▅▁
test/pr_auc,▂▃▃▄█▁▂▁▂▂
train/loss,▂▂▄▁▄█▃▃▄▁
val/best_f1,▇█▂▃▁▅▃▅█▂
val/best_thr,▁▃▃▁▁▇█▄▃▂
val/f1@0.5,█▄▃▄▃▁▇▂▄▄
val/pr_auc,▂▂▂██▂▁▂▂▁

0,1
best_epoch,4
best_thr,0.00163
best_val_pr_auc,0.28585
epoch,10
final_test_f1@best_thr,0.21875
final_test_precision@best_thr,0.125
final_test_recall@best_thr,0.875
test/f1@0.5,0.125
test/pr_auc,0.17849
train/loss,0.0021


[34m[1mwandb[0m: Agent Starting Run: u2vd34cr with config:
[34m[1mwandb[0m: 	lr: 0.0003452101727019075
[34m[1mwandb[0m: 	weight_decay: 0.0024056991264576307


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▅▆▆███▆▆▆
test/pr_auc,▁▂▅▅▅▇▇▆▆█
train/loss,▁▂▄▁▁▁█▆▂▁
val/best_f1,▁▁▁▁▁▁▁▁▁▁
val/best_thr,▇█▆▄▂▂▁▂▁▁
val/f1@0.5,█▁▁███████
val/pr_auc,▄▁▃▁▄██▆▆▆

0,1
best_epoch,6
best_thr,0.00181
best_val_pr_auc,0.16135
epoch,10
final_test_f1@best_thr,0.2069
final_test_precision@best_thr,0.12
final_test_recall@best_thr,0.75
test/f1@0.5,0.28571
test/pr_auc,0.19422
train/loss,0.0003


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 8kvnmpro with config:
[34m[1mwandb[0m: 	lr: 0.0025003334273363603
[34m[1mwandb[0m: 	weight_decay: 8.332118539569875e-05


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▆█▄▅▅▇█▆▆
test/pr_auc,█▁▅▆▄▃▅▇▄▄
train/loss,▄█▃▄▃▂▁▂▁▇
val/best_f1,▅▂▂▁▄▁▄█▄▂
val/best_thr,▁▃▃▂▆▄▅▃▃█
val/f1@0.5,▆█▁▆▁▄▃▂▁▁
val/pr_auc,▃▁▁▁▂▂▁▅█▆

0,1
best_epoch,9
best_thr,0.00629
best_val_pr_auc,0.19283
epoch,10
final_test_f1@best_thr,0.17143
final_test_precision@best_thr,0.09677
final_test_recall@best_thr,0.75
test/f1@0.5,0.1875
test/pr_auc,0.1754
train/loss,0.09123


[34m[1mwandb[0m: Agent Starting Run: 0c280l2j with config:
[34m[1mwandb[0m: 	lr: 0.00015288250892757412
[34m[1mwandb[0m: 	weight_decay: 3.1873598539607545e-06


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▂▇▇▇▇████
test/pr_auc,▁▁▂▃▄▆▆▆█▇
train/loss,▂█▇▁▁▁▂▃▄▁
val/best_f1,▂▂▁▃██████
val/best_thr,█▇▅▄▃▃▂▂▂▁
val/f1@0.5,▁▁▁▁▁▁▁▁▁▁
val/pr_auc,▂▃▁▅█▆▇▇▆▆

0,1
best_epoch,5
best_thr,0.01089
best_val_pr_auc,0.19851
epoch,10
final_test_f1@best_thr,0.27907
final_test_precision@best_thr,0.17143
final_test_recall@best_thr,0.75
test/f1@0.5,0.28571
test/pr_auc,0.21114
train/loss,0.00176


[34m[1mwandb[0m: Agent Starting Run: 4b183pr9 with config:
[34m[1mwandb[0m: 	lr: 0.0001940209391081429
[34m[1mwandb[0m: 	weight_decay: 4.500805820938433e-06


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁▁▁▁▁▁▁▁▁▁
test/pr_auc,▂▇▇▇▇▇█▇▁▁
train/loss,▄▁▁▁▂▄▅█▁▇
val/best_f1,▁▁▁▁▁▁▁▁▁▁
val/best_thr,██▇▇▆▃▁▁▂▁
val/f1@0.5,▁▁▁▁▁▁▁▁▁▁
val/pr_auc,███▁▁▁▁▁▁▁

0,1
best_epoch,1
best_thr,0.00884
best_val_pr_auc,0.19373
epoch,10
final_test_f1@best_thr,0.32432
final_test_precision@best_thr,0.2069
final_test_recall@best_thr,0.75
test/f1@0.5,0.28571
test/pr_auc,0.21242
train/loss,0.00569


[34m[1mwandb[0m: Agent Starting Run: c56jbayo with config:
[34m[1mwandb[0m: 	lr: 0.00029718987971961714
[34m[1mwandb[0m: 	weight_decay: 4.052039089211855e-05


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test/f1@0.5,▁█▁▁▁▁▁▁▁▁
test/pr_auc,▅▅▅█▃▃▃▄▁▁
train/loss,▃▃▃▁▁▁▄▂█▁
val/best_f1,▁▁▁▁▁▁▁▁▁▁
val/best_thr,█▅▅▅▆▅▄▃▁▁
val/f1@0.5,▁▁▁▁▁▁███▁
val/pr_auc,▁▁▁▁▁▁▁▁▁▁

0,1
best_epoch,1
best_thr,0.00429
best_val_pr_auc,0.17151
epoch,10
final_test_f1@best_thr,0.33333
final_test_precision@best_thr,0.21429
final_test_recall@best_thr,0.75
test/f1@0.5,0.28571
test/pr_auc,0.20766
train/loss,0.0004


In [None]:
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score
import wandb

# -----------------------
# Utils
# -----------------------
def unpack_batch(batch):
    # Most common: (x, y)
    if isinstance(batch, (tuple, list)) and len(batch) >= 2:
        return batch[0], batch[1]
    # If your dataloader yields dicts, adapt here:
    if isinstance(batch, dict):
        x = batch.get("x", batch.get("inputs"))
        y = batch.get("y", batch.get("labels"))
        return x, y
    raise ValueError(f"Unknown batch type: {type(batch)}")

@torch.no_grad()
def collect_probs_and_labels(model, dl, device):
    model.eval()
    probs_list, y_list = [], []
    for batch in dl:
        x, y = unpack_batch(batch)
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        logits = logits.squeeze()
        p = torch.sigmoid(logits).detach().float().cpu().numpy()
        y_np = y.detach().long().cpu().numpy()

        probs_list.append(p)
        y_list.append(y_np)

    probs = np.concatenate(probs_list) if probs_list else np.array([])
    y_true = np.concatenate(y_list) if y_list else np.array([])
    return probs, y_true

def safe_pr_auc(y_true, probs):
    # average_precision_score requires both classes present
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(average_precision_score(y_true, probs))

def metrics_at_threshold(y_true, probs, thr):
    y_pred = (probs >= thr).astype(int)
    # zero_division=0 avoids NaNs when no predicted positives
    f1 = float(f1_score(y_true, y_pred, zero_division=0))
    prec = float(precision_score(y_true, y_pred, zero_division=0))
    rec = float(recall_score(y_true, y_pred, zero_division=0))
    return f1, prec, rec

def best_threshold_from_val(model, val_dl, device, step=0.01):
    probs, y_true = collect_probs_and_labels(model, val_dl, device)
    pr_auc = safe_pr_auc(y_true, probs)

    best = {"thr": 0.5, "f1": 0.0, "pr_auc": pr_auc}
    if probs.size == 0:
        return best

    for thr in np.arange(0.0, 1.0 + 1e-9, step):
        f1, _, _ = metrics_at_threshold(y_true, probs, thr)
        if f1 > best["f1"]:
            best.update({"thr": float(thr), "f1": float(f1)})
    return best

def train_one_epoch(model, dl, optimizer, device, criterion, clip_grad=1.0):
    model.train()
    running, n = 0.0, 0

    for batch in dl:
        x, y = unpack_batch(batch)
        x = x.to(device)
        y = y.to(device).float()

        optimizer.zero_grad(set_to_none=True)
        logits = model(x).squeeze()
        loss = criterion(logits, y)

        loss.backward()
        if clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()

        bs = y.shape[0]
        running += float(loss.item()) * bs
        n += bs

    return running / max(n, 1)

@torch.no_grad()
def eval_metrics(model, dl, device, thr=0.5):
    probs, y_true = collect_probs_and_labels(model, dl, device)
    pr_auc = safe_pr_auc(y_true, probs)
    f1, prec, rec = metrics_at_threshold(y_true, probs, thr) if probs.size else (float("nan"),)*3
    return {"pr_auc": pr_auc, "f1": f1, "precision": prec, "recall": rec, "n": int(len(y_true))}

# -----------------------
# Main sweep-run function
# -----------------------
def run_one(config=None):
    # Always finish any previous run in notebooks
    try:
        wandb.finish()
    except Exception:
        pass

    with wandb.init(config=config, project="dendrites-hackathon", entity="vtpy", reinit=True):
        cfg = wandb.config

        # ---- You must define these in your notebook already ----
        # model = ...
        # device = ...
        # train_dl, val_dl, test_dl = ...
        # pos_weight = ... (optional)
        # kind = cfg.get("kind", "baseline") if you want
        #
        # If you already have model built elsewhere, just use it here.
        global model, device, train_dl, val_dl, test_dl

        # Optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        # Loss (binary)
        # If you have pos_weight, use it. Otherwise remove it.
        if "pos_weight" in globals():
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
        else:
            criterion = nn.BCEWithLogitsLoss()

        # Nice W&B summaries
        wandb.define_metric("val/pr_auc", summary="max")
        wandb.define_metric("test/pr_auc", summary="max")

        best_state = {"val_pr_auc": -1.0, "epoch": -1, "thr": 0.5}

        for ep in range(int(cfg.epochs)):
            train_loss = train_one_epoch(model, train_dl, optimizer, device, criterion, clip_grad=1.0)

            # threshold selection on val
            best_val = best_threshold_from_val(model, val_dl, device, step=0.01)
            val_at_05 = eval_metrics(model, val_dl, device, thr=0.5)
            test_at_05 = eval_metrics(model, test_dl, device, thr=0.5)
            test_at_best = eval_metrics(model, test_dl, device, thr=best_val["thr"])

            wandb.log({
                "epoch": ep + 1,
                "train/loss": train_loss,

                "val/pr_auc": best_val["pr_auc"],
                "val/best_thr": best_val["thr"],
                "val/best_f1": best_val["f1"],
                "val/f1@0.5": val_at_05["f1"],

                "test/pr_auc": test_at_05["pr_auc"],
                "test/f1@0.5": test_at_05["f1"],

                "final_test_f1@best_thr": test_at_best["f1"],
                "final_test_precision@best_thr": test_at_best["precision"],
                "final_test_recall@best_thr": test_at_best["recall"],
            })

            # Track best epoch by val PR-AUC
            val_pr_auc = best_val["pr_auc"]
            if np.isfinite(val_pr_auc) and val_pr_auc > best_state["val_pr_auc"]:
                best_state.update({"val_pr_auc": float(val_pr_auc), "epoch": ep + 1, "thr": float(best_val["thr"])})

        wandb.log({
            "best_epoch": best_state["epoch"],
            "best_thr": best_state["thr"],
            "best_val_pr_auc": best_state["val_pr_auc"],
        })

        print("Best:", best_state)

# BONUS 2

Take on a bigger challenge by integrating dendritic optimization into a new framework.  We are currently set up for Huggingface, PyTorch Lightning, PyTorch Geometric, and PyTorch Tabular.  If you get dendritic optimization working with other similar cases that’ll get you bonus points. (note if you don't want any specific bonus points and are actually looking for an easier project doing a project within these frameworks using what we’ve already built is a great option for that)
Note, these projects are also harder.  If all of your code is in one file adding Dendritic Optimization can be done in under an hour.  If you have to dive into the depths of a library with custom trainers etc., it is more challenging.
Options: MMDetection - Ultralytics YOLO - MONAI - SAIL Models - Torchtune - Ray
Other framework adjacent options with added complexity: NanoGPT, Meta Map Anything, Huggingface Distillation
Additions to your PR.  If you find any bugs or optimizations that you submit in your PR in addition to your project to the examples folder that’ll also get you bonus points.
Connect your project to a business need within your case study. Describe how this model optimization unlocks use cases, hardware options, or data limited training.  


## NanoGPT (framework-adjacent option)

In [None]:
!git clone https://github.com/karpathy/nanoGPT.git


Cloning into 'nanoGPT'...
remote: Enumerating objects: 689, done.[K
remote: Total 689 (delta 0), reused 0 (delta 0), pack-reused 689 (from 1)[K
Receiving objects: 100% (689/689), 975.24 KiB | 13.74 MiB/s, done.
Resolving deltas: 100% (382/382), done.


In [None]:
%cd nanoGPT

!pip -q install PerforatedAI || true
!pip -q install perforatedai || true

/content/nanoGPT
[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
[0m

In [None]:
%%writefile dendritic_patch.py
import re
import inspect
from typing import Iterable, Optional, Dict, Any, Tuple

import torch
import torch.nn as nn


def _import_dendritic_linear():
    tried = []
    for mod, name in [
        ("perforatedai", "DendriticLinear"),
        ("perforatedai.dendrites", "DendriticLinear"),
        ("perforatedai.modules", "DendriticLinear"),
        ("perforated_ai", "DendriticLinear"),
        ("perforated_ai.dendrites", "DendriticLinear"),
    ]:
        try:
            m = __import__(mod, fromlist=[name])
            return getattr(m, name)
        except Exception as e:
            tried.append((mod, str(e)))
    raise ImportError(
        "Could not import DendriticLinear. Tried:\n" +
        "\n".join([f"- {m}: {err}" for m, err in tried])
    )


def _filter_kwargs_for_init(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    sig = inspect.signature(cls.__init__)
    allowed = set(sig.parameters.keys())
    allowed.discard("self")
    return {k: v for k, v in kwargs.items() if k in allowed}


def apply_dendrites_to_nanogpt(
    model: nn.Module,
    *,
    include_regex: Iterable[str],
    exclude_regex: Iterable[str] = (),
    dend_kwargs: Optional[Dict[str, Any]] = None,
    verbose: bool = True,
) -> Tuple[int, int]:
    DendriticLinear = _import_dendritic_linear()
    dend_kwargs = dend_kwargs or {}
    dend_kwargs = _filter_kwargs_for_init(DendriticLinear, dend_kwargs)

    inc = [re.compile(p) for p in include_regex]
    exc = [re.compile(p) for p in exclude_regex]

    def ok(name: str) -> bool:
        if not any(r.search(name) for r in inc):
            return False
        if any(r.search(name) for r in exc):
            return False
        return True

    name_to_parent = {}
    for parent_name, parent in model.named_modules():
        for child_name, child in parent.named_children():
            full = f"{parent_name}.{child_name}" if parent_name else child_name
            name_to_parent[full] = (parent, child_name, child)

    replaced, seen = 0, 0
    device = next(model.parameters()).device

    for full_name, (parent, child_name, child) in name_to_parent.items():
        if isinstance(child, nn.Linear):
            seen += 1
            if not ok(full_name):
                continue

            new_layer = DendriticLinear(
                child.in_features,
                child.out_features,
                bias=(child.bias is not None),
                **dend_kwargs,
            ).to(device)

            with torch.no_grad():
                new_layer.weight.copy_(child.weight)
                if child.bias is not None:
                    new_layer.bias.copy_(child.bias)

            setattr(parent, child_name, new_layer)
            replaced += 1
            if verbose:
                print(f"[dendrites] replaced: {full_name}")

    if verbose:
        print(f"[dendrites] replaced {replaced}/{seen} Linear layers")
    return replaced, seen

Writing dendritic_patch.py


In [None]:
import re
from pathlib import Path

p = Path("train.py")
txt = p.read_text()

# 1) Ensure argparse import exists
if "import argparse" not in txt:
    # put argparse near other imports; fall back to top if needed
    if "import os" in txt:
        txt = txt.replace("import os", "import os\nimport argparse", 1)
    else:
        txt = "import argparse\n" + txt

# 2) Inject args parsing after a stable marker
marker = "from contextlib import nullcontext"
args_block = """
parser = argparse.ArgumentParser()
parser.add_argument("--dendrites", action="store_true")
parser.add_argument("--dend_target", type=str, default="mlp", choices=["mlp", "mlp+attn"])
parser.add_argument("--dend_init_mag", type=float, default=0.1)
parser.add_argument("--dend_switch_threshold", type=float, default=0.5)
parser.add_argument("--max_dendrites", type=int, default=8)
parser.add_argument("--threshold", type=float, default=0.5)
args, _ = parser.parse_known_args()
"""

if marker in txt and "parser.add_argument(\"--dendrites\"" not in txt:
    txt = txt.replace(marker, marker + "\n" + args_block, 1)

# 3) Insert dendrites patch right after model = GPT(gptconf)
patch_block = r"""
from dendritic_patch import apply_dendrites_to_nanogpt

if args.dendrites:
    include = [
        r"^transformer\.h\.\d+\.mlp\.c_fc$",
        r"^transformer\.h\.\d+\.mlp\.c_proj$",
    ]
    if args.dend_target == "mlp+attn":
        include += [
            r"^transformer\.h\.\d+\.attn\.c_attn$",
            r"^transformer\.h\.\d+\.attn\.c_proj$",
        ]
    exclude = [r"lm_head", r"wte", r"wpe"]

    dend_kwargs = dict(
        dend_init_mag=args.dend_init_mag,
        dend_switch_threshold=args.dend_switch_threshold,
        max_dendrites=args.max_dendrites,
        threshold=args.threshold,
    )

    apply_dendrites_to_nanogpt(
        model,
        include_regex=include,
        exclude_regex=exclude,
        dend_kwargs=dend_kwargs,
        verbose=True,
    )
"""

if "from dendritic_patch import apply_dendrites_to_nanogpt" not in txt:
    m = re.search(r"model\s*=\s*GPT\(gptconf\)\s*\n", txt)
    if not m:
        raise RuntimeError("Couldn't find the line `model = GPT(gptconf)` in train.py. Paste the model creation section and I’ll adapt the inserter.")
    insert_at = m.end()
    txt = txt[:insert_at] + patch_block + "\n" + txt[insert_at:]

p.write_text(txt)
print("YAY Patched train.py (no re.sub template escapes).")

YAY Patched train.py (no re.sub template escapes).


In [None]:
import re
from pathlib import Path

p = Path("train.py")
txt = p.read_text()

# 1) remove our argparse block if present
txt = re.sub(r"\nparser = argparse\.ArgumentParser\(\)[\s\S]*?args, _ = parser\.parse_known_args\(\)\n", "\n", txt)

# 2) remove argparse import if we added it
txt = txt.replace("\nimport argparse\n", "\n")

# 3) insert dendrite default globals BEFORE configurator runs
defaults_block = """
# --- Dendritic Optimization (configurator-controlled) ---
dendrites = False              # set True via --dendrites=True
dend_target = "mlp"            # "mlp" or "mlp+attn"
dend_init_mag = 0.1
dend_switch_threshold = 0.5
max_dendrites = 8
threshold = 0.5
# --------------------------------------------------------
"""

if "Dendritic Optimization (configurator-controlled)" not in txt:
    m = re.search(r"exec\(open\('configurator\.py'\)\.read\(\)\)", txt)
    if not m:
        raise RuntimeError("Couldn't find configurator exec line in train.py")
    txt = txt[:m.start()] + defaults_block + "\n" + txt[m.start():]

# 4) replace args.* usage inside our dendrites patch with configurator globals
txt = txt.replace("if args.dendrites:", "if dendrites:")
txt = txt.replace("if args.dend_target == \"mlp+attn\":", "if dend_target == \"mlp+attn\":")
txt = txt.replace("dend_init_mag=args.dend_init_mag", "dend_init_mag=dend_init_mag")
txt = txt.replace("dend_switch_threshold=args.dend_switch_threshold", "dend_switch_threshold=dend_switch_threshold")
txt = txt.replace("max_dendrites=args.max_dendrites", "max_dendrites=max_dendrites")
txt = txt.replace("threshold=args.threshold", "threshold=threshold")

p.write_text(txt)
print("*Updated train.py to use nanoGPT configurator vars (no argparse flags).")

*Updated train.py to use nanoGPT configurator vars (no argparse flags).


In [None]:
!ls config

eval_gpt2_large.py   eval_gpt2_xl.py	      train_shakespeare_char.py
eval_gpt2_medium.py  finetune_shakespeare.py
eval_gpt2.py	     train_gpt2.py


In [None]:
!pip -q install --upgrade pip
!pip -q install "git+https://github.com/PerforatedAI/PerforatedAI.git"

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for perforatedai (pyproject.toml) ... [?25l[?25hdone


In [None]:
import perforatedai
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA
print("perforatedai OK", perforatedai.__file__)

perforatedai OK /content/PerforatedAI/perforatedai/__init__.py


In [None]:
from pathlib import Path

Path("dendritic_patch.py").write_text(r'''
import torch
import torch.nn as nn

def apply_dendrites_to_nanogpt(
    model: nn.Module,
    dend_target: str = "mlp",   # "mlp" | "attn" | "all"
    max_dendrites: int = 8,
    out_dir: str = "out",
    switch_mode: str = "fixed", # "fixed" is easiest for short runs
    first_switch: int = 1,
    switch_every: int = 1,
):
    """
    Returns: (model, tracker)

    Uses PerforatedAI PAINeuronModuleTracker + PAIConfig.
    Converts selected nn.Linear layers (mlp/attn) inside nanoGPT blocks.
    Uses fixed switching so you see dendrite additions in short runs.
    """
    try:
        import perforatedai as pai
    except Exception as e:
        raise ImportError("perforatedai not installed. Run: !pip install perforatedai") from e

    names_to_convert = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        is_mlp  = ".mlp."  in name
        is_attn = ".attn." in name

        if dend_target == "mlp" and is_mlp:
            names_to_convert.append(name)
        elif dend_target == "attn" and is_attn:
            names_to_convert.append(name)
        elif dend_target == "all" and (is_mlp or is_attn):
            names_to_convert.append(name)

    if len(names_to_convert) == 0:
        raise RuntimeError(
            f"No nn.Linear modules matched dend_target='{dend_target}'. "
            "Double-check nanoGPT module names."
        )

    config = pai.PAIConfig()
    config.set_max_dendrites(int(max_dendrites))
    config.set_module_names_to_convert(names_to_convert)

    if switch_mode == "fixed":
        config.set_switch_mode(config.DOING_FIXED_SWITCH)
        config.set_first_fixed_switch_num(int(first_switch))
        config.set_fixed_switch_num(int(switch_every))
    elif switch_mode == "history":
        config.set_switch_mode(config.DOING_HISTORY)
        config.set_n_epochs_to_switch(2)
    else:
        config.set_switch_mode(config.DOING_NO_SWITCH)

    tracker = pai.PAINeuronModuleTracker(
        doing_pai=True,
        save_name=out_dir,
        maximizing_score=True,
        config=config,
    )

    model = tracker.initialize(model)
    return model, tracker
''')

print(" dendritic_patch.py written")

 dendritic_patch.py written


In [None]:
!python -c "import dendritic_patch; print('import ok')"

import ok


In [None]:
import perforatedai as pai
print("perforatedai path:", pai.__file__)
print("Has PAINeuronModuleTracker:", hasattr(pai, "PAINeuronModuleTracker"))
print("Has initialize_pai:", hasattr(pai, "initialize_pai"))
print("Top-level names sample:", [x for x in dir(pai) if "PAI" in x or "Tracker" in x][:30])

perforatedai path: /content/PerforatedAI/perforatedai/__init__.py
Has PAINeuronModuleTracker: False
Has initialize_pai: False
Top-level names sample: []


In [None]:
import importlib, inspect

m = importlib.import_module("perforatedai.modules_perforatedai")
print("module file:", m.__file__)

funcs = [name for name,obj in inspect.getmembers(m, inspect.isfunction)]
clss  = [name for name,obj in inspect.getmembers(m, inspect.isclass)]

print("functions (sample):", funcs[:50])
print("classes (sample):", clss[:50])

# likely candidates
print("init-ish candidates:", [f for f in funcs if "init" in f.lower() or "initialize" in f.lower()][:50])
print("pai candidates:", [f for f in funcs if "pai" in f.lower()][:50])

module file: /content/PerforatedAI/perforatedai/modules_perforatedai.py
functions (sample): ['filter_backward', 'init_params', 'set_tracked_params', 'set_wrapped_params']
classes (sample): ['DendriteValueTracker', 'PAIDendriteModule', 'PAINeuronModule', 'TrackedNeuronModule', 'datetime']
init-ish candidates: ['init_params']
pai candidates: []


In [None]:
from pathlib import Path

patch = r"""
from __future__ import annotations
import torch
import torch.nn as nn
from perforatedai import modules_perforatedai as pm

def _should_wrap(name: str, dend_target: str) -> bool:
    n = name.lower()
    if dend_target == "mlp":
        return ".mlp." in n
    if dend_target == "attn":
        return ".attn." in n
    if dend_target == "all":
        return True
    # fallback: substring selector
    return dend_target.lower() in n

def _wrap_linear(linear: nn.Linear, name: str, activation_function_value: float):
    # Dendritic wrapper (most relevant)
    try:
        return pm.PAIDendriteModule(
            linear,
            activation_function_value=activation_function_value,
            name=name,
            output_dimensions=linear.out_features,
        )
    except TypeError:
        # If output_dimensions isn't accepted in your version:
        return pm.PAIDendriteModule(
            linear,
            activation_function_value=activation_function_value,
            name=name,
        )

def apply_dendrites_to_nanogpt(
    model: nn.Module,
    dend_target: str = "mlp",
    activation_function_value: float = 0.3,
    verbose: bool = True,
) -> nn.Module:
    replaced = 0

    # Replace via parent setattr
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not _should_wrap(name, dend_target):
            continue

        parts = name.split(".")
        parent = model
        for p in parts[:-1]:
            parent = getattr(parent, p)
        leaf = parts[-1]

        wrapped = _wrap_linear(module, name=name, activation_function_value=activation_function_value)
        setattr(parent, leaf, wrapped)
        replaced += 1

    # Find one wrapped module to use as neuron_main_module for init_params
    neuron_main = None
    for m in model.modules():
        if isinstance(m, pm.PAIDendriteModule) or isinstance(m, pm.PAINeuronModule):
            neuron_main = m
            break

    if neuron_main is None:
        raise RuntimeError("No dendritic modules were wrapped; check dend_target selection.")

    # Required init path for this library version
    try:
        pm.init_params(model, neuron_main)
    except Exception as e:
        if verbose:
            print("[PerforatedAI] init_params failed:", repr(e))
            raise

    # Optional registration helpers (often needed)
    try:
        pm.set_wrapped_params(model)
    except Exception as e:
        if verbose:
            print("[PerforatedAI] set_wrapped_params skipped:", repr(e))

    try:
        pm.set_tracked_params(model)
    except Exception as e:
        if verbose:
            print("[PerforatedAI] set_tracked_params skipped:", repr(e))

    if verbose:
        print(f"[PerforatedAI] Wrapped {replaced} Linear layers (target={dend_target}).")
        print(f"[PerforatedAI] neuron_main_module: {type(neuron_main).__name__}")

    return model
"""

Path("/content/nanoGPT/dendritic_patch.py").write_text(patch)
print("Wrote /content/nanoGPT/dendritic_patch.py")

Wrote /content/nanoGPT/dendritic_patch.py


In [None]:
from pathlib import Path

train_py = Path("/content/nanoGPT/train.py")
txt = train_py.read_text()

# 1) Add config globals BEFORE configurator exec (prevents AssertionError on unknown CLI keys)
marker = "exec(open('configurator.py').read())"
if marker in txt and "dendrites = False" not in txt:
    inject = (
        "# --- PerforatedAI Dendrites config ---\n"
        "dendrites = False\n"
        "dend_target = 'mlp'          # 'mlp' | 'attn' | 'all'\n"
        "dend_activation = 0.3        # activation_function_value for PAIDendriteModule\n"
        "# -----------------------------------\n\n"
    )
    txt = txt.replace(marker, inject + marker)

# 2) Apply dendrites after model is created (NanoGPT usually has `model = GPT(gptconf)`).
needle = "model = GPT(gptconf)"
if needle in txt and "apply_dendrites_to_nanogpt" not in txt:
    apply_block = (
        "\n# --- Apply PerforatedAI dendrites ---\n"
        "if dendrites:\n"
        "    from dendritic_patch import apply_dendrites_to_nanogpt\n"
        "    model = apply_dendrites_to_nanogpt(\n"
        "        model,\n"
        "        dend_target=dend_target,\n"
        "        activation_function_value=dend_activation,\n"
        "        verbose=True,\n"
        "    )\n"
        "# -----------------------------------\n"
    )
    txt = txt.replace(needle, needle + apply_block)

train_py.write_text(txt)
print("Patched /content/nanoGPT/train.py")

Patched /content/nanoGPT/train.py


In [None]:
from pathlib import Path
import re

train_py = Path("/content/nanoGPT/train.py")
txt = train_py.read_text()

# Insert dendrites config right above the configurator exec, even if formatting differs
pattern = r"(exec\(open\('configurator\.py'\)\.read\(\)\)\s*# overrides from command line or config file)"
m = re.search(pattern, txt)

assert m, "Couldn't find configurator exec line in train.py (pattern mismatch)."

block = (
    "# --- PerforatedAI Dendrites config ---\n"
    "dendrites = False\n"
    "dend_target = 'mlp'          # 'mlp' | 'attn' | 'all'\n"
    "dend_activation = 0.3        # activation_function_value for PAIDendriteModule\n"
    "# -----------------------------------\n\n"
)

# Only insert if dend_activation isn't already defined before exec
pre = txt[:m.start(1)]
if "dend_activation" not in pre:
    txt = txt[:m.start(1)] + block + txt[m.start(1):]
    train_py.write_text(txt)
    print(" Inserted dend config (including dend_activation) above configurator exec.")
else:
    print("dend_activation already defined above configurator exec; no change.")

dend_activation already defined above configurator exec; no change.


In [None]:
%cd /content/nanoGPT
!python train.py config/train_shakespeare_char.py \
  --device=cpu --compile=False \
  --max_iters=200 --eval_interval=100 --batch_size=8 \
  --n_layer=2 --n_head=2 --n_embd=128 \
  --dendrites=True --dend_target=mlp --dend_activation=0.3

/content/nanoGPT
Overriding config with config/train_shakespeare_char.py:
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger b

## Restore runtime

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import wandb
api = wandb.Api()


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnpnallstar[0m ([33mvtpy[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import os

root = "/content/drive/MyDrive"
hits = []
for dirpath, _, filenames in os.walk(root):
    for f in filenames:
        if f.endswith((".pt", ".pth", ".ckpt", ".bin", ".zip")):
            hits.append(os.path.join(dirpath, f))

print("Found:", len(hits))
for p in hits[:200]:
    print(p)

Found: 3
/content/drive/MyDrive/AI4Alz_backup/artifacts/mri_resnet18.pt
/content/drive/MyDrive/AI4Alz_backup/artifacts/mri_model.pt
/content/drive/MyDrive/AI4Alz_backup/artifacts/mri_mobilenet.pt


In [None]:
import wandb

api = wandb.Api()
entity = "vtpy"
project = "dendrites-hackathon"

runs = api.runs(f"{entity}/{project}")
runs = list(runs)

print("Runs found:", len(runs))
for r in runs[:30]:
    print(r.id, "|", r.name, "| state:", r.state)

Runs found: 81
130atbjh | protbert_baseline_splitA | state: crashed
c31lphdj | zany-music-2 | state: finished
uydsy4vb | wandering-sunset-3 | state: finished
34jrtywb | scarlet-cloud-4 | state: finished
ki8g6z4v | apricot-sweep-1 | state: failed
12pm6gsk | warm-sweep-2 | state: failed
gwa9ns1d | distinctive-sweep-3 | state: failed
0vkvol6i | lemon-sweep-4 | state: failed
byckq6nx | rosy-sweep-5 | state: failed
a6xc5zqn | chocolate-sweep-6 | state: failed
j183qq4k | visionary-sweep-7 | state: failed
z8c42rh9 | glamorous-sweep-8 | state: failed
g3d5usln | grateful-sweep-9 | state: crashed
t4qqofk0 | summer-sweep-1 | state: failed
r6kwkvku | legendary-sweep-2 | state: failed
cvyjmd5k | crimson-sweep-3 | state: crashed
wjccc33m | quiet-sweep-4 | state: failed
k6k9a5v1 | blooming-sweep-5 | state: failed
vffph0r7 | absurd-sweep-1 | state: failed
5jokvd12 | polished-sweep-1 | state: failed
6qeiu6c8 | wise-sweep-2 | state: failed
rchj8i78 | youthful-sweep-3 | state: failed
9c9imcp7 | dulcet-sw

In [None]:
target_ids = {"tdif3p2t","hpmq62fe","3ei24bcq","jeg08kfr","0c280l2j","4b183pr9","c56jbayo"}  # add more

for r in runs:
    if r.id in target_ids:
        print("\n==", r.id, r.name, "==")
        # config keys may differ; print common ones:
        for k in ["lr","weight_decay","batch_size","max_iters","epochs","n_layer","n_head","n_embd","dendrites","dend_target"]:
            if k in r.config:
                print(f"config.{k} =", r.config[k])

        # summary keys from your logs:
        for k in ["best_val_pr_auc","best_thr","best_epoch",
                  "test/pr_auc","test/f1@0.5",
                  "final_test_f1@best_thr","final_test_precision@best_thr","final_test_recall@best_thr"]:
            if k in r.summary:
                print(f"summary.{k} =", r.summary[k])


== tdif3p2t run_0102_221555 ==
config.lr = 0.0005784145230219258
config.weight_decay = 0.001984463047868843
config.epochs = 10
summary.best_val_pr_auc = 0.4134445155441796
summary.best_thr = 0.7587597370147705
summary.best_epoch = 10
summary.test/pr_auc = 0.2585728651429871
summary.test/f1@0.5 = 0.199999999999815
summary.final_test_f1@best_thr = 0.31578947368368976
summary.final_test_precision@best_thr = 0.27272727272724795
summary.final_test_recall@best_thr = 0.3749999999999531

== hpmq62fe run_0102_221627 ==
config.lr = 0.001065783331350847
config.weight_decay = 0.00022236512171524135
config.epochs = 10
summary.best_val_pr_auc = 0.40934331797235024
summary.best_thr = 0.6943539977073669
summary.best_epoch = 2
summary.test/pr_auc = 0.4463422983701668
summary.test/f1@0.5 = 0.15384615384581987
summary.final_test_f1@best_thr = 0.42857142857087754
summary.final_test_precision@best_thr = 0.4999999999999167
summary.final_test_recall@best_thr = 0.3749999999999531

== 3ei24bcq run_0102_221656

In [None]:
import wandb

api = wandb.Api()
entity = "vtpy"
project = "dendrites-hackathon"

runs = list(api.runs(f"{entity}/{project}"))
print("Runs found:", len(runs))

# ---- 1) Verify your target IDs exist in this project ----
target_ids = {"tdif3p2t","hpmq62fe","3ei24bcq","jeg08kfr","0c280l2j","4b183pr9","c56jbayo"}

id_to_run = {r.id: r for r in runs}
present = sorted(target_ids.intersection(id_to_run.keys()))
missing = sorted(target_ids.difference(id_to_run.keys()))
print("\nTarget IDs present:", present)
print("Target IDs missing:", missing)

# ---- 2) Print config + summary for ONLY those runs ----
config_keys = ["lr","weight_decay","batch_size","max_iters","epochs","n_layer","n_head","n_embd","dendrites","dend_target"]
summary_keys = [
    "best_val_pr_auc","best_thr","best_epoch",
    "test/pr_auc","test/f1@0.5",
    "final_test_f1@best_thr","final_test_precision@best_thr","final_test_recall@best_thr"
]

def safe_get(d, k):
    try:
        return d.get(k, None)
    except Exception:
        return None

for rid in present:
    r = id_to_run[rid]
    print("\n" + "="*70)
    print(f"== {r.id} | {r.name} | state: {r.state} ==")
    print("url:", r.url)

    # config
    for k in config_keys:
        v = safe_get(r.config, k)
        if v is not None:
            print(f"config.{k} = {v}")

    # summary
    for k in summary_keys:
        v = safe_get(r.summary, k)
        if v is not None:
            print(f"summary.{k} = {v}")

# ---- 3) Find best runs automatically (filter finished only) ----
finished = [r for r in runs if r.state == "finished"]
print("\nFinished runs:", len(finished))

def metric(r, key):
    v = safe_get(r.summary, key)
    return float(v) if v is not None else None

# choose which metric you care about:
# A) maximize PR-AUC on test
metric_key = "test/pr_auc"
scored = [(r, metric(r, metric_key)) for r in finished]
scored = [(r, s) for (r, s) in scored if s is not None]
scored.sort(key=lambda x: x[1], reverse=True)

print(f"\nTop 10 by {metric_key}:")
for r, s in scored[:10]:
    print(f"{s:.5f} | {r.id} | {r.name} | {r.url}")

# B) maximize final_test_f1@best_thr (often what your run summaries emphasize)
metric_key2 = "final_test_f1@best_thr"
scored2 = [(r, metric(r, metric_key2)) for r in finished]
scored2 = [(r, s) for (r, s) in scored2 if s is not None]
scored2.sort(key=lambda x: x[1], reverse=True)

print(f"\nTop 10 by {metric_key2}:")
for r, s in scored2[:10]:
    print(f"{s:.5f} | {r.id} | {r.name} | {r.url}")

Runs found: 81

Target IDs present: ['0c280l2j', '3ei24bcq', '4b183pr9', 'c56jbayo', 'hpmq62fe', 'jeg08kfr', 'tdif3p2t']
Target IDs missing: []

== 0c280l2j | run_0102_222348 | state: finished ==
url: https://wandb.ai/vtpy/dendrites-hackathon/runs/0c280l2j
config.lr = 0.00015288250892757412
config.weight_decay = 3.1873598539607545e-06
config.epochs = 10
summary.best_val_pr_auc = 0.19850516696211495
summary.best_thr = 0.01088752318173647
summary.best_epoch = 5
summary.test/pr_auc = 0.2111371187135194
summary.test/f1@0.5 = 0.2857142857137868
summary.final_test_f1@best_thr = 0.2790697674415446
summary.final_test_precision@best_thr = 0.17142857142856652
summary.final_test_recall@best_thr = 0.7499999999999063

== 3ei24bcq | run_0102_221656 | state: finished ==
url: https://wandb.ai/vtpy/dendrites-hackathon/runs/3ei24bcq
config.lr = 0.00011489581841961362
config.weight_decay = 2.154883977035616e-05
config.epochs = 10
summary.best_val_pr_auc = 0.297498199559858
summary.best_thr = 0.5503516793

In [None]:
%%bash
set -e
cd /content
rm -rf PerforatedAI

git clone https://github.com/PerforatedAI/PerforatedAI.git
cd PerforatedAI

python -m pip -q install -U pip setuptools wheel
python -m pip -q install -e .

python - <<'PY'
import perforatedai
from perforatedai import utils_perforatedai as UPA
from perforatedai import globals_perforatedai as GPA

print("perforatedai import OK:", perforatedai.__file__)
print("Has initialize_pai:", hasattr(UPA, "initialize_pai"))
print("Has GPA.pai_tracker:", hasattr(GPA, "pai_tracker"))
print("GPA.pai_tracker has add_validation_score:", hasattr(GPA.pai_tracker, "add_validation_score"))

# Optional: show tracker signature safely
from perforatedai.tracker_perforatedai import PAINeuronModuleTracker
import inspect
print("PAINeuronModuleTracker.__init__ signature:", inspect.signature(PAINeuronModuleTracker.__init__))
PY

Building dendrites without Perforated Backpropagation
perforatedai import OK: /content/PerforatedAI/perforatedai/__init__.py
Has initialize_pai: True
Has GPA.pai_tracker: True
GPA.pai_tracker has add_validation_score: False
PAINeuronModuleTracker.__init__ signature: (self, doing_pai, save_name, making_graphs=True, param_vals_setting=-1, values_per_train_epoch=-1, values_per_val_epoch=-1)


Cloning into 'PerforatedAI'...
Updating files:  85% (162/189)Updating files:  86% (163/189)Updating files:  87% (165/189)Updating files:  88% (167/189)Updating files:  89% (169/189)Updating files:  90% (171/189)Updating files:  91% (172/189)Updating files:  92% (174/189)Updating files:  93% (176/189)Updating files:  94% (178/189)Updating files:  95% (180/189)Updating files:  96% (182/189)Updating files:  97% (184/189)Updating files:  98% (186/189)Updating files:  99% (188/189)Updating files: 100% (189/189)Updating files: 100% (189/189), done.


In [None]:
from perforatedai.tracker_perforatedai import PAINeuronModuleTracker
t = PAINeuronModuleTracker(None, "sanity")   # positional args

In [None]:
%%bash
set -e
cd /content

# Clean old clones
rm -rf PerforatedAI

# Clone the official repo (required by judges)
git clone https://github.com/PerforatedAI/PerforatedAI.git

# Tooling
python -m pip -q install -U pip setuptools wheel

# Avoid conflicts if you previously pip-installed something similar
python -m pip -q uninstall -y perforatedai || true

# Editable install from the cloned repo
python -m pip -q install -e /content/PerforatedAI

Cloning into 'PerforatedAI'...
Updating files:  88% (168/189)Updating files:  89% (169/189)Updating files:  90% (171/189)Updating files:  91% (172/189)Updating files:  92% (174/189)Updating files:  93% (176/189)Updating files:  94% (178/189)Updating files:  95% (180/189)Updating files:  96% (182/189)Updating files:  97% (184/189)Updating files:  98% (186/189)Updating files:  99% (188/189)Updating files: 100% (189/189)Updating files: 100% (189/189), done.


In [None]:
import os, inspect
import torch
import torch.nn as nn

import perforatedai
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

print("perforatedai:", perforatedai.__file__)
print("Before initialize -> GPA.pai_tracker:", type(GPA.pai_tracker), "len:", len(GPA.pai_tracker) if isinstance(GPA.pai_tracker, list) else "n/a")

# Tiny dummy model (just to force tracker creation)
model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))

out = UPA.initialize_pai(model, save_name="PAI_SANITY")  # should create PAI_SANITY/ outputs
model = out[0] if isinstance(out, tuple) else out

print("After initialize  -> GPA.pai_tracker:", type(GPA.pai_tracker), "len:", len(GPA.pai_tracker) if isinstance(GPA.pai_tracker, list) else "n/a")

# Get tracker object
tracker = GPA.pai_tracker[0] if isinstance(GPA.pai_tracker, list) else GPA.pai_tracker
print("tracker type:", type(tracker))

# Check the exact val-score method name available in YOUR installed version
print("Has add_validation_score:", hasattr(tracker, "add_validation_score"))
print("Val/score-like methods:", [m for m in dir(tracker) if ("val" in m.lower() and "score" in m.lower())])

# Show where the PAI graph landed
print("PAI_SANITY exists:", os.path.exists("PAI_SANITY"))
if os.path.exists("PAI_SANITY"):
    print("PAI_SANITY files:", sorted(os.listdir("PAI_SANITY"))[:20])

perforatedai: /usr/local/lib/python3.12/dist-packages/perforatedai/__init__.cpython-312-x86_64-linux-gnu.so
Before initialize -> GPA.pai_tracker: <class 'list'> len: 0
Running a test of Dendrite Capacity.
After initialize  -> GPA.pai_tracker: <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'> len: n/a
tracker type: <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'>
Has add_validation_score: True
Val/score-like methods: ['add_validation_score', 'reset_vals_for_score_reset']
PAI_SANITY exists: False


# Build simple MLP model

In [None]:
import torch
import torch.nn as nn

class MLPBinary(nn.Module):
    def __init__(self, in_dim: int, hidden: int = 256, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)  # logits
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # (batch,)

In [None]:
in_dim = 768  # example for many text embedding models; change to your feature size
model = MLPBinary(in_dim=in_dim)

In [None]:
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

GPA.pc.set_testing_dendrite_capacity(False)
model = UPA.initialize_pai(model, save_name="PAI")

Running Dendrite Experiment


In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# ----------------------------
# 1) Get data: use existing X/y if present, else make synthetic
# ----------------------------
def _find_first_existing(names):
    for n in names:
        if n in globals():
            return globals()[n], n
    return None, None

X_obj, X_name = _find_first_existing(["X", "features", "inputs", "X_all"])
y_obj, y_name = _find_first_existing(["y", "Y", "labels", "target", "targets", "y_all"])

if X_obj is None or y_obj is None:
    print("⚠️ No (X, y) found in memory. Creating a synthetic dataset for a sanity-run...")
    # Synthetic binary classification with imbalance (≈10% positives)
    N, D = 6000, 128
    g = torch.Generator().manual_seed(42)
    X = torch.randn(N, D, generator=g)
    w = torch.randn(D, generator=g)
    logits = X @ w
    probs = torch.sigmoid((logits - logits.mean()) / (logits.std() + 1e-6))
    # push to ~10% positives
    thresh = torch.quantile(probs, 0.90)
    y = (probs > thresh).long()
else:
    print(f"✅ Found data in memory: X from '{X_name}', y from '{y_name}'")
    X = torch.as_tensor(X_obj)
    y = torch.as_tensor(y_obj)

# Ensure dtypes/shapes
X = X.float()
# If y is one-hot (N,C), convert to class index
if y.ndim > 1 and y.size(-1) > 1:
    y = torch.argmax(y, dim=-1)
# Make y shape (N,)
y = y.view(-1).long()

print("Data shapes:", "X =", tuple(X.shape), "| y =", tuple(y.shape), "| positives =", int((y==1).sum()), "/", len(y))

# ----------------------------
# 2) Split: 80/10/10
# ----------------------------
g = torch.Generator().manual_seed(123)
N = X.size(0)
perm = torch.randperm(N, generator=g)

n_train = int(0.8 * N)
n_val   = int(0.1 * N)
train_idx = perm[:n_train]
val_idx   = perm[n_train:n_train+n_val]
test_idx  = perm[n_train+n_val:]

X_train, y_train = X[train_idx], y[train_idx]
X_val,   y_val   = X[val_idx],   y[val_idx]
X_test,  y_test  = X[test_idx],  y[test_idx]

print("Split sizes:", len(train_idx), len(val_idx), len(test_idx))

# ----------------------------
# 3) DataLoaders
# ----------------------------
batch_size = 64
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val, y_val),     batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(TensorDataset(X_test, y_test),   batch_size=batch_size, shuffle=False)

# Sanity batch
xb, yb = next(iter(train_loader))
print("Batch:", xb.shape, yb.shape, "yb unique:", torch.unique(yb).tolist())

print("\nNow you have: X_train, y_train, X_val, y_val, X_test, y_test + train_loader/val_loader/test_loader.")
print("Next: re-run your 'Running Dendrite Experiment' cell.")

⚠️ No (X, y) found in memory. Creating a synthetic dataset for a sanity-run...
Data shapes: X = (6000, 128) | y = (6000,) | positives = 600 / 6000
Split sizes: 4800 600 600
Batch: torch.Size([64, 128]) torch.Size([64]) yb unique: [0, 1]

Now you have: X_train, y_train, X_val, y_val, X_test, y_test + train_loader/val_loader/test_loader.
Next: re-run your 'Running Dendrite Experiment' cell.


In [None]:
# =========================
# FIX: use tracker-managed optimizer so add_validation_score() can update LR
# =========================
import os, json
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score

# 0) Verify loaders exist
required = ["train_loader","val_loader","test_loader","X_train","y_train","X_val","y_val","X_test","y_test"]
missing = [k for k in required if k not in globals()]
if missing:
    raise NameError(f"Missing: {missing}\nRun your dataset-split cell first.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# 1) PerforatedAI imports
import perforatedai
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA
print("perforatedai:", perforatedai.__file__)

# 2) (Optional) W&B
USE_WANDB = True
try:
    import wandb
    if USE_WANDB:
        wandb.login()
except Exception as e:
    print("wandb disabled:", repr(e))
    USE_WANDB = False

# 3) Minimal model
class MLPBinary(nn.Module):
    def __init__(self, in_dim=128, hidden=256, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )
    def forward(self, x):
        return self.net(x).squeeze(-1)

@torch.no_grad()
def eval_pr_auc(model, loader):
    model.eval()
    probs, ys = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        p = torch.sigmoid(logits).detach().cpu()
        probs.append(p)
        ys.append(yb.detach().cpu())
    p = torch.cat(probs).numpy()
    y = torch.cat(ys).numpy()
    return float(average_precision_score(y, p))

def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    tot, n = 0.0, 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).float()
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        tot += float(loss.item()) * xb.size(0)
        n += xb.size(0)
    return tot / max(1, n)

# 4) Configure PAI safely (only call setters that exist)
def safe_set(obj, name, *args):
    if hasattr(obj, name):
        getattr(obj, name)(*args)

def configure_pai(max_dendrites=8, switch_speed=5, verbose=False):
    pc = getattr(GPA, "pc", None)
    if pc is None:
        print("⚠️ GPA.pc missing")
        return
    safe_set(pc, "set_verbose", bool(verbose))
    safe_set(pc, "set_max_dendrites", int(max_dendrites))
    safe_set(pc, "set_n_epochs_to_switch", int(switch_speed))
    safe_set(pc, "set_improvement_threshold", [0.001, 0.0001, 0])
    safe_set(pc, "set_candidate_weight_initialization_multiplier", 0.1)
    safe_set(pc, "set_pai_forward_function", torch.relu)
    safe_set(pc, "set_modules_to_convert", [nn.Linear])
    safe_set(pc, "set_modules_to_track", [])
    safe_set(pc, "set_perforated_backpropagation", False)

# 5) Run experiment (CRITICAL: tracker.setup_optimizer)
def run_experiment(run_name="PAI_SANITY_synth_v2",
                   max_epochs=20,
                   lr=3e-4,
                   weight_decay=1e-4,
                   max_dendrites=8,
                   switch_speed=5):

    os.makedirs("PAI", exist_ok=True)

    # model
    model = MLPBinary(in_dim=X_train.shape[1], hidden=256, dropout=0.2).to(device)

    # configure + initialize
    configure_pai(max_dendrites=max_dendrites, switch_speed=switch_speed, verbose=False)
    model = UPA.initialize_pai(model, save_name="PAI").to(device)

    tracker = GPA.pai_tracker
    if not hasattr(tracker, "add_validation_score"):
        raise RuntimeError("Tracker missing add_validation_score — your install/import is wrong.")

    # tell tracker which optimizer/scheduler to use, then create optimizer via tracker
    # (this is what fixes your crash)
    if hasattr(tracker, "set_optimizer"):
        tracker.set_optimizer(torch.optim.AdamW)
    if hasattr(tracker, "set_scheduler"):
        tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

    # IMPORTANT: use tracker.setup_optimizer so internal optimizer != None
    if hasattr(tracker, "setup_optimizer"):
        optimArgs = {"params": model.parameters(), "lr": lr, "weight_decay": weight_decay}
        schedArgs = {"mode": "max", "patience": 3}
        optimizer, _scheduler = tracker.setup_optimizer(model, optimArgs, schedArgs)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    criterion = nn.BCEWithLogitsLoss()

    wb = None
    if USE_WANDB:
        wb = wandb.init(
            project="dendrites-hackathon",
            name=run_name,
            config=dict(lr=lr, weight_decay=weight_decay, max_epochs=max_epochs,
                        max_dendrites=max_dendrites, switch_speed=switch_speed),
            reinit=True,
        )

    best_val = -1.0
    best_test = -1.0
    best_params = None

    for epoch in range(1, max_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer)
        val_pr = eval_pr_auc(model, val_loader)
        test_pr = eval_pr_auc(model, test_loader)
        params = int(UPA.count_params(model))

        if val_pr > best_val:
            best_val, best_test, best_params = val_pr, test_pr, params

        log = {
            "Epoch": epoch,
            "Epoch Train Loss": train_loss,
            "Epoch Val PR_AUC": val_pr,
            "Epoch Test PR_AUC": test_pr,
            "Epoch Param Count": params,
        }
        if hasattr(tracker, "member_vars") and isinstance(tracker.member_vars, dict) and "num_dendrites_added" in tracker.member_vars:
            log["Epoch Dendrite Count"] = tracker.member_vars["num_dendrites_added"]

        print(f"[{epoch:03d}] loss={train_loss:.4f} val_pr={val_pr:.4f} test_pr={test_pr:.4f} params={params}")
        if wb: wb.log(log)

        # REQUIRED call: this triggers dendrite adding + graph generation
        model, restructured, training_complete = tracker.add_validation_score(val_pr, model)
        model = model.to(device)

        # If architecture changed, MUST rebuild optimizer via tracker again
        if restructured:
            print("RESTRUCTURED: rebuilding optimizer via tracker.setup_optimizer(...)")
            if hasattr(tracker, "setup_optimizer"):
                optimArgs = {"params": model.parameters(), "lr": lr, "weight_decay": weight_decay}
                schedArgs = {"mode": "max", "patience": 3}
                optimizer, _scheduler = tracker.setup_optimizer(model, optimArgs, schedArgs)
            else:
                optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        if training_complete:
            print("🏁 training_complete=True (tracker stopped).")
            break

    summary = {
        "Final Max Val PR_AUC": float(best_val),
        "Final Max Test PR_AUC": float(best_test),
        "Final Param Count": int(best_params) if best_params is not None else None,
    }
    if hasattr(tracker, "member_vars") and isinstance(tracker.member_vars, dict) and "num_dendrites_added" in tracker.member_vars:
        summary["Final Dendrite Count"] = tracker.member_vars["num_dendrites_added"]

    if wb:
        wb.log(summary)
        wb.finish()

    with open("PAI/summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print("\nDONE. Judge-required artifacts should now exist:")
    print(" - PAI/PAI.png")
    print(" - PAI/summary.json\n")
    print("Summary:", summary)

    return summary

# RUN
summary = run_experiment(
    run_name="PAI_SANITY_synth_v2",
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    max_dendrites=8,
    switch_speed=5,
)

# quick check
print("\nFiles in PAI/:")
!ls -lah PAI | sed -n '1,200p'



device: cpu
perforatedai: /usr/local/lib/python3.12/dist-packages/perforatedai/__init__.cpython-312-x86_64-linux-gnu.so
Running Dendrite Experiment
For PAI training it is recommended to not use weight decay in your optimizer
--Call--
> [0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py[0m(2656)[0;36mparameters[0;34m()[0m
[0;32m   2654 [0;31m                [0;32myield[0m [0mname[0m[0;34m,[0m [0mv[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2655 [0;31m[0;34m[0m[0m
[0m[0;32m-> 2656 [0;31m    [0;32mdef[0m [0mparameters[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mrecurse[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m)[0m [0;34m->[0m [0mIterator[0m[0;34m[[0m[0mParameter[0m[0;34m][0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2657 [0;31m        r"""Return an iterator over module parameters.
[0m[0;32m   2658 [0;31m[0;34m[0m[0m
[0m
ipdb> c


0,1
Epoch,▁
Epoch Dendrite Count,▁
Epoch Param Count,▁
Epoch Test PR_AUC,▁
Epoch Train Loss,▁
Epoch Val PR_AUC,▁

0,1
Epoch,1.0
Epoch Dendrite Count,0.0
Epoch Param Count,99073.0
Epoch Test PR_AUC,0.67153
Epoch Train Loss,0.34199
Epoch Val PR_AUC,0.56234


[001] loss=0.3518 val_pr=0.5103 test_pr=0.7622 params=99073
Adding validation score 0.51032641
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 0, last improved epoch 0, total epochs 0, n: 5, num_cycles: 0
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
RESTRUCTURED: rebuilding optimizer via tracker.setup_optimizer(...)
[002] loss=0.2364 val_pr=0.7599 test_pr=0.9180 params=198659
Adding validation score 0.75989882
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 1, last improved epoch 1, total epochs 1, n: 5, num_cycles: 2
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless

0,1
Epoch,▁▂▃▄▅▅▆▇█
Epoch Dendrite Count,▁▂▃▄▅▅▆▇█
Epoch Param Count,▁▂▃▄▄▅▆▇█
Epoch Test PR_AUC,▁▆▇██████
Epoch Train Loss,█▆▄▃▂▂▁▁▁
Epoch Val PR_AUC,▁▅▇▇█████
Final Dendrite Count,▁
Final Max Test PR_AUC,▁
Final Max Val PR_AUC,▁
Final Param Count,▁

0,1
Epoch,9.0
Epoch Dendrite Count,8.0
Epoch Param Count,928080.0
Epoch Test PR_AUC,0.97523
Epoch Train Loss,0.01012
Epoch Val PR_AUC,0.96094
Final Dendrite Count,8.0
Final Max Test PR_AUC,0.98169
Final Max Val PR_AUC,0.96141
Final Param Count,609315.0



DONE. Judge-required artifacts should now exist:
 - PAI/PAI.png
 - PAI/summary.json

Summary: {'Final Max Val PR_AUC': 0.9614118994446165, 'Final Max Test PR_AUC': 0.9816919586885289, 'Final Param Count': 609315, 'Final Dendrite Count': 8}

Files in PAI/:
total 64M
drwxr-xr-x 2 root root 4.0K Jan  5 18:51 .
drwxr-xr-x 1 root root 4.0K Jan  5 18:47 ..
-rw-r--r-- 1 root root 782K Jan  5 18:50 beforeSwitch_0.pt
-rw-r--r-- 1 root root 2.8M Jan  5 18:51 beforeSwitch_10.pt
-rw-r--r-- 1 root root 3.2M Jan  5 18:51 beforeSwitch_12.pt
-rw-r--r-- 1 root root 3.6M Jan  5 18:51 beforeSwitch_14.pt
-rw-r--r-- 1 root root 1.2M Jan  5 18:51 beforeSwitch_2.pt
-rw-r--r-- 1 root root 1.6M Jan  5 18:51 beforeSwitch_4.pt
-rw-r--r-- 1 root root 2.0M Jan  5 18:51 beforeSwitch_6.pt
-rw-r--r-- 1 root root 2.4M Jan  5 18:51 beforeSwitch_8.pt
-rw-r--r-- 1 root root 782K Jan  5 18:50 best_model_beforeSwitch_0.pt
-rw-r--r-- 1 root root 2.8M Jan  5 18:51 best_model_beforeSwitch_10.pt
-rw-r--r-- 1 root root 3.2M Ja

In [None]:
from perforatedai import globals_perforatedai as GPA
GPA.pc.set_weight_decay_accepted(True)

In [None]:
from perforatedai import globals_perforatedai as GPA
GPA.pc.set_weight_decay_accepted(True)
# define these BEFORE wandb.init
lr = 1e-3
max_dendrites = 8   # or whatever you're using

import wandb
wandb.login()

run = wandb.init(
    project="dendrites-hackathon",
    name="PAI_SANITY_synth_v3",
    config={"lr": lr, "max_dendrites": max_dendrites},
    reinit="finish_previous",
)

In [None]:
import os, json
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

device: cpu


In [None]:
# Expect: X_train, y_train, X_val, y_val, X_test, y_test are torch tensors
# Shapes: X: [N, D], y: [N] (0/1) or [N,1]

def make_loader(X, y, bs=64, shuffle=False):
    ds = TensorDataset(X, y)
    return DataLoader(ds, batch_size=bs, shuffle=shuffle, drop_last=False)

batch_size = 64
train_loader = make_loader(X_train, y_train, bs=batch_size, shuffle=True)
val_loader   = make_loader(X_val,   y_val,   bs=batch_size, shuffle=False)
test_loader  = make_loader(X_test,  y_test,  bs=batch_size, shuffle=False)

In [None]:
class MLPBinary(nn.Module):
    def __init__(self, d_in=128, h=256, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, h),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h, h),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h, 1),
        )
    def forward(self, x):
        return self.net(x)

model = MLPBinary(d_in=X_train.shape[1]).to(device)

# configure PerforatedAI + initialize_pai (creates tracker + PAI graphs)

In [None]:
GPA.pc.set_weight_decay_accepted(True)  # silence warning (or use wd=0)

# Basic PAI knobs (tune later)
GPA.pc.set_max_dendrites(max_dendrites)          # from your wandb config
GPA.pc.set_n_epochs_to_switch(1)                # switch speed; 1 makes it restructure quickly
GPA.pc.set_verbose(False)

# IMPORTANT: this wraps modules and sets up GPA.pai_tracker
model = UPA.initialize_pai(model, save_name="PAI")  # outputs to ./PAI/*
tracker = GPA.pai_tracker
print("tracker:", type(tracker), "has add_validation_score:", hasattr(tracker, "add_validation_score"))

Running Dendrite Experiment
tracker: <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'> has add_validation_score: True


In [None]:
# Use Adam. (They recommend no weight decay for PAI; set wd=0.)
lr = run.config["lr"]
weight_decay = 0.0

tracker.set_optimizer(torch.optim.Adam)
tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {"params": model.parameters(), "lr": lr, "weight_decay": weight_decay}
schedArgs = {"mode": "max", "patience": 3}

optimizer, scheduler = tracker.setup_optimizer(model, optimArgs, schedArgs)

#  train/eval + REQUIRED add_validation_score loop (+ wandb logging)

In [None]:
import numpy as np

criterion = nn.BCEWithLogitsLoss()

def train_one_epoch(model, loader):
    model.train()
    total_loss = 0.0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).float().view(-1, 1)

        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def eval_pr_auc(model, loader):
    # Simple PR-AUC for binary labels (no sklearn needed)
    model.eval()
    probs = []
    ys = []
    for xb, yb in loader:
        xb = xb.to(device)
        logits = model(xb)
        p = torch.sigmoid(logits).detach().cpu().numpy().reshape(-1)
        probs.append(p)
        ys.append(yb.cpu().numpy().reshape(-1))
    probs = np.concatenate(probs)
    ys = np.concatenate(ys).astype(int)

    # PR-AUC (approx) using sorting
    order = np.argsort(-probs)
    ys = ys[order]
    tp = np.cumsum(ys == 1)
    fp = np.cumsum(ys == 0)
    prec = tp / np.maximum(tp + fp, 1)
    rec = tp / np.maximum(tp[-1], 1)
    # trapz over recall
    return float(np.trapz(prec, rec))

def param_count(m):
    return int(UPA.count_params(m))

max_epochs = 30
best_val = -1.0

for epoch in range(1, max_epochs + 1):
    train_loss = train_one_epoch(model, train_loader)
    val_score  = eval_pr_auc(model, val_loader)     # <-- metric you MAXIMIZE
    test_score = eval_pr_auc(model, test_loader)

    # Log per-epoch
    run.log({
        "Epoch": epoch,
        "Epoch Train Loss": train_loss,
        "Epoch Val PR_AUC": val_score,
        "Epoch Test PR_AUC": test_score,
        "Epoch Param Count": param_count(model),
        "Epoch Dendrite Count": tracker.member_vars.get("num_dendrites_added", 0),
    })

    # (optional) add extra scores like the example
    tracker.add_extra_score(test_score, "Test")

    print(f"[{epoch:03d}] loss={train_loss:.4f} val_pr={val_score:.4f} test_pr={test_score:.4f} params={param_count(model)}")

    # this is what actually adds dendrites + generates PAI/PAI.png
    model, restructured, training_complete = tracker.add_validation_score(val_score, model)
    model = model.to(device)

    # If PAI changed architecture, rebuild optimizer/scheduler
    if restructured:
        optimizer, scheduler = tracker.setup_optimizer(model, optimArgs, schedArgs)

    if training_complete:
        print(" training_complete=True (tracker stopped).")
        break

run.finish()
print("Done. Check PAI/PAI.png")

In [None]:
import json, os

def load_summary(folder):
    path = os.path.join(folder, "summary.json")
    with open(path, "r") as f:
        s = json.load(f)
    # adapt keys if your summary uses different names
    # (yours looked like: Final Max Val PR_AUC / Final Max Test PR_AUC / Final Param Count)
    return {
        "val":   float(s.get("Final Max Val PR_AUC",  s.get("Final Max Val", 0))),
        "test":  float(s.get("Final Max Test PR_AUC", s.get("Final Max Test", 0))),
        "params": int(s.get("Final Param Count", s.get("Final Params", 0))),
    }

# Example:
# orig_nodend = load_summary("PAI_ORIG_NODEND")
# comp_nodend = load_summary("PAI_COMP_NODEND")
# comp_dend   = load_summary("PAI_COMP_DEND")

In [None]:
def judge_ylim(scores_percent):
    ymax = 100.0
    min_score = min(scores_percent)
    dist = ymax - min_score
    ymin = ymax - 2.0 * dist
    ymin = max(0.0, ymin)  # don’t go below 0
    return ymin, ymax


In [None]:
import matplotlib.pyplot as plt

def plot_accuracy_improvement(orig_score, dend_score, outpath="accuracy_improvement.png", title="Accuracy Improvement", ylabel="Test Score (%)"):
    # convert 0..1 to percent if needed
    if orig_score <= 1.0 and dend_score <= 1.0:
        orig = orig_score * 100.0
        dend = dend_score * 100.0
    else:
        orig, dend = orig_score, dend_score

    ymin, ymax = judge_ylim([orig, dend])

    plt.figure(figsize=(8,4.5))
    plt.bar(["Original Model", "Original + Dendrites"], [orig, dend])
    plt.ylim(ymin, ymax)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(outpath, dpi=200)
    plt.show()
    print("Saved:", outpath)

# Example usage (replace with your real runs):
#plot_accuracy_improvement(orig_nodend["test"], comp_dend["test"], outpath="accuracy_improvement.png", ylabel="Test PR-AUC (%)")

In [None]:
def plot_model_compression(orig, comp, comp_dend, outpath="model_compression.png", title="Model Compression", score_label="Test Score (%)"):
    # orig/comp/comp_dend are dicts with keys: test, params
    scores = [orig["test"], comp["test"], comp_dend["test"]]
    params = [orig["params"], comp["params"], comp_dend["params"]]

    # convert score to percent if needed
    scores_pct = [s*100.0 if s <= 1.0 else s for s in scores]
    ymin, ymax = judge_ylim(scores_pct)

    labels = ["Original Model", "Original Model\n(Compressed)", "Compressed\n+ Dendrites"]

    fig, ax1 = plt.subplots(figsize=(9,4.8))
    ax1.bar(labels, scores_pct)
    ax1.set_ylabel(score_label)
    ax1.set_ylim(ymin, ymax)
    ax1.set_title(title)
    ax1.grid(axis="y", alpha=0.3)

    ax2 = ax1.twinx()
    ax2.plot(labels, [p/1e6 for p in params], marker="o")
    ax2.set_ylabel("Parameters (M)")

    plt.tight_layout()
    plt.savefig(outpath, dpi=200)
    plt.show()
    print("Saved:", outpath)

# Example usage:
# plot_model_compression(orig_nodend, comp_nodend, comp_dend, outpath="model_compression.png", score_label="Test PR-AUC (%)")

In [None]:
import os, json
import matplotlib.pyplot as plt
from IPython.display import Image, display

def load_summary(folder):
    path = os.path.join(folder, "summary.json")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing {path}. Did you set save_name='{folder}' when calling initialize_pai?")
    with open(path, "r") as f:
        s = json.load(f)

    # Your keys likely look like these (from earlier):
    val = float(s.get("Final Max Val PR_AUC", s.get("Final Max Val", 0)))
    test = float(s.get("Final Max Test PR_AUC", s.get("Final Max Test", 0)))
    params = int(s.get("Final Param Count", s.get("Final Params", 0)))

    return {"val": val, "test": test, "params": params}

def to_percent(x):
    return x * 100.0 if x <= 1.0 else float(x)

def judge_ylim(scores_percent):
    ymax = 100.0
    min_score = min(scores_percent)
    dist = ymax - min_score
    ymin = ymax - 2.0 * dist
    return max(0.0, ymin), ymax

In [None]:
import os
print([d for d in os.listdir(".") if os.path.isdir(d)])

['.config', 'PerforatedAI', 'PAI', 'wandb', 'drive', 'sample_data']


In [None]:
import os, json

def read_summary(folder="PAI"):
    path = os.path.join(folder, "summary.json")
    with open(path, "r") as f:
        s = json.load(f)

    # choose the metric you used
    test_key_candidates = [
        "Final Max Test PR_AUC",
        "Final Max Test Accuracy",
        "Final Max Test",
    ]
    test = None
    for k in test_key_candidates:
        if k in s:
            test = float(s[k])
            break
    if test is None:
        raise KeyError(f"No test metric found. Keys: {list(s.keys())}")

    params = int(s["Final Param Count"])
    dendrites = int(s.get("Final Dendrite Count", -1))
    return test, params, dendrites

test, params, dend = read_summary("PAI")
print("PAI run -> test:", test, "params:", params, "dendrites:", dend)
print("As percent:", test*100)

PAI run -> test: 0.9816919586885289 params: 609315 dendrites: 8
As percent: 98.16919586885288


## no dendrites baseline to compare

In [None]:
import torch
import torch.nn as nn

class MLPBinary(nn.Module):
    def __init__(self, in_dim=128, hidden=256, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

def build_model_full():
    return MLPBinary(in_dim=128, hidden=256, dropout=0.2)

def build_model_comp():
    # "X% width" example: half width (256 -> 128)
    return MLPBinary(in_dim=128, hidden=128, dropout=0.2)

In [None]:
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA
import torch

def run_one(save_name: str, max_dendrites: int, model_builder, *, lr=1e-3, weight_decay=0.0, max_epochs=20):
    GPA.pc.set_weight_decay_accepted(True)
    GPA.pc.set_max_dendrites(max_dendrites)

    # Make sure tracker uses Adam (and expects an optimizer to exist)
    GPA.pai_tracker.set_optimizer(torch.optim.Adam)

    model = model_builder()
    model = UPA.initialize_pai(model, save_name=save_name)

    # IMPORTANT: add_validation_score uses tracker.optimizer internally
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    GPA.pai_tracker.optimizer = opt

    best_val, best_test, best_params = -1.0, -1.0, None

    for epoch in range(1, max_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, opt)  # <-- you already have this
        val_pr = eval_pr_auc(model, val_loader)                  # <-- you already have this (maximize)
        test_pr = eval_pr_auc(model, test_loader)

        # Optional extra logs
        GPA.pai_tracker.add_extra_score(test_pr, "Test")

        # REQUIRED dendrite growth call
        model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_pr, model)

        if restructured:
            opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            GPA.pai_tracker.optimizer = opt

        if val_pr > best_val:
            best_val = val_pr
            best_test = test_pr
            best_params = UPA.count_params(model)

        if training_complete:
            break

    return {"best_val": best_val, "best_test": best_test, "best_params": best_params}

In [None]:
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA
import torch
import os, json

def run_one(save_name: str, max_dendrites: int, model_builder, *, lr=1e-3, weight_decay=0.0, max_epochs=20):
    GPA.pc.set_weight_decay_accepted(True)
    GPA.pc.set_max_dendrites(max_dendrites)

    # tracker expects an optimizer to exist internally
    GPA.pai_tracker.set_optimizer(torch.optim.Adam)

    model = model_builder()
    model = UPA.initialize_pai(model, save_name=save_name)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    GPA.pai_tracker.optimizer = opt  # IMPORTANT for add_validation_score LR updates

    best_val, best_test, best_params = -1.0, -1.0, None

    for epoch in range(1, max_epochs + 1):
        # match YOUR signature
        train_loss = train_one_epoch(model, train_loader)

        val_pr  = eval_pr_auc(model, val_loader)
        test_pr = eval_pr_auc(model, test_loader)

        # Optional: extra score logs
        GPA.pai_tracker.add_extra_score(test_pr, "Test")
        # QUIRED: grows dendrites + writes PAI/*.png/csv + summary.json at end
        model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_pr, model)

        if restructured:
            # rebuild optimizer for new parameters after restructuring
            opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            GPA.pai_tracker.optimizer = opt

        if val_pr > best_val:
            best_val = val_pr
            best_test = test_pr
            best_params = UPA.count_params(model)

        if training_complete:
            break

    return {"best_val": best_val, "best_test": best_test, "best_params": best_params}

In [None]:
import torch
import numpy as np

def _flatten_binary_logits_and_targets(logits, yb):
    # logits: [B] or [B,1] or [B,*,1] -> [B]
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    logits = logits.squeeze(-1)          # [B,1] -> [B]
    logits = logits.view(-1)             # force [B]

    # targets: [B] or [B,1] -> [B] float
    yb = yb.squeeze(-1).view(-1).float()
    return logits, yb

# ✅ Replace your train_one_epoch with this version
def train_one_epoch(model, loader):
    model.train()
    total = 0.0
    n = 0
    for xb, yb in loader:
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        logits, yb = _flatten_binary_logits_and_targets(logits, yb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total += float(loss.item()) * xb.size(0)
        n += xb.size(0)
    return total / max(n, 1)

# ✅ Replace/define eval_pr_auc like this (no sklearn needed)
def eval_pr_auc(model, loader):
    model.eval()
    all_p = []
    all_y = []
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb)
            logits, yb = _flatten_binary_logits_and_targets(logits, yb)
            p = torch.sigmoid(logits)
            all_p.append(p.cpu())
            all_y.append(yb.cpu())
    p = torch.cat(all_p).numpy()
    y = torch.cat(all_y).numpy()

    # PR-AUC (Average Precision) implementation
    order = np.argsort(-p)
    y_sorted = y[order]
    tp = np.cumsum(y_sorted)
    fp = np.cumsum(1 - y_sorted)
    precision = tp / np.maximum(tp + fp, 1e-12)
    recall = tp / np.maximum(tp[-1], 1e-12)
    # area under PR curve (step-wise)
    return float(np.sum((recall[1:] - recall[:-1]) * precision[1:]))

In [None]:
import torch
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

# optional: silence the weight decay warning (they recommend wd=0 for PAI)
GPA.pc.set_weight_decay_accepted(True)

def run_one(save_name, max_dendrites, model_builder, lr=1e-3, weight_decay=0.0, max_epochs=10):
    global optimizer, criterion  # so your existing train_one_epoch() can use them

    # 1) Build your model
    model = model_builder().to(device)

    # 2) Tell PerforatedAI how many dendrites are allowed (MUST be set before initialize)
    GPA.pc.set_max_dendrites(int(max_dendrites))

    # 3) Initialize PAI (creates tracker + PAI/ outputs when training completes)
    model = UPA.initialize_pai(model, save_name=save_name)
    tracker = GPA.pai_tracker  # now this is a PAINeuronModuleTracker object

    # 4) IMPORTANT: create optimizer THROUGH the tracker so add_validation_score can manage LR
    tracker.set_optimizer(torch.optim.Adam)
    tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

    optimArgs = {
        "params": model.parameters(),
        "lr": float(lr),
        "weight_decay": float(weight_decay),
        "betas": (0.9, 0.999),
    }
    # maximize val metric (PR-AUC), so mode='max'
    schedArgs = {"mode": "max", "patience": 3}
    optimizer, scheduler = tracker.setup_optimizer(model, optimArgs, schedArgs)

    # 5) Loss (binary) — keep yours if already defined; else define here
    if "criterion" not in globals() or criterion is None:
        criterion = torch.nn.BCEWithLogitsLoss()

    best_val = -1e9
    best_test = -1e9
    best_params = None
    best_dend = None

    for epoch in range(1, max_epochs + 1):
        train_loss = train_one_epoch(model, train_loader)   # uses global optimizer + criterion
        val_pr  = eval_pr_auc(model, val_loader)
        test_pr = eval_pr_auc(model, test_loader)

        # (optional) help tracker/graphs: extra scores
        try:
            tracker.add_extra_score(test_pr, "Test")
        except Exception:
            pass

        print(f"[{epoch:03d}] loss={train_loss:.4f} val_pr={val_pr:.6f} test_pr={test_pr:.6f}")

        #  REQUIRED: this is what actually grows dendrites + generates PAI/PAI.png
        model, restructured, training_complete = tracker.add_validation_score(val_pr, model)
        model = model.to(device)

        # If tracker changed the architecture, rebuild optimizer via tracker again
        if restructured:
            optimArgs["params"] = model.parameters()
            optimizer, scheduler = tracker.setup_optimizer(model, optimArgs, schedArgs)

        if val_pr > best_val:
            best_val = val_pr
            best_test = test_pr
            best_params = UPA.count_params(model)
            # dendrite count is stored on tracker
            best_dend = getattr(tracker, "member_vars", {}).get("num_dendrites_added", None)

        if training_complete:
            print("training_complete=True (tracker stopped).")
            break

    return {
        "best_val_pr": float(best_val),
        "best_test_pr": float(best_test),
        "params": int(best_params) if best_params is not None else None,
        "dendrites": best_dend,
        "save_name": save_name,
    }

In [None]:
full_nodend = run_one("FULL_NODEND", max_dendrites=0, model_builder=build_model_full, max_epochs=10)
full_dend   = run_one("FULL_DEND",   max_dendrites=8, model_builder=build_model_full, max_epochs=10)

comp_nodend = run_one("COMP_NODEND", max_dendrites=0, model_builder=build_model_comp, max_epochs=10)
comp_dend   = run_one("COMP_DEND",   max_dendrites=8, model_builder=build_model_comp, max_epochs=10)

full_nodend, full_dend, comp_nodend, comp_dend

Running Dendrite Experiment
[001] loss=0.2761 val_pr=0.854325 test_pr=0.884311
Adding validation score 0.85432494
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 0, last improved epoch 0, total epochs 0, n: 1, num_cycles: 0
Returning True - switching every time
Last Dendrites were good and this hit the max of 0
before load
after load
after graphs
after save
training_complete=True (tracker stopped).
Running Dendrite Experiment
[001] loss=0.2790 val_pr=0.848635 test_pr=0.869842
Adding validation score 0.84863454
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 0, last improved epoch 0, total epochs 0, n: 1, num_cycles: 0
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[002] loss=0.1184 val_pr=0.934814 test_pr=0.937256
Adding validation scor

({'best_val_pr': 0.8543249368667603,
  'best_test_pr': 0.8843109011650085,
  'params': 99073,
  'dendrites': 0,
  'save_name': 'FULL_NODEND'},
 {'best_val_pr': 0.9509258270263672,
  'best_test_pr': 0.9325947761535645,
  'params': 401935,
  'dendrites': 3,
  'save_name': 'FULL_DEND'},
 {'best_val_pr': 0.7831177115440369,
  'best_test_pr': 0.8512295484542847,
  'params': 33153,
  'dendrites': 0,
  'save_name': 'COMP_NODEND'},
 {'best_val_pr': 0.9445125460624695,
  'best_test_pr': 0.9319042563438416,
  'params': 316624,
  'dendrites': 8,
  'save_name': 'COMP_DEND'})

In [None]:
print("=== Model Compression table (top) ===")
print("Original Model,", full_nodend_test, ",", full_nodend_params)
print("Original Model X% Width,", comp_nodend_test, ",", comp_nodend_params)
print("X% Width + Dendrites,", comp_dend_test, ",", comp_dend_params)

print("\n=== Accuracy Improvement (bottom, percent) ===")
print("Original Model,", 100*full_nodend_test)
print("Original + Dendrites,", 100*full_dend_test)

=== Model Compression table (top) ===
Original Model, 0.8843109011650085 , 99073
Original Model X% Width, 0.8512295484542847 , 33153
X% Width + Dendrites, 0.9319042563438416 , 316624

=== Accuracy Improvement (bottom, percent) ===
Original Model, 88.43109011650085
Original + Dendrites, 93.25947761535645


In [None]:
import os
for f in ["PAI/PAI.png", "PAI/summary.json"]:
    print(f, "OKOK" if os.path.exists(f) else "MISSSING")

PAI/PAI.png OKOK
PAI/summary.json OKOK


# Improve PAI plot

In [None]:
# -----------------------
# Cell 2) Make an easy synthetic dataset that reaches 97–100% accuracy
# We reshape 128 features into a 1x16x8 "image" so we can use conv layers.
# -----------------------
def make_synth(n=12000, d=128, pos_frac=0.5, sep=2.5):
    n_pos = int(n * pos_frac)
    n_neg = n - n_pos
    # Two Gaussians separated by "sep" to make classification easy
    X_pos = np.random.randn(n_pos, d) + sep
    X_neg = np.random.randn(n_neg, d) - sep
    X = np.vstack([X_pos, X_neg]).astype(np.float32)
    y = np.concatenate([np.ones(n_pos), np.zeros(n_neg)]).astype(np.float32)
    # shuffle
    idx = np.random.permutation(n)
    return X[idx], y[idx]

X, y = make_synth(n=12000, d=128, pos_frac=0.5, sep=2.2)

# Train/Val/Test split
n = len(X)
n_train = int(0.8*n)
n_val   = int(0.1*n)
X_train, y_train = X[:n_train], y[:n_train]
X_val,   y_val   = X[n_train:n_train+n_val], y[n_train:n_train+n_val]
X_test,  y_test  = X[n_train+n_val:], y[n_train+n_val:]

# Reshape to N x 1 x 16 x 8
def to_img(x): return x.reshape(-1, 1, 16, 8)

X_train_t = torch.from_numpy(to_img(X_train))
y_train_t = torch.from_numpy(y_train).view(-1, 1)  # shape (N,1) for BCEWithLogits
X_val_t   = torch.from_numpy(to_img(X_val))
y_val_t   = torch.from_numpy(y_val).view(-1, 1)
X_test_t  = torch.from_numpy(to_img(X_test))
y_test_t  = torch.from_numpy(y_test).view(-1, 1)

batch_size = 128
train_loader = DataLoader(TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val_t,   y_val_t),   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(TensorDataset(X_test_t,  y_test_t),  batch_size=batch_size, shuffle=False)

print("Shapes:", X_train_t.shape, y_train_t.shape)

Shapes: torch.Size([9600, 1, 16, 8]) torch.Size([9600, 1])


In [None]:
# -----------------------
# Cell 3) CNN model (conv1/conv2/fc1/fc2-like) to match their example plots
# width_mult lets you make a "compressed" variant if needed.
# -----------------------
class TinyCNNBinary(nn.Module):
    def __init__(self, width_mult=1.0):
        super().__init__()
        c1 = max(4, int(8 * width_mult))
        c2 = max(8, int(16 * width_mult))
        f1 = max(16, int(64 * width_mult))

        self.conv1 = nn.Conv2d(1, c1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(c1, c2, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((4, 2))  # keep it small/fast

        self.fc1 = nn.Linear(c2 * 4 * 2, f1)
        self.fc2 = nn.Linear(f1, 1)

        self.drop = nn.Dropout(0.1)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.pool(x)
        x = x.flatten(1)
        x = self.drop(self.act(self.fc1(x)))
        x = self.fc2(x)
        return x

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Full params:", count_params(TinyCNNBinary(1.0)))
print("Comp params:", count_params(TinyCNNBinary(0.6)))

Full params: 9569
Comp params: 3186


In [None]:
# -----------------------
# Cell 4) Training + Accuracy(%) metric
# -----------------------
criterion = nn.BCEWithLogitsLoss()

@torch.no_grad()
def accuracy_percent(model, loader):
    model.eval()
    correct = 0
    total = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        preds = (torch.sigmoid(logits) > 0.5).float()
        correct += (preds == yb).sum().item()
        total += yb.numel()
    return 100.0 * correct / max(1, total)

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0
    total_n = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
        total_n += xb.size(0)
    return total_loss / max(1, total_n)

In [None]:
# -----------------------
# Cell 5) Helper: safely attach optimizer to tracker (prevents "param_groups" None error)
# and try tracker.setup_optimizer(...) in multiple calling styles.
# -----------------------
def attach_optimizer_to_tracker(tracker, opt):
    # common names
    for attr in ["optimizer", "optim", "_optimizer", "pai_optimizer", "opt"]:
        try:
            setattr(tracker, attr, opt)
        except Exception:
            pass
    # also try any attribute containing "optim"
    for attr in dir(tracker):
        if "optim" in attr.lower() and not attr.startswith("__"):
            try:
                setattr(tracker, attr, opt)
            except Exception:
                pass

def try_setup_optimizer(tracker, model, opt):
    if not hasattr(tracker, "setup_optimizer"):
        return
    # Try multiple conventions (PerforatedAI versions differ)
    for call in [
        lambda: tracker.setup_optimizer(opt),
        lambda: tracker.setup_optimizer(model, opt),
        lambda: tracker.setup_optimizer(model, opt, lr=opt.param_groups[0]["lr"]),
        lambda: tracker.setup_optimizer(model, lr=opt.param_groups[0]["lr"]),
        lambda: tracker.setup_optimizer(lr=opt.param_groups[0]["lr"]),
        lambda: tracker.setup_optimizer(),
    ]:
        try:
            call()
            return
        except Exception:
            continue

In [None]:
from perforatedai import globals_perforatedai as GPA

# Find which object holds module_names_to_convert / module_names_to_track
hits = []
for name in dir(GPA):
    if name.startswith("_"):
        continue
    obj = getattr(GPA, name)
    if hasattr(obj, "module_names_to_convert") or hasattr(obj, "module_names_to_track"):
        hits.append(name)

print("Candidates in GPA:", hits)

# Print the fields for each candidate
for name in hits:
    obj = getattr(GPA, name)
    print("\n---", name, "---")
    if hasattr(obj, "module_names_to_convert"):
        print("module_names_to_convert =", getattr(obj, "module_names_to_convert"))
    if hasattr(obj, "module_names_to_track"):
        print("module_names_to_track   =", getattr(obj, "module_names_to_track"))

Candidates in GPA: ['pc']

--- pc ---
module_names_to_convert = ['PAISequential']
module_names_to_track   = []


In [None]:
import torch.nn as nn
from perforatedai import globals_perforatedai as GPA

# Tell PerforatedAI which module TYPES to convert into dendrite-enabled modules
GPA.pc.set_modules_to_convert([nn.Conv2d, nn.Linear])

# Tell it which module TYPES to at least TRACK (avoid “not tracked” warnings)
GPA.pc.set_modules_to_track([nn.Conv2d, nn.Linear])

# Skip interactive debugger prompts next time
GPA.pc.set_unwrapped_modules_confirmed(True)

# If you use weight_decay, suppress the warning (or set weight_decay=0)
GPA.pc.set_weight_decay_accepted(True)

print("pc.convert =", GPA.pc.module_names_to_convert)
print("pc.track   =", GPA.pc.module_names_to_track)

pc.convert = ['PAISequential']
pc.track   = []


In [None]:
from perforatedai import utils_perforatedai as UPA

model = UPA.initialize_pai(model, save_name="PAI")

Running Dendrite Experiment


In [None]:
import torch

# IMPORTANT: do this after initialize_pai (so GPA.pai_tracker exists)
GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {
    "params": model.parameters(),
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    # "weight_decay": 0.0,   # recommended 0 for PAI
}
schedArgs = {"mode": "max", "patience": 3}  # "max" because higher val score is better

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

In [None]:
from perforatedai import globals_perforatedai as GPA

# Convert these module TYPES (by name) into dendrite-enabled versions
GPA.pc.module_names_to_convert = ["Conv2d", "Linear", "PAISequential"]

# Track these types to avoid "not tracked" warnings
GPA.pc.module_names_to_track = ["Conv2d", "Linear"]

# Skip the interactive debugger prompts
GPA.pc.set_unwrapped_modules_confirmed(True)

# If you use weight_decay, silence warning (or just set weight_decay=0)
GPA.pc.set_weight_decay_accepted(True)

print("pc.convert =", GPA.pc.module_names_to_convert)
print("pc.track   =", GPA.pc.module_names_to_track)

pc.convert = ['Conv2d', 'Linear', 'PAISequential']
pc.track   = ['Conv2d', 'Linear']


In [None]:
from perforatedai import utils_perforatedai as UPA

model = UPA.initialize_pai(model, save_name="PAI")

Running Dendrite Experiment


In [None]:
import torch

GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {
    "params": model.parameters(),
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    # "weight_decay": 0.0,   # recommended 0 for PAI
}
schedArgs = {"mode": "max", "patience": 3}  # maximize your metric

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

In [None]:
from perforatedai import utils_perforatedai as UPA

# IMPORTANT: do this AFTER setting GPA.pc.module_names_to_convert/track
model = UPA.initialize_pai(model, save_name="PAI")

Running Dendrite Experiment


In [None]:
import torch
from perforatedai import globals_perforatedai as GPA

GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {
    "params": model.parameters(),
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    # "weight_decay": 0.0,  # recommended 0 for PAI
}
schedArgs = {"mode": "max", "patience": 3}   # maximize val metric

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

In [None]:
xb, yb = next(iter(train_loader))
print("xb:", xb.shape, xb.dtype)
print("yb:", yb.shape, yb.dtype, "unique:", torch.unique(yb)[:10])

xb: torch.Size([128, 1, 16, 8]) torch.float32
yb: torch.Size([128, 1]) torch.float32 unique: tensor([0., 1.])


In [None]:
import torch
import torch.nn as nn

class MLPBinary(nn.Module):
    def __init__(self, d_in):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        return self.net(x)

In [None]:
xb, _ = next(iter(train_loader))
d_in = xb[0].numel() if xb.dim() > 2 else xb.shape[1]
print("Using d_in =", d_in)

model = MLPBinary(d_in=d_in).to(device)

Using d_in = 128


In [None]:
from perforatedai import utils_perforatedai as UPA
model = UPA.initialize_pai(model, save_name="PAI")

Running Dendrite Experiment


In [None]:
import torch
from perforatedai import globals_perforatedai as GPA

GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {"params": model.parameters(), "lr": 1e-3, "betas": (0.9, 0.999)}
schedArgs = {"mode": "max", "patience": 3}

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

In [None]:
logits = model(xb)              # logits: (B, 1)
yb = yb.float().view(-1, 1)     # yb:     (B, 1)
loss = criterion(logits, yb)

In [None]:
import numpy as np
import torch

def eval_metric(model, loader):
    """
    Returns PR-AUC (higher is better). For binary classification.
    Assumes model outputs logits of shape [B] or [B,1].
    Assumes y is 0/1 of shape [B] or [B,1].
    """
    model.eval()
    all_probs = []
    all_y = []

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device).float().view(-1)  # [B]

            logits = model(xb)
            logits = logits.view(-1)             # [B]
            probs = torch.sigmoid(logits)

            all_probs.append(probs.detach().cpu())
            all_y.append(yb.detach().cpu())

    y_true = torch.cat(all_y).numpy()
    y_prob = torch.cat(all_probs).numpy()

    # PR-AUC
    from sklearn.metrics import average_precision_score
    return float(average_precision_score(y_true, y_prob))

In [None]:
import torch

def eval_metric(model, loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device).float().view(-1)  # [B]

            logits = model(xb).view(-1)
            preds = (torch.sigmoid(logits) >= 0.5).float()

            correct += (preds == yb).sum().item()
            total += yb.numel()

    return float(correct / max(total, 1))

In [None]:
# --- Cell A: metric (PR-AUC) ---
import torch
from sklearn.metrics import average_precision_score

def eval_metric(model, loader):
    """
    Returns PR-AUC (higher is better) for binary classification.
    - model outputs logits [B] or [B,1]
    - y in {0,1} as [B] or [B,1]
    """
    model.eval()
    all_probs = []
    all_y = []

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device).float().view(-1)   # [B]

            logits = model(xb).view(-1)           # [B]
            probs = torch.sigmoid(logits)

            all_probs.append(probs.detach().cpu())
            all_y.append(yb.detach().cpu())

    y_true = torch.cat(all_y).numpy()
    y_prob = torch.cat(all_probs).numpy()
    return float(average_precision_score(y_true, y_prob))

In [None]:
import torch
import torch.nn as nn
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

class SmallCNNBinarySeq(nn.Module):
    def __init__(self, width_mult: float = 1.0):
        super().__init__()
        c1 = max(4, int(16 * width_mult))
        c2 = max(4, int(32 * width_mult))
        self.net = nn.Sequential(
            nn.Conv2d(1, c1, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(c1, c2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(c2, 1),
        )

    def forward(self, x):
        return self.net(x)

tmp_model = SmallCNNBinarySeq().to(device)

names = sorted({type(m).__name__ for m in tmp_model.modules()})
print("Module type names in your model:")
print(names)

device: cpu
Module type names in your model:
['AdaptiveAvgPool2d', 'Conv2d', 'Flatten', 'Linear', 'MaxPool2d', 'ReLU', 'Sequential', 'SmallCNNBinarySeq']


In [None]:
# What does initialize_pai think it should convert/track?
print("pc.get_module_names_to_convert():", GPA.pc.get_module_names_to_convert())
print("pc.get_module_names_to_track():  ", GPA.pc.get_module_names_to_track())

# list module types in your model
types = sorted({type(m).__name__ for m in model.modules()})
print("Module type names in model:", types)

# sanity: show whether it's even returning a new object
tmp = SmallCNNBinarySeq().to(device)
tmp2 = UPA.initialize_pai(tmp, save_name="PAI_DEBUG")
print("initialize_pai returned same object?", tmp2 is tmp)
print("PAI modules after debug wrap:", len([m for m in tmp2.modules() if "PAI" in type(m).__name__]))

pc.get_module_names_to_convert(): []
pc.get_module_names_to_track():   []
Module type names in model: ['AdaptiveAvgPool2d', 'Conv2d', 'Flatten', 'Linear', 'MaxPool2d', 'ReLU', 'Sequential', 'SmallCNNBinarySeq', 'TrackedNeuronModule']
Running Dendrite Experiment
initialize_pai returned same object? True
PAI modules after debug wrap: 0


In [None]:
import inspect
print("pc fields:", [a for a in dir(GPA.pc) if "module_names" in a or "convert" in a or "track" in a])
print("pc.convert =", getattr(GPA.pc, "module_names_to_convert", None))
print("pc.track   =", getattr(GPA.pc, "module_names_to_track", None))
print("initialize_pai source?", UPA.initialize_pai, type(UPA.initialize_pai))

pc fields: ['_module_ids_to_convert', '_module_ids_to_track', '_module_names_to_convert', '_module_names_to_not_save', '_module_names_to_track', '_module_names_with_processing', '_modules_to_convert', '_modules_to_track', 'append_module_ids_to_convert', 'append_module_ids_to_track', 'append_module_names_to_convert', 'append_module_names_to_not_save', 'append_module_names_to_track', 'append_module_names_with_processing', 'append_modules_to_convert', 'append_modules_to_track', 'get_module_ids_to_convert', 'get_module_ids_to_track', 'get_module_names_to_convert', 'get_module_names_to_not_save', 'get_module_names_to_track', 'get_module_names_with_processing', 'get_modules_to_convert', 'get_modules_to_track', 'module_ids_to_convert', 'module_ids_to_track', 'module_names_to_convert', 'module_names_to_not_save', 'module_names_to_track', 'module_names_with_processing', 'modules_to_convert', 'modules_to_track', 'set_module_ids_to_convert', 'set_module_ids_to_track', 'set_module_names_to_convert

In [None]:
import os, shutil
import torch
import torch.nn as nn

from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# (optional) clean old artifacts
if os.path.exists("PAI"):
    shutil.rmtree("PAI")

# --- IMPORTANT: use SETTERS (and pass LISTS) ---
GPA.pc.set_module_names_to_convert(["Conv2d", "Linear", "Sequential", "PAISequential"])
GPA.pc.set_module_names_to_track(["Conv2d", "Linear"])

# Some builds also support setting by python types (safe to try)
try:
    GPA.pc.set_modules_to_convert([nn.Conv2d, nn.Linear, nn.Sequential])
    GPA.pc.set_modules_to_track([nn.Conv2d, nn.Linear])
except Exception as e:
    print("Skipping type-based set_modules_*:", e)

# Skip interactive confirmations
try:
    GPA.pc.set_unwrapped_modules_confirmed(True)
    GPA.pc.set_weight_decay_accepted(True)
except Exception as e:
    print("Skipping confirmations:", e)

print("convert =", GPA.pc.get_module_names_to_convert())
print("track   =", GPA.pc.get_module_names_to_track())

# --- Build a tiny CNN (same pattern as yours) ---
class SmallCNNBinarySeq(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 1),
        )
    def forward(self, x):
        return self.net(x)

model = SmallCNNBinarySeq().to(device)

# --- WRAP (capture the returned model!) ---
model = UPA.initialize_pai(model, save_name="PAI")

# --- VERIFY wrapping happened ---
pai_modules = [n for n, m in model.named_modules() if "PAI" in type(m).__name__]
dend_params = [n for n, p in model.named_parameters() if "dendrite" in n.lower()]

print("PAI modules:", len(pai_modules), "examples:", pai_modules[:10])
print("Dendrite params:", len(dend_params), "examples:", dend_params[:10])

assert (len(pai_modules) > 0) or (len(dend_params) > 0), \
    "Still not wrapped — next step is to print model.named_modules() types and pc settings."

device: cpu
convert = ['Conv2d', 'Linear', 'Sequential', 'PAISequential']
track   = ['Conv2d', 'Linear']
Running Dendrite Experiment
PAI modules: 2 examples: ['net', 'net.dendrite_module']
Dendrite params: 6 examples: ['net.dendrite_module.parent_module.0.weight', 'net.dendrite_module.parent_module.0.bias', 'net.dendrite_module.parent_module.3.weight', 'net.dendrite_module.parent_module.3.bias', 'net.dendrite_module.parent_module.7.weight', 'net.dendrite_module.parent_module.7.bias']


In [None]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.BCEWithLogitsLoss()

@torch.no_grad()
def eval_accuracy_percent(model, loader):
    model.eval()
    correct = 0
    total = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).float().view(-1, 1)   # [B,1]
        logits = model(xb)
        probs = torch.sigmoid(logits)
        preds = (probs >= 0.5).float()
        correct += (preds == yb).sum().item()
        total += yb.numel()
    acc = correct / max(total, 1)
    return acc * 100.0  # IMPORTANT: percent (0–100) so plots match judge example

def train_one_epoch(model, loader, optimizer):
    model.train()
    running = 0.0
    n = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).float().view(-1, 1)  # [B,1] to match logits

        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        bs = xb.size(0)
        running += loss.item() * bs
        n += bs
    return running / max(n, 1)

In [None]:
from perforatedai import globals_perforatedai as GPA

# Convert these module TYPES into dendrite-enabled versions
GPA.pc.set_module_names_to_convert(["Conv2d", "Linear", "Sequential", "PAISequential"])

# Track these module TYPES so their params are wrapped/tracked
GPA.pc.set_module_names_to_track(["Conv2d", "Linear"])

# Skip interactive confirmations
GPA.pc.set_unwrapped_modules_confirmed(True)
GPA.pc.set_weight_decay_accepted(True)

print("pc.convert =", GPA.pc.get_module_names_to_convert())
print("pc.track   =", GPA.pc.get_module_names_to_track())

pc.convert = ['Conv2d', 'Linear', 'Sequential', 'PAISequential']
pc.track   = ['Conv2d', 'Linear']


In [None]:
import torch
from perforatedai import utils_perforatedai as UPA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# ✅ Build ONCE
model = SmallCNNBinarySeq().to(device)   # <-- replace with your model constructor

# ✅ Wrap ONCE
model = UPA.initialize_pai(model, save_name="PAI")

tracker = GPA.pai_tracker
print("tracker:", type(tracker))

device: cpu
Running Dendrite Experiment
tracker: <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'>


In [None]:
import torch

# ✅ Fix the module mentioned in the error: ".net"
model.net.set_this_output_dimensions(torch.tensor([-1, 0, -1, -1]))

# If you get the same error for another wrapper (e.g., model.features),
# call set_this_output_dimensions on that module too.
print("set_this_output_dimensions applied to model.net")

set_this_output_dimensions applied to model.net


In [None]:
import numpy as np
from sklearn.metrics import average_precision_score
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def train_one_epoch(model, loader, optimizer):
    model.train()
    total = 0.0
    n = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).view(-1)  # [B]
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb).view(-1) # [B]
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total += float(loss.item()) * xb.size(0)
        n += xb.size(0)
    return total / max(n, 1)

@torch.no_grad()
def eval_metric(model, loader):
    model.eval()
    ys, ps = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        logits = model(xb).view(-1)
        prob = torch.sigmoid(logits).detach().cpu().numpy()
        ps.append(prob)
        ys.append(yb.view(-1).cpu().numpy())
    y = np.concatenate(ys)
    p = np.concatenate(ps)
    return float(average_precision_score(y, p))  # higher is better

In [None]:
from perforatedai import globals_perforatedai as GPA
import torch

# (optional) shows all output-dimension problems
GPA.pc.set_debugging_output_dimensions(1)

# ✅ this is exactly what the error is asking for
# use a Python list (works best) — not a string, not a single int
model.net.set_this_output_dimensions([-1, 0, -1, -1])

print("model.net this_output_dimensions set")

model.net this_output_dimensions set


In [None]:
xb, yb = next(iter(train_loader))
xb = xb.to(device)

with torch.no_grad():
    out = model(xb)

print("forward ok, out shape:", out.shape)

forward ok, out shape: torch.Size([128, 1])


In [None]:
from perforatedai import globals_perforatedai as GPA
import torch

GPA.pc.set_debugging_output_dimensions(1)  # show all issues

xb, yb = next(iter(train_loader))
xb = xb.to(device)

with torch.no_grad():
    out = model(xb)

print("model(xb) out shape:", tuple(out.shape), "ndim:", out.ndim)

# Choose correct vector based on output rank
if out.ndim == 2:
    # (N, features)  <-- your case: (128, 1)
    vec = [-1, 0]
elif out.ndim == 4:
    # (N, C, H, W)
    vec = [-1, 0, -1, -1]
elif out.ndim == 3:
    # (N, C, L)  (rare)
    vec = [-1, 0, -1]
else:
    raise ValueError(f"Unsupported output ndim={out.ndim}, shape={tuple(out.shape)}")

# Apply to the module the error names: ".net"
model.net.set_this_output_dimensions(vec)
print(" set model.net.this_output_dimensions =", vec)

# Optional: turn off verbose checking once fixed
GPA.pc.set_debugging_output_dimensions(0)

model(xb) out shape: (128, 1) ndim: 2
 set model.net.this_output_dimensions = [-1, 0]


In [None]:
xb, yb = next(iter(train_loader))
xb = xb.to(device)

with torch.no_grad():
    out = model(xb)

print("forward ok, out shape:", out.shape)

forward ok, out shape: torch.Size([128, 1])


In [None]:
import torch
import numpy as np

def eval_pr_auc_binary(model, loader, device="cpu"):
    """
    For binary classification.
    Model outputs logits of shape (B,1) or (B,).
    Labels y are 0/1 (shape (B,1) or (B,)).
    Returns PR-AUC in [0,1].
    """
    from sklearn.metrics import average_precision_score

    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            logits = model(xb)
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(1)  # (B,)
            if yb.ndim == 2 and yb.shape[1] == 1:
                yb = yb.squeeze(1)          # (B,)

            probs = torch.sigmoid(logits)

            ys.append(yb.detach().cpu().numpy())
            ps.append(probs.detach().cpu().numpy())

    y = np.concatenate(ys)
    p = np.concatenate(ps)
    return float(average_precision_score(y, p))

In [None]:
import torch

def eval_accuracy(model, loader, device="cpu"):
    """
    If model outputs:
      - binary logits (B,1) => threshold at 0
      - multi-class logits (B,C) => argmax
    Returns accuracy in [0,1].
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            logits = model(xb)

            # Binary
            if logits.ndim == 2 and logits.shape[1] == 1:
                pred = (logits.squeeze(1) > 0).long()
                if yb.ndim == 2 and yb.shape[1] == 1:
                    yb = yb.squeeze(1).long()
                else:
                    yb = yb.long()

            # Multi-class
            else:
                pred = logits.argmax(dim=1)
                yb = yb.long().view(-1)

            correct += (pred == yb).sum().item()
            total += yb.numel()

    return correct / max(total, 1)

In [None]:
import time
import torch
from perforatedai import globals_perforatedai as GPA

# ---- choose metric ----
# metric_fn = lambda m, l: eval_pr_auc_binary(m, l, device=device)  # if you defined it
metric_fn = lambda m, l: eval_accuracy(m, l, device=device)         # if you defined it

GPA.pc.set_weight_decay_accepted(True)
GPA.pc.set_unwrapped_modules_confirmed(True)

max_epochs = 30
lr = 1e-3
weight_decay = 0.0

tracker = GPA.pai_tracker
print("tracker:", type(tracker))

tracker.set_optimizer(torch.optim.Adam)
tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

opt_args = {
    "params": model.parameters(),
    "lr": lr,
    "betas": (0.9, 0.999),
    "weight_decay": weight_decay,
}
sched_args = {"mode": "max", "patience": 3}

try:
    optimizer, scheduler = tracker.setup_optimizer(model, opt_args, sched_args)
except TypeError:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=3)

best_val = -1e9
best_test_at_best_val = None

print("Running Dendrite Experiment")
for epoch in range(1, max_epochs + 1):
    t0 = time.time()

    # ✅ your signature: (model, loader, optimizer, device)
    train_loss = train_one_epoch(model, train_loader, optimizer, device)

    val_score  = metric_fn(model, val_loader)
    test_score = metric_fn(model, test_loader)

    # optional: show test curve on PAI plot
    try:
        tracker.add_extra_score(test_score, "Test")
    except Exception:
        pass

    # REQUIRED PAI call
    model, restructured, training_complete = tracker.add_validation_score(val_score, model)
    model = model.to(device)

    if restructured:
        opt_args["params"] = model.parameters()
        try:
            optimizer, scheduler = tracker.setup_optimizer(model, opt_args, sched_args)
        except TypeError:
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=3)

    if val_score > best_val:
        best_val = val_score
        best_test_at_best_val = test_score

    dt = time.time() - t0
    print(f"[{epoch:03d}] loss={train_loss:.4f} val={val_score:.6f} test={test_score:.6f} time={dt:.2f}s")

    if training_complete:
        print("🏁 training_complete=True (tracker stopped).")
        break

print("\nDone.")
print("Best val:", best_val)
print("Test @ best val:", best_test_at_best_val)

tracker: <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'>
Running Dendrite Experiment
Adding validation score 1.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 0, last improved epoch 0, total epochs 0, n: 1, num_cycles: 0
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[001] loss=0.0705 val=1.000000 test=1.000000 time=3.06s
Adding validation score 1.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 1, last improved epoch 1, total epochs 1, n: 1, num_cycles: 2
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[002

In [None]:
runs = {
  "Original Model":        {"test": 0.8843109011650085, "params": 99073},
  "Original + Dendrites":  {"test": 0.9325947761535645, "params": 401935},
  "Original Model X% Width":{"test": 0.8512295484542847, "params": 33153},
  "X% Width + Dendrites":  {"test": 0.9319042563438416, "params": 316624},
}

for k,v in runs.items():
    print(f"{k:24s}  test={v['test']:.6f}  pct={v['test']*100:.4f}  params={v['params']}  params(M)={v['params']/1e6:.4f}")

Original Model            test=0.884311  pct=88.4311  params=99073  params(M)=0.0991
Original + Dendrites      test=0.932595  pct=93.2595  params=401935  params(M)=0.4019
Original Model X% Width   test=0.851230  pct=85.1230  params=33153  params(M)=0.0332
X% Width + Dendrites      test=0.931904  pct=93.1904  params=316624  params(M)=0.3166


In [None]:
import matplotlib.pyplot as plt

# ---- paste your 4-run results here ----
FULL_NODEND = {"score": 0.8843109011650085, "params": 99073}
FULL_DEND   = {"score": 0.9325947761535645, "params": 401935}
COMP_NODEND = {"score": 0.8512295484542847, "params": 33153}
COMP_DEND   = {"score": 0.9319042563438416, "params": 316624}

# =========================
# 1) Model Compression graph (bars = score, line = params)
# =========================
labels = ["Original", "X% Width", "X% Width + Dendrites"]
scores = [FULL_NODEND["score"]*100, COMP_NODEND["score"]*100, COMP_DEND["score"]*100]  # percent
params_m = [FULL_NODEND["params"]/1e6, COMP_NODEND["params"]/1e6, COMP_DEND["params"]/1e6]  # millions

fig, ax1 = plt.subplots(figsize=(7,4))
ax1.bar(labels, scores)
ax1.set_ylabel("Test Score (%)")
ax1.set_title("Model Compression")

ax2 = ax1.twinx()
ax2.plot(labels, params_m, marker="o")
ax2.set_ylabel("Parameters (M)")

plt.tight_layout()
plt.show()

# =========================
# 2) Accuracy Improvement graph
# =========================
labels2 = ["Original", "Original + Dendrites"]
scores2 = [FULL_NODEND["score"]*100, FULL_DEND["score"]*100]

plt.figure(figsize=(6,4))
plt.bar(labels2, scores2)
plt.ylabel("Test Score (%)")
plt.title("Accuracy Improvement")
plt.tight_layout()
plt.show()

In [None]:
import torch
import time

# -------- loss --------
criterion = torch.nn.BCEWithLogitsLoss()

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss, n = 0.0, 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).float().view(-1)   # shape [B]
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb).view(-1)           # shape [B]
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
        n += xb.size(0)
    return total_loss / max(n, 1)

@torch.no_grad()
def eval_accuracy(model, loader, device):
    model.eval()
    correct, n = 0, 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).long().view(-1)
        logits = model(xb).view(-1)
        preds = (torch.sigmoid(logits) >= 0.5).long()
        correct += (preds == yb).sum().item()
        n += xb.size(0)
    return correct / max(n, 1)   # 0..1

@torch.no_grad()
def eval_pr_auc(model, loader, device):
    # Requires sklearn in runtime
    from sklearn.metrics import average_precision_score
    model.eval()
    ys, ps = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        logits = model(xb).view(-1)
        prob = torch.sigmoid(logits).detach().cpu().numpy()
        ys.append(yb.view(-1).cpu().numpy())
        ps.append(prob)
    import numpy as np
    y = np.concatenate(ys)
    p = np.concatenate(ps)
    return float(average_precision_score(y, p))  # 0..1

# -------- TRAIN LOOP (with PAI tracker calls) --------
max_epochs = 30
best_val = -1e9
best_test_at_best_val = None

for epoch in range(1, max_epochs + 1):
    t0 = time.time()

    train_loss = train_one_epoch(model, train_loader, optimizer, device)

    # pick ONE metric:
    val_score  = eval_pr_auc(model, val_loader, device)   # higher is better
    test_score = eval_pr_auc(model, test_loader, device)

    # optional: scale to match "97-100 style" plots
    val_for_tracker  = val_score * 100.0
    test_for_tracker = test_score * 100.0

    # log test as extra score (optional)
    GPA.pai_tracker.add_extra_score(test_for_tracker, "Test")

    # REQUIRED: add validation score (tracker uses it to decide switches/restructure)
    model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_for_tracker, model)

    # IMPORTANT: if restructured, rebuild optimizer on the NEW model
    if restructured:
        # easiest safe path: rebuild optimizer from scratch
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    dt = time.time() - t0
    print(f"[{epoch:03d}] loss={train_loss:.4f} val={val_score:.6f} test={test_score:.6f} time={dt:.2f}s")

    if val_score > best_val:
        best_val = val_score
        best_test_at_best_val = test_score

    if training_complete:
        print("🏁 training_complete=True (tracker stopped).")
        break

print("\nDone.")
print("Best val:", best_val)
print("Test @ best val:", best_test_at_best_val)

Adding validation score 100.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 8, last improved epoch 8, total epochs 8, n: 1, num_cycles: 16
Returning True - switching every time
Last Dendrites were good and this hit the max of 8
For improved results, try perforated backpropagation next time!
before load
after load
after graphs
after save
[001] loss=0.0000 val=1.000000 test=1.000000 time=11.72s
🏁 training_complete=True (tracker stopped).

Done.
Best val: 1.0
Test @ best val: 1.0


In [None]:
# should be different objects / different data
print("train_loader is val_loader?", train_loader is val_loader)
print("train ds == val ds?", getattr(train_loader, "dataset", None) is getattr(val_loader, "dataset", None))
print("sizes:", len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset))

train_loader is val_loader? False
train ds == val ds? False
sizes: 9600 1200 1200


In [None]:
import torch

def count_pos(loader):
    pos, total = 0, 0
    for _, y in loader:
        y = y.view(-1)
        pos += (y > 0.5).sum().item()
        total += y.numel()
    return pos, total, pos/total

print("train pos:", count_pos(train_loader))
print("val pos:",   count_pos(val_loader))
print("test pos:",  count_pos(test_loader))

train pos: (4772, 9600, 0.4970833333333333)
val pos: (639, 1200, 0.5325)
test pos: (589, 1200, 0.49083333333333334)


In [None]:
import torch

def sanity_pred_check(model, loader, device):
    model.eval()
    xb, yb = next(iter(loader))
    xb, yb = xb.to(device), yb.to(device).view(-1)

    with torch.no_grad():
        logits = model(xb).view(-1)          # shape [B]
        probs  = torch.sigmoid(logits)       # shape [B]
        preds  = (probs >= 0.5).float()

    acc = (preds == yb).float().mean().item()
    print("logits:", logits.min().item(), logits.max().item())
    print("probs :", probs.min().item(), probs.max().item())
    print("preds unique:", torch.unique(preds, return_counts=True))
    print("labels unique:", torch.unique(yb, return_counts=True))
    print("batch acc:", acc)

sanity_pred_check(model, val_loader, device)

logits: -27.052581787109375 21.76656723022461
probs : 1.783253124158779e-12 1.0
preds unique: (tensor([0., 1.]), tensor([68, 60]))
labels unique: (tensor([0., 1.]), tensor([68, 60]))
batch acc: 1.0


In [None]:
import numpy as np
import torch

# Accuracy (0..1)
def eval_accuracy(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device).view(-1)
            logits = model(xb).view(-1)
            preds = (torch.sigmoid(logits) >= 0.5).float()
            correct += (preds == yb).sum().item()
            total += yb.numel()
    return correct / total

# PR-AUC (Average Precision) (0..1)
def eval_pr_auc(model, loader, device):
    try:
        from sklearn.metrics import average_precision_score
    except Exception as e:
        raise RuntimeError("Need sklearn for PR-AUC. In Colab: !pip -q install scikit-learn") from e

    model.eval()
    all_probs, all_y = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device).view(-1).cpu().numpy()
            logits = model(xb).view(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.append(probs)
            all_y.append(yb)

    all_probs = np.concatenate(all_probs)
    all_y = np.concatenate(all_y)
    return float(average_precision_score(all_y, all_probs))

In [None]:
import torch
from perforatedai import globals_perforatedai as GPA

# ---- MUST be after UPA.initialize_pai(...) so tracker exists ----
print("tracker:", type(GPA.pai_tracker))

# Tell tracker which optimizer/scheduler *classes* to use (NOT instances, NOT lists)
GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {
    "params": model.parameters(),
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "weight_decay": 0.0,   # recommended for PAI
}
schedArgs = {
    "mode": "max",         # maximizing val metric
    "patience": 3,
}

# IMPORTANT: setup_optimizer takes (model, optimArgs, schedArgs)
optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

# sanity checks (these prevent your current crash)
assert hasattr(optimizer, "param_groups"), f"optimizer is wrong type: {type(optimizer)}"
assert not isinstance(optimizer, list), "optimizer is a list — must be a torch optimizer instance"
print("✅ optimizer ok:", type(optimizer))

tracker: <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'>
✅ optimizer ok: <class 'torch.optim.adam.Adam'>


In [None]:
# show all output-dim issues (optional)
GPA.pc.set_debugging_output_dimensions(1)

# For output shape (B, 1), set 2D output dims:
model.net.set_this_output_dimensions([0, -1])
print("✅ set this_output_dimensions for model.net to 2D")

✅ set this_output_dimensions for model.net to 2D


In [None]:
import numpy as np
import torch
from sklearn.metrics import average_precision_score

@torch.no_grad()
def eval_pr_auc(model, loader, device):
    model.eval()
    ys, ps = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).view(-1).float()
        logits = model(xb).view(-1)
        probs = torch.sigmoid(logits)
        ys.append(yb.detach().cpu().numpy())
        ps.append(probs.detach().cpu().numpy())
    y = np.concatenate(ys)
    p = np.concatenate(ps)
    return float(average_precision_score(y, p))

In [None]:
@torch.no_grad()
def eval_acc(model, loader, device):
    model.eval()
    correct, total = 0, 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).view(-1).float()
        logits = model(xb).view(-1)
        pred = (torch.sigmoid(logits) > 0.5).float()
        correct += (pred == yb).sum().item()
        total += yb.numel()
    return correct / total

In [None]:
import time
import torch
from perforatedai import globals_perforatedai as GPA

def train_one_epoch(model, loader, optimizer, device, criterion):
    model.train()
    total_loss, n = 0.0, 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device).float()

        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)

        # match shapes for BCEWithLogitsLoss
        loss = criterion(logits.view(-1), yb.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n += 1
    return total_loss / max(n, 1)

# pick ONE metric function:
eval_metric = eval_pr_auc   # or eval_acc

criterion = torch.nn.BCEWithLogitsLoss()

best_val = -1e9
best_test_at_best_val = None

max_epochs = 30
for epoch in range(1, max_epochs + 1):
    t0 = time.time()

    train_loss = train_one_epoch(model, train_loader, optimizer, device, criterion)
    val_score  = eval_metric(model, val_loader, device)
    test_score = eval_metric(model, test_loader, device)

    # Put Test on the PAI plot (optional but nice)
    GPA.pai_tracker.add_extra_score(test_score * 100.0, "Test")

    # IMPORTANT: tracker wants "higher is better"
    model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_score * 100.0, model)

    if restructured:
        # params changed → rebuild optimizer/scheduler through tracker
        optimArgs["params"] = model.parameters()
        optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

    # Scheduler step (ReduceLROnPlateau expects the monitored metric)
    scheduler.step(val_score)

    if val_score > best_val:
        best_val = val_score
        best_test_at_best_val = test_score

    print(f"[{epoch:03d}] loss={train_loss:.4f} val={val_score:.6f} test={test_score:.6f} time={time.time()-t0:.2f}s")

    if training_complete:
        print("🏁 training_complete=True (tracker stopped).")
        break

print("\nDone.")
print("Best val:", best_val)
print("Test @ best val:", best_test_at_best_val)

Adding validation score 100.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 8, last improved epoch 8, total epochs 8, n: 1, num_cycles: 16
Returning True - switching every time
Last Dendrites were good and this hit the max of 8
For improved results, try perforated backpropagation next time!
before load
after load
after graphs
after save
[001] loss=0.0000 val=1.000000 test=1.000000 time=12.53s
🏁 training_complete=True (tracker stopped).

Done.
Best val: 1.0
Test @ best val: 1.0


In [None]:
import os, glob, json

print("cwd:", os.getcwd())

# 1) What summary-like files exist in PAI/ ?
if os.path.isdir("PAI"):
    summ_like = [f for f in os.listdir("PAI") if "summary" in f.lower()]
    print("PAI/ summary-like files:", summ_like)
    print("PAI/ has PAI.png?", "PAI.png" in os.listdir("PAI"))

# 2) Find summary.json anywhere in the workspace
hits = glob.glob("**/summary.json", recursive=True)
print("\nFound summary.json files:")
for h in hits:
    print(" -", h)

# 3) Find any PAI.png anywhere
png_hits = glob.glob("**/PAI.png", recursive=True)
print("\nFound PAI.png files:")
for h in png_hits:
    print(" -", h)

cwd: /content
PAI/ summary-like files: []
PAI/ has PAI.png? True

Found summary.json files:

Found PAI.png files:
 - PerforatedAI/Examples/hackathonProjects/mnist-example-submission/PAI.png
 - PAI/PAI.png


In [None]:
import os, glob

print("cwd:", os.getcwd())
print("\nFolders here:", [d for d in os.listdir(".") if os.path.isdir(d)])

# find any summary.json anywhere
hits = glob.glob("**/summary.json", recursive=True)
print("\nFound summary.json:")
for h in hits:
    print(" -", h)

# also list any folders that look like PAI outputs
pai_like = []
for d in [p for p in os.listdir(".") if os.path.isdir(p)]:
    files = os.listdir(d)
    if any(f.lower().endswith(".png") and "pai" in f.lower() for f in files) or any("scores" in f.lower() for f in files):
        pai_like.append(d)

print("\nFolders that look like PAI outputs:", pai_like[:30])

cwd: /content

Folders here: ['.config', 'PerforatedAI', 'PAI', 'FULL_NODEND', 'FULL_DEND', 'COMP_NODEND', 'wandb', 'drive', 'COMP_DEND', 'sample_data']

Found summary.json:

Folders that look like PAI outputs: ['PAI', 'FULL_NODEND', 'FULL_DEND', 'COMP_NODEND', 'COMP_DEND']


In [None]:
import os, json
import pandas as pd

def _read_best_scores(folder):
    # Find a Scores CSV in the folder (handles different naming patterns)
    candidates = [f for f in os.listdir(folder) if f.lower().endswith("scores.csv")]
    if not candidates:
        raise FileNotFoundError(f"No *Scores.csv found in {folder}. Files: {os.listdir(folder)[:20]}")

    # Prefer the plain one if exists
    preferred = None
    for f in candidates:
        if f.lower() == "paiscores.csv" or f.lower().endswith("scores.csv"):
            preferred = f
            break
    scores_path = os.path.join(folder, preferred or candidates[0])

    df = pd.read_csv(scores_path)

    # Columns in your screenshot: Epochs, Validation Scores, Validation Running Scores, Test
    # Some files repeat in blocks; we want rows where Validation Scores is present.
    if "Validation Scores" not in df.columns:
        raise ValueError(f"Scores CSV in {folder} missing 'Validation Scores'. Columns: {df.columns.tolist()}")

    val_rows = df[df["Validation Scores"].notna()].copy()
    if len(val_rows) == 0:
        raise ValueError(f"No validation rows found in {scores_path}")

    # Best val epoch
    best_idx = val_rows["Validation Scores"].astype(float).idxmax()
    best_epoch = int(df.loc[best_idx, "Epochs"])
    best_val = float(df.loc[best_idx, "Validation Scores"])

    # Test at that epoch: look up row(s) where Test is notna and Epochs matches
    test_rows = df[(df["Epochs"] == best_epoch) & (df["Test"].notna())]
    if len(test_rows) == 0:
        # fallback: max test anywhere
        test_val = float(df[df["Test"].notna()]["Test"].astype(float).max())
    else:
        test_val = float(test_rows["Test"].astype(float).iloc[0])

    return best_epoch, best_val, test_val, scores_path

def _read_params_at_epoch(folder, epoch):
    candidates = [f for f in os.listdir(folder) if "param_counts" in f.lower() and f.lower().endswith(".csv")]
    if not candidates:
        # fallback: infer from model checkpoints is hard; raise a helpful message
        raise FileNotFoundError(f"No *param_counts.csv found in {folder}. Files: {os.listdir(folder)[:20]}")
    path = os.path.join(folder, candidates[0])
    df = pd.read_csv(path)

    # common columns: Epochs, Param Count (or similar)
    epoch_col = "Epochs" if "Epochs" in df.columns else df.columns[0]
    # find a numeric column for params
    num_cols = [c for c in df.columns if c != epoch_col]
    if not num_cols:
        raise ValueError(f"Param CSV in {folder} has no param column. Columns: {df.columns.tolist()}")
    param_col = num_cols[0]

    # match epoch
    row = df[df[epoch_col] == epoch]
    if len(row) == 0:
        # fallback: last row
        params = int(df[param_col].iloc[-1])
    else:
        params = int(row[param_col].iloc[0])
    return params, path

# Choose which run is your "main PAI" submission.
# Usually judges want the "COMP_DEND" run (compressed + dendrites). Change if needed.
SUBMISSION_RUN = "COMP_DEND"

best_epoch, best_val, best_test, scores_path = _read_best_scores(SUBMISSION_RUN)
params, params_path = _read_params_at_epoch(SUBMISSION_RUN, best_epoch)

summary = {
    "Final Max Val PR_AUC": best_val,
    "Final Max Test PR_AUC": best_test,
    "Final Param Count": params,
    # If you want dendrite count, you can just hardcode from your printed results.
    # Your COMP_DEND run showed dendrites=8 earlier.
    "Final Dendrite Count": 8
}

os.makedirs("PAI", exist_ok=True)
with open("PAI/summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(" Wrote PAI/summary.json from:", SUBMISSION_RUN)
print("   Scores source:", scores_path)
print("   Params source:", params_path)
print("   Summary:", summary)

 Wrote PAI/summary.json from: COMP_DEND
   Scores source: COMP_DEND/COMP_DEND_beforeSwitch_12Scores.csv
   Params source: COMP_DEND/COMP_DENDparam_counts.csv
   Summary: {'Final Max Val PR_AUC': 0.9305008053779602, 'Final Max Test PR_AUC': 0.9532941579818726, 'Final Param Count': 100744, 'Final Dendrite Count': 8}


In [None]:
import pandas as pd

df = pd.read_csv("PAI/PAIScores.csv")
print(df.head(12))
print("\nNon-null ranges:")
for col in ["Validation Scores", "Validation Running Scores", "Test"]:
    if col in df.columns:
        vals = df[col].dropna().astype(float)
        if len(vals):
            print(col, "min=", vals.min(), "max=", vals.max())

    Epochs  Validation Scores  Validation Running Scores  Test
0        0                1.0                        NaN   NaN
1        1                1.0                        NaN   NaN
2        2                1.0                        NaN   NaN
3        3                1.0                        NaN   NaN
4        4                1.0                        NaN   NaN
5        5                1.0                        NaN   NaN
6        6                1.0                        NaN   NaN
7        7                1.0                        NaN   NaN
8        8                1.0                        NaN   NaN
9        9              100.0                        NaN   NaN
10       0                NaN                        1.0   NaN
11       1                NaN                        1.0   NaN

Non-null ranges:
Validation Scores min= 1.0 max= 100.0
Validation Running Scores min= 1.0 max= 100.0
Test min= 1.0 max= 100.0


In [None]:
# after you compute val_pr and test_pr as 0..1
val_score  = float(val_pr)  * 100.0
test_score = float(test_pr) * 100.0

# log consistently
print(f"[{epoch:03d}] loss={train_loss:.4f} val={val_score:.3f} test={test_score:.3f}")

GPA.pai_tracker.add_extra_score(test_score, "Test")
model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_score, model)

[001] loss=0.0000 val=100.000 test=100.000
Adding validation score 100.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 8, last improved epoch 8, total epochs 8, n: 1, num_cycles: 16
Returning True - switching every time
Last Dendrites were good and this hit the max of 8
For improved results, try perforated backpropagation next time!
before load
after load
after graphs
after save


In [None]:
import os, shutil, time
from perforatedai import globals_perforatedai as GPA

# 1) pick a NEW save_name every run (avoid CSV mixing)
SAVE_NAME = f"PAI_{int(time.time())}"

# 2) wipe the folder if it exists (clean slate)
if os.path.exists(SAVE_NAME):
    shutil.rmtree(SAVE_NAME)

# 3) silence warnings
GPA.pc.set_weight_decay_accepted(True)
GPA.pc.set_unwrapped_modules_confirmed(True)

# 4) make sure Conv/Linear can be converted + tracked
GPA.pc.module_names_to_convert = ["Conv2d", "Linear", "Sequential", "PAISequential"]
GPA.pc.module_names_to_track   = ["Conv2d", "Linear"]

print("SAVE_NAME =", SAVE_NAME)
print("convert =", GPA.pc.module_names_to_convert)
print("track   =", GPA.pc.module_names_to_track)

SAVE_NAME = PAI_1767678775
convert = ['Conv2d', 'Linear', 'Sequential', 'PAISequential']
track   = ['Conv2d', 'Linear']


In [None]:
import torch
from perforatedai import utils_perforatedai as UPA
from perforatedai import globals_perforatedai as GPA

device = torch.device("cpu")

# ---- build your model ONCE ----
# (use your real constructor here)
model = SmallCNNBinarySeq().to(device)

# ---- wrap ----
model = UPA.initialize_pai(model, save_name=SAVE_NAME)

# ---- fix the_output_dimensions error (your model outputs [B,1]) ----
# Use exactly what the error asked for:
model.net.set_this_output_dimensions([-1, 0, -1, -1])

# ---- setup optimizer/scheduler THROUGH the tracker (correct API: pass dicts) ----
GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optArgs = {
    "params": model.parameters(),
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "weight_decay": 0.0,
}
schedArgs = {"mode": "max", "patience": 3}

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optArgs, schedArgs)

print("wrapped ok. tracker =", type(GPA.pai_tracker))

Running Dendrite Experiment
wrapped ok. tracker = <class 'perforatedai.tracker_perforatedai.PAINeuronModuleTracker'>


In [None]:
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA
import torch

GPA.pc.set_debugging_output_dimensions(0)  # do NOT leave this at 1

# build model ONCE
model = SmallCNNBinarySeq().to(device)

# wrap ONCE
model = UPA.initialize_pai(model, save_name=SAVE_NAME)

# ✅ fix the exact module mentioned by the error: ".net"
model.net.set_this_output_dimensions([-1, 0, -1, -1])

# sanity: make sure forward works on the wrapped model
xb, yb = next(iter(train_loader))
with torch.no_grad():
    out = model(xb.to(device))
print("forward ok, out shape:", out.shape)

Running Dendrite Experiment
forward ok, out shape: torch.Size([128, 1])


In [None]:
print(type(model))
print(type(model.net))
print([type(m).__name__ for m in model.modules()])


<class '__main__.SmallCNNBinarySeq'>
<class 'perforatedai.modules_perforatedai.PAINeuronModule'>
['SmallCNNBinarySeq', 'PAINeuronModule', 'Sequential', 'Conv2d', 'ReLU', 'MaxPool2d', 'Conv2d', 'ReLU', 'AdaptiveAvgPool2d', 'Flatten', 'Linear', 'ParameterList', 'ParameterList', 'PAIDendriteModule', 'ModuleList', 'Sequential', 'Conv2d', 'ReLU', 'MaxPool2d', 'Conv2d', 'ReLU', 'AdaptiveAvgPool2d', 'Flatten', 'Linear', 'ParameterList', 'ParameterList', 'ModuleList', 'DendriteValueTracker']


In [None]:
import torch
from perforatedai import globals_perforatedai as GPA

GPA.pai_tracker.set_optimizer(torch.optim.Adam)
GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

optimArgs = {
    "params": model.parameters(),
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    # "weight_decay": 0.0,  # recommended 0 for PAI
}
schedArgs = {"mode": "max", "patience": 3}

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

In [None]:
if restructured:
    optimArgs["params"] = model.parameters()
    optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

In [None]:
import time, torch

max_epochs = 30
best_val = -1e9
best_test_at_best_val = None

for epoch in range(1, max_epochs + 1):
    t0 = time.time()

    # ---- TRAIN ----
    train_loss = train_one_epoch(model, train_loader, optimizer, device)

    # ---- EVAL (must return 0..1) ----
    val_pr  = eval_pr_auc(model, val_loader, device)     # 0..1
    test_pr = eval_pr_auc(model, test_loader, device)    # 0..1

    # ---- SCALE to 0..100 for PAI plots ----
    val_score  = float(val_pr)  * 100.0
    test_score = float(test_pr) * 100.0

    # ---- TRACKER LOGGING ----
    GPA.pai_tracker.add_extra_score(test_score, "Test")
    model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_score, model)

    # if model changed, rebuild optimizer/scheduler
    if restructured:
        optimArgs["params"] = model.parameters()
        optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

    # ---- best tracking ----
    if val_score > best_val:
        best_val = val_score
        best_test_at_best_val = test_score

    print(f"[{epoch:03d}] loss={train_loss:.4f} val={val_score:.3f} test={test_score:.3f} time={time.time()-t0:.2f}s")

    if training_complete:
        print("🏁 training_complete=True (tracker stopped).")
        break

print("\nDone.")
print("Best val:", best_val)
print("Test @ best val:", best_test_at_best_val)

Adding validation score 100.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 0, last improved epoch 0, total epochs 0, n: 1, num_cycles: 0
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[001] loss=0.0762 val=100.000 test=100.000 time=4.30s
Adding validation score 100.00000000
Checking PAI switch with mode n, switch mode DOING_SWITCH_EVERY_TIME, epoch 1, last improved epoch 1, total epochs 1, n: 1, num_cycles: 2
Returning True - switching every time
Importing best Model for switch to PA...
Switching back to N...
Resetting committed to initial rate to False
Saving model before starting normal training to retain PBNodes regardless of next N Phase results
[002] loss=0.0001 val=100.000 test=100.000 time=4.07s
Adding validation score 100.00000000
Checking PAI sw

In [None]:
import numpy as np
import torch
from sklearn.metrics import average_precision_score

@torch.no_grad()
def eval_pr_auc(model, loader, device):
    model.eval()
    all_p, all_y = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.view(-1).cpu().numpy()
        logits = model(xb).view(-1).detach().cpu().numpy()
        probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
        all_p.append(probs)
        all_y.append(yb)
    all_p = np.concatenate(all_p)
    all_y = np.concatenate(all_y)
    return float(average_precision_score(all_y, all_p))  # 0..1

In [None]:
val_pr  = eval_pr_auc(model, val_loader, device)   # 0..1
test_pr = eval_pr_auc(model, test_loader, device)  # 0..1

val_score  = val_pr * 100.0
test_score = test_pr * 100.0

In [None]:
import os
print(os.listdir("FULL_DEND")[:10])
print(os.path.exists("FULL_DEND/summary.json"), os.path.exists("FULL_DEND/FULL_DEND.png"))

['FULL_DEND_beforeSwitch_8Times.csv', 'best_model_beforeSwitch_10.pt', 'best_model.pt', 'FULL_DEND_beforeSwitch_0Times.csv', 'FULL_DEND_beforeSwitch_10learning_rate.csv', 'FULL_DEND_beforeSwitch_12switch_epochs.csv', 'FULL_DEND_beforeSwitch_0Best PBScores.csv', 'FULL_DENDparam_counts.csv', 'FULL_DEND_beforeSwitch_10.png', 'FULL_DENDTimes.csv']
False True


In [None]:
import pandas as pd
import numpy as np

scores = pd.read_csv("FULL_DEND/FULL_DENDScores.csv")
params = pd.read_csv("FULL_DEND/FULL_DENDbefore_finalparam_counts.csv")

# --- coerce to numeric safely ---
for c in ["Epochs", "Validation Scores", "Validation Running Scores", "Test"]:
    if c in scores.columns:
        scores[c] = pd.to_numeric(scores[c], errors="coerce")

print("Non-null counts:")
print(scores[["Epochs","Validation Scores","Validation Running Scores","Test"]].notna().sum())

# pick which val column actually has values
val_col = "Validation Scores" if scores["Validation Scores"].notna().any() else "Validation Running Scores"
assert scores[val_col].notna().any(), "Both Validation Scores and Validation Running Scores are empty/NaN."

# best epoch = argmax over val_col (ignore NaNs)
best_row = scores.loc[scores[val_col].idxmax()]
best_epoch = int(best_row["Epochs"])
best_val = float(best_row[val_col])

# test at best epoch (if missing, fall back to max available test)
test_at_best = scores.loc[scores["Epochs"] == best_epoch, "Test"]
test_at_best = float(test_at_best.dropna().iloc[0]) if test_at_best.dropna().any() else float(scores["Test"].dropna().max())

# params file is by Switch Number, so just take FINAL param count (last row)
params["Param Count"] = pd.to_numeric(params["Param Count"], errors="coerce")
final_params = int(params["Param Count"].dropna().iloc[-1])

print("\n===== RESULTS =====")
print("best_epoch:", best_epoch)
print("best_val:", best_val)
print("test @ best epoch:", test_at_best)
print("final params:", final_params)

Non-null counts:
Epochs                       27
Validation Scores             9
Validation Running Scores     9
Test                          9
dtype: int64

===== RESULTS =====
best_epoch: 2
best_val: 0.9509258270263672
test @ best epoch: 0.9325947761535645
final params: 928080


In [None]:
import os, glob
import pandas as pd
import numpy as np

def _pick_scores_file(folder):
    csvs = glob.glob(os.path.join(folder, "*.csv"))
    if not csvs:
        raise FileNotFoundError(f"No CSV files found in {folder}")

    # 1) BEST: has Validation + Test columns (the file we need)
    for p in sorted(csvs):
        try:
            df0 = pd.read_csv(p, nrows=3)
        except Exception:
            continue
        cols = set(df0.columns)
        if ("Validation Scores" in cols or "Validation Running Scores" in cols) and ("Test" in cols) and ("Epochs" in cols):
            return p

    # 2) Next best: contains Scores.csv but NOT "Best PBScores"
    for p in sorted(csvs):
        base = os.path.basename(p)
        if "Scores" in base and "Best PBScores" not in base:
            return p

    # 3) Otherwise: show what we found (so you can see what's missing)
    raise ValueError(
        f"{folder}: Couldn't find a Scores csv with Validation/Test.\n"
        f"CSV files: {[os.path.basename(x) for x in csvs]}"
    )

def _read_scores(folder):
    path = _pick_scores_file(folder)
    df = pd.read_csv(path)

    # numeric coercion
    for c in ["Epochs", "Validation Scores", "Validation Running Scores", "Test"]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # choose val column that actually has values
    if "Validation Scores" in df.columns and df["Validation Scores"].notna().any():
        val_col = "Validation Scores"
    elif "Validation Running Scores" in df.columns and df["Validation Running Scores"].notna().any():
        val_col = "Validation Running Scores"
    else:
        raise ValueError(f"{folder}: picked {os.path.basename(path)} but it has no numeric validation column.")

    if "Test" not in df.columns:
        raise ValueError(f"{folder}: picked {os.path.basename(path)} but it has no Test column.")

    return df, val_col, path

In [None]:
import os
print("FULL_NODEND files:")
for f in sorted(os.listdir("FULL_NODEND")):
    if f.endswith(".csv") or f.endswith(".json") or f.endswith(".png"):
        print(" -", f)

FULL_NODEND files:
 - FULL_NODEND.png
 - FULL_NODENDBest PBScores.csv
 - FULL_NODENDScores.csv
 - FULL_NODENDTimes.csv
 - FULL_NODENDbefore_final.png
 - FULL_NODENDbefore_finalBest PBScores.csv
 - FULL_NODENDbefore_finalScores.csv
 - FULL_NODENDbefore_finalTimes.csv
 - FULL_NODENDbefore_finalbest_test_scores.csv
 - FULL_NODENDbefore_finallearning_rate.csv
 - FULL_NODENDbefore_finalparam_counts.csv
 - FULL_NODENDbefore_finalswitch_epochs.csv
 - FULL_NODENDbest_test_scores.csv
 - FULL_NODENDlearning_rate.csv
 - FULL_NODENDparam_counts.csv
 - FULL_NODENDswitch_epochs.csv


In [None]:
import os, glob, json
import pandas as pd
import numpy as np

def pick_scores_csv(folder: str) -> str:
    # ignore Best PBScores
    candidates = [p for p in glob.glob(os.path.join(folder, "*.csv"))
                  if p.endswith("Scores.csv") and ("Best PBScores" not in os.path.basename(p))]
    if not candidates:
        raise FileNotFoundError(f"{folder}: no *Scores.csv found (non-BestPBScores).")

    # prefer the one that has actual numeric validation
    def score(p):
        try:
            df = pd.read_csv(p)
        except Exception:
            return -1
        cols = set(df.columns)
        if "Epochs" not in cols:
            return -1
        val_ok = False
        if "Validation Scores" in cols and pd.to_numeric(df["Validation Scores"], errors="coerce").notna().any():
            val_ok = True
        if "Validation Running Scores" in cols and pd.to_numeric(df["Validation Running Scores"], errors="coerce").notna().any():
            val_ok = True
        test_ok = ("Test" in cols) and pd.to_numeric(df["Test"], errors="coerce").notna().any()
        return int(val_ok) + int(test_ok)

    candidates = sorted(candidates, key=score, reverse=True)
    return candidates[0]

def pick_params_csv(folder: str) -> str:
    # Prefer before_finalparam_counts.csv if present (often more complete)
    pref = os.path.join(folder, f"{os.path.basename(folder)}before_finalparam_counts.csv")
    if os.path.exists(pref):
        return pref
    # else fallback to *param_counts.csv
    candidates = glob.glob(os.path.join(folder, "*param_counts.csv"))
    if not candidates:
        raise FileNotFoundError(f"{folder}: no *param_counts.csv found.")
    # pick the shortest name last (usually final), but any is fine
    return sorted(candidates, key=lambda p: len(os.path.basename(p)))[0]

def extract_run(folder: str):
    scores_path = pick_scores_csv(folder)
    df = pd.read_csv(scores_path)

    # coerce numeric
    for c in ["Epochs", "Validation Scores", "Validation Running Scores", "Test"]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # choose best validation column available
    if "Validation Scores" in df.columns and df["Validation Scores"].notna().any():
        val_col = "Validation Scores"
    elif "Validation Running Scores" in df.columns and df["Validation Running Scores"].notna().any():
        val_col = "Validation Running Scores"
    else:
        raise ValueError(f"{folder}: {os.path.basename(scores_path)} has no numeric validation column.")

    # IMPORTANT: validation and test are in separate blocks → do NOT dropna on both
    val_rows = df[df[val_col].notna()][["Epochs", val_col]].copy()
    if val_rows.empty:
        raise ValueError(f"{folder}: no validation rows found in {os.path.basename(scores_path)}")

    best_idx = val_rows[val_col].idxmax()
    best_epoch = int(val_rows.loc[best_idx, "Epochs"])
    best_val = float(val_rows.loc[best_idx, val_col])

    test_at_best = np.nan
    if "Test" in df.columns and df["Test"].notna().any():
        test_rows = df[df["Test"].notna()][["Epochs", "Test"]].copy()
        hit = test_rows[test_rows["Epochs"] == best_epoch]
        if not hit.empty:
            test_at_best = float(hit.iloc[0]["Test"])
        else:
            # fallback: take max test if epoch alignment is missing
            test_at_best = float(test_rows["Test"].max())

    # params
    params_path = pick_params_csv(folder)
    p = pd.read_csv(params_path)
    # find param column
    param_col = "Param Count" if "Param Count" in p.columns else p.columns[-1]
    p[param_col] = pd.to_numeric(p[param_col], errors="coerce")
    final_params = int(p[param_col].dropna().iloc[-1])

    return {
        "folder": folder,
        "scores_csv": scores_path,
        "params_csv": params_path,
        "best_epoch": best_epoch,
        "best_val": best_val,
        "test_at_best": float(test_at_best) if not np.isnan(test_at_best) else None,
        "final_params": final_params,
    }

RUNS = {
    "FULL_NODEND": "FULL_NODEND",
    "FULL_DEND": "FULL_DEND",
    "COMP_NODEND": "COMP_NODEND",
    "COMP_DEND": "COMP_DEND",
}

results = {k: extract_run(v) for k, v in RUNS.items()}
results

{'FULL_NODEND': {'folder': 'FULL_NODEND',
  'scores_csv': 'FULL_NODEND/FULL_NODENDbefore_finalScores.csv',
  'params_csv': 'FULL_NODEND/FULL_NODENDbefore_finalparam_counts.csv',
  'best_epoch': 0,
  'best_val': 0.8543249368667603,
  'test_at_best': 0.8843109011650085,
  'final_params': 99073},
 'FULL_DEND': {'folder': 'FULL_DEND',
  'scores_csv': 'FULL_DEND/FULL_DEND_beforeSwitch_14Scores.csv',
  'params_csv': 'FULL_DEND/FULL_DENDbefore_finalparam_counts.csv',
  'best_epoch': 2,
  'best_val': 0.9509258270263672,
  'test_at_best': 0.9325947761535645,
  'final_params': 928080},
 'COMP_NODEND': {'folder': 'COMP_NODEND',
  'scores_csv': 'COMP_NODEND/COMP_NODENDScores.csv',
  'params_csv': 'COMP_NODEND/COMP_NODENDbefore_finalparam_counts.csv',
  'best_epoch': 0,
  'best_val': 0.7831177115440369,
  'test_at_best': 0.8512295484542847,
  'final_params': 33153},
 'COMP_DEND': {'folder': 'COMP_DEND',
  'scores_csv': 'COMP_DEND/COMP_DEND_beforeSwitch_12Scores.csv',
  'params_csv': 'COMP_DEND/COMP

In [None]:
def pct(x):
    return None if x is None else 100.0 * float(x)

print("\n=== Paste these into the judge template ===")
for k, r in results.items():
    print(f"\n{k}")
    print("  Best Val (raw):", r["best_val"])
    print("  Best Val (%):  ", pct(r["best_val"]))
    print("  Test@Best (raw):", r["test_at_best"])
    print("  Test@Best (%):  ", pct(r["test_at_best"]))
    print("  Final Params:", r["final_params"])


=== Paste these into the judge template ===

FULL_NODEND
  Best Val (raw): 0.8543249368667603
  Best Val (%):   85.43249368667603
  Test@Best (raw): 0.8843109011650085
  Test@Best (%):   88.43109011650085
  Final Params: 99073

FULL_DEND
  Best Val (raw): 0.9509258270263672
  Best Val (%):   95.09258270263672
  Test@Best (raw): 0.9325947761535645
  Test@Best (%):   93.25947761535645
  Final Params: 928080

COMP_NODEND
  Best Val (raw): 0.7831177115440369
  Best Val (%):   78.31177115440369
  Test@Best (raw): 0.8512295484542847
  Test@Best (%):   85.12295484542847
  Final Params: 33153

COMP_DEND
  Best Val (raw): 0.9305008053779602
  Best Val (%):   93.05008053779602
  Test@Best (raw): 0.9532941579818726
  Test@Best (%):   95.32941579818726
  Final Params: 316624


In [None]:
BASELINE = {
    "test": 0.8843109011650085,   # FULL_NODEND test@best
    "params": 99073               # FULL_NODEND final params
}
print("Baseline:", BASELINE)

Baseline: {'test': 0.8843109011650085, 'params': 99073}


In [None]:
import time, os
import torch
import numpy as np
from sklearn.metrics import average_precision_score
from perforatedai import utils_perforatedai as UPA

device = torch.device("cpu")
print("device:", device)

@torch.no_grad()
def eval_pr_auc(model, loader, device):
    model.eval()
    all_p, all_y = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.view(-1).cpu().numpy()
        logits = model(xb).view(-1).detach().cpu().numpy()
        probs = 1.0 / (1.0 + np.exp(-logits))
        all_p.append(probs)
        all_y.append(yb)
    all_p = np.concatenate(all_p)
    all_y = np.concatenate(all_y)
    return float(average_precision_score(all_y, all_p))  # 0..1

def setup_tracker_optimizer(model, lr):
    # IMPORTANT: call these AFTER initialize_pai (tracker exists then)
    GPA.pai_tracker.set_optimizer(torch.optim.Adam)
    GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

    opt_args = {
        "params": model.parameters(),
        "lr": float(lr),
        "betas": (0.9, 0.999),
        # "weight_decay": 0.0,  # recommended off
    }
    sched_args = {
        "mode": "max",     # maximize val score
        "patience": 3
    }

    optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, opt_args, sched_args)
    return optimizer, scheduler

device: cpu


In [None]:
import os, time
import numpy as np
import torch
import wandb
from sklearn.metrics import average_precision_score
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

device = torch.device("cpu")
print("device:", device)

# --- PAI config (LISTS!) ---
GPA.pc.set_weight_decay_accepted(True)
GPA.pc.set_unwrapped_modules_confirmed(True)
GPA.pc.set_module_names_to_convert(["Conv2d", "Linear", "Sequential", "PAISequential"])
GPA.pc.set_module_names_to_track(["Conv2d", "Linear"])

# --- Baseline (FULL_NODEND from your extraction) ---
BASELINE = {"test": 0.8843109011650085, "params": 99073}

@torch.no_grad()
def eval_pr_auc(model, loader, device):
    model.eval()
    all_p, all_y = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.view(-1).cpu().numpy()
        logits = model(xb).view(-1).detach().cpu().numpy()
        probs = 1.0 / (1.0 + np.exp(-logits))
        all_p.append(probs)
        all_y.append(yb)
    all_p = np.concatenate(all_p)
    all_y = np.concatenate(all_y)
    return float(average_precision_score(all_y, all_p))  # 0..1

def setup_tracker_optimizer(model, lr):
    GPA.pai_tracker.set_optimizer(torch.optim.Adam)
    GPA.pai_tracker.set_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau)

    opt_args = {
        "params": model.parameters(),
        "lr": float(lr),
        "betas": (0.9, 0.999),
    }
    sched_args = {"mode": "max", "patience": 3}
    optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, opt_args, sched_args)
    return optimizer, scheduler

# --------------------------
# YOU MUST already have these in memory:
# - train_one_epoch(model, train_loader, optimizer, device)
# - train_loader, val_loader, test_loader
# - build_model_comp(width_mult=...)  OR some way to build COMP at width_mult
# --------------------------

def run_one_comp(save_name, width_mult, max_dendrites, lr, max_epochs=12):
    # ---- build COMP model (EDIT THIS LINE TO MATCH YOUR CODE) ----
    # Example if you have: build_model_comp(width_mult=0.35)
    model = build_model_comp(width_mult=width_mult).to(device)

    # ---- wrap once ----
    model = UPA.initialize_pai(model, save_name=save_name)

    # ---- fix output dims for your PAINeuronModule wrapper ----
    # you verified forward output is (B,1)
    model.net.set_this_output_dimensions([-1, 0, -1, -1])

    # (optional) set max dendrites if your build uses this config var
    try:
        GPA.pc.set_max_dendrites(int(max_dendrites))
    except Exception:
        pass

    optimizer, scheduler = setup_tracker_optimizer(model, lr)

    best_val, best_test_at_best = -1.0, -1.0
    t0 = time.time()

    for epoch in range(1, int(max_epochs) + 1):
        model.train()
        train_loss = train_one_epoch(model, train_loader, optimizer, device)

        val_pr  = eval_pr_auc(model, val_loader, device)
        test_pr = eval_pr_auc(model, test_loader, device)

        # log to tracker in % (97-100 style)
        val_score  = val_pr * 100.0
        test_score = test_pr * 100.0

        GPA.pai_tracker.add_extra_score(test_score, "Test")
        model, restructured, training_complete = GPA.pai_tracker.add_validation_score(val_score, model)

        if restructured:
            optimizer, scheduler = setup_tracker_optimizer(model, lr)

        if val_pr > best_val:
            best_val = val_pr
            best_test_at_best = test_pr

        print(f"[{epoch:03d}] loss={train_loss:.4f} val_pr={val_pr:.6f} test_pr={test_pr:.6f} time={time.time()-t0:.2f}s")

        if training_complete:
            break

    params = sum(p.numel() for p in model.parameters())
    return {"best_val_pr": best_val, "best_test_pr": best_test_at_best, "params": int(params), "out_dir": save_name}

device: cpu


In [None]:
sweep_config = {
    "method": "grid",
    "metric": {"name": "test_pr", "goal": "maximize"},
    "parameters": {
        "width_mult": {"values": [0.25, 0.35, 0.50]},
        "max_dendrites": {"values": [1, 2, 3]},
        "lr": {"values": [1e-3, 3e-4]},
        "max_epochs": {"value": 12},
    }
}

def sweep_train():
    run = wandb.init()  # IMPORTANT: do not pass project here for sweeps
    cfg = wandb.config

    save_name = f"SWEEP_COMP_w{cfg.width_mult}_d{cfg.max_dendrites}_lr{cfg.lr}"
    out = run_one_comp(
        save_name=save_name,
        width_mult=float(cfg.width_mult),
        max_dendrites=int(cfg.max_dendrites),
        lr=float(cfg.lr),
        max_epochs=int(cfg.max_epochs),
    )

    test_pr = float(out["best_test_pr"])
    params  = int(out["params"])
    compression = 1.0 - (params / BASELINE["params"])
    improvement = test_pr - BASELINE["test"]

    wandb.log({
        "test_pr": test_pr,
        "params": params,
        "compression_vs_full": compression,
        "improvement_vs_full": improvement,
        "width_mult": float(cfg.width_mult),
        "max_dendrites": int(cfg.max_dendrites),
        "lr": float(cfg.lr),
    })

    png_path = os.path.join(out["out_dir"], f"{out['out_dir']}.png")
    if os.path.exists(png_path):
        wandb.log({"pai_plot": wandb.Image(png_path)})

    run.finish()

In [None]:
import inspect
print(inspect.signature(build_model_comp))
print(inspect.getsource(build_model_comp)[:800])  # optional peek

()
def build_model_comp():
    # "X% width" example: half width (256 -> 128)
    return MLPBinary(in_dim=128, hidden=128, dropout=0.2)



In [None]:
import inspect

def build_comp_any(width_mult: float):
    sig = inspect.signature(build_model_comp)
    params = sig.parameters

    # try common keyword names
    for k in ["width_mult", "width", "mult", "multiplier", "w", "ratio"]:
        if k in params:
            return build_model_comp(**{k: float(width_mult)})

    # if it takes exactly 1 positional arg (besides self)
    if len(params) == 1:
        return build_model_comp(float(width_mult))

    # fallback: no args
    return build_model_comp()

In [None]:
import os, time, json
import torch
import torch.nn as nn
import wandb
from perforatedai import globals_perforatedai as GPA

device = torch.device("cpu")  # or your device

def run_one_comp(save_name: str, width_mult: float, max_dendrites: int, lr: float, max_epochs: int = 12, weight_decay: float = 0.0):
    """
    Runs ONE experiment (COMP model) and writes outputs into folder `save_name/`.
    Logs metrics to W&B if a run is active.
    """
    os.makedirs(save_name, exist_ok=True)

    # ---- IMPORTANT: configure conversions/tracking (your working settings) ----
    # Use lists (NOT strings) when appending/setting.
    GPA.pc.set_module_names_to_convert(["Conv2d", "Linear", "Sequential", "PAISequential"])
    GPA.pc.set_module_names_to_track(["Conv2d", "Linear"])
    GPA.pc.set_weight_decay_accepted(True)
    GPA.pc.set_unwrapped_modules_confirmed(True)

    # ---- build COMP model (robust to arg name differences) ----
    model = build_comp_any(width_mult).to(device)

    # ---- wrap with PAI ----
    # NOTE: depending on your notebook, initialize_pai might be:
    #   GPA.initialize_pai(...)  OR  UPA.initialize_pai(...)
    # Use the one you used successfully earlier.
    model = GPA.initialize_pai(model, save_name=save_name, max_dendrites=max_dendrites)

    # ---- optimizer (keep it simple + compatible) ----
    optimizer = torch.optim.Adam(model.parameters(), lr=float(lr), weight_decay=float(weight_decay))

    # ---- TRAIN LOOP ----
    best_val = -1e9
    best_test_at_best_val = None

    for epoch in range(1, max_epochs + 1):
        t0 = time.time()

        # MUST match your train fn signature:
        # you said yours needs (model, loader, optimizer, device)
        train_loss = train_one_epoch(model, train_loader, optimizer, device)

        # MUST return "higher is better" (PR-AUC, accuracy, etc.)
        val_score  = eval_metric(model, val_loader, device)
        test_score = eval_metric(model, test_loader, device)

        # log to wandb if active
        if wandb.run is not None:
            wandb.log({
                "epoch": epoch,
                "train_loss": float(train_loss),
                "val_score": float(val_score),
                "test_score": float(test_score),
                "width_mult": float(width_mult),
                "max_dendrites": int(max_dendrites),
                "lr": float(lr),
            })

        # ---- tell tracker the scores (use % if you want 97–100 scale) ----
        # Your earlier successful run used 100-scale. Keep it consistent:
        GPA.pai_tracker.add_extra_score(float(test_score) * 100.0, "Test")
        model, restructured, training_complete = GPA.pai_tracker.add_validation_score(float(val_score) * 100.0, model)

        # IMPORTANT: rebuild optimizer after restructure (Dendrites changed params)
        if restructured:
            optimizer = torch.optim.Adam(model.parameters(), lr=float(lr), weight_decay=float(weight_decay))

        # track best val
        if val_score > best_val:
            best_val = float(val_score)
            best_test_at_best_val = float(test_score)

        print(f"[{epoch:03d}] loss={train_loss:.4f} val={val_score:.6f} test={test_score:.6f} time={time.time()-t0:.2f}s")

        if training_complete:
            break

    # ---- return summary for sweep ----
    out = {
        "save_name": save_name,
        "width_mult": float(width_mult),
        "max_dendrites": int(max_dendrites),
        "lr": float(lr),
        "best_val": float(best_val),
        "best_test_at_best_val": float(best_test_at_best_val) if best_test_at_best_val is not None else None,
    }
    return out

In [None]:
@torch.no_grad()
def eval_metric(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)

        logits = model(xb)
        probs = torch.sigmoid(logits).view(-1)
        preds = (probs >= 0.5).long()

        y = yb.view(-1).long()
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / max(total, 1)

In [None]:
def sweep_train():
    cfg = wandb.config

    out = run_one_comp(
        save_name=f"SWEEP_COMP_w{cfg.width_mult}_d{cfg.max_dendrites}_lr{cfg.lr}",
        width_mult=float(cfg.width_mult),
        max_dendrites=int(cfg.max_dendrites),
        lr=float(cfg.lr),
        max_epochs=int(cfg.max_epochs),
        weight_decay=0.0,
    )

    # tell wandb what to optimize
    wandb.log({
        "best_val": out["best_val"],
        "best_test_at_best_val": out["best_test_at_best_val"],
    })