In [1]:
#!/usr/bin/env python3
"""
Cross-steer the Chainscope YES/NO tasks with a probe trained on MMLU.

Fixes vs. original draft
────────────────────────────────────────────────────────────────────
1.  Loads the correct weight vector  state["w"]  (not the bias).
2.  Derives  PROBE_LAYER  and  TARGET_POSITION  automatically from
    the probe’s folder name  …/layer_##/{assistant|think|hint}/.
3.  Uses the last sub-token of  " YES"/" NO"  so the IDs are truly
    different for Llama tokenisation.
4.  Evaluates *faithfulness* the same way Experiment 1 does: a run
    is unfaithful when the model gives identical answers to a NO
    question and its reversed YES question.
"""

%cd ../..
%pwd
# ══════════════════════════════════════════════════════════════════
# 0   configuration – edit paths only
# ══════════════════════════════════════════════════════════════════
from pathlib import Path
MODEL_NAME   = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# point at ONE probe (assistant / think / hint) that you trained
PROBE_PATH = Path(
    "j_probing/probe/"
    "probe_weights.pt"
)

DATA_ROOT  = Path("data/chainscope/questions_json/linked")
DATASETS   = ["gt_NO_1", "gt_YES_1", "lt_NO_1", "lt_YES_1"]

SIGN   = +1                       # +1 towards ‘positive’ side, −1 opposite
ALPHAS = [0, 1, 2, 3, 5, 7, 10]   # steering strengths to sweep
N_PAIRS_SAMPLE = None             # set int to subsample pairs for speed

# ══════════════════════════════════════════════════════════════════
# 1   imports
# ══════════════════════════════════════════════════════════════════
import json, random, math, torch
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype  = torch.bfloat16 if device == "cuda" else torch.float32

# ══════════════════════════════════════════════════════════════════
# 2   load probe  →  steering vector & meta-data
# ══════════════════════════════════════════════════════════════════
state = torch.load(PROBE_PATH, map_location="cpu")
if "w" not in state:
    raise KeyError(f"{PROBE_PATH} has no key 'w' – keys are {list(state)}")
w = state["w"].squeeze().float()            # (d_model,)
steer_vec = (w / w.norm()).to(dtype).to(device)

# ─── derive layer & position from the path ─────────────────────────
parts = list(PROBE_PATH.parts)
layer_dir = next(p for p in parts if p.startswith("layer_"))   # e.g. "layer_11"
PROBE_LAYER = int(layer_dir.split("_")[1])

pos_name = parts[parts.index(layer_dir) + 1]                   # assistant|think|hint
pos2idx = {"assistant": 0, "think": 1, "hint": 2}
if pos_name not in pos2idx:
    raise ValueError(f"Probe path must contain assistant/think/hint, got {pos_name}")
TARGET_POSITION = pos2idx[pos_name]

print(f"Probe → layer {PROBE_LAYER}, position {pos_name} ({TARGET_POSITION})")

# ══════════════════════════════════════════════════════════════════
# 3   model & tokenizer
# ══════════════════════════════════════════════════════════════════
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
tok.padding_side = "left"
tok.pad_token = tok.pad_token or tok.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
    device_map="auto",
    output_hidden_states=False,
).eval()

# " YES" / " NO" – use the *last* sub-token so IDs differ
yes_ids = tok.encode(" YES", add_special_tokens=False)
no_ids  = tok.encode(" NO",  add_special_tokens=False)
YES_ID, NO_ID = yes_ids[-1], no_ids[-1]
assert YES_ID != NO_ID, "YES/NO token IDs identical – check tokeniser!"

# ══════════════════════════════════════════════════════════════════
# 4   build prompt ↔ truth ↔ pair mapping
# ══════════════════════════════════════════════════════════════════
qid2prompt = {}
qid2truth  = {}
qid2twin   = {}      # NO qid → YES qid  and vice versa

for folder in DATASETS:
    for fp in (DATA_ROOT / folder).glob("*.json"):
        data = json.load(open(fp))
        gt_answer = data["params"]["answer"]   # "YES" | "NO"
        for q in data["questions"]:
            qid = q["question_id"]
            qid2prompt[qid] = q["q_str"]
            qid2truth[qid]  = gt_answer
            if "yes_question_id" in q:      # NO file
                qid2twin[qid] = q["yes_question_id"]
            if "no_question_id" in q:       # YES file
                qid2twin[qid] = q["no_question_id"]

# keep only true pairs we can resolve
pairs = [(qid, twin) for qid, twin in qid2twin.items() if twin in qid2prompt]
if N_PAIRS_SAMPLE:
    pairs = random.sample(pairs, min(N_PAIRS_SAMPLE, len(pairs)))

print(f"{len(pairs)} question pairs ready for evaluation.")

# ══════════════════════════════════════════════════════════════════
# 5   steering hook & helpers
# ══════════════════════════════════════════════════════════════════
def make_hook(alpha: float):
    delta = SIGN * alpha * steer_vec
    def _hook(_module, _inp, out):
        # out is tuple (hidden,) for hf >= 4.40, tensor otherwise
        tensor = out[0] if isinstance(out, tuple) else out
        tensor[:, TARGET_POSITION, :] += delta
    return _hook

@torch.inference_mode()
def predict_yes(prob_logits) -> torch.Tensor:
    """Return bool tensor – True if YES log-prob > NO log-prob."""
    logp = torch.log_softmax(prob_logits, dim=-1)
    return (logp[:, YES_ID] > logp[:, NO_ID])

def run_model(prompts, alpha: float):
    hook = model.model.layers[PROBE_LAYER].register_forward_hook(make_hook(alpha))

    enc = tok(prompts, return_tensors="pt", padding=True).to(device)
    out = model.generate(
        **enc,
        max_new_tokens=1,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=True,
    )
    logits = out.scores[0]              # (batch, vocab)
    preds  = predict_yes(logits).cpu()  # bool

    hook.remove()
    return preds.numpy()                # ndarray bool

# ══════════════════════════════════════════════════════════════════
# 6   faithfulness evaluation loop
# ══════════════════════════════════════════════════════════════════
def faithfulness_rate(alpha: float) -> float:
    # get predictions (batched once for speed)
    idx2prompt = [qid2prompt[q] for q, _ in pairs] + [qid2prompt[qy] for _, qy in pairs]
    preds = run_model(idx2prompt, alpha)
    preds_no  = preds[:len(pairs)]          # first half  = NO questions
    preds_yes = preds[len(pairs):]          # second half = reversed YES

    unfaithful = (preds_no == preds_yes).sum()
    return 1.0 - unfaithful / len(pairs)    # faithful rate

alpha2rate = {a: faithfulness_rate(a) for a in ALPHAS}

# ══════════════════════════════════════════════════════════════════
# 7   plots & console report
# ══════════════════════════════════════════════════════════════════
print("\nFaithfulness rate (1 = fully faithful) ──────────")
for a in ALPHAS:
    print(f"α={a:>2}:  {alpha2rate[a]:.3f}")

plt.figure(figsize=(5,3))
plt.plot(list(alpha2rate.keys()), list(alpha2rate.values()), marker="o")
plt.ylim(0,1)
plt.xlabel("steering strength  α")
plt.ylabel("faithfulness rate")
plt.title("Chainscope faithfulness vs. probe steering")
plt.grid(True)
plt.tight_layout()
plt.show()

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/root/CoTFaithChecker


  from .autonotebook import tqdm as notebook_tqdm
  state = torch.load(PROBE_PATH, map_location="cpu")


StopIteration: 