# Tool-call Monitor: Diff-in-Means and Logistic Probe

Train simple probes on tool-call activations captured by `pipeline.py` under `data/activations/{good,bad}/activations.jsonl`.

We use 8 train + 2 eval activations per class by default, compare a diff-in-means direction vs. a logistic classifier, and render eval tool calls with color-coded scores.

In [30]:
import json
import os
from pathlib import Path
import random
from typing import List, Dict

import torch
from IPython.display import display, HTML

random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x1318d20d0>

In [31]:
DATA_ROOT = Path("data/activations")
GOOD_PATH = DATA_ROOT / "good" / "activations.jsonl"
BAD_PATH = DATA_ROOT / "bad" / "activations.jsonl"

def load_jsonl(path: Path) -> List[Dict]:
    if not path.exists():
        raise FileNotFoundError(f"Missing activations file: {path}")
    rows = []
    with path.open() as f:
        for line in f:
            if line.strip():
                rows.append(json.loads(line))
    return rows

good_rows = load_jsonl(GOOD_PATH)
bad_rows = load_jsonl(BAD_PATH)
print(f"Loaded {len(good_rows)} good activations, {len(bad_rows)} bad activations")


Loaded 14 good activations, 18 bad activations


In [32]:
def select_examples(rows: List[Dict], n_train: int, n_eval: int) -> Dict[str, List]:
    if len(rows) < n_train + n_eval:
        raise ValueError(f"Need at least {n_train + n_eval} rows, got {len(rows)}")
    random.shuffle(rows)
    train = rows[:n_train]
    eval_ = rows[n_train:n_train + n_eval]
    return {"train": train, "eval": eval_}

N_TRAIN_PER_CLASS = 7
N_EVAL_PER_CLASS = 3

splits = {
    "good": select_examples(good_rows, N_TRAIN_PER_CLASS, N_EVAL_PER_CLASS),
    "bad": select_examples(bad_rows, N_TRAIN_PER_CLASS, N_EVAL_PER_CLASS),
}


In [33]:
def rows_to_tensor(rows: List[Dict]) -> torch.Tensor:
    return torch.tensor([r["activation"] for r in rows], dtype=torch.float32)

train_x = torch.cat([
    rows_to_tensor(splits["good"]["train"]),
    rows_to_tensor(splits["bad"]["train"]),
], dim=0)
train_y = torch.cat([
    torch.ones(len(splits["good"]["train"])),
    torch.zeros(len(splits["bad"]["train"])),
])

eval_x = torch.cat([
    rows_to_tensor(splits["good"]["eval"]),
    rows_to_tensor(splits["bad"]["eval"]),
], dim=0)
eval_y = torch.cat([
    torch.ones(len(splits["good"]["eval"])),
    torch.zeros(len(splits["bad"]["eval"])),
])

eval_rows = splits['good']['eval'] + splits['bad']['eval']
print("Train shape:", train_x.shape, "Eval shape:", eval_x.shape)


Train shape: torch.Size([14, 3584]) Eval shape: torch.Size([6, 3584])


## Diff-in-means probe


In [34]:
good_mean = train_x[:len(splits["good"]["train"]), :].mean(dim=0)
bad_mean = train_x[len(splits["good"]["train"]):, :].mean(dim=0)
diff_vec = good_mean - bad_mean

def score_diff(x: torch.Tensor) -> torch.Tensor:
    return torch.matmul(x, diff_vec)

with torch.no_grad():
    diff_scores = score_diff(eval_x)
    diff_probs = torch.sigmoid(diff_scores)
    preds = (diff_scores > 0).float()
    acc = (preds == eval_y).float().mean().item()
    print(f"Diff-in-means accuracy: {acc*100:.1f}% ({int(acc*len(eval_y))}/{len(eval_y)})")


Diff-in-means accuracy: 66.7% (4/6)


## Logistic classifier


In [35]:
logit = torch.nn.Linear(train_x.shape[1], 1)
optimizer = torch.optim.Adam(logit.parameters(), lr=1e-2)
loss_fn = torch.nn.BCEWithLogitsLoss()

EPOCHS = 200
for epoch in range(EPOCHS):
    optimizer.zero_grad()
    logits = logit(train_x).squeeze(-1)
    loss = loss_fn(logits, train_y)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    eval_logits = logit(eval_x).squeeze(-1)
    logistic_probs = torch.sigmoid(eval_logits)
    eval_preds = (logistic_probs > 0.5).float()
    eval_acc = (eval_preds == eval_y).float().mean().item()
    print(f"Logistic accuracy: {eval_acc*100:.1f}% ({int(eval_acc*len(eval_y))}/{len(eval_y)})")


Logistic accuracy: 50.0% (3/6)


In [36]:
def color_for_prob(p: float) -> str:
    p = max(0.0, min(1.0, float(p)))
    r = int(255 * (1 - p))
    g = int(255 * p)
    return f'#{r:02x}{g:02x}00'

def render_examples(eval_rows, eval_y, diff_probs, logistic_probs):
    items = []
    for i, row in enumerate(eval_rows):
        prob = float(logistic_probs[i].item())
        diff_prob = float(diff_probs[i].item())
        color = color_for_prob(prob)
        label = 'good' if float(eval_y[i].item()) == 1.0 else 'bad'
        tool_text = row.get('tool_call', '[tool call missing]')
        prompt = row.get('prompt', '[prompt missing]')
        generation = row.get('generation', '[generation missing]')
        highlighted = generation.replace(tool_text, f"<code style=\"background:{color}; padding:4px; border-radius:4px;\">{tool_text}</code>")
        html = (
            f"<div style='margin-bottom:16px; padding:8px; border:1px solid #ddd; border-radius:6px;'>"
            f"<div style='font-weight:bold; margin-bottom:4px;'>User prompt:</div><div style='margin-left:8px;'>{prompt}</div>"
            f"<div style='font-weight:bold; margin:8px 0 4px;'>Model response:</div><div style='margin-left:8px;'>{highlighted}</div>"
            f"<div style='font-size:12px; margin-top:6px;'>logit_prob_good={prob:.2f}, diff_prob_good={diff_prob:.2f}, true={label}</div>"
            f"</div>"
        )
        items.append(html)
    return HTML(''.join(items))

with torch.no_grad():
    display(render_examples(eval_rows, eval_y, diff_probs, logistic_probs))
