In [None]:
# ──────────────────────────────────────────────────────────────────────────────
#  XY⇄YX  faithfulness  ·  APPLY  a previously-trained linear probe
# ──────────────────────────────────────────────────────────────────────────────
%cd ../..
%pwd
import os, re, json, time, random, logging
from pathlib import Path
from collections import defaultdict, Counter

import joblib
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, f1_score, confusion_matrix, classification_report,
    roc_curve, auc,
)

# ╭───────────────────────── CONFIGURABLE PARAMETERS ─────────────────────────╮
# All paths below are relative; they are resolved against the working
# directory selected by the `%cd ../..` magic at the top of the cell.

# Fitted probe, loaded with joblib.  The probed layer index is parsed from the
# "layer<N>" part of this filename further down.
PROBE_PATH          = Path("linear_probes/realyesno_None4k/linear_probe_layer5.joblib")
# Searched recursively for "*_hidden.pt" hidden-state files.
ACTIVATIONS_DIR     = Path("h_hidden_space/outputs/f1_hint_xyyx/xyyx_deterministic/gt_lt_completions_1")
# Each directory is globbed (non-recursively) for "*.json" answer files that
# feed `answers_map`.
ANSWERS_DIRS        = [
    Path("e_confirm_xy_yx/outputs/matched_vals_gt"),
    Path("e_confirm_xy_yx/outputs/matched_vals_lt"),
]
# Searched recursively for the dataset "*.json" files; a hidden-state file is
# matched to its dataset by filename stem.
QUESTION_JSON_ROOT  = Path("data/chainscope/questions_json/linked")

# Row offset of a hidden file inside its dataset is batch_index * this value.
INFERENCE_BATCH_SIZE = 32        # must match the batch-size used when capturing hiddens
# Any falsy value (None or 0) disables sub-sampling.
MAX_SAMPLES          = None      # e.g. 10_000 to subsample for speed, or None = use all
# Seeds the global RNGs below and the RandomState used for sub-sampling.
SEED                 = 0
# ╰────────────────────────────────────────────────────────────────────────────╯

# Seed Python's, NumPy's and PyTorch's global random generators.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# ─── logging setup ───────────────────────────────────────────────────────────
logging.basicConfig(
    format="%(asctime)s  %(levelname)s │ %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("apply_probe")

# Wall-clock start; the last log line of the cell reports the elapsed time.
t0 = time.time()

# Deserialize the fitted probe.  joblib.load unpickles the file, so only point
# PROBE_PATH at artifacts you trust.
log.info("loading linear probe  →  %s", PROBE_PATH)
probe = joblib.load(PROBE_PATH)

# The index of the probed layer is encoded in the filename as "layer<N>".
layer_match = re.search(r"layer(\d+)", PROBE_PATH.name)
if layer_match is None:
    raise RuntimeError("Could not infer layer-number from probe filename – "
                       "rename as “…layerXX.joblib” or set LAYER manually.")
LAYER = int(layer_match.group(1))
log.info("   → expecting activations from transformer layer  %d", LAYER)

# ╭────────────────────────────────────────────────────────────────────────────╮
# │ 1.  answers-metadata lookup                                               │
# ╰────────────────────────────────────────────────────────────────────────────╯
# question-id → {"expected": "NO" | "YES", "actual": <first recorded answer>,
#                "same": <pair-level flag>}
# Each record contributes two entries: `question_id` (expected "NO", answer
# taken from a_answers[0]) and `question_yes_id` (expected "YES", answer taken
# from b_answers[0]).
answers_map = {}
n_conflicting_ids = 0
for dir_ in ANSWERS_DIRS:
    # sorted(): glob order depends on the filesystem.  Because a later file
    # overwrites an earlier entry with the same id, a fixed order is needed
    # for the resulting map to be reproducible.
    for fp in sorted(dir_.glob("*.json")):
        with open(fp, encoding="utf-8") as f:
            raw = json.load(f)
        # A file is either {"questions": [...]} or a bare list of records.
        questions = raw["questions"] if isinstance(raw, dict) else raw
        for q in questions:
            # "same" may be stored as a list; only its first element is used.
            same_flag = q["same"][0] if isinstance(q["same"], list) else q["same"]
            new_entries = (
                (q["question_id"],
                 {"expected": "NO", "actual": q["a_answers"][0], "same": same_flag}),
                (q["question_yes_id"],
                 {"expected": "YES", "actual": q["b_answers"][0], "same": same_flag}),
            )
            for qid, meta in new_entries:
                previous = answers_map.get(qid)
                if previous is not None and previous != meta:
                    n_conflicting_ids += 1
                answers_map[qid] = meta      # last writer wins, as before

log.info("answers_map built with %d entries", len(answers_map))
if n_conflicting_ids:
    log.warning("%d question-id(s) were seen more than once with different "
                "metadata; the last file read wins", n_conflicting_ids)

# ╭────────────────────────────────────────────────────────────────────────────╮
# │ 2.  utilities to resolve hidden-file rows → question-ids                  │
# ╰────────────────────────────────────────────────────────────────────────────╯
# dataset stem (filename without ".json") → path of that JSON; filled lazily
_STEM2PATH = {}


def _index_datasets_once():
    """Fill ``_STEM2PATH`` from ``QUESTION_JSON_ROOT`` the first time it is needed.

    Every ``*.json`` found recursively under the root is registered under its
    stem; when two files share a stem, the one visited later replaces the
    earlier one.  Once the index holds at least one entry, further calls do
    nothing.

    Raises:
        RuntimeError: if no ``*.json`` file exists under the root.
    """
    if not _STEM2PATH:
        _STEM2PATH.update(
            (json_path.stem, json_path)
            for json_path in QUESTION_JSON_ROOT.rglob("*.json")
        )
        if not _STEM2PATH:
            raise RuntimeError(f"No *.json found under {QUESTION_JSON_ROOT}")

# dataset stem → parsed question list.  Many hidden-state files (one per batch)
# point at the same dataset, so each JSON is read and parsed only once.
_QUESTIONS_CACHE = {}


def _dataset_questions_for(stem: str):
    """Return the question records of the dataset whose JSON file is named ``stem``.

    The file is located through the stem index built by
    ``_index_datasets_once``.  It may contain either ``{"questions": [...]}``
    or a bare list; in both cases the list of records is returned.  The result
    is cached per stem, so repeated calls return the *same* list object;
    callers must treat it as read-only.

    Args:
        stem: dataset filename without the ``.json`` suffix.

    Raises:
        FileNotFoundError: if no indexed JSON file has this stem.
        RuntimeError: propagated from ``_index_datasets_once`` when the
            dataset root contains no JSON at all.
    """
    if stem in _QUESTIONS_CACHE:
        return _QUESTIONS_CACHE[stem]

    _index_datasets_once()
    fp = _STEM2PATH.get(stem)
    if fp is None:
        raise FileNotFoundError(f"Cannot find dataset JSON for stem {stem}")
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(fp, encoding="utf-8") as f:
        raw = json.load(f)
    questions = raw["questions"] if isinstance(raw, dict) else raw
    _QUESTIONS_CACHE[stem] = questions
    return questions

# "<dataset-stem>_batch<N>_hidden.pt"  →  group(1) is the batch index N
_BATCH_RE = re.compile(r"_batch(\d+)_hidden\.pt$")


def question_ids_for_hidden_file(hid_path: Path, batch_size: int, actual_len: int):
    """Return the question ids of the rows stored in one hidden-state file.

    The filename must end in ``_batch<N>_hidden.pt``; everything before that
    suffix is taken as the dataset stem.  Row ``j`` of the file is mapped to
    record ``N * batch_size + j`` of that dataset.

    Args:
        hid_path: path of the hidden-state file (only its name is inspected).
        batch_size: batch size used to compute the row offset of the file.
        actual_len: number of rows actually present in this file.

    Returns:
        A list of ``actual_len`` ``question_id`` values, in row order.

    Raises:
        ValueError: if the filename does not match the expected pattern.
        IndexError: if the requested rows run past the end of the dataset,
            which usually means ``batch_size`` differs from the one used when
            the hidden states were captured.
        FileNotFoundError: propagated when the dataset JSON cannot be found.
    """
    m = _BATCH_RE.search(hid_path.name)
    if m is None:
        raise ValueError(f"Bad hidden filename: {hid_path.name}")
    batch_idx = int(m.group(1))
    stem = hid_path.name[:m.start()]          # strip “…_batchX_hidden.pt”
    q_list = _dataset_questions_for(stem)

    start = batch_idx * batch_size
    stop = start + actual_len
    if stop > len(q_list):
        # Same exception type a plain out-of-range lookup would raise, but
        # with enough context to diagnose a batch-size mismatch.
        raise IndexError(
            f"{hid_path.name}: rows {start}..{stop - 1} requested but dataset "
            f"'{stem}' has only {len(q_list)} questions – "
            f"does batch_size={batch_size} match the capture run?"
        )
    return [q_list[i]["question_id"] for i in range(start, stop)]

# ╭────────────────────────────────────────────────────────────────────────────╮
# │ 3.  collect hidden-vectors & labels                                       │
# ╰────────────────────────────────────────────────────────────────────────────╯
# Build the evaluation set: one activation vector of the probed layer per
# retained question, with label 1 = faithful-and-correct, 0 = same-answer error.
X_vecs, y_labels = [], []
skipped = Counter()          # reason → number of rows left out

hidden_files = sorted(ACTIVATIONS_DIR.rglob("*_hidden.pt"))
log.info("found %d hidden-state files in %s", len(hidden_files), ACTIVATIONS_DIR)

for hid_fp in hidden_files:
    # map_location="cpu": tensors saved from a GPU would otherwise be restored
    # onto that device, where .numpy() fails (or loading fails outright on a
    # CPU-only machine).
    # NOTE(review): torch.load unpickles the file.  If these files hold only
    # tensors, passing weights_only=True would be safer – confirm first.
    batch_hidden = torch.load(hid_fp, map_location="cpu")   # list[n_layers]  each  (B,H)
    # detach(): .numpy() refuses tensors that still require grad.
    vecs_L       = batch_hidden[LAYER].detach()             # (B,H) tensor for probed layer
    batch_len    = vecs_L.size(0)

    q_ids = question_ids_for_hidden_file(
        hid_fp, INFERENCE_BATCH_SIZE, batch_len
    )

    for row_idx, qid in enumerate(q_ids):
        meta = answers_map.get(qid)
        if meta is None:
            # shouldn’t happen unless dirs mismatch – counted so it is visible
            skipped["no answers-metadata"] += 1
            continue

        correct = (meta["actual"] == meta["expected"])
        if not correct and not meta["same"]:
            # we *skip* ordinary wrong answers
            skipped["ordinary wrong answer"] += 1
            continue

        X_vecs.append(vecs_L[row_idx].float().numpy())
        y_labels.append(1 if correct else 0)

log.info("collected %d labelled activation-vectors", len(X_vecs))
for reason, count in sorted(skipped.items()):
    log.info("   skipped %d rows  (%s)", count, reason)
if not X_vecs:
    raise RuntimeError("Nothing to evaluate on – check the filters & paths.")

# optional subsample for speed (a dedicated RandomState keeps the draw
# independent of how much global randomness was consumed earlier)
if MAX_SAMPLES and len(X_vecs) > MAX_SAMPLES:
    log.info("sub-sampling to %d items (random, reproducible)", MAX_SAMPLES)
    idx = np.random.RandomState(SEED).choice(len(X_vecs), MAX_SAMPLES, replace=False)
    X_vecs = [X_vecs[i] for i in idx]
    y_labels = [y_labels[i] for i in idx]

X = np.stack(X_vecs)         # (N, H)
y = np.array(y_labels)       # (N,)

# ╭────────────────────────────────────────────────────────────────────────────╮
# │ 4.  run the probe                                                         │
# ╰────────────────────────────────────────────────────────────────────────────╯
log.info("running probe on %s vectors (d=%d)", len(X), X.shape[1])

# `prob_pos` is the score of the positive class (label 1) and `pred` the hard
# 0/1 decision.  The two scoring APIs live on different scales, so each needs
# its own cut-off.
if hasattr(probe, "predict_proba"):
    # predict_proba columns follow probe.classes_; look up where label 1 sits
    # and fall back to column 1 when that cannot be determined.
    probe_classes = list(getattr(probe, "classes_", []))
    pos_col = probe_classes.index(1) if 1 in probe_classes else 1
    prob_pos = probe.predict_proba(X)[:, pos_col]
    pred = (prob_pos >= 0.5).astype(int)
else:
    # decision_function returns signed margins, not probabilities: the class
    # boundary is at 0 (a 0.5 cut-off would misclassify margins in (0, 0.5)).
    prob_pos = probe.decision_function(X)
    pred = (prob_pos > 0).astype(int)

acc = accuracy_score(y, pred)
f1  = f1_score(y, pred)
log.info("RESULTS  ·  acc %.3f   f1 %.3f", acc, f1)
log.info("\n" + classification_report(y, pred, digits=3))

# ╭────────────────────────────────────────────────────────────────────────────╮
# │ 5.  quick-look figures                                                    │
# ╰────────────────────────────────────────────────────────────────────────────╯
# Figures are written next to the probe file.
figdir = PROBE_PATH.parent
figdir.mkdir(parents=True, exist_ok=True)

# Confusion-matrix
# labels=[0, 1] keeps the matrix 2×2 even when only one class occurs in
# y/pred, so the fixed tick labels below always line up with the cells.
cm = confusion_matrix(y, pred, labels=[0, 1])
plt.figure(figsize=(3, 3))
plt.imshow(cm, cmap="Blues")
for (i, j), v in np.ndenumerate(cm):
    plt.text(j, i, f"{v:,}", ha="center", va="center", fontsize=9)
plt.xticks([0, 1], ["error", "correct"])
plt.yticks([0, 1], ["error", "correct"])
# confusion_matrix puts true labels on rows and predictions on columns.
plt.xlabel("predicted")
plt.ylabel("true")
plt.title("confusion matrix")
plt.tight_layout()
cm_path = figdir / "confusion_matrix_apply.png"
plt.savefig(cm_path)
plt.close()
log.info("• confusion-matrix written → %s", cm_path)

# ROC
fpr, tpr, _ = roc_curve(y, prob_pos)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(4, 4))
plt.plot(fpr, tpr, lw=2)
plt.plot([0, 1], [0, 1], "--")          # chance diagonal
plt.title(f"ROC AUC {roc_auc:.3f}")
plt.xlabel("false-positive-rate")
plt.ylabel("true-positive-rate")
plt.tight_layout()
roc_path = figdir / "roc_apply.png"
plt.savefig(roc_path)
plt.close()
log.info("• ROC curve written        → %s", roc_path)

# probability histogram (one overlaid histogram per true label)
plt.figure(figsize=(5, 3))
plt.hist(prob_pos[y == 0], bins=40, alpha=0.6, label="label 0  (same-answer error)")
plt.hist(prob_pos[y == 1], bins=40, alpha=0.6, label="label 1  (correct)")
plt.xlabel("probe p(correct)")
plt.ylabel("count")
plt.legend()
plt.tight_layout()
hist_path = figdir / "prob_hist_apply.png"
plt.savefig(hist_path)
plt.close()
log.info("• probability histogram    → %s", hist_path)

log.info("done in %.1f s", time.time() - t0)


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
02:34:05  INFO │ 🔌  loading linear probe  →  linear_probes/realyesno_None4k/linear_probe_layer5.joblib
02:34:05  INFO │    → expecting activations from transformer layer  5
02:34:05  INFO │ answers_map built with 100 entries
02:34:05  INFO │ found 592 hidden-state files in h_hidden_space/outputs/f1_hint_xyyx/xyyx_deterministic/gt_lt_completions_1


  batch_hidden = torch.load(hid_fp)            # list[n_layers]  each  (B,H)


/root/CoTFaithChecker


02:34:07  INFO │ collected 9620 labelled activation-vectors
02:34:07  INFO │ running probe on 9620 vectors (d=4096)
02:34:07  INFO │ RESULTS  ·  acc 0.770   f1 0.870
02:34:07  INFO │ 
              precision    recall  f1-score   support

           0      0.800     0.002     0.004      2220
           1      0.770     1.000     0.870      7400

    accuracy                          0.770      9620
   macro avg      0.785     0.501     0.437      9620
weighted avg      0.777     0.770     0.670      9620

02:34:07  INFO │ • confusion-matrix written → linear_probes/realyesno_None4k/confusion_matrix_apply.png
02:34:07  INFO │ • ROC curve written        → linear_probes/realyesno_None4k/roc_apply.png
02:34:08  INFO │ • probability histogram    → linear_probes/realyesno_None4k/prob_hist_apply.png
02:34:08  INFO │ done in 2.2 s
