### First try at linear probes

- Global probe (z_H) is_solved accuracy: 1.0000
- Local probe (z_L) per_cell_correct accuracy: 0.9998
- Saved trained probes to results/probes


Either something is going wrong, or this is strong evidence that z_H encodes global features and z_L encodes local ones.

### Second try at linear probes (z_L and z_H swapped)

python scripts/train_linear_probes.py --probes_dir results/probes --use_global_z z_L --use_local_z z_H

- Global probe (z_L) is_solved accuracy: 1.0000
- Local probe (z_H) per_cell_correct accuracy: 0.9660

Rerunning both configurations gave lower local-probe accuracy:

python scripts/train_linear_probes.py --probes_dir results/probes
- Global probe (z_H) is_solved accuracy: 1.0000
- Local probe (z_L) per_cell_correct accuracy: 0.9128
- Saved trained probes to results/probes

python scripts/train_linear_probes.py --probes_dir results/probes --use_global_z z_L --use_local_z z_H
- Global probe (z_L) is_solved accuracy: 1.0000
- Local probe (z_H) per_cell_correct accuracy: 0.8991
- Saved trained probes to results/probes
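Both global probes reach 1.0000 is_solved accuracy regardless of which latent they read. A probe would also score that high if nearly every entry shared the same label. A quick sanity check is the majority-class baseline: the accuracy you get by always predicting the most common label. This sketch assumes the labels can be extracted as a flat list of 0/1 ints, for example with `.tolist()` on the label tensor. The helper name `majority_baseline` is hypothetical.

```python
from collections import Counter
from typing import Iterable, Tuple

def majority_baseline(labels: Iterable[int]) -> Tuple[int, float]:
    """Return (majority_label, accuracy of always predicting that label)."""
    counts = Counter(int(v) for v in labels)
    if not counts:
        raise ValueError("no labels")
    label, n = counts.most_common(1)[0]
    return label, n / sum(counts.values())

# Toy labels (hypothetical): 9 unsolved entries, 1 solved entry
label, acc = majority_baseline([0] * 9 + [1])
print(label, acc)  # 0 0.9
```

If the baseline is already close to 1.0, probe accuracy alone says little. In that case, precision, recall, and the confusion counts are more informative.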

### Additional Probe Metrics and Why They Matter

We'll evaluate more than raw accuracy to understand how well hidden-state probes capture signal and how reliable their scores are across steps and representations:
- Precision/Recall/F1: measure performance under class imbalance and unequal error costs.
- Confusion Matrix: error types (false positives vs false negatives).
- Calibration (Reliability) Curve: do probe probabilities reflect true likelihoods?
- Threshold Sweep: sensitivity to decision boundary; find optimal operating point.
- Per-Step Curves: how information emerges across ACT steps.
- z_H vs z_L Comparison: which level encodes global vs local signals more linearly.
- Per-Puzzle Aggregation: consistency across puzzles (variance, outliers).
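Per-puzzle aggregation appears in the list above, but the cells below do not implement it. The sketch below assumes each probe sample can be tagged with a puzzle identifier; the probe datasets loaded here may not store one. It groups per-sample correctness by puzzle, then reports the mean and population standard deviation across puzzles. The names `per_puzzle_accuracy` and `puzzle_ids` are hypothetical.

```python
from collections import defaultdict
from statistics import mean, pstdev
from typing import Dict, Hashable, Iterable, Tuple

def per_puzzle_accuracy(
    puzzle_ids: Iterable[Hashable], correct: Iterable[int]
) -> Tuple[Dict[Hashable, float], float, float]:
    """Group 0/1 correctness flags by puzzle id.

    Returns (per-puzzle accuracy, mean across puzzles, population std across puzzles).
    """
    groups = defaultdict(list)
    for pid, c in zip(puzzle_ids, correct):
        groups[pid].append(int(c))
    if not groups:
        raise ValueError("no samples")
    per_puzzle = {pid: sum(v) / len(v) for pid, v in groups.items()}
    accs = list(per_puzzle.values())
    return per_puzzle, mean(accs), pstdev(accs)

# Toy example (hypothetical data): puzzle "a" fully correct, puzzle "b" half correct
per, mu, sd = per_puzzle_accuracy(["a", "a", "b", "b"], [1, 1, 1, 0])
# per == {"a": 1.0, "b": 0.5}; mu == 0.75; sd is the population std of [1.0, 0.5], i.e. 0.25
```

To surface the outlier puzzles worth inspecting by hand, sort `per` by value, e.g. `sorted(per.items(), key=lambda kv: kv[1])[:5]`.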

In [1]:
# Utilities to load probe datasets and trained probes, then compute metrics
import os
import json
import torch
import numpy as np

# Paths
probes_dir = os.path.join("results", "probes")
global_path = os.path.join(probes_dir, "probe_global.pt")
local_path = os.path.join(probes_dir, "probe_local.pt")
index_path = os.path.join(probes_dir, "probe_index.json")

# Load probe datasets
global_samples = torch.load(global_path)
local_samples = torch.load(local_path)
with open(index_path, "r") as f:
    probe_index = json.load(f)

def sigmoid(x):
    return 1 / (1 + torch.exp(-x))

def metrics_binary(logits: torch.Tensor, y: torch.Tensor, threshold: float = 0.5):
    probs = sigmoid(logits)
    preds = (probs >= threshold).to(torch.int64)
    y = y.to(torch.int64)
    tp = int(((preds == 1) & (y == 1)).sum().item())
    tn = int(((preds == 0) & (y == 0)).sum().item())
    fp = int(((preds == 1) & (y == 0)).sum().item())
    fn = int(((preds == 0) & (y == 1)).sum().item())
    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    prec = tp / max(1, tp + fp)
    rec = tp / max(1, tp + fn)
    f1 = 2 * prec * rec / max(1e-8, prec + rec)
    return {
        "acc": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "probs": probs.detach(),
    }

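def calibration_bins(probs: torch.Tensor, y: torch.Tensor, n_bins: int = 10):
    # Reliability-curve data: per-bin mean predicted probability vs. fraction of true positives.
    # `y` must be the ground-truth labels, not thresholded predictions.
    probs = probs.detach().cpu().numpy()
    y = y.detach().cpu().numpy().astype(np.float64)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitize against interior edges so ids fall in [0, n_bins - 1]; p == 1.0 lands in the last bin
    bin_ids = np.digitize(probs, bins[1:-1])
    out = []
    for b in range(n_bins):
        m = bin_ids == b
        if m.sum() == 0:
            out.append((bins[b:b+2].mean(), np.nan, 0))
        else:
            out.append((probs[m].mean(), y[m].mean(), int(m.sum())))
    return out  # list of (avg_prob, frac_positive, count)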

def load_trained_probe(path: str, in_dim: int):
    # A simple Linear(D,1); we only need weights to compute logits: x @ W^T + b
    sd = torch.load(path)
    W = sd.get('linear.weight')
    b = sd.get('linear.bias')
    if W is None:
        # For safety, try alternative keys
        for k in sd:
            if k.endswith('weight') and sd[k].shape == (1, in_dim):
                W = sd[k]
            if k.endswith('bias') and sd[k].shape == (1,):
                b = sd[k]
    assert W is not None and b is not None, "Invalid probe state dict"
    return W.detach(), b.detach()

def build_global_dataset(samples, use_z: str = 'z_H'):
    X, y, steps = [], [], []
    for row in samples:
        z = row.get(use_z)  # [B, D]; None if this entry lacks the representation
        labels = row.get('labels', {})
        is_solved = labels.get('is_solved')
        if z is None or is_solved is None:
            continue
        # Aggregate batch to single vector and label for the entry
        z_vec = z.mean(dim=0)  # [D]
        if is_solved.ndim > 0:
            y_scalar = (is_solved.float().mean() > 0.5).to(torch.int64)
        else:
            y_scalar = is_solved.to(torch.int64)
        X.append(z_vec.float())
        y.append(y_scalar)
        steps.append(int(row.get('step', 0)))
    X = torch.stack(X)
    y = torch.stack(y)
    steps = torch.tensor(steps)
    return X, y, steps

def build_local_dataset(samples, use_z: str = 'z_L'):
    X, y, steps = [], [], []
    for row in samples:
        z = row.get(use_z)  # [B, T, D] or [T, D]
        labels = row.get('labels', {})
        pc = labels.get('per_cell_correct')  # [B, T] or [T]
        if z is None or pc is None:
            continue
        # Flatten batch and token axes into rows: [B*T, D] (or [T, D])
        z_flat = z.reshape(-1, z.shape[-1]).float()
        X.append(z_flat)
        y.append(pc.reshape(-1).to(torch.int64))
        # One step id per flattened row, so `steps` stays aligned with X and y
        steps.append(torch.full((z_flat.shape[0],), int(row.get('step', 0)), dtype=torch.int64))
    X = torch.cat(X, dim=0)
    y = torch.cat(y, dim=0)
    steps = torch.cat(steps, dim=0)
    return X, y, steps

# Build datasets for z_H and z_L
Xg_H, yg, steps_g = build_global_dataset(global_samples, use_z='z_H')
Xg_L, _ygL, _steps_gL = build_global_dataset(global_samples, use_z='z_L')
Xl_H, yl_H, steps_lH = build_local_dataset(local_samples, use_z='z_H')
Xl_L, yl_L, steps_lL = build_local_dataset(local_samples, use_z='z_L')

# Load trained probes if present
gp_H_path = os.path.join(probes_dir, 'global_probe_z_H.pt')
gp_L_path = os.path.join(probes_dir, 'global_probe_z_L.pt')
lp_H_path = os.path.join(probes_dir, 'local_probe_z_H.pt')
lp_L_path = os.path.join(probes_dir, 'local_probe_z_L.pt')

def probe_logits(X: torch.Tensor, W: torch.Tensor, b: torch.Tensor):
    return (X @ W.t()) + b  # [N,1]

results = {}
# Global z_H
if os.path.exists(gp_H_path):
    W, b = load_trained_probe(gp_H_path, Xg_H.shape[1])
    logits = probe_logits(Xg_H, W, b).view(-1)
    results['global_z_H'] = metrics_binary(logits, yg)

# Global z_L
if os.path.exists(gp_L_path):
    W, b = load_trained_probe(gp_L_path, Xg_L.shape[1])
    logits = probe_logits(Xg_L, W, b).view(-1)
    # Use the labels built alongside Xg_L so rows stay aligned even if some entries lack z_L
    results['global_z_L'] = metrics_binary(logits, _ygL)

# Local z_H
if os.path.exists(lp_H_path):
    W, b = load_trained_probe(lp_H_path, Xl_H.shape[1])
    logits = probe_logits(Xl_H, W, b).view(-1)
    results['local_z_H'] = metrics_binary(logits, yl_H)

# Local z_L
if os.path.exists(lp_L_path):
    W, b = load_trained_probe(lp_L_path, Xl_L.shape[1])
    logits = probe_logits(Xl_L, W, b).view(-1)
    results['local_z_L'] = metrics_binary(logits, yl_L)

# Print summary metrics
for k, m in results.items():
    print(f"{k}: acc={m['acc']:.4f} precision={m['precision']:.4f} recall={m['recall']:.4f} f1={m['f1']:.4f} TP={m['tp']} TN={m['tn']} FP={m['fp']} FN={m['fn']}")

# Ground-truth labels for each probe, keyed like `results`
true_labels = {
    'global_z_H': yg,
    'global_z_L': _ygL,
    'local_z_H': yl_H,
    'local_z_L': yl_L,
}

# Calibration summary: per-bin mean predicted probability vs. observed positive rate
for k, m in results.items():
    bins = calibration_bins(m['probs'], true_labels[k])
    print(f"\nCalibration for {k}:")
    for avg_p, frac_pos, count in bins:
        rate = 'nan' if np.isnan(frac_pos) else f"{frac_pos:.3f}"
        print(f"  bin_avg_p={avg_p:.3f} pos_rate={rate} count={count}")

# Threshold sweep to see operating point sensitivity

def threshold_sweep(logits: torch.Tensor, y: torch.Tensor, steps=11):
    ts = np.linspace(0.0, 1.0, steps)
    out = []
    for t in ts:
        m = metrics_binary(logits, y, threshold=float(t))
        out.append((float(t), m['acc'], m['precision'], m['recall'], m['f1']))
    return out

for k, m in results.items():
    # Recover logits from stored probabilities; eps keeps p away from exactly 0 and 1
    logits = torch.logit(m['probs'], eps=1e-7)
    print(f"\nThreshold sweep for {k} (t, acc, prec, rec, f1):")
    for row in threshold_sweep(logits, true_labels[k]):
        print("  ", tuple(round(v, 4) for v in row))


FileNotFoundError: [Errno 2] No such file or directory: 'results/probes/probe_global.pt'
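The threshold sweep in the cell above prints a grid of operating points but does not choose one. When false positives and false negatives carry different costs, one simple rule is to pick the threshold that minimizes `cost_fp * FP + cost_fn * FN`. This sketch assumes probabilities and 0/1 labels are available as plain Python lists. `best_threshold` and its cost parameters are hypothetical and not part of the probe scripts.

```python
from typing import Sequence, Tuple

def best_threshold(
    probs: Sequence[float],
    labels: Sequence[int],
    cost_fp: float = 1.0,
    cost_fn: float = 1.0,
    grid: int = 101,
) -> Tuple[float, float]:
    """Scan thresholds i / (grid - 1) and return (threshold, cost) minimizing
    cost_fp * FP + cost_fn * FN. Ties keep the lowest threshold."""
    if grid < 2:
        raise ValueError("grid must be >= 2")
    best_t, best_cost = 0.0, float("inf")
    for i in range(grid):
        t = i / (grid - 1)
        # Same decision rule as metrics_binary: predict positive when p >= t
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y == 1)
        cost = cost_fp * fp + cost_fn * fn
        if cost < best_cost:
            best_t, best_cost = t, cost
    return best_t, best_cost

probs = [0.2, 0.4, 0.6, 0.8]
labels = [0, 1, 0, 1]
print(best_threshold(probs, labels))               # (0.21, 1.0)
print(best_threshold(probs, labels, cost_fp=5.0))  # (0.61, 1.0)
```

Making false positives more expensive pushes the chosen threshold up. That fits a use such as trusting a "solved" signal only when the probe is confident.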

### Interpreting These Metrics

- Precision/Recall/F1: Accuracy alone can hide class imbalance. Precision penalizes false positives (claiming solved/correct when not), recall penalizes false negatives (missing solved/correct). F1 balances both.
- Confusion Matrix: Helps diagnose whether the probe tends to over/under-predict; useful for adjusting thresholds in downstream use.
- Calibration Curve: If the probe outputs are well-calibrated, a 0.8 probability should correspond to ~80% correctness. Poor calibration suggests scores are not reliable as probabilities.
- Threshold Sweep: Shows robustness of decisions to the chosen cutoff; helpful when the application has asymmetric costs or requires high precision/recall.
- Per-Step Dynamics: In ACT models, we expect the signal to strengthen over steps. Plotting accuracy against step reveals when representations become linearly separable.
- z_H vs z_L: Validates representational roles. Global status tends to be in z_H, local correctness in z_L. If this flips, it informs architecture/training changes.
- Per-Puzzle Aggregation: High variance across puzzles suggests brittleness; consistent performance indicates generalizable encoding.
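The calibration bins can be collapsed into a single number. A common summary is expected calibration error (ECE): the count-weighted average gap between the mean predicted probability and the observed positive rate in each bin. The sketch below uses the probe's positive-class probability with equal-width bins, following the same binning idea as `calibration_bins`. Multiclass models often use a max-class-confidence form of ECE instead. `expected_calibration_error` is a hypothetical helper.

```python
from typing import Sequence

def expected_calibration_error(
    probs: Sequence[float], labels: Sequence[int], n_bins: int = 10
) -> float:
    """Positive-class ECE: sum over bins of (bin_count / N) * |positive_rate - mean_prob|."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes into the last bin
        bins[idx].append((p, y))
    n = sum(len(b) for b in bins)
    ece = 0.0
    for b in bins:
        if b:
            mean_prob = sum(p for p, _ in b) / len(b)
            pos_rate = sum(y for _, y in b) / len(b)
            ece += len(b) / n * abs(pos_rate - mean_prob)
    return ece

# Perfectly calibrated toy case vs. a slightly under-confident one
print(expected_calibration_error([1.0, 0.0], [1, 0]))  # 0.0
print(expected_calibration_error([0.9, 0.9, 0.1, 0.1], [1, 1, 0, 0]))  # approximately 0.1
```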

In [None]:
# Optional: Per-step accuracy curves (requires mapping entries to steps)
# For global probes, we approximated aggregation per entry; here, we compute per-step slices using the step metadata.

def per_step_accuracy_global(samples, use_z: str, probe_path: str):
    X, y, steps = build_global_dataset(samples, use_z=use_z)
    if not os.path.exists(probe_path):
        print(f"Missing probe {probe_path}")
        return []
    W, b = load_trained_probe(probe_path, X.shape[1])
    logits = probe_logits(X, W, b).view(-1)
    probs = sigmoid(logits)
    preds = (probs >= 0.5).to(torch.int64)
    out = []
    for s in sorted(set(steps.tolist())):
        m = steps == s
        if m.sum() == 0:
            continue
        acc = (preds[m] == y[m]).float().mean().item()
        out.append((s, acc))
    return out

gH_steps = per_step_accuracy_global(global_samples, 'z_H', gp_H_path)
gL_steps = per_step_accuracy_global(global_samples, 'z_L', gp_L_path)
print("Per-step global accuracy (z_H):", gH_steps)
print("Per-step global accuracy (z_L):", gL_steps)
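`per_step_accuracy_global` returns a list of `(step, accuracy)` pairs. A small helper can read off when the signal becomes linearly decodable: the first step at which accuracy reaches a chosen target, and whether the curve improves monotonically. Both helpers and the 0.95 target below are illustrative assumptions, not part of the probe scripts.

```python
from typing import List, Optional, Tuple

Curve = List[Tuple[int, float]]

def first_step_reaching(curve: Curve, target: float) -> Optional[int]:
    """Return the earliest step whose accuracy is >= target, or None if it never gets there."""
    for step, acc in sorted(curve):
        if acc >= target:
            return step
    return None

def is_non_decreasing(curve: Curve) -> bool:
    """True if accuracy never drops from one step to the next."""
    accs = [acc for _, acc in sorted(curve)]
    return all(a <= b for a, b in zip(accs, accs[1:]))

# Toy curve (hypothetical numbers)
curve = [(0, 0.5), (1, 0.8), (2, 0.97), (3, 0.99)]
print(first_step_reaching(curve, 0.95), is_non_decreasing(curve))  # 2 True
```

Applied to the notebook's output, `first_step_reaching(gH_steps, 0.95)` and `first_step_reaching(gL_steps, 0.95)` give a direct z_H vs z_L comparison of how early each level becomes decodable.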
