# Multiclass (0–3) Substation Classification with Hetero GNN
We predict **Classes_4** excluding label `4` (i.e., keep only classes **0–3**). 
This notebook uses a robust stratified split with per-class minimums, **class-balanced focal loss**, 
and a **GATv2** hetero model with edge attributes.

**Summary from previous runs (this repo):**
- Labeled per-class counts: `{2: 45, 3: 77, 1: 13, 0: 59}` (class imbalance).
- Best validation macro-F1: **0.5447**; test macro-F1: **0.3554** (on a matching split).

In [1]:
# A1 — Setup & Paths
import os, json, random
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn

from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from torch_geometric.nn import HeteroConv, GATv2Conv

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

DATA_PATH = "IncidentDataFinal.csv"
GRAPH_IN  = "Hetero_Final_NW_graph_fixed_kara.pt"      # unlabeled graph you already built
GRAPH_OUT = "Hetero_graph_kara_classes4_labeled.pt"    # will save labeled (0..3 only)


Device: cuda


In [2]:
# A2 — Load incidents (CSV) and base graph
_cols = pd.read_csv(DATA_PATH, nrows=0).columns
_date_cols = [c for c in ['Job OFF Time','Job ON Time'] if c in _cols]
inc = pd.read_csv(DATA_PATH, parse_dates=_date_cols)

g = torch.load(GRAPH_IN)
print("Incidents:", inc.shape)
print(g)


Incidents: (264458, 31)
HeteroData(
  substation={
    x=[194, 18],
    node_ids=[194],
    substation_id=[194],
  },
  (substation, spatial, substation)={
    edge_index=[2, 7738],
    edge_attr=[7738, 8],
  },
  (substation, temporal, substation)={
    edge_index=[2, 19134],
    edge_attr=[19134, 2],
  },
  (substation, causal, substation)={
    edge_index=[2, 7192],
    edge_attr=[7192, 4],
  }
)


  g = torch.load(GRAPH_IN)


In [3]:
# A3 — Attach Classes_4 (drop 4), majority per substation name, save labeled graph
def attach_classes4_ignore_4(graph, incident_df, target_col="Classes_4", save_path=None):
    g = graph
    sub = g['substation']

    df = incident_df.copy()
    if target_col not in df.columns:
        raise ValueError(f"Missing '{target_col}' in incidents.")
    df['Job Substation'] = df['Job Substation'].astype(str).str.strip().str.upper()
    df[target_col] = pd.to_numeric(df[target_col], errors='coerce')

    # keep only labels 0..3
    df = df[df[target_col].isin([0,1,2,3])]

    labels_by_name = (
        df.groupby('Job Substation')[target_col]
          .apply(lambda s: s.dropna().astype(int).mode().iloc[0] if s.dropna().size else np.nan)
    )

    node_names = [str(n).strip().upper() for n in getattr(sub, 'node_ids', [])]
    y_list, mask_list = [], []
    for name in node_names:
        v = labels_by_name.get(name, np.nan)
        if pd.isna(v): y_list.append(-1); mask_list.append(False)
        else:          y_list.append(int(v)); mask_list.append(True)

    sub.y = torch.tensor(y_list, dtype=torch.long)
    sub.train_mask = torch.tensor(mask_list, dtype=torch.bool)

    labeled = int(sub.train_mask.sum().item())
    print(f"Labeled (0..3 only): {labeled}/{len(node_names)} nodes.")
    if save_path: torch.save(g, save_path); print("Saved ->", save_path)
    return g

g = attach_classes4_ignore_4(g, inc, "Classes_4", save_path=GRAPH_OUT)


Labeled (0..3 only): 194/194 nodes.
Saved -> Hetero_graph_kara_classes4_labeled.pt


In [4]:
# A4 — Build tensors, normalize features/edges, make spatial undirected
g = torch.load(GRAPH_OUT).to(device)

# Edge dicts + per-relation z-score normalization
edge_index_dict = {rel: g[rel].edge_index.to(device) for rel in g.edge_types}
edge_attr_dict, edge_dim_dict = {}, {}
for rel in g.edge_types:
    ea = getattr(g[rel], 'edge_attr', None)
    if ea is not None and ea.numel() > 0:
        m, s = ea.mean(0, keepdim=True), ea.std(0, keepdim=True)
        s[s==0] = 1.0
        edge_attr_dict[rel] = ((ea - m) / s).to(device)
        edge_dim_dict[rel]  = ea.size(1)
    else:
        edge_attr_dict[rel] = None
        edge_dim_dict[rel]  = 0

# Node features z-score
x_raw = g['substation'].x.to(device)
x_mean, x_std = x_raw.mean(0, keepdim=True), x_raw.std(0, keepdim=True)
x_std[x_std==0] = 1.0
x_dict = {'substation': (x_raw - x_mean) / x_std}

# Labels
y = g['substation'].y.to(device)
num_nodes = x_dict['substation'].size(0)

# Make spatial undirected (duplicate reverse edges). Keep temporal/causal directed.
if ('substation','spatial','substation') in g.edge_types:
    rel = ('substation','spatial','substation')
    ei = edge_index_dict[rel]; ea = edge_attr_dict[rel]
    rev_ei = torch.stack([ei[1], ei[0]], dim=0)
    edge_index_dict[rel] = torch.cat([ei, rev_ei], dim=1)
    if ea is not None:
        edge_attr_dict[rel] = torch.cat([ea, ea], dim=0)
    print("Spatial edges (undirected):", edge_index_dict[rel].size(1))


Spatial edges (undirected): 15476


  g = torch.load(GRAPH_OUT).to(device)


In [53]:
# A5 — Robust stratified split with per-class minimums
import math

def split_counts(mask, name):
    yy = y[mask].cpu().numpy()
    from collections import Counter
    c  = dict(Counter(yy))
    print(f"{name:5} size={int(mask.sum().item()):3d} | class counts:", c)

def _check_feasibility(y_all, labeled_idx, train_frac, val_frac, test_frac, min_per_class):
    from collections import Counter
    counts = Counter(y_all[labeled_idx])
    p_rest = 1.0 - train_frac
    p_val  = val_frac / (val_frac + test_frac)
    p_test = test_frac / (val_frac + test_frac)
    need_val  = {c: math.ceil((min_per_class)/(p_rest*p_val)  + 1e-9) for c in counts}
    need_test = {c: math.ceil((min_per_class)/(p_rest*p_test) + 1e-9) for c in counts}
    infeasible = {c: (counts[c], max(need_val[c], need_test[c])) for c in counts if counts[c] < max(need_val[c], need_test[c])}
    return counts, infeasible

def stratified_split_with_min(y_all, labeled_idx, train_frac=0.8, val_frac=0.1, test_frac=0.1,
                              min_per_class=2, base_seed=42, max_tries=500):
    assert abs(train_frac + val_frac + test_frac - 1.0) < 1e-6
    counts, infeasible = _check_feasibility(y_all, labeled_idx, train_frac, val_frac, test_frac, min_per_class)
    if infeasible:
        msg_lines = ["Split infeasible with current fractions/min_per_class.",
                     f"fractions=(train={train_frac:.2f}, val={val_frac:.2f}, test={test_frac:.2f}), min_per_class={min_per_class}",
                     "Per-class labeled counts:"]
        msg_lines += [f"  class {c}: n={counts[c]} (needs ≥{need} to ensure val/test ≥{min_per_class})"
                      for c, (_, need) in infeasible.items()]
        raise RuntimeError("\n".join(msg_lines))

    y_lab = y_all[labeled_idx]
    for t in range(max_tries):
        seed = base_seed + t
        tr_idx, rest_idx = train_test_split(
            labeled_idx, test_size=(1 - train_frac), stratify=y_lab, random_state=seed
        )
        pos_map = {nid:i for i, nid in enumerate(labeled_idx)}
        y_rest  = np.array([y_lab[pos_map[nid]] for nid in rest_idx])

        val_idx, test_idx = train_test_split(
            rest_idx, test_size=test_frac/(val_frac+test_frac),
            stratify=y_rest, random_state=seed
        )
        from collections import Counter
        ok_val  = all(v >= min_per_class for v in Counter(y_all[val_idx]).values())
        ok_test = all(v >= min_per_class for v in Counter(y_all[test_idx]).values())
        if ok_val and ok_test:
            return tr_idx, val_idx, test_idx, seed

    raise RuntimeError("Could not find a split meeting per-class minimums after reseeding.")

# Build labeled index + labels (0..3 only)
labeled_idx  = torch.where(y >= 0)[0].cpu().numpy()
y_np         = y.cpu().numpy()

# Diagnostics
counts, infeasible = _check_feasibility(y_np, labeled_idx, 0.8, 0.1, 0.1, min_per_class=1)
print("Labeled per-class counts:", dict(counts))

train_idx_l, val_idx_l, test_idx_l, used_seed = stratified_split_with_min(
    y_np, labeled_idx, train_frac=0.68, val_frac=0.16, test_frac=0.16, min_per_class=2, base_seed=SEED

)

train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device); train_mask[torch.tensor(train_idx_l, device=device)] = True
val_mask   = torch.zeros(num_nodes, dtype=torch.bool, device=device); val_mask[torch.tensor(val_idx_l, device=device)]   = True
test_mask  = torch.zeros(num_nodes, dtype=torch.bool, device=device); test_mask[torch.tensor(test_idx_l, device=device)]  = True

print("Used split seed:", used_seed)
split_counts(train_mask, "train")
split_counts(val_mask,   "val")
split_counts(test_mask,  "test")


Labeled per-class counts: {2: 45, 3: 77, 1: 13, 0: 59}
Used split seed: 42
train size=131 | class counts: {2: 30, 3: 52, 1: 9, 0: 40}
val   size= 31 | class counts: {2: 8, 3: 12, 0: 9, 1: 2}
test  size= 32 | class counts: {2: 7, 3: 13, 1: 2, 0: 10}


In [54]:
# B1 — Class-balanced alpha (Cui et al.) + focal CE loss
def class_balanced_alpha(labels, num_classes=4, beta=0.9999):
    counts = np.bincount(labels, minlength=num_classes).astype(np.float32)
    effective_num = 1.0 - np.power(beta, counts)
    weights = (1.0 - beta) / np.maximum(effective_num, 1e-8)
    weights = weights / weights.sum() * num_classes
    return torch.tensor(weights, dtype=torch.float32, device=device), counts

ALPHA_CB, cls_counts = class_balanced_alpha(y_np[labeled_idx], num_classes=4, beta=0.9999)
print("Class-balanced alpha:", ALPHA_CB.detach().cpu().numpy(), "| counts:", cls_counts)

def focal_ce_loss(logits, targets, alpha=None, gamma=2.0):
    logp = F.log_softmax(logits, dim=1)
    p    = torch.exp(logp)
    ce   = F.nll_loss(logp, targets, reduction='none')
    pt   = p[torch.arange(p.size(0), device=logits.device), targets]
    loss = ((1 - pt) ** gamma) * ce
    if alpha is not None:
        loss = alpha[targets] * loss
    return loss.mean()


Class-balanced alpha: [0.5259618  2.381557   0.6891103  0.40337116] | counts: [59. 13. 45. 77.]


In [55]:
# C1 — Hetero GATv2 (edge_dim), larger hidden, BN, more dropout
class HeteroGATv2Edge(nn.Module):
    def __init__(self, hidden=128, out_channels=4, metadata=None, heads=4, dropout=0.35):
        super().__init__()
        self.dropout = dropout
        conv1_dict, conv2_dict = {}, {}
        for rel in metadata[1]:
            edim = edge_dim_dict.get(rel, 0)
            if edim > 0:
                conv1 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  edge_dim=edim, add_self_loops=False, dropout=dropout)
                conv2 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  edge_dim=edim, add_self_loops=False, dropout=dropout)
            else:
                conv1 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  add_self_loops=False, dropout=dropout)
                conv2 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  add_self_loops=False, dropout=dropout)
            conv1_dict[rel] = conv1; conv2_dict[rel] = conv2

        self.conv1 = HeteroConv(conv1_dict, aggr='mean')
        self.conv2 = HeteroConv(conv2_dict, aggr='mean')
        self.bn1   = nn.BatchNorm1d(hidden)
        self.bn2   = nn.BatchNorm1d(hidden)
        self.lin   = nn.Linear(hidden, out_channels)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        x_dict = self.conv1(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x = self.bn1(x_dict['substation'])
        x = torch.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x_dict = {'substation': x}
        x_dict = self.conv2(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x = self.bn2(x_dict['substation'])
        x = torch.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

model = HeteroGATv2Edge(metadata=g.metadata()).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.0015, weight_decay=2e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', patience=12, factor=0.5)


In [61]:
# === Auto Hyperparam Tuning: Random -> Local Refinement -> Refit (macro-F1) ===
# Assumes you already have: model, device, x_dict, edge_index_dict, edge_attr_dict, y, train_mask, val_mask, test_mask

import math, random, numpy as np, torch
import torch.nn.functional as F
from copy import deepcopy
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# ---------------------------
# SEARCH SETTINGS
# ---------------------------
SEARCH_TRIALS_STAGE1 = 24        # initial random search
SEARCH_TRIALS_STAGE2 = 12        # local refinement around best
BASE_SEED            = 2024
MAX_EPOCHS           = 600
ES_PATIENCE          = 25
ES_MIN_DELTA         = 1e-4
MIN_LR               = 1e-6
PRINT_EVERY          = 10

DO_LOCAL_REFINEMENT  = True
DO_REFIT_TRAINVAL    = True
REFIT_SHADOW_VAL_FR  = 0.10      # carve this from train+val for early stopping
DO_ENSEMBLE_TOPK     = False     # optional: average logits of top-k single-model runs
ENSEMBLE_TOPK        = 3
ENSEMBLE_REPEATS     = 2         # per top config

# If you can rebuild the model (to tune dropout/hidden dim/layers), set this to a factory:
# Example:
# def MODEL_FACTORY(cfg=None):
#     return MyModel(hidden_dim=cfg.get("hidden_dim",128),
#                    dropout=cfg.get("dropout",0.3),
#                    num_layers=cfg.get("num_layers",3)).to(device)
MODEL_FACTORY = None

# ---------------------------
# UTILITIES
# ---------------------------
def set_seeds(seed):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def compute_alpha_cb(y_labels, beta, device):
    with torch.no_grad():
        labeled = y_labels[y_labels >= 0].cpu()
        counts  = torch.bincount(labeled)
        eff = 1.0 - (beta ** counts.float())
        eff = eff.clamp_min(1e-12)
        alpha = (1.0 - beta) / eff
        alpha = (alpha / alpha.sum() * len(alpha)).to(device)
    return alpha

def focal_ce_loss(logits, targets, alpha=None, gamma=2.0, label_smoothing=0.0):
    ce = F.cross_entropy(
        logits, targets, weight=alpha, reduction='none', label_smoothing=label_smoothing
    )
    pt = torch.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    loss = ((1.0 - pt).clamp_min(1e-6) ** gamma) * ce
    return loss.mean()

@torch.no_grad()
def eval_f1_mask(model, mask):
    model.eval()
    logits = model(x_dict, edge_index_dict, edge_attr_dict)
    pred = logits[mask].argmax(dim=1).cpu().numpy()
    true = y[mask].cpu().numpy()
    return f1_score(true, pred, average='macro')

def build_or_reset_model(init_state, cfg=None):
    if MODEL_FACTORY is not None:
        return MODEL_FACTORY(cfg)
    m = deepcopy(model).to(device)
    m.load_state_dict(deepcopy(init_state))
    return m

def train_one_trial(init_state, config, train_mask_, val_mask_):
    m = build_or_reset_model(init_state, config)
    opt = torch.optim.AdamW(m.parameters(), lr=config["lr"], weight_decay=config["wd"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="max", factor=config["plateau_factor"], patience=config["plateau_patience"],
        threshold=ES_MIN_DELTA, cooldown=0, min_lr=MIN_LR
    )
    alpha_cb = compute_alpha_cb(y, config["beta"], device)

    best_f1, best_state, patience, best_epoch = -1.0, None, 0, 0
    for epoch in range(1, MAX_EPOCHS + 1):
        m.train(True)
        logits = m(x_dict, edge_index_dict, edge_attr_dict)
        loss = focal_ce_loss(
            logits[train_mask_], y[train_mask_],
            alpha=alpha_cb, gamma=config["gamma"], label_smoothing=config["label_smoothing"]
        )
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(m.parameters(), config["clip_norm"])
        opt.step()

        val_f1 = eval_f1_mask(m, val_mask_)
        scheduler.step(val_f1)

        if val_f1 > best_f1 + ES_MIN_DELTA:
            best_f1 = val_f1; best_epoch = epoch; patience = 0
            best_state = {k: v.detach().cpu().clone() for k, v in m.state_dict().items()}
        else:
            patience += 1

        if epoch % PRINT_EVERY == 0 or epoch == 1:
            print(f"  Epoch {epoch:03d} | loss {loss.item():.4f} | val F1 {val_f1:.4f} | lr {opt.param_groups[0]['lr']:.2e}")

        if patience >= ES_PATIENCE:
            break

    # load best and compute test F1
    m.load_state_dict({k: v.to(device) for k, v in best_state.items()})
    test_f1 = eval_f1_mask(m, test_mask)
    return best_f1, test_f1, best_epoch, best_state

def loguniform(a, b):
    return 10 ** np.random.uniform(np.log10(a), np.log10(b))

# ---------------------------
# SEARCH SPACES
# ---------------------------
def sample_config_global():
    cfg = {
        "lr":               float(loguniform(5e-5, 5e-3)),
        "wd":               float(loguniform(1e-6, 1e-3)),
        "gamma":            float(np.random.uniform(1.0, 4.0)),
        "beta":             float(np.random.uniform(0.990, 0.9995)),
        "label_smoothing":  float(np.random.choice([0.0, 0.05, 0.10])),
        "clip_norm":        float(np.random.choice([0.5, 1.0, 2.0])),
        "plateau_patience": int(np.random.choice([8, 12, 16])),
        "plateau_factor":   float(np.random.choice([0.5, 0.7])),
    }
    # Optional architecture knobs (only used if MODEL_FACTORY is provided)
    if MODEL_FACTORY is not None:
        cfg.update({
            "hidden_dim": int(np.random.choice([64, 128, 256])),
            "dropout":    float(np.random.choice([0.1, 0.2, 0.3, 0.5])),
            "num_layers": int(np.random.choice([2, 3, 4])),
        })
    return cfg

def sample_config_local(best_cfg):
    # narrow around the best (multiplicative noise for positive params)
    def jitter_mult(val, low=0.6, high=1.6): return float(val * np.random.uniform(low, high))
    cfg = dict(best_cfg)
    cfg["lr"]               = float(np.clip(jitter_mult(best_cfg["lr"]), 5e-5, 5e-3))
    cfg["wd"]               = float(np.clip(jitter_mult(best_cfg["wd"]), 1e-6, 1e-3))
    cfg["gamma"]            = float(np.clip(np.random.normal(best_cfg["gamma"], 0.4), 1.0, 4.0))
    cfg["beta"]             = float(np.clip(np.random.normal(best_cfg["beta"], 3e-4), 0.990, 0.9995))
    cfg["label_smoothing"]  = float(np.random.choice([best_cfg["label_smoothing"], 0.0, 0.05, 0.1]))
    cfg["clip_norm"]        = float(np.random.choice([best_cfg["clip_norm"], 0.5, 1.0, 2.0]))
    cfg["plateau_patience"] = int(np.random.choice([8, 12, 16]))
    cfg["plateau_factor"]   = float(np.random.choice([0.5, 0.7]))
    if MODEL_FACTORY is not None:
        cfg["hidden_dim"] = int(np.random.choice([best_cfg.get("hidden_dim",128), 64, 128, 256]))
        cfg["dropout"]    = float(np.clip(np.random.normal(best_cfg.get("dropout",0.3), 0.1), 0.05, 0.6))
        cfg["num_layers"] = int(np.random.choice([best_cfg.get("num_layers",3), 2, 3, 4]))
    return cfg

# ---------------------------
# RUN STAGE 1: RANDOM SEARCH
# ---------------------------
init_state = deepcopy(model.state_dict())
set_seeds(BASE_SEED)
results = []
best_global = {"val_f1": -1.0}

print(f"Stage 1 — RANDOM search ({SEARCH_TRIALS_STAGE1} trials)")
for t in range(SEARCH_TRIALS_STAGE1):
    trial_seed = BASE_SEED + t
    set_seeds(trial_seed)
    cfg = sample_config_global()
    print(f"\n=== Trial {t+1}/{SEARCH_TRIALS_STAGE1} | seed={trial_seed} | cfg={cfg} ===")
    val_f1, test_f1, best_epoch, best_state = train_one_trial(init_state, cfg, train_mask, val_mask)
    row = {"trial": t+1, "seed": trial_seed, "config": cfg, "val_f1": float(val_f1),
           "test_f1": float(test_f1), "best_epoch": int(best_epoch), "state": best_state}
    results.append(row)
    if val_f1 > best_global["val_f1"] + 1e-9:
        best_global = row

print("\nTop-5 after Stage 1 (by val F1):")
for r in sorted(results, key=lambda z: (-z["val_f1"], -z["test_f1"]))[:5]:
    c = r["config"]
    print(f"  valF1={r['val_f1']:.4f} | testF1={r['test_f1']:.4f} | epoch={r['best_epoch']:03d} | "
          f"lr={c['lr']:.2e}, wd={c['wd']:.1e}, γ={c['gamma']:.2f}, β={c['beta']:.4f}, "
          f"ls={c['label_smoothing']}, clip={c['clip_norm']}")

# ---------------------------
# RUN STAGE 2: LOCAL REFINEMENT
# ---------------------------
if DO_LOCAL_REFINEMENT:
    print(f"\nStage 2 — LOCAL refinement around best (± jitter) with {SEARCH_TRIALS_STAGE2} trials")
    base_cfg = best_global["config"]
    for t in range(SEARCH_TRIALS_STAGE2):
        trial_seed = BASE_SEED + 10_000 + t
        set_seeds(trial_seed)
        cfg = sample_config_local(base_cfg)
        print(f"\n=== Local Trial {t+1}/{SEARCH_TRIALS_STAGE2} | seed={trial_seed} | cfg={cfg} ===")
        val_f1, test_f1, best_epoch, best_state = train_one_trial(init_state, cfg, train_mask, val_mask)
        row = {"trial": f"L{t+1}", "seed": trial_seed, "config": cfg, "val_f1": float(val_f1),
               "test_f1": float(test_f1), "best_epoch": int(best_epoch), "state": best_state}
        results.append(row)
        if val_f1 > best_global["val_f1"] + 1e-9:
            best_global = row

# ---------------------------
# LEADERBOARD
# ---------------------------
results_sorted = sorted(results, key=lambda r: (-r["val_f1"], -r["test_f1"]))
print("\n=== Leaderboard (top 10 by val F1) ===")
for row in results_sorted[:10]:
    cfg = row["config"]
    print(f"{str(row['trial']).rjust(3)} | valF1={row['val_f1']:.4f} | testF1={row['test_f1']:.4f} "
          f"| epoch={row['best_epoch']:03d} | lr={cfg['lr']:.2e}, wd={cfg['wd']:.1e}, "
          f"γ={cfg['gamma']:.2f}, β={cfg['beta']:.4f}, ls={cfg['label_smoothing']}, "
          f"clip={cfg['clip_norm']}, pat={cfg.get('plateau_patience',12)}, fact={cfg.get('plateau_factor',0.5)}")

print("\n=== Best (by val F1) ===")
print(best_global["config"])
print(f"Best val F1: {best_global['val_f1']:.4f} @ epoch {best_global['best_epoch']} | "
      f"Test F1: {best_global['test_f1']:.4f} | seed={best_global['seed']}")

# ---------------------------
# OPTIONAL: ENSEMBLE top-k
# ---------------------------
def logits_all_for_state(state):
    m = build_or_reset_model(init_state, best_global["config"])
    m.load_state_dict({k: v.to(device) for k, v in state.items()})
    m.eval()
    with torch.no_grad():
        return m(x_dict, edge_index_dict, edge_attr_dict)

if DO_ENSEMBLE_TOPK:
    topk = results_sorted[:ENSEMBLE_TOPK]
    ensemble_logits = None
    for i, row in enumerate(topk, 1):
        # average several repeats per row (retrain from same config+seed to smooth randomness)
        for rep in range(ENSEMBLE_REPEATS):
            print(f"Ensembling: model {i}/{len(topk)}, repeat {rep+1}/{ENSEMBLE_REPEATS}")
            # retrain quickly to get a fresh checkpoint
            _val, _test, _ep, _state = train_one_trial(init_state, row["config"], train_mask, val_mask)
            log = logits_all_for_state(_state)
            ensemble_logits = log if ensemble_logits is None else (ensemble_logits + log)
    ensemble_logits /= float(len(topk) * ENSEMBLE_REPEATS)

    pred_ens = ensemble_logits[test_mask].argmax(1).cpu().numpy()
    true_t   = y[test_mask].cpu().numpy()
    test_f1_ens = f1_score(true_t, pred_ens, average='macro')
    print(f"\nEnsemble Test F1 (top-{ENSEMBLE_TOPK} x{ENSEMBLE_REPEATS}): {test_f1_ens:.4f}")

# ---------------------------
# REFIT on TRAIN+VAL (with small shadow-val)
# ---------------------------
if DO_REFIT_TRAINVAL:
    print("\nRefit — train on (train ∪ val) with a small shadow validation for early stop")
    best_cfg = best_global["config"]
    # indices of train+val
    tv_idx = torch.where((train_mask | val_mask).cpu())[0].cpu().numpy()
    y_tv   = y[tv_idx].cpu().numpy()

    # stratified shadow split
    tv_train_idx, tv_val_idx = train_test_split(tv_idx, test_size=REFIT_SHADOW_VAL_FR,
                                                stratify=y_tv, random_state=BASE_SEED+123)
    tv_train_mask = torch.zeros_like(train_mask); tv_train_mask[tv_train_idx] = True
    tv_val_mask   = torch.zeros_like(val_mask);   tv_val_mask[tv_val_idx]     = True

    # train one more time on train+val (shadow-val for early stop)
    set_seeds(BASE_SEED+55)
    val_f1_refit, test_f1_refit, ep_refit, state_refit = train_one_trial(init_state, best_cfg, tv_train_mask, tv_val_mask)
    print(f"Refit: shadow-val best F1={val_f1_refit:.4f} @ epoch {ep_refit}")
    # final test
    m_final = build_or_reset_model(init_state, best_cfg)
    m_final.load_state_dict({k: v.to(device) for k, v in state_refit.items()})
    final_test_f1 = eval_f1_mask(m_final, test_mask)
    print(f"Final Test F1 after refit: {final_test_f1:.4f}")

# Load best single-model weights by default for downstream use
model.load_state_dict({k: v.to(device) for k, v in best_global["state"].items()})
model.eval();


Stage 1 — RANDOM search (24 trials)

=== Trial 1/24 | seed=2024 | cfg={'lr': 0.0007498925553792222, 'wd': 0.00012511985806183071, 'gamma': 1.564455880115518, 'beta': 0.9904161813555952, 'label_smoothing': 0.05, 'clip_norm': 1.0, 'plateau_patience': 8, 'plateau_factor': 0.7} ===
  Epoch 001 | loss 0.0260 | val F1 0.5257 | lr 7.50e-04
  Epoch 010 | loss 0.0266 | val F1 0.5000 | lr 5.25e-04
  Epoch 020 | loss 0.0115 | val F1 0.5101 | lr 5.25e-04
  Epoch 030 | loss 0.0089 | val F1 0.5101 | lr 3.67e-04
  Epoch 040 | loss 0.0068 | val F1 0.5311 | lr 2.57e-04

=== Trial 2/24 | seed=2025 | cfg={'lr': 9.331389802327092e-05, 'wd': 0.0004608452422319137, 'gamma': 3.7978169196595073, 'beta': 0.9942328975584521, 'label_smoothing': 0.0, 'clip_norm': 0.5, 'plateau_patience': 16, 'plateau_factor': 0.5} ===
  Epoch 001 | loss 0.0047 | val F1 0.5259 | lr 9.33e-05
  Epoch 010 | loss 0.0121 | val F1 0.5259 | lr 9.33e-05
  Epoch 020 | loss 0.0050 | val F1 0.4804 | lr 4.67e-05

=== Trial 3/24 | seed=2026 | 

In [62]:
# === Finalize best config, refit on train+val (with shadow-val), evaluate on test, save artifacts ===
import os, json, numpy as np, torch
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

BEST_CFG = {
    "lr": 0.004678186464332436,
    "wd": 3.14280901334886e-06,
    "gamma": 3.342487225772438,
    "beta": 0.9905963698507055,
    "label_smoothing": 0.1,
    "clip_norm": 2.0,
    "plateau_patience": 8,
    "plateau_factor": 0.7,
}

SEED = 2040           # winner's seed
SHADOW_FR = 0.10      # 10% of train+val for early stop
MAX_EPOCHS = 800
ES_PATIENCE = 30
ES_MIN_DELTA = 1e-4
MIN_LR = 1e-6

# Build shadow split from train ∪ val
tv_idx = torch.where((train_mask | val_mask).cpu())[0].cpu().numpy()
y_tv   = y[tv_idx].cpu().numpy()
tv_train_idx, tv_val_idx = train_test_split(
    tv_idx, test_size=SHADOW_FR, stratify=y_tv, random_state=SEED
)
tv_train_mask = torch.zeros_like(train_mask); tv_train_mask[tv_train_idx] = True
tv_val_mask   = torch.zeros_like(val_mask);   tv_val_mask[tv_val_idx]     = True

# Train loop (refit with best config)
m = build_or_reset_model(init_state, BEST_CFG)
opt = torch.optim.AdamW(m.parameters(), lr=BEST_CFG["lr"], weight_decay=BEST_CFG["wd"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="max", factor=BEST_CFG["plateau_factor"],
    patience=BEST_CFG["plateau_patience"], threshold=ES_MIN_DELTA, min_lr=MIN_LR
)
alpha_cb = compute_alpha_cb(y, BEST_CFG["beta"], device)

best_f1, best_state, patience, best_epoch = -1.0, None, 0, 0
for epoch in range(1, MAX_EPOCHS + 1):
    m.train(True)
    logits = m(x_dict, edge_index_dict, edge_attr_dict)
    loss = focal_ce_loss(
        logits[tv_train_mask], y[tv_train_mask],
        alpha=alpha_cb, gamma=BEST_CFG["gamma"], label_smoothing=BEST_CFG["label_smoothing"]
    )
    opt.zero_grad(); loss.backward()
    torch.nn.utils.clip_grad_norm_(m.parameters(), BEST_CFG["clip_norm"])
    opt.step()

    val_f1 = eval_f1_mask(m, tv_val_mask)
    scheduler.step(val_f1)

    if val_f1 > best_f1 + ES_MIN_DELTA:
        best_f1, best_epoch, patience = val_f1, epoch, 0
        best_state = {k: v.detach().cpu().clone() for k, v in m.state_dict().items()}
    else:
        patience += 1
    if patience >= ES_PATIENCE:
        break

# Evaluate on test with best state
m.load_state_dict({k: v.to(device) for k, v in best_state.items()})
m.eval()
with torch.no_grad():
    logits = m(x_dict, edge_index_dict, edge_attr_dict)
pred = logits[test_mask].argmax(1).cpu().numpy()
true = y[test_mask].cpu().numpy()

macro_f1 = f1_score(true, pred, average='macro')
report   = classification_report(true, pred, digits=4)
cm       = confusion_matrix(true, pred)
cm_norm  = confusion_matrix(true, pred, normalize='true')

print(f"Refit best shadow-val F1 {best_f1:.4f} @ epoch {best_epoch}")
print(f"Test macro-F1: {macro_f1:.4f}\n")
print(report)
print("Confusion matrix (counts):\n", cm)
print("\nConfusion matrix (row-normalized):\n", np.round(cm_norm, 3))

# Save artifacts (weights, config, confusion matrix)
os.makedirs("artifacts", exist_ok=True)
torch.save(best_state, "artifacts/model_best_state.pt")
with open("artifacts/best_config.json", "w") as f:
    json.dump(BEST_CFG, f, indent=2)
np.savetxt("artifacts/confusion_matrix.csv", cm, fmt="%d", delimiter=",")
np.savetxt("artifacts/confusion_matrix_normalized.csv", cm_norm, fmt="%.6f", delimiter=",")


Refit best shadow-val F1 1.0000 @ epoch 2
Test macro-F1: 0.5809

              precision    recall  f1-score   support

           0     1.0000    0.7000    0.8235        10
           1     0.0000    0.0000    0.0000         2
           2     0.5000    1.0000    0.6667         7
           3     0.9091    0.7692    0.8333        13

    accuracy                         0.7500        32
   macro avg     0.6023    0.6173    0.5809        32
weighted avg     0.7912    0.7500    0.7417        32

Confusion matrix (counts):
 [[ 7  0  2  1]
 [ 0  0  2  0]
 [ 0  0  7  0]
 [ 0  0  3 10]]

Confusion matrix (row-normalized):
 [[0.7   0.    0.2   0.1  ]
 [0.    0.    1.    0.   ]
 [0.    0.    1.    0.   ]
 [0.    0.    0.231 0.769]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
