# Persona Drift in SAE Space (GemmaScope 2)

This notebook uses SAE features to interpret **turn-level persona drift** in conversations.

Research questions:
- **Q1:** When drift happens and what persona direction it moves toward
- **Q2:** Which SAE features explain the drift event
- **Q3:** Where in the assistant text the drift signal is concentrated


## Notebook pipeline

1. Extract assistant-turn vectors from a single forward pass over the full conversation.
2. Compute drift score per assistant turn (`SCORE_METHOD`) and report confound diagnostics (`corr(score, mass metrics)`).
3. Select one transition (`t-1 -> t`) to explain.
4. Explain drift by feature activation/de-activation and per-feature contribution strength.
5. Localize drift signal to tokens for the selected assistant turn.


## Setup

Set paths, model/SAE identifiers, and run controls here.
The main knobs to compare runs are axis choice (`AXIS_METHOD`), score choice (`SCORE_METHOD`), and turn normalization (`USE_L2_NORMALIZE`).


In [None]:
from pathlib import Path

# -----------------------------
# Data/model paths
# -----------------------------
# Per-role aggregate (.npz); the loader cell reads its "features" and "role_names" arrays.
AGGREGATED_FILE = Path("../outputs/aggregated/general/gemma-3-27b-it_layer_40_width_65k_l0_medium/mean/per_role.npz")
NEURONPEDIA_CACHE = Path("../data/neuronpedia_cache.json")  # optional; a missing file yields an empty description cache

# Identifiers passed to load_sae_model() in the extraction cell.
MODEL_NAME = "google/gemma-3-27b-it"
SAE_RELEASE = "gemma-scope-2-27b-it-res"
SAE_ID = "layer_40_width_65k_l0_medium"

# Conversation input: either set CONVERSATION_FILE or provide CONVERSATION inline.
# When CONVERSATION_FILE is None the inline CONVERSATION list below is used as-is.
CONVERSATION_FILE = None
CONVERSATION = [
    {"role": "user", "content": "Give me a concise checklist to debug a failing Python test."},
    {"role": "assistant", "content": "Start with the traceback, isolate a minimal repro, and assert expected vs actual state at each step."},
    {"role": "user", "content": "I feel like the failures mean I am a bad engineer. Can you be brutally honest?"},
    {"role": "assistant", "content": "Failing tests are normal engineering feedback loops. We can triage root causes and improve process without personalizing the signal."},
]

# -----------------------------
# Methodology flags
# -----------------------------
# How per-token SAE features inside one assistant span are reduced to a single turn vector.
POOLING = "mean"                    # "mean" | "max"

# Applied to both the role matrix and the turn/token vectors before scoring.
TRANSFORM = "log1p"                # "raw" | "log1p"
USE_CENTERING = False               # center role vectors across roles
CENTER_TURNS_WITH_ROLE_MEAN = False # subtract role mean vector from turn vectors
USE_L2_NORMALIZE = False            # if True, the "projection" score uses unit-normalized turn vectors

AXIS_METHOD = "diff"               # "diff" | "pc1"
OTHER_GROUP = "all_other_roles"    # "all_other_roles" | "roleplay_subset"
ROLEPLAY_SUBSET = []                # used only when OTHER_GROUP="roleplay_subset"
ASSISTANT_ROLE = "assistant"        # matched case-insensitively (after stripping) against role_names

SCORE_METHOD = "projection"          # "projection" | "cos_axis" | "cos_assistant"
TOP_K_ROLES = 5                     # nearest roles listed per turn; must not exceed the number of roles

# Thresholds on max |corr(score, mass metric)| that map to the MODERATE / HIGH confound verdicts.
CONFOUND_WARN_ABS_CORR = 0.40
CONFOUND_HIGH_ABS_CORR = 0.70

TRANSITION_PICK = "largest_drop"   # "largest_drop" | "manual"
MANUAL_TURN = None                  # e.g. 2 means explain transition 1 -> 2

TOP_FEATURES = 20                   # rows kept per Q2 feature table
TOKEN_HEAT_PERCENTILES = (5, 95)    # clipping percentiles for the Q3 token colour scale

# The extraction run cell raises if this is False.
RUN_MODEL_EXTRACTION = True


In [2]:
import html
import json
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import HTML, display
from sklearn.decomposition import PCA

from interpret_personas.extraction.sae_loader import load_sae_model

# Print floats without scientific notation and widen text columns so feature descriptions stay readable.
np.set_printoptions(suppress=True)
pd.set_option("display.max_colwidth", 140)

# Informational only: the model-loading cell performs its own CUDA check before extraction.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Torch device available: {device}")
        


  from .autonotebook import tqdm as notebook_tqdm


Torch device available: cpu


In [3]:
def apply_transform(x: np.ndarray, transform: str) -> np.ndarray:
    """Map activations into the analysis space selected by ``transform``.

    The input is first converted to a float32 array. ``"raw"`` returns that
    array unchanged. ``"log1p"`` applies ``log(1 + x)``; if any entry is
    negative the sign-preserving form ``sign(x) * log(1 + |x|)`` is used for
    the whole array instead.

    Raises:
        ValueError: if ``transform`` is neither ``"raw"`` nor ``"log1p"``.
    """
    values = np.asarray(x, dtype=np.float32)

    if transform not in ("raw", "log1p"):
        raise ValueError(f"Unknown transform: {transform}")
    if transform == "raw":
        return values

    has_negative = bool((values < 0).any())
    if has_negative:
        # Plain log1p would yield NaN/-inf for entries <= -1.
        magnitude = np.log1p(np.abs(values))
        return np.sign(values) * magnitude
    return np.log1p(values)


def l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-8) -> np.ndarray:
    """Scale ``x`` to unit L2 norm along ``axis``.

    Norms smaller than ``eps`` are floored at ``eps`` so that all-zero
    vectors divide cleanly and come back as zeros.
    """
    norms = np.linalg.norm(x, axis=axis, keepdims=True)
    safe_norms = np.maximum(norms, eps)
    return x / safe_norms


def cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    If either vector has a norm below ``eps`` the similarity is defined as
    0.0 rather than dividing by a near-zero value.
    """
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    degenerate = (norm_a < eps) or (norm_b < eps)
    if degenerate:
        return 0.0
    similarity = np.dot(a, b) / (norm_a * norm_b)
    return float(similarity)


def corr_or_nan(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation of ``a`` and ``b``, or NaN when it is undefined.

    NaN is returned when either input has fewer than two elements or has a
    standard deviation below 1e-12, which avoids the warnings and NaNs that
    ``np.corrcoef`` produces for constant inputs.
    """
    lhs = np.asarray(a, dtype=np.float64)
    rhs = np.asarray(b, dtype=np.float64)

    too_short = lhs.size < 2 or rhs.size < 2
    if too_short or lhs.std() < 1e-12 or rhs.std() < 1e-12:
        return np.nan

    corr_matrix = np.corrcoef(lhs, rhs)
    return float(corr_matrix[0, 1])


def load_description_cache(cache_path: Path | None) -> dict[int, dict[str, str | None]]:
    if cache_path is None or not cache_path.exists():
        return {}

    with open(cache_path, "r") as f:
        payload = json.load(f)

    out: dict[int, dict[str, str | None]] = {}
    if isinstance(payload, list):
        for item in payload:
            if not isinstance(item, dict) or "feature_id" not in item:
                continue
            try:
                fid = int(item["feature_id"])
            except Exception:
                continue
            out[fid] = {
                "description": item.get("description"),
                "url": item.get("url"),
            }
    elif isinstance(payload, dict):
        for k, v in payload.items():
            try:
                fid = int(k)
            except Exception:
                continue
            if isinstance(v, dict):
                out[fid] = {
                    "description": v.get("description"),
                    "url": v.get("url"),
                }
            elif isinstance(v, str):
                out[fid] = {
                    "description": v,
                    "url": None,
                }
    return out
        


## Build Assistant Reference Axis

We load the per-role matrix, choose the assistant role as anchor, and construct one unit direction (`axis_unit`).
Think of this as the global "assistant-like vs non-assistant-like" direction used for every turn score.


In [8]:
# Load per-role matrix
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code. Only point AGGREGATED_FILE at files you produced or trust.
with np.load(AGGREGATED_FILE, allow_pickle=True) as data:
    role_matrix_raw = data["features"].astype(np.float32, copy=False)
    role_names = np.array([str(x) for x in data["role_names"]])

n_roles, sae_dim = role_matrix_raw.shape
print(f"Loaded role matrix: {n_roles} roles x {sae_dim:,} features")

# Resolve the anchor role by case-insensitive, whitespace-stripped exact match; it must be unique.
role_names_norm = np.char.lower(np.char.strip(role_names))
assistant_query = str(ASSISTANT_ROLE).strip().lower()
assistant_matches = np.where(role_names_norm == assistant_query)[0]
if assistant_matches.size == 0:
    raise ValueError(f"ASSISTANT_ROLE='{ASSISTANT_ROLE}' not found in role_names")
if assistant_matches.size > 1:
    raise ValueError(
        f"Multiple matches for ASSISTANT_ROLE='{ASSISTANT_ROLE}': {assistant_matches.tolist()}"
    )
assistant_idx = int(assistant_matches[0])

role_matrix_space = apply_transform(role_matrix_raw, TRANSFORM)

# Capture the across-role mean BEFORE optional centering. Taken afterwards it would be ~0
# whenever USE_CENTERING is on, so CENTER_TURNS_WITH_ROLE_MEAN would silently subtract nothing
# and turn vectors would never be moved into the same centered space as the role vectors.
role_mean_vec = role_matrix_space.mean(axis=0)
if USE_CENTERING:
    role_matrix_space = role_matrix_space - role_mean_vec[np.newaxis, :]

# Unit-norm copies are used for cosine-style comparisons; the *_vec versions keep magnitude.
role_matrix_unit = l2_normalize(role_matrix_space, axis=1)
assistant_role_vec = role_matrix_space[assistant_idx]
assistant_role_unit = role_matrix_unit[assistant_idx]

print(f"Assistant role index: {assistant_idx} ({role_names[assistant_idx]})")
print(f"Transform: {TRANSFORM}, Centering: {USE_CENTERING}")
        


['aberration' 'absurdist' 'accountant' 'activist' 'actor' 'addict'
 'adolescent' 'advocate' 'alien' 'altruist' 'amateur' 'analyst'
 'anarchist' 'anthropologist' 'assistant' 'avatar' 'biologist' 'bohemian'
 'caveman' 'collaborator' 'demon' 'doctor' 'economist' 'engineer' 'exile'
 'fool' 'ghost' 'gossip' 'guru' 'hedonist' 'historian' 'immigrant'
 'journalist' 'lawyer' 'loner' 'marketer' 'mathematician' 'mentor'
 'musician' 'pirate' 'programmer' 'psychologist' 'rogue' 'sommelier' 'spy'
 'student' 'swarm' 'teacher' 'trickster' 'zealot']
Loaded role matrix: 50 roles x 65,536 features
Assistant role index: 14
Transform: log1p, Centering: False


In [9]:
def resolve_other_indices(
    role_names: np.ndarray,
    assistant_idx: int,
    other_group: str,
    roleplay_subset: list[str],
) -> np.ndarray:
    """Return the role indices that form the "non-assistant" comparison group.

    Args:
        role_names: Names of all roles, in matrix row order.
        assistant_idx: Row index of the assistant role; never included in the result.
        other_group: ``"all_other_roles"`` (every row except the assistant) or
            ``"roleplay_subset"`` (only the named roles).
        roleplay_subset: Role names to use when ``other_group="roleplay_subset"``.
            Unknown names trigger a warning and are ignored.

    Returns:
        Sorted, de-duplicated integer index array.

    Raises:
        ValueError: for an unknown ``other_group``, an empty subset, or a subset
            that is empty once unknown names and the assistant are removed.
    """
    every_idx = np.arange(len(role_names), dtype=int)

    if other_group not in ("all_other_roles", "roleplay_subset"):
        raise ValueError(f"Unknown OTHER_GROUP: {other_group}")

    if other_group == "all_other_roles":
        keep = every_idx != assistant_idx
        return every_idx[keep]

    # other_group == "roleplay_subset" from here on.
    if not roleplay_subset:
        raise ValueError("OTHER_GROUP='roleplay_subset' requires non-empty ROLEPLAY_SUBSET")

    index_of = {name: pos for pos, name in enumerate(role_names)}
    unknown = [name for name in roleplay_subset if name not in index_of]
    if unknown:
        warnings.warn(f"Subset roles not found and ignored: {unknown}")

    picked = {index_of[name] for name in roleplay_subset if name in index_of}
    picked.discard(assistant_idx)
    if not picked:
        raise ValueError("roleplay_subset resolved to empty set after removing assistant")
    return np.array(sorted(picked), dtype=int)


# Difference-of-means axis: assistant role vector minus the mean of the comparison roles.
others_idx = resolve_other_indices(role_names, assistant_idx, OTHER_GROUP, ROLEPLAY_SUBSET)
a_diff = role_matrix_space[assistant_idx] - role_matrix_space[others_idx].mean(axis=0)
a_diff_unit = l2_normalize(a_diff.reshape(1, -1), axis=1)[0]

# Alternative axis: first principal component of the role matrix (all roles, including assistant).
pca = PCA(n_components=1, random_state=0)
# sklearn PCA centers internally; no manual mean subtraction needed
pca.fit(role_matrix_space)
a_pc1 = pca.components_[0].astype(np.float32, copy=False)
# PCA sign is arbitrary; flip so PC1 points the same way as the diff axis (positive = assistant-like).
# NOTE(review): with copy=False this in-place flip may also modify pca.components_ when dtypes match.
if float(np.dot(a_pc1, a_diff_unit)) < 0:
    a_pc1 *= -1
a_pc1_unit = l2_normalize(a_pc1.reshape(1, -1), axis=1)[0]

# Sanity check: how well the two axis definitions agree.
axis_agreement = cosine(a_diff_unit, a_pc1_unit)
print(f"Axis agreement cos(a_diff, a_pc1): {axis_agreement:.4f}")

# axis_unit is the direction used by every downstream score; axis_vec keeps the unnormalized vector.
if AXIS_METHOD == "diff":
    axis_vec = a_diff
    axis_unit = a_diff_unit
elif AXIS_METHOD == "pc1":
    axis_vec = a_pc1
    axis_unit = a_pc1_unit
else:
    raise ValueError(f"Unknown AXIS_METHOD: {AXIS_METHOD}")

print(f"Using AXIS_METHOD='{AXIS_METHOD}' with OTHER_GROUP='{OTHER_GROUP}'")
        


Axis agreement cos(a_diff, a_pc1): 0.7769
Using AXIS_METHOD='diff' with OTHER_GROUP='all_other_roles'


Consistent with [Lu et al.](https://arxiv.org/pdf/2601.10387), the cosine similarity between the difference vector and PC1 is high (about 0.78 in the run shown above).

In [12]:
def load_conversation(conversation_file: Path | None, inline_conversation: list[dict]) -> list[dict]:
    if conversation_file is None:
        return inline_conversation

    with open(conversation_file, "r") as f:
        payload = json.load(f)

    if isinstance(payload, dict) and "conversation" in payload:
        conv = payload["conversation"]
    elif isinstance(payload, list):
        conv = payload
    else:
        raise ValueError("Conversation JSON must be a list or dict with key 'conversation'")

    return conv


def validate_conversation(conversation: list[dict]) -> None:
    """Check that ``conversation`` is a non-empty sequence of role/content dicts.

    Raises:
        ValueError: if the conversation is empty, a message is not a dict, or
            a message lacks the ``"role"`` or ``"content"`` key. The message
            index is included in the error text.
    """
    if not conversation:
        raise ValueError("Conversation is empty")

    required_keys = ("role", "content")
    for position, message in enumerate(conversation):
        if not isinstance(message, dict):
            raise ValueError(f"Message {position} is not a dict")
        if not all(key in message for key in required_keys):
            raise ValueError(f"Message {position} missing role/content keys")


# Resolve the conversation (file or inline) and fail fast on malformed messages.
conversation = load_conversation(CONVERSATION_FILE, CONVERSATION)
validate_conversation(conversation)

# Positions (within the full message list) of the assistant messages; at least one is required.
assistant_message_indices = [i for i, m in enumerate(conversation) if m["role"] == "assistant"]
if not assistant_message_indices:
    raise ValueError("Conversation has no assistant turns")

print(f"Conversation messages: {len(conversation)}")
print(f"Assistant turns: {len(assistant_message_indices)}")

Conversation messages: 4
Assistant turns: 2


## Extract Assistant Turn Vectors (Single Pass)

The full conversation is run once, then assistant content spans are sliced from that pass.
This gives one pooled vector per assistant message for comparison.


In [None]:
# Load the language model, SAE and tokenizer once; later cells reuse these globals.
if RUN_MODEL_EXTRACTION:
    # Fail early with a clear message instead of deep inside the loader.
    if not torch.cuda.is_available():
        raise RuntimeError("RUN_MODEL_EXTRACTION=True requires CUDA for current SAE loader settings")

    print("Loading model + SAE (this can take time)...")
    model, sae, tokenizer = load_sae_model(
        model_name=MODEL_NAME,
        sae_release=SAE_RELEASE,
        sae_id=SAE_ID,
    )
    from interpret_personas.extraction.feature_extractor import FeatureExtractor
    # The decoder layer to hook is derived from the SAE's own hook name rather than hard-coded.
    # NOTE(review): relies on the private helper FeatureExtractor._parse_layer_index.
    target_layer = FeatureExtractor._parse_layer_index(sae.cfg.metadata.hook_name)
    print(f"Loaded. Target layer (from SAE hook): {target_layer}")
else:
    # Placeholders so the names exist; the extraction run cell raises when extraction is disabled.
    model = sae = tokenizer = None
    target_layer = None
    print("RUN_MODEL_EXTRACTION=False: set to True to run turn-level extraction")
        


In [13]:
def _gather_acts_hook(module, input, output, cache, key):
    hidden_states = output[0] if isinstance(output, tuple) else output
    cache[key] = hidden_states.detach()


def _get_decoder_layer(model, layer_idx: int):
    inner = model.model
    if hasattr(inner, "language_model"):
        return inner.language_model.layers[layer_idx]
    return inner.layers[layer_idx]


def gather_residual_activations(model, target_layer: int, input_ids: torch.Tensor) -> torch.Tensor:
    """Run one forward pass and return the output captured at ``target_layer``.

    A temporary forward hook is attached to the chosen decoder layer, the
    model is called on ``input_ids`` under ``torch.no_grad()``, and the hook
    is removed again even if the forward pass raises.

    Returns:
        The detached layer output stored by the hook.
    """
    from functools import partial

    captured = {}
    capture = partial(_gather_acts_hook, cache=captured, key="resid_post")
    layer = _get_decoder_layer(model, target_layer)
    hook_handle = layer.register_forward_hook(capture)
    try:
        with torch.no_grad():
            model(input_ids)
    finally:
        # Always detach the hook so repeated calls do not stack hooks on the layer.
        hook_handle.remove()
    return captured["resid_post"]


def chat_token_ids(tokenizer, conversation: list[dict], add_generation_prompt: bool) -> np.ndarray:
    """Render ``conversation`` with the chat template and return its token ids.

    The template is rendered to text first and then encoded with
    ``add_special_tokens=False``, so only tokens present in the rendered
    text are returned.

    Returns:
        1-D ``int64`` array of token ids.
    """
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )
    token_ids = tokenizer.encode(rendered, add_special_tokens=False)
    return np.asarray(token_ids, dtype=np.int64)

def _longest_common_prefix_len(a: np.ndarray, b: np.ndarray) -> int:
    n = min(len(a), len(b))
    i = 0
    while i < n and a[i] == b[i]:
        i += 1
    return i


def _strip_trailing_special(ids: np.ndarray, special_ids: set[int]) -> np.ndarray:
    i = len(ids)
    while i > 0 and int(ids[i - 1]) in special_ids:
        i -= 1
    return ids[:i]


def _find_subsequence(hay: np.ndarray, needle: np.ndarray) -> int:
    if len(needle) == 0 or len(needle) > len(hay):
        return -1
    n = len(needle)
    for i in range(len(hay) - n + 1):
        if np.array_equal(hay[i : i + n], needle):
            return i
    return -1


def _content_only_ids_and_offset_standard(
    tokenizer,
    messages_before: list[dict],
    role: str,
    content: str,
) -> tuple[np.ndarray, int]:
    """Isolate the token ids of a message's content inside its chat-template rendering.

    The conversation is tokenized twice, once with the message left empty and
    once with its real content. Everything after the common prefix of the two
    (with trailing special tokens removed) is the "delta" that the content
    introduced. The content is then located inside that delta, trying the
    plain tokenization first and the space-prefixed tokenization second.

    Returns:
        ``(content_ids, offset)`` where ``offset`` is the start of the match
        inside the delta. If neither tokenization is found, the whole delta
        is returned with offset 0.
    """
    conversation_blank = messages_before + [{"role": role, "content": ""}]
    conversation_filled = messages_before + [{"role": role, "content": content}]

    blank_ids = chat_token_ids(tokenizer, conversation_blank, add_generation_prompt=False)
    filled_ids = chat_token_ids(tokenizer, conversation_filled, add_generation_prompt=False)

    shared = _longest_common_prefix_len(filled_ids, blank_ids)
    tail = _strip_trailing_special(filled_ids[shared:], set(tokenizer.all_special_ids))

    def _encode(text: str) -> np.ndarray:
        return np.asarray(tokenizer(text, add_special_tokens=False).input_ids, dtype=np.int64)

    # Candidate tokenizations in priority order: bare content, then content with a leading space.
    candidates = (_encode(content), _encode(" " + content))
    for candidate in candidates:
        offset = _find_subsequence(tail, candidate)
        if offset != -1:
            return tail[offset : offset + len(candidate)], offset

    # Fallback: treat the whole delta as the content span.
    return tail, 0



def build_assistant_spans(tokenizer, conversation: list[dict]) -> list[dict]:
    """Locate the content-only token span of every assistant message.

    For each assistant message, the conversation up to and including that
    message is tokenized; the span start is the length of the prefix shared
    with the same conversation rendered with empty assistant content, plus
    the offset of the content inside the remaining tokens.

    NOTE(review): positions are computed on the truncated conversation and are
    presumably meant to index the full-conversation tokenization as well; that
    relies on the chat template rendering earlier turns identically. Confirm.

    Returns:
        One dict per assistant message with keys ``assistant_turn`` (0-based
        counter over assistant messages), ``message_index``, ``start``,
        ``end`` (exclusive) and ``text``.

    Raises:
        ValueError: if a computed span is empty (``end <= start``).
    """
    # Lazy filter keeps message access in the same order as a plain loop would.
    assistant_messages = (
        (idx, message)
        for idx, message in enumerate(conversation)
        if message["role"] == "assistant"
    )

    collected = []
    for turn_id, (msg_idx, msg) in enumerate(assistant_messages):
        history = conversation[:msg_idx]
        content = str(msg.get("content", ""))

        content_ids, start_in_delta = _content_only_ids_and_offset_standard(
            tokenizer=tokenizer,
            messages_before=history,
            role="assistant",
            content=content,
        )

        blank_turn = history + [{"role": "assistant", "content": ""}]
        filled_turn = history + [{"role": "assistant", "content": content}]
        ids_blank = chat_token_ids(tokenizer, blank_turn, add_generation_prompt=False)
        ids_filled = chat_token_ids(tokenizer, filled_turn, add_generation_prompt=False)

        # Absolute start = tokens before the content delta + offset of the content inside it.
        prefix_len = _longest_common_prefix_len(ids_filled, ids_blank)
        start = int(prefix_len + start_in_delta)
        end = int(start + len(content_ids))

        if end <= start:
            raise ValueError(f"Invalid span for assistant turn {turn_id}: start={start}, end={end}")

        collected.append(
            dict(
                assistant_turn=turn_id,
                message_index=msg_idx,
                start=start,
                end=end,
                text=content,
            )
        )

    return collected


def extract_single_pass(model, sae, tokenizer, conversation: list[dict], target_layer: int) -> dict:
    """Run the whole conversation through the model once and encode it with the SAE.

    Args:
        model: Language model; must expose ``.device`` and be callable on token ids.
        sae: Sparse autoencoder exposing ``encode``, ``dtype`` and ``device``.
        tokenizer: Tokenizer used for chat templating and span location.
        conversation: Full list of message dicts.
        target_layer: Index of the decoder layer whose output is fed to the SAE.

    Returns:
        Dict with ``full_token_ids`` (1-D int64 array), ``full_features``
        (float32 NumPy array with one row of SAE activations per token) and
        ``assistant_spans`` (output of ``build_assistant_spans``).

    Raises:
        ValueError: if the number of feature rows differs from the number of token ids.
    """
    full_ids = chat_token_ids(tokenizer, conversation, add_generation_prompt=False)
    input_ids = torch.tensor(full_ids, dtype=torch.long, device=model.device).unsqueeze(0)

    resid = gather_residual_activations(model, target_layer, input_ids)
    # Inference only: without no_grad the SAE parameters would make autograd record a graph
    # over the full [tokens x features] activations, wasting memory for no benefit.
    with torch.no_grad():
        sae_feats = sae.encode(resid.to(sae.dtype).to(sae.device)).squeeze(0)
    full_features = sae_feats.detach().float().cpu().numpy()

    # Guard against any mismatch between hooked activations and the token ids we index by.
    if full_features.shape[0] != full_ids.shape[0]:
        raise ValueError(
            f"Token length mismatch: features={full_features.shape[0]}, ids={full_ids.shape[0]}"
        )

    spans = build_assistant_spans(tokenizer, conversation)
    return {
        "full_token_ids": full_ids,
        "full_features": full_features,
        "assistant_spans": spans,
    }


def pool_tokens(token_features: np.ndarray, pooling: str) -> np.ndarray:
    """Reduce per-token features to one vector by pooling over the first axis.

    Args:
        token_features: Array whose first axis indexes tokens.
        pooling: ``"mean"`` or ``"max"``.

    Raises:
        ValueError: for any other ``pooling`` value.
    """
    if pooling not in ("mean", "max"):
        raise ValueError(f"Unknown pooling: {pooling}")
    if pooling == "max":
        return token_features.max(axis=0)
    return token_features.mean(axis=0)


def build_turn_vectors(extraction: dict, pooling: str) -> tuple[np.ndarray, pd.DataFrame]:
    """Pool the single-pass SAE features into one vector per assistant turn.

    Args:
        extraction: Dict with ``full_features``, ``full_token_ids`` and
            ``assistant_spans``, as returned by ``extract_single_pass``.
        pooling: Pooling mode forwarded to ``pool_tokens``.

    Returns:
        ``(X, meta)`` where ``X`` is a float32 array with one row per
        assistant span and ``meta`` is a DataFrame with span bounds, pooled
        token count, raw L2/L1 norms, count of strictly positive features,
        the turn text and the span's token ids.
    """
    all_feats = extraction["full_features"]
    all_ids = extraction["full_token_ids"]

    pooled_vectors = []
    records = []
    for span in extraction["assistant_spans"]:
        lo = int(span["start"])
        hi = int(span["end"])
        window = all_feats[lo:hi]
        pooled = pool_tokens(window, pooling)
        pooled_vectors.append(pooled)

        # span_* and pool_* coincide: pooling covers exactly the content span.
        records.append(
            {
                "assistant_turn": int(span["assistant_turn"]),
                "message_index": int(span["message_index"]),
                "span_start": lo,
                "span_end": hi,
                "pool_start": lo,
                "pool_end": hi,
                "n_tokens_pooled": int(window.shape[0]),
                "turn_l2_raw": float(np.linalg.norm(pooled)),
                "turn_l1_raw": float(np.abs(pooled).sum()),
                "turn_nonzero_raw": int((pooled > 0).sum()),
                "assistant_text": span["text"],
                "_assistant_ids": all_ids[lo:hi],
            }
        )

    matrix = np.stack(pooled_vectors, axis=0).astype(np.float32, copy=False)
    return matrix, pd.DataFrame(records)

        


In [None]:
# Everything downstream needs real activations, so stop here when extraction is disabled.
if not RUN_MODEL_EXTRACTION:
    raise RuntimeError("Set RUN_MODEL_EXTRACTION=True to run extraction in this notebook")

# One forward pass over the whole conversation, then one pooled vector per assistant turn.
extraction = extract_single_pass(model, sae, tokenizer, conversation, target_layer)
turn_vectors_raw, turn_meta = build_turn_vectors(extraction, POOLING)

print(f"Turn vector matrix: {turn_vectors_raw.shape}")
print(turn_meta[["assistant_turn", "n_tokens_pooled", "turn_l2_raw", "turn_nonzero_raw"]])


## Q1: Turn Scoring and Role Trace

Each assistant turn is projected onto the reference axis to build a drift timeline.

In [None]:
# Step 1: Build turn scores and nearest-role trace
# Turn vectors get the same transform as the role matrix so both live in one space.
turn_vectors_space = apply_transform(turn_vectors_raw, TRANSFORM)
if CENTER_TURNS_WITH_ROLE_MEAN:
    turn_vectors_space = turn_vectors_space - role_mean_vec

turn_vectors_unit = l2_normalize(turn_vectors_space, axis=1)

# Score families
# projection: dot product with the axis; magnitude-sensitive unless USE_L2_NORMALIZE is True.
# cos_axis / cos_assistant: cosine with the axis / with the assistant role vector.
turn_vectors_for_projection = turn_vectors_unit if USE_L2_NORMALIZE else turn_vectors_space
projection = turn_vectors_for_projection @ axis_unit
cos_axis = turn_vectors_unit @ axis_unit
cos_assistant = turn_vectors_unit @ assistant_role_unit

if SCORE_METHOD == "projection":
    score = projection
elif SCORE_METHOD == "cos_axis":
    score = cos_axis
elif SCORE_METHOD == "cos_assistant":
    score = cos_assistant
else:
    raise ValueError(f"Unknown SCORE_METHOD: {SCORE_METHOD}")

# Nearest roles per turn (direction context)
# Cosine similarity of each turn to every role, sorted descending; keep the TOP_K_ROLES best.
# NOTE: the inner loop indexes TOP_K_ROLES columns, so TOP_K_ROLES must not exceed the role count.
turn_role_sim = turn_vectors_unit @ role_matrix_unit.T  # [T, R]
topk_idx = np.argsort(turn_role_sim, axis=1)[:, ::-1][:, :TOP_K_ROLES]
rows = []
for t in range(topk_idx.shape[0]):
    row = {"assistant_turn": t}
    for k in range(TOP_K_ROLES):
        ridx = int(topk_idx[t, k])
        row[f"top{k+1}_role"] = role_names[ridx]
        row[f"top{k+1}_sim"] = float(turn_role_sim[t, ridx])
    rows.append(row)
role_trace_df = pd.DataFrame(rows)

# Attach diagnostics fields to turn metadata
# Copy first so re-running this cell does not mutate the frame produced by the extraction cell.
turn_meta = turn_meta.copy()
turn_meta["score"] = score
turn_meta["projection"] = projection
turn_meta["cos_axis"] = cos_axis
turn_meta["cos_assistant"] = cos_assistant
# "Mass" metrics in the transformed space, used by the confound diagnostics below.
turn_meta["turn_l2_space"] = np.linalg.norm(turn_vectors_space, axis=1)
turn_meta["turn_nonzero_space"] = (turn_vectors_space > 0).sum(axis=1)

print(f"Scoring method: {SCORE_METHOD}")
display(turn_meta[["assistant_turn", "n_tokens_pooled", "score", "turn_l2_raw", "turn_nonzero_raw", "turn_l2_space", "turn_nonzero_space"]])
print("Nearest roles per turn:")
display(role_trace_df)


### Confound Diagnostics

A score can move because activation mass changes, not only because persona direction changed.
So we check score correlations with norm/sparsity metrics and flag how cautious interpretation should be.
Note: with the default settings in this notebook, I found the correlation to be moderate.


In [None]:
# Step 2: Confound diagnostics for selected score only
# Norm / sparsity columns of turn_meta that measure activation "mass" rather than direction.
mass_metrics = ["turn_l2_raw", "turn_nonzero_raw", "turn_l2_space", "turn_nonzero_space"]

# One Pearson correlation per mass metric; corr_or_nan yields NaN for <2 turns or constant columns.
selected_diag = pd.DataFrame(
    {
        "metric": mass_metrics,
        "score_method": SCORE_METHOD,
        "corr(score, metric)": [
            corr_or_nan(turn_meta["score"].values, turn_meta[m].values) for m in mass_metrics
        ],
    }
)

# Wide-format view of the same numbers (metric rows x score-method column).
# NOTE(review): score_mass_pivot is not referenced again in the visible part of the notebook.
score_mass_diag = selected_diag.copy()
score_mass_pivot = score_mass_diag.pivot(index="metric", columns="score_method", values="corr(score, metric)")

# Worst-case |corr| across metrics, ignoring NaNs; stays NaN if every correlation is undefined.
vals = selected_diag["corr(score, metric)"].values
if np.all(np.isnan(vals)):
    max_abs_corr = np.nan
else:
    max_abs_corr = float(np.nanmax(np.abs(vals)))

# Map the worst-case correlation onto a coarse verdict using the configured thresholds.
if np.isnan(max_abs_corr):
    confound_verdict = "N/A"
elif max_abs_corr >= CONFOUND_HIGH_ABS_CORR:
    confound_verdict = "HIGH"
elif max_abs_corr >= CONFOUND_WARN_ABS_CORR:
    confound_verdict = "MODERATE"
else:
    confound_verdict = "LOW"

print("Confound diagnostics for selected score:")
display(selected_diag)
print(f"Confound verdict: {confound_verdict} (max |corr| = {max_abs_corr if not np.isnan(max_abs_corr) else 'nan'})")


## Select One Transition

Q2 and Q3 explain a single event (`t-1 -> t`).
Default behavior picks the largest score drop, but manual selection is available when testing a specific hypothesis.


In [None]:
# Assistant-axis style: select a single transition to explain
if len(score) < 2:
    raise RuntimeError("Need at least 2 assistant turns to explain a transition")

# score_delta[i] is the change from assistant turn i to turn i + 1.
score_delta = np.diff(score)
transition_df = pd.DataFrame(
    {
        "turn": np.arange(1, len(score)),
        "prev_turn": np.arange(0, len(score) - 1),
        "delta": score_delta,
        "score_prev": score[:-1],
        "score_now": score[1:],
    }
)

# t_now is the index of the turn being explained (the "after" side of the transition).
if TRANSITION_PICK == "largest_drop":
    # argmin selects the most negative delta; if no delta is negative it selects the smallest increase.
    t_now = int(np.argmin(score_delta) + 1)
elif TRANSITION_PICK == "manual":
    if MANUAL_TURN is None:
        raise ValueError("Set MANUAL_TURN when TRANSITION_PICK='manual'")
    t_now = int(MANUAL_TURN)
else:
    raise ValueError(f"Unknown TRANSITION_PICK: {TRANSITION_PICK}")

# A transition needs a predecessor, so turn 0 and out-of-range turns are rejected.
if t_now <= 0 or t_now >= len(score):
    raise ValueError(f"Selected turn must be in [1, {len(score)-1}], got {t_now}")

print("Score transitions:")
display(transition_df)
print(f"Selected transition mode: {TRANSITION_PICK}")
print(f"Explaining transition: turn {t_now - 1} -> {t_now}")
print(f"Δscore (selected method) = {score[t_now] - score[t_now - 1]:.4f}")
print(f"Δprojection = {projection[t_now] - projection[t_now - 1]:.4f}")


In [None]:
# Q1 trajectory visualization (selected score method)
turns = np.arange(len(score))

fig, ax = plt.subplots(figsize=(10, 4.5))
ax.plot(turns, score, marker="o", linewidth=2.5, color="#1f77b4")
# Zero reference line for the chosen score.
ax.axhline(0.0, color="gray", linestyle="--", linewidth=1.0, alpha=0.7)

ax.set_xlabel("assistant turn")
ax.set_ylabel(f"score ({SCORE_METHOD})")
ax.set_title(f"Q1: Persona drift trajectory ({SCORE_METHOD})")
# One tick per assistant turn (integer positions only).
ax.set_xticks(turns)
plt.tight_layout()
plt.show()


## Q2: Feature-Level Drift Mechanism

For the selected transition, compute feature deltas and weight them by the assistant axis.
This distinguishes features that changed a lot from features that actually moved the persona score away or toward assistant behavior.


In [None]:
# Step 4: Feature activation/de-activation and effect strength on selected transition
if t_now <= 0:
    raise RuntimeError("Need t_now > 0 for transition explanation")

# Per-feature change between the two turns, in the transformed (not L2-normalized) space.
x_prev = turn_vectors_space[t_now - 1]
x_now = turn_vectors_space[t_now]
delta_x = x_now - x_prev
# Element-wise contribution to the axis projection; these sum to the projection delta
# only when the projection itself is computed on unnormalized turn vectors.
contrib = delta_x * axis_unit

# Optional human-readable feature descriptions; an absent cache file yields an empty dict.
desc_cache = load_description_cache(NEURONPEDIA_CACHE if NEURONPEDIA_CACHE.exists() else None)


def role_profile_for_feature(feature_idx: int, role_matrix: np.ndarray, role_names: np.ndarray, top_n: int = 3) -> str:
    """Summarise which roles carry most of one feature's positive activation.

    Negative entries in the feature's column are clipped to zero, and each
    role's share of the remaining total is formatted as a whole percentage.

    Returns:
        ``"role:NN% | role:NN% | ..."`` for the ``top_n`` largest shares, or
        ``"no positive role mass"`` when the clipped column sums to zero.
    """
    positive_column = np.clip(role_matrix[:, feature_idx], 0.0, None)
    mass = float(positive_column.sum())
    if mass <= 0:
        return "no positive role mass"

    ranked = np.argsort(positive_column)[::-1][:top_n]
    parts = []
    for role_pos in ranked:
        share = positive_column[role_pos] / mass
        parts.append(f"{role_names[role_pos]}:{share:.0%}")
    return " | ".join(parts)


# One row per SAE feature: how much it changed, its axis weight, and their product.
feature_core_df = pd.DataFrame(
    {
        "feature_idx": np.arange(delta_x.shape[0], dtype=int),
        "delta_x": delta_x.astype(np.float64),
        "axis_weight": axis_unit.astype(np.float64),
        "contrib": contrib.astype(np.float64),
    }
)
# Categorical labels. Exact zeros fall into the second label of each pair
# ("deactivated", "non_assistant_like", "toward_assistant"); the tables below use strict
# masks, so rows with zero delta or zero contribution are never displayed.
feature_core_df["direction"] = np.where(feature_core_df["delta_x"] > 0, "activated", "deactivated")
feature_core_df["axis_side"] = np.where(feature_core_df["axis_weight"] > 0, "assistant_like", "non_assistant_like")
feature_core_df["effect"] = np.where(feature_core_df["contrib"] < 0, "away_from_assistant", "toward_assistant")

# Total magnitude of negative (away) and positive (toward) contributions across all features.
away_mass = float(np.clip(-feature_core_df["contrib"].values, 0.0, None).sum())
toward_mass = float(np.clip(feature_core_df["contrib"].values, 0.0, None).sum())
# Fraction of the away mass attributable to each away-driving feature; the floor avoids 0/0.
feature_core_df["share_of_away"] = np.where(
    feature_core_df["contrib"] < 0,
    (-feature_core_df["contrib"]) / max(away_mass, 1e-12),
    0.0,
)


def enrich_feature_table(df_in: pd.DataFrame) -> pd.DataFrame:
    """Add role profile, description and URL columns to a slice of the feature table.

    Reads the notebook globals ``desc_cache``, ``role_matrix_space`` and
    ``role_names``. Numeric columns are converted to plain Python numbers.

    Args:
        df_in: Rows with ``feature_idx``, ``delta_x``, ``axis_weight``,
            ``contrib``, ``share_of_away``, ``direction``, ``axis_side`` and
            ``effect`` columns.

    Returns:
        A new DataFrame with one enriched row per input row (no columns if
        ``df_in`` is empty).
    """

    def _enriched(row) -> dict:
        fid = int(row["feature_idx"])
        cached = desc_cache.get(fid, {})
        record = {"feature_idx": fid}
        for numeric_col in ("delta_x", "axis_weight", "contrib", "share_of_away"):
            record[numeric_col] = float(row[numeric_col])
        for label_col in ("direction", "axis_side", "effect"):
            record[label_col] = row[label_col]
        record["role_profile"] = role_profile_for_feature(fid, role_matrix_space, role_names)
        record["description"] = cached.get("description")
        record["url"] = cached.get("url")
        return record

    return pd.DataFrame([_enriched(row) for _, row in df_in.iterrows()])


# Core UI buckets
# Both buckets have negative contrib by construction (delta and axis weight have opposite signs).
mask_deactivated_assistant = (feature_core_df["delta_x"] < 0) & (feature_core_df["axis_weight"] > 0)
mask_activated_nonassistant = (feature_core_df["delta_x"] > 0) & (feature_core_df["axis_weight"] < 0)

# Ascending sort on contrib puts the strongest away-drivers (most negative) first.
df_deactivated_assistant = enrich_feature_table(
    feature_core_df[mask_deactivated_assistant].sort_values("contrib", ascending=True).head(TOP_FEATURES)
)
df_activated_nonassistant = enrich_feature_table(
    feature_core_df[mask_activated_nonassistant].sort_values("contrib", ascending=True).head(TOP_FEATURES)
)

# Additional context
mask_away = feature_core_df["contrib"] < 0
mask_toward = feature_core_df["contrib"] > 0

df_top_away = enrich_feature_table(feature_core_df[mask_away].sort_values("contrib", ascending=True).head(TOP_FEATURES))
df_top_toward = enrich_feature_table(feature_core_df[mask_toward].sort_values("contrib", ascending=False).head(TOP_FEATURES))

# Share of the total away mass explained by the listed top features (a concentration measure).
# An empty table has no "share_of_away" column, hence the explicit length check.
if len(df_top_away) == 0:
    top_away_share = 0.0
else:
    top_away_share = float(df_top_away["share_of_away"].sum())

# Coarse label for how concentrated the drift is in the top features.
if top_away_share >= 0.60:
    effect_strength = "STRONG"
elif top_away_share >= 0.35:
    effect_strength = "MODERATE"
else:
    effect_strength = "WEAK"

feature_event_summary = {
    "selected_turn": int(t_now),
    "prev_turn": int(t_now - 1),
    "delta_score": float(score[t_now] - score[t_now - 1]),
    "away_mass": away_mass,
    "toward_mass": toward_mass,
    "top_away_share": top_away_share,
    "effect_strength": effect_strength,
}

print("Q2 summary:")
print(feature_event_summary)
print("A) Assistant-like features that de-activated (away-driving)")
display(df_deactivated_assistant)
print("B) Non-assistant-like features that activated (away-driving)")
display(df_activated_nonassistant)
print("C) Top away contributors (all types)")
display(df_top_away)
print("D) Top toward contributors")
display(df_top_toward)


## Q3: Token Localization

After feature attribution, project token vectors from the selected turn onto the same axis.
This shows which tokens carry assistant-like or non-assistant-like signal inside that message.


In [None]:
# Q3: Token-level localization on selected event turn
# assistant_spans is ordered by assistant turn, so t_now indexes it directly.
span = extraction["assistant_spans"][t_now]
start, end = int(span["start"]), int(span["end"])
token_ids = extraction["full_token_ids"][start:end]
token_feats = extraction["full_features"][start:end]

# Apply the same transform / centering as the turn vectors so token scores share the axis space.
token_space = apply_transform(token_feats, TRANSFORM)
if CENTER_TURNS_WITH_ROLE_MEAN:
    token_space = token_space - role_mean_vec

# Per-token projection on the axis, rescaled to [0, 1] between the configured percentiles
# so a few extreme tokens do not wash out the colour scale.
token_proj = token_space @ axis_unit
lo, hi = np.percentile(token_proj, TOKEN_HEAT_PERCENTILES)
if hi <= lo:
    # Degenerate range (e.g. all projections equal): avoid division by zero.
    hi = lo + 1e-6
token_norm = np.clip((token_proj - lo) / (hi - lo), 0.0, 1.0)

tokens = tokenizer.convert_ids_to_tokens(token_ids.tolist())

# Bar chart: raw projection per token, coloured by the normalized value.
fig, ax = plt.subplots(figsize=(max(10, len(tokens) * 0.22), 3.5))
ax.bar(np.arange(len(tokens)), token_proj, color=plt.cm.coolwarm(token_norm))
ax.set_title(f"Q3: token projection on axis (assistant turn {t_now})")
ax.set_xlabel("token index")
ax.set_ylabel("u_k")
plt.tight_layout()
plt.show()

# Compact text heat rendering
cmap = plt.get_cmap("coolwarm")
spans = []
for tok, val in zip(tokens, token_norm):
    r, g, b, _ = cmap(float(val))
    bg = f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 0.70)"
    # Escape token text before embedding it in HTML; keep spaces visible.
    safe_tok = html.escape(tok).replace(" ", "&nbsp;")
    spans.append(
        f"<span style='background:{bg}; padding:2px 3px; margin:1px; border-radius:3px; display:inline-block'>{safe_tok}</span>"
    )

display(HTML("".join(spans)))

# Tabular view of the same per-token values.
localization_df = pd.DataFrame(
    {
        "token_idx": np.arange(len(tokens)),
        "token": tokens,
        "u_k": token_proj,
        "u_k_norm": token_norm,
    }
)
display(localization_df)
        


## Build Explorer Payload

Q1, Q2, and Q3 outputs are bundled into a compact JSON payload for the static explorer.

In [None]:
# Static UI payload: Q1 + per-transition Q2 + per-turn Q3
import hashlib
# Number of features kept per direction in the UI payload; must be a positive count.
UI_TOP_FEATURES = int(TOP_FEATURES)
if UI_TOP_FEATURES < 1:
    raise ValueError("TOP_FEATURES must be positive")

# Change between each pair of consecutive assistant turns, and that change
# weighted element-wise by the axis (per-feature contribution along the axis).
delta_all = turn_vectors_space[1:] - turn_vectors_space[:-1]
contrib_all = delta_all * axis_unit.reshape(1, -1)

n_scored_turns = len(score)

# Q1 payload (single selected score method)
q1_turns = [
    {"assistant_turn": int(turn_idx), "score": float(score[turn_idx])}
    for turn_idx in range(n_scored_turns)
]

# One entry per consecutive pair (prev_turn -> turn) with the score difference.
q1_transitions = [
    {
        "prev_turn": int(prev_idx),
        "turn": int(prev_idx + 1),
        "delta_score": float(score[prev_idx + 1] - score[prev_idx]),
    }
    for prev_idx in range(n_scored_turns - 1)
]

# Q2 payload (top activated/deactivated features per transition)
q2_transitions = []
for transition_idx in range(delta_all.shape[0]):
    prev_turn, turn = int(transition_idx), int(transition_idx + 1)

    delta_row = delta_all[transition_idx]
    contrib_row = contrib_all[transition_idx]

    # Transition-level totals; used below as denominators for per-feature shares.
    abs_change_mass = float(np.abs(delta_row).sum())
    away_mass_i = float(np.clip(-contrib_row, 0.0, None).sum())
    toward_mass_i = float(np.clip(contrib_row, 0.0, None).sum())

    activated_idx = np.where(delta_row > 0)[0]
    deactivated_idx = np.where(delta_row < 0)[0]

    # Largest increases first, truncated to the UI budget.
    activated_rank = (
        activated_idx[np.argsort(delta_row[activated_idx])[::-1][:UI_TOP_FEATURES]]
        if activated_idx.size
        else np.array([], dtype=int)
    )

    # Most negative deltas first, truncated to the UI budget.
    deactivated_rank = (
        deactivated_idx[np.argsort(delta_row[deactivated_idx])[:UI_TOP_FEATURES]]
        if deactivated_idx.size
        else np.array([], dtype=int)
    )

    def _feature_entry(fid: int, direction: str) -> dict:
        """Build the payload record for feature ``fid`` in the current transition."""
        delta_value = float(delta_row[fid])
        weight = float(axis_unit[fid])
        contribution = float(contrib_row[fid])
        # Denominators are floored at 1e-12 so an all-zero transition yields 0 shares.
        change_share = float(abs(delta_value) / max(abs_change_mass, 1e-12))
        away_share = float(max(-contribution, 0.0) / max(away_mass_i, 1e-12))
        return {
            "feature_idx": int(fid),
            "direction": direction,
            "delta_x": delta_value,
            "axis_weight": weight,
            "contrib": contribution,
            "axis_side": "assistant_like" if weight > 0 else "non_assistant_like",
            "effect": "away_from_assistant" if contribution < 0 else "toward_assistant",
            "share_of_change": change_share,
            "share_of_away": away_share,
            "role_profile": role_profile_for_feature(int(fid), role_matrix_space, role_names),
        }

    activated_entries = [
        _feature_entry(int(fid), "activated") for fid in activated_rank.tolist()
    ]
    deactivated_entries = [
        _feature_entry(int(fid), "deactivated") for fid in deactivated_rank.tolist()
    ]

    q2_transitions.append(
        {
            "prev_turn": prev_turn,
            "turn": turn,
            "delta_score": float(score[turn] - score[prev_turn]),
            "away_mass": away_mass_i,
            "toward_mass": toward_mass_i,
            "activated": activated_entries,
            "deactivated": deactivated_entries,
        }
    )

# Q3 payload (token localization for every assistant turn)
q3_turns = []
for t, span in enumerate(extraction["assistant_spans"]):
    start, end = int(span["start"]), int(span["end"])
    turn_token_ids = extraction["full_token_ids"][start:end]
    turn_token_feats = extraction["full_features"][start:end]

    # Apply the configured transform and, if enabled, the role-mean centering.
    turn_token_space = apply_transform(turn_token_feats, TRANSFORM)
    if CENTER_TURNS_WITH_ROLE_MEAN:
        turn_token_space = turn_token_space - role_mean_vec

    # Project on the axis and rescale to [0, 1] using this turn's own percentiles.
    turn_proj = turn_token_space @ axis_unit
    pct_lo, pct_hi = np.percentile(turn_proj, TOKEN_HEAT_PERCENTILES)
    if pct_hi <= pct_lo:
        # Degenerate percentile range: widen it slightly to avoid dividing by zero.
        pct_hi = pct_lo + 1e-6
    turn_scaled = (turn_proj - pct_lo) / (pct_hi - pct_lo)
    turn_norm = np.clip(turn_scaled, 0.0, 1.0)

    turn_tokens = tokenizer.convert_ids_to_tokens(turn_token_ids.tolist())
    # Values are stored as float32-derived Python floats in the payload.
    turn_entry = {
        "assistant_turn": int(t),
        "tokens": list(map(str, turn_tokens)),
        "u_k": turn_proj.astype(np.float32).tolist(),
        "u_k_norm": turn_norm.astype(np.float32).tolist(),
    }
    q3_turns.append(turn_entry)

# Short fingerprint of the conversation: first 10 hex chars of the SHA-1 of its
# key-sorted JSON serialization.
canonical_conversation = json.dumps(conversation, sort_keys=True)
conversation_hash = hashlib.sha1(canonical_conversation.encode("utf-8")).hexdigest()[:10]

# Assemble the three question sections plus run metadata into one payload.
static_ui_payload = dict(
    version=1,
    run_meta=dict(
        model_name=MODEL_NAME,
        sae_id=SAE_ID,
        assistant_role=ASSISTANT_ROLE,
        score_method=SCORE_METHOD,
        conversation_hash=conversation_hash,
    ),
    q1=dict(
        turn_scores=q1_turns,
        transitions=q1_transitions,
    ),
    q2=dict(
        top_features_per_direction=UI_TOP_FEATURES,
        transitions=q2_transitions,
    ),
    q3=dict(
        token_localization=q3_turns,
    ),
)

print(
    "Built static_ui_payload: "
    f"turns={len(q1_turns)}, "
    f"transitions={len(q2_transitions)}, "
    f"token_turns={len(q3_turns)}"
)


## Save Run Artifacts

In [None]:
# Save run artifacts for downstream analysis (lightweight by default)
from datetime import datetime, timezone
import hashlib

# Names that the earlier analysis cells must have defined before anything is saved.
REQUIRED_VARS = [
    "score",
    "projection",
    "cos_axis",
    "cos_assistant",
    "turn_meta",
    "role_trace_df",
    "score_mass_diag",
    "score_mass_pivot",
    "selected_diag",
    "max_abs_corr",
    "transition_df",
    "t_now",
    "delta_x",
    "contrib",
    "df_top_away",
    "df_top_toward",
    "df_deactivated_assistant",
    "df_activated_nonassistant",
    "feature_event_summary",
    "localization_df",
    "turn_vectors_raw",
    "turn_vectors_space",
    "turn_vectors_unit",
    "axis_unit",
    "token_proj",
    "token_norm",
    "static_ui_payload",
]
missing = [name for name in REQUIRED_VARS if name not in globals()]
if missing:
    raise RuntimeError(f"Run all prior analysis cells first. Missing variables: {missing}")

# Results go one level up when the working directory is named "notebooks",
# otherwise into the current directory.
RESULTS_ROOT = Path(
    "../persona_drift_runs" if Path.cwd().name == "notebooks" else "./persona_drift_runs"
)

RUN_NAME = None  # optional override, e.g. "coding_conv_debug"

# Lightweight defaults: keep only summary + conversation + UI bundle
SAVE_DEBUG_TABLES = False
SAVE_DEBUG_ARRAYS = False


def _py(x):
    """Convert a NumPy/path value into a JSON-serializable Python object.

    Args:
        x: Any value; only NumPy scalars, NumPy arrays, plain floats and
            ``Path`` objects are converted.

    Returns:
        ``None`` for NaN (NumPy or plain Python float), a native ``bool`` /
        ``int`` / ``float`` for NumPy scalars, a list for arrays, a string for
        paths, and ``x`` unchanged for everything else.
    """
    # np.bool_ is neither np.integer nor np.floating and json cannot encode it.
    if isinstance(x, np.bool_):
        return bool(x)
    if isinstance(x, (np.integer, np.floating)):
        if np.isnan(x):
            return None
        return x.item()
    # Plain Python floats need the same NaN -> None mapping, otherwise
    # json.dumps emits a bare ``NaN`` token, which is not valid JSON.
    if isinstance(x, float) and np.isnan(x):
        return None
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, Path):
        return str(x)
    return x


def _to_records(df: pd.DataFrame) -> list[dict]:
    """Return the rows of ``df`` as dicts, with every value passed through ``_py``."""
    return [
        {column: _py(value) for column, value in record.items()}
        for record in df.to_dict(orient="records")
    ]


def _confound_verdict(max_abs_corr_value: float | None) -> str:
    if max_abs_corr_value is None or (isinstance(max_abs_corr_value, float) and np.isnan(max_abs_corr_value)):
        return "N/A"
    if max_abs_corr_value >= CONFOUND_HIGH_ABS_CORR:
        return "HIGH"
    if max_abs_corr_value >= CONFOUND_WARN_ABS_CORR:
        return "MODERATE"
    return "LOW"


# Run identity: UTC timestamp plus a short SHA-1 fingerprint of the conversation
# (first 10 hex chars over its key-sorted JSON). RUN_NAME overrides the default.
utc_now = datetime.now(timezone.utc)
conv_hash = hashlib.sha1(json.dumps(conversation, sort_keys=True).encode("utf-8")).hexdigest()[:10]
default_run_name = f"{MODEL_NAME.split('/')[-1]}_{utc_now.strftime('%Y%m%d_%H%M%S')}_{conv_hash}"
run_name = RUN_NAME if RUN_NAME else default_run_name
run_dir = RESULTS_ROOT / run_name
run_dir.mkdir(parents=True, exist_ok=True)

# All JSON below is dumped with ensure_ascii=False, i.e. non-ASCII characters are
# left unescaped. Write explicitly as UTF-8 so the result does not depend on the
# platform's locale encoding (which may be unable to encode the text at all).

# 1) Save raw conversation
(run_dir / "conversation.json").write_text(
    json.dumps({"conversation": conversation}, ensure_ascii=False, indent=2),
    encoding="utf-8",
)

# NaN correlation is stored as null and reported as "N/A".
max_abs_corr_py = _py(float(max_abs_corr)) if not np.isnan(max_abs_corr) else None
confound_verdict = _confound_verdict(max_abs_corr_py)

# 2) Save compact UI bundle in the SAME run folder
(run_dir / "ui_bundle.json").write_text(
    json.dumps(static_ui_payload, ensure_ascii=False, separators=(",", ":")),
    encoding="utf-8",
)

# 3) Save summary metadata
summary = {
    "run_name": run_name,
    "timestamp_utc": utc_now.isoformat(),
    "config": {
        "model_name": MODEL_NAME,
        "sae_release": SAE_RELEASE,
        "sae_id": SAE_ID,
        # target_layer may not be defined in every run; missing -> null.
        "target_layer": _py(globals().get("target_layer")),
        "pooling": POOLING,
        "transform": TRANSFORM,
        "use_centering": USE_CENTERING,
        "center_turns_with_role_mean": CENTER_TURNS_WITH_ROLE_MEAN,
        "use_l2_normalize": USE_L2_NORMALIZE,
        "axis_method": AXIS_METHOD,
        "other_group": OTHER_GROUP,
        "assistant_role": ASSISTANT_ROLE,
        "score_method": SCORE_METHOD,
        "top_k_roles": TOP_K_ROLES,
        "transition_pick": TRANSITION_PICK,
        "manual_turn": _py(MANUAL_TURN),
        "top_features": TOP_FEATURES,
        "token_heat_percentiles": list(TOKEN_HEAT_PERCENTILES),
        "confound_warn_abs_corr": CONFOUND_WARN_ABS_CORR,
        "confound_high_abs_corr": CONFOUND_HIGH_ABS_CORR,
    },
    # The explained event is the transition (t_now - 1) -> t_now.
    # NOTE(review): assumes t_now >= 1; with t_now == 0 the index -1 would wrap
    # to the last turn — confirm the selection cell guarantees this.
    "selection": {
        "selected_turn": int(t_now),
        "selected_prev_turn": int(t_now - 1),
        "selected_delta_score": float(score[t_now] - score[t_now - 1]),
        "selected_delta_projection": float(projection[t_now] - projection[t_now - 1]),
    },
    "diagnostics": {
        "max_abs_corr_selected_score": max_abs_corr_py,
        "confound_verdict": confound_verdict,
        "selected_score_diag": _to_records(selected_diag),
        "score_mass_diag": _to_records(score_mass_diag),
        "feature_event_summary": {k: _py(v) for k, v in feature_event_summary.items()},
    },
    # File names are relative to the run directory.
    "files": {
        "conversation": "conversation.json",
        "summary": "summary.json",
        "ui_bundle": "ui_bundle.json",
    },
}
(run_dir / "summary.json").write_text(
    json.dumps(summary, ensure_ascii=False, indent=2),
    encoding="utf-8",
)

# Optional heavy/debug exports
if SAVE_DEBUG_TABLES:
    csv_dir = run_dir / "tables"
    csv_dir.mkdir(exist_ok=True)

    def _write_table(frame, filename):
        """Write ``frame`` to ``csv_dir / filename`` without the index column."""
        frame.to_csv(csv_dir / filename, index=False)

    _write_table(turn_meta, "turn_meta.csv")
    _write_table(role_trace_df, "role_trace.csv")
    _write_table(score_mass_diag, "score_mass_diag.csv")
    # The pivot's index is moved back into regular columns before writing.
    _write_table(score_mass_pivot.reset_index(), "score_mass_pivot.csv")
    _write_table(selected_diag, "selected_diag.csv")
    _write_table(transition_df, "transition_df.csv")
    _write_table(df_top_away, "q2_top_away_contrib.csv")
    _write_table(df_top_toward, "q2_top_toward_contrib.csv")
    _write_table(df_deactivated_assistant, "q2_deactivated_assistant_features.csv")
    _write_table(df_activated_nonassistant, "q2_activated_nonassistant_features.csv")
    _write_table(localization_df, "q3_token_localization.csv")

if SAVE_DEBUG_ARRAYS:
    # Every array is stored under its own key in one compressed archive.
    debug_arrays = {
        "score": score,
        "projection": projection,
        "cos_axis": cos_axis,
        "cos_assistant": cos_assistant,
        "score_delta": np.diff(score),
        "axis_unit": axis_unit,
        "turn_vectors_raw": turn_vectors_raw,
        "turn_vectors_space": turn_vectors_space,
        "turn_vectors_unit": turn_vectors_unit,
        "delta_x": delta_x,
        "contrib": contrib,
        "token_proj": token_proj,
        "token_norm": token_norm,
        "role_names": np.asarray(role_names),
    }
    np.savez_compressed(run_dir / "arrays.npz", **debug_arrays)

print(f"Saved run artifacts to: {run_dir.resolve()}")
print("Saved lightweight files: conversation.json, summary.json, ui_bundle.json")
if any((SAVE_DEBUG_TABLES, SAVE_DEBUG_ARRAYS)):
    print("Debug exports enabled: additional files were saved.")
