In [5]:
#code block #1 — CONFIG
import os
import glob
import random
import itertools
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

from transformer_lens import HookedTransformer
from sae_lens import SAE

# ----------------- PATHS: EDIT THESE -----------------
BASE_PRISM_DIR = r"C:\Users\thors\Documents\GitHub\prism"

# PRISM descriptions for GPT-2 small SAE (same as before)
DESCRIPTION_DIR = os.path.join(
    BASE_PRISM_DIR,
    r"descriptions\gemini-1-5-pro\gpt2-small-sae"
)

# PRISM metrics CSV (polysemanticity + quality)
METRICS_CSV = os.path.join(
    BASE_PRISM_DIR,
    r"results\meta-evaluation_cosine-similarity_target-gpt2-small-sae_textgen-gemini-1-5-pro_mean_evalgen-gemini-1-5-pro_cosmopedia_1000.csv"
)

# Your prompt CSV (generated externally)
PROMPTS_CSV = r"C:\Users\thors\Documents\GitHub\llm-interpretability\gpt-2_layer0_prompts.csv"   # must include: unit,text,sample_type (and ideally layer)

# Output dir
OUTPUT_DIR = os.path.join(BASE_PRISM_DIR, "runtime_collision_results")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ----------------- MODELS -----------------
EVAL_MODEL_NAME = "gpt2-small"
SAE_RELEASE = "gpt2-small-resid-post-v5-32k"

# ----------------- EXPERIMENT -----------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# If your CSV doesn’t have a layer column, set it here (you said layer 0)
DEFAULT_LAYER_IF_MISSING = 0

# Optional: save raw results (can be big)
SAVE_RAW_RESULTS = False


Using device: cpu


In [6]:
#code block #2 — LOAD PRISM (LAYER 0 ONLY) + MERGE METRICS
# Load metrics once
metrics_df = pd.read_csv(METRICS_CSV)

# Load descriptions (restrict to layer 0 to keep it light)
desc_files_layer0 = glob.glob(os.path.join(DESCRIPTION_DIR, "gpt2-small-sae_layer-0_unit-*.csv"))
if len(desc_files_layer0) == 0:
    raise FileNotFoundError("No layer-0 description CSVs found. Check DESCRIPTION_DIR.")

desc_dfs = [pd.read_csv(p) for p in desc_files_layer0]
descriptions_df = pd.concat(desc_dfs, ignore_index=True)

# Merge
merged_df = descriptions_df.merge(metrics_df, on=["layer", "unit"], how="left")

# For joining later
metrics_cols = ["layer", "unit", "cosine_similarity", "cosine_similarity_random", "max_auc", "max_mad"]
metrics_unique = merged_df[metrics_cols].drop_duplicates(subset=["layer", "unit"]).reset_index(drop=True)

print("Loaded PRISM layer 0 descriptions:", len(descriptions_df))
print("Merged rows:", len(merged_df))


Loaded PRISM layer 0 descriptions: 95
Merged rows: 95


In [7]:
#code block #3 — LOAD EVAL MODEL + SAE LOADER
eval_model = HookedTransformer.from_pretrained(
    EVAL_MODEL_NAME,
    device=device,
    dtype=torch.float32,
)
eval_model.eval()
print("Loaded eval model:", EVAL_MODEL_NAME)

# Max PRISM unit per layer (sanity check)
max_unit_by_layer = merged_df.groupby("layer")["unit"].max().to_dict()

sae_cache = {}

def hook_name_for_layer(layer: int) -> str:
    return f"blocks.{layer}.hook_resid_post"

def get_sae_for_layer(layer: int) -> SAE:
    if layer in sae_cache:
        return sae_cache[layer]

    sae_id = f"blocks.{layer}.hook_resid_post"
    sae, cfg, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
        SAE_RELEASE,
        sae_id,
        device=device,
    )
    sae.eval()

    # simple check
    if layer in max_unit_by_layer:
        if max_unit_by_layer[layer] >= sae.cfg.d_sae:
            raise ValueError(
                f"Layer {layer}: PRISM max unit={max_unit_by_layer[layer]} but SAE d_sae={sae.cfg.d_sae}. "
                "Mismatch between PRISM and SAE release."
            )

    sae_cache[layer] = sae
    return sae


Loaded pretrained model gpt2-small into HookedTransformer
Loaded eval model: gpt2-small


In [8]:
#code block #4 — LOAD YOUR PROMPTS CSV + NORMALIZE
texts_df = pd.read_csv(PROMPTS_CSV)

# Normalize column names
if "single/mixed" in texts_df.columns:
    texts_df = texts_df.rename(columns={"single/mixed": "sample_type"})

required = {"unit", "text", "sample_type"}
missing = required - set(texts_df.columns)
if missing:
    raise ValueError(f"Missing required columns: {missing}. Found columns: {list(texts_df.columns)}")

# Add layer if missing
if "layer" not in texts_df.columns:
    texts_df["layer"] = DEFAULT_LAYER_IF_MISSING

# Make sample_type robust
texts_df["sample_type"] = (
    texts_df["sample_type"]
    .astype(str)
    .str.lower()
    .str.strip()
    .replace({"0": "single", "1": "mixed", "false": "single", "true": "mixed"})
)

texts_df = texts_df[texts_df["sample_type"].isin(["single", "mixed"])].reset_index(drop=True)

# Basic cleanup
texts_df["text"] = texts_df["text"].astype(str)

# Sanity summary
print("Loaded prompts:", len(texts_df))
print(texts_df["sample_type"].value_counts(dropna=False))

# Ensure we can compute collision penalty (need both single and mixed per unit)
counts = texts_df.groupby(["layer", "unit"])["sample_type"].nunique()
bad = counts[counts < 2]
if len(bad):
    print("Warning: some (layer, unit) missing single or mixed; collision penalty will be NaN for them.")
    print(bad.head(10))


Loaded prompts: 114
sample_type
single    95
mixed     19
Name: count, dtype: int64


In [9]:
#code block #5 — LOSS FUNCTIONS (BASELINE + SAE ABLATION)
from functools import partial

def _per_sample_loss_from_logits(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """
    logits: [B, S, V]
    tokens: [B, S]
    returns: [B] mean NLL per token (next-token)
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = tokens[:, 1:]

    logp = torch.log_softmax(shift_logits, dim=-1)
    nll = -logp.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)  # [B, S-1]

    # GPT-2 usually no padding here; token_count = S-1
    token_counts = torch.full((tokens.shape[0],), fill_value=shift_labels.shape[1], device=tokens.device, dtype=torch.float32)
    return nll.sum(dim=1) / token_counts

@torch.no_grad()
def compute_losses_baseline(texts: list[str]) -> np.ndarray:
    if not texts:
        return np.array([], dtype=np.float32)
    tokens = eval_model.to_tokens(texts).to(device)
    logits = eval_model(tokens, return_type="logits")
    losses = _per_sample_loss_from_logits(logits, tokens)
    return losses.cpu().numpy()

def sae_intervention_hook(
    acts: torch.Tensor,
    hook,
    sae: SAE,
    feature_idx: int,
):
    """
    Ablate one SAE feature at resid_post by setting it to 0.
    acts: [B, S, d_model]
    """
    B, S, D = acts.shape
    flat = acts.reshape(-1, D)
    feats = sae.encode(flat)           # [B*S, d_sae]
    feats[:, feature_idx] = 0.0
    recon = sae.decode(feats)          # [B*S, d_model]
    return recon.reshape(B, S, D)

@torch.no_grad()
def compute_losses_ablate_feature(texts: list[str], layer: int, unit: int) -> np.ndarray:
    if not texts:
        return np.array([], dtype=np.float32)

    sae = get_sae_for_layer(layer)
    hook_name = hook_name_for_layer(layer)

    tokens = eval_model.to_tokens(texts).to(device)

    hook_fn = partial(sae_intervention_hook, sae=sae, feature_idx=int(unit))
    logits = eval_model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=[(hook_name, hook_fn)],
    )

    losses = _per_sample_loss_from_logits(logits, tokens)
    return losses.cpu().numpy()


In [None]:
#code block #6 — RUN COLLISION EXPERIMENT ON YOUR CSV (ABLATION PENALTY + BASELINE LOSS DIAGNOSTICS)

results_rows = []

for (layer, unit), g in texts_df.groupby(["layer", "unit"]):
    layer = int(layer)
    unit = int(unit)

    singles = g[g["sample_type"] == "single"]["text"].tolist()
    mixed   = g[g["sample_type"] == "mixed"]["text"].tolist()

    # Baseline + ablate on singles
    base_s = compute_losses_baseline(singles)
    ablt_s = compute_losses_ablate_feature(singles, layer=layer, unit=unit)

    # Baseline + ablate on mixed
    base_m = compute_losses_baseline(mixed)
    ablt_m = compute_losses_ablate_feature(mixed, layer=layer, unit=unit)

    # Store per-sample results (minimal)
    for i, (b, a) in enumerate(zip(base_s, ablt_s)):
        results_rows.append({
            "layer": layer,
            "unit": unit,
            "sample_type": "single",
            "sample_idx": i,
            "baseline": float(b),
            "ablate": float(a),
        })
    for i, (b, a) in enumerate(zip(base_m, ablt_m)):
        results_rows.append({
            "layer": layer,
            "unit": unit,
            "sample_type": "mixed",
            "sample_idx": i,
            "baseline": float(b),
            "ablate": float(a),
        })

results_df = pd.DataFrame(results_rows)
results_df["delta"] = results_df["ablate"] - results_df["baseline"]

# -------- Feature-level summary with baseline-loss diagnostics --------

feat_summary = results_df.groupby(["layer", "unit", "sample_type"]).agg(
    base_loss_mean=("baseline", "mean"),
    base_loss_std=("baseline", "std"),
    delta_mean=("delta", "mean"),
    delta_std=("delta", "std"),
    n=("delta", "size"),
).reset_index()

single_feat = feat_summary[feat_summary["sample_type"] == "single"].rename(columns={
    "base_loss_mean": "base_single_mean",
    "base_loss_std": "base_single_std",
    "delta_mean": "delta_single_mean",
    "delta_std": "delta_single_std",
    "n": "n_single",
}).drop(columns=["sample_type"])

mixed_feat = feat_summary[feat_summary["sample_type"] == "mixed"].rename(columns={
    "base_loss_mean": "base_mixed_mean",
    "base_loss_std": "base_mixed_std",
    "delta_mean": "delta_mixed_mean",
    "delta_std": "delta_mixed_std",
    "n": "n_mixed",
}).drop(columns=["sample_type"])

feat = single_feat.merge(mixed_feat, on=["layer", "unit"], how="outer")

# Collision penalty (absolute)
feat["collision_penalty"] = feat["delta_mixed_mean"] - feat["delta_single_mean"]

# Baseline gap (mixed harder than single?)
feat["baseline_gap"] = feat["base_mixed_mean"] - feat["base_single_mean"]

# Relative deltas (robustness diagnostics)
feat["delta_single_rel"] = feat["delta_single_mean"] / feat["base_single_mean"]
feat["delta_mixed_rel"]  = feat["delta_mixed_mean"]  / feat["base_mixed_mean"]
feat["collision_penalty_rel"] = feat["delta_mixed_rel"] - feat["delta_single_rel"]

# Join PRISM metrics (layer 0 present; others will be NaN)
feat_with_metrics = feat.merge(metrics_unique, on=["layer", "unit"], how="left")

print("Computed feature-level collision metrics for", len(feat_with_metrics), "features")
print(feat_with_metrics[
    ["layer","unit",
     "collision_penalty","collision_penalty_rel",
     "delta_single_mean","delta_mixed_mean",
     "base_single_mean","base_mixed_mean","baseline_gap",
     "cosine_similarity","max_auc","max_mad"]
].head(10))

# Optional saves
if SAVE_RAW_RESULTS:
    raw_path = os.path.join(OUTPUT_DIR, "runtime_collision_from_csv_raw.csv")
    results_df.to_csv(raw_path, index=False)
    print("Saved raw:", raw_path)

feat_path = os.path.join(OUTPUT_DIR, "runtime_collision_from_csv_feature_level.csv")
feat_with_metrics.to_csv(feat_path, index=False)
print("Saved feature-level:", feat_path)

Computed feature-level collision metrics for 19 features
   layer   unit  collision_penalty  collision_penalty_rel  delta_single_mean  \
0      0    259           0.535411               0.113635          -0.546685   
1      0   2002           0.112652               0.026823          -0.070695   
2      0   2236           1.246048               0.227148          -1.244619   
3      0   2332           0.207311               0.047592          -0.224504   
4      0   3358           0.460175               0.095806          -0.416933   
5      0   5369           0.662599               0.132548          -0.647283   
6      0   8966           0.989261               0.180271          -0.988548   
7      0   9661           0.894785               0.166979          -0.856334   
8      0  10233           0.234287               0.052397          -0.212474   
9      0  10917           0.362007               0.080967          -0.359046   

   delta_mixed_mean  base_single_mean  base_mixed_mean  baseli

In [11]:
#code block #7 — OPTIONAL QC BLOCK (PROMPT QUALITY + FEATURE ACTIVATION)
# Toggle this on/off
RUN_QC = True

if RUN_QC:
    import re
    from collections import Counter
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    def simple_tokenize(s: str):
        return re.findall(r"\w+|[^\w\s]", str(s).lower())

    def repetition_metrics(text: str, ngram_n: int = 3):
        toks = simple_tokenize(text)
        n = len(toks)
        if n == 0:
            return {"n_tokens": 0, "unique_token_ratio": 0.0, "max_ngram_count": 0}

        unique_ratio = len(set(toks)) / n
        if n < ngram_n:
            max_ng = 0
        else:
            ngrams = [tuple(toks[i:i+ngram_n]) for i in range(n-ngram_n+1)]
            counts = Counter(ngrams)
            max_ng = max(counts.values()) if counts else 0

        return {"n_tokens": n, "unique_token_ratio": float(unique_ratio), "max_ngram_count": int(max_ng)}

    def get_descriptions_for_feature(layer: int, unit: int) -> list[str]:
        g = merged_df[(merged_df["layer"] == layer) & (merged_df["unit"] == unit)]
        return g["description"].dropna().astype(str).tolist()

    def tfidf_coverage_scores(text: str, descriptions: list[str]):
        if not descriptions:
            return {"cov_min": np.nan, "cov_mean": np.nan, "cov_max": np.nan}

        corpus = descriptions + [str(text)]
        vec = TfidfVectorizer(ngram_range=(1, 2), max_features=5000).fit_transform(corpus)
        desc_vecs = vec[:-1]
        text_vec = vec[-1]
        sims = cosine_similarity(desc_vecs, text_vec).reshape(-1)

        return {"cov_min": float(np.min(sims)), "cov_mean": float(np.mean(sims)), "cov_max": float(np.max(sims))}

    @torch.no_grad()
    def activation_stats_for_feature(texts: list[str], layer: int, unit: int):
        """
        Returns per-text stats:
          act_mean = mean over positions of SAE feature activation
          act_max  = max over positions of SAE feature activation
        """
        sae = get_sae_for_layer(layer)
        hook_name = hook_name_for_layer(layer)

        tokens = eval_model.to_tokens(texts).to(device)
        _, cache = eval_model.run_with_cache(tokens, names_filter=[hook_name])
        acts = cache[hook_name]  # [B, S, d_model]

        B, S, D = acts.shape
        flat = acts.reshape(-1, D)
        feats = sae.encode(flat)                 # [B*S, d_sae]
        u = feats[:, int(unit)].reshape(B, S)    # [B, S]

        act_mean = u.mean(dim=1).cpu().numpy()
        act_max  = u.max(dim=1).values.cpu().numpy()
        return act_mean, act_max

    # ---- thresholds (start permissive; tighten later) ----
    MIN_UNIQUE_TOKEN_RATIO = 0.30
    MAX_TRIGRAM_REPEAT = 3
    MIN_COV_MIN = 0.05
    MIN_ACT_RATIO_VS_SINGLE_MEDIAN = 0.80  # mixed should be at least 80% of typical single activation

    qc_parts = []

    for (layer, unit), g in texts_df.groupby(["layer", "unit"]):
        layer = int(layer)
        unit = int(unit)
        g = g.copy()

        descs = get_descriptions_for_feature(layer, unit)

        # text-only metrics
        rep = g["text"].apply(repetition_metrics).apply(pd.Series)
        cov = g["text"].apply(lambda t: tfidf_coverage_scores(t, descs)).apply(pd.Series)
        g = pd.concat([g.reset_index(drop=True), rep.reset_index(drop=True), cov.reset_index(drop=True)], axis=1)

        # activation metrics (batched per feature)
        act_mean, act_max = activation_stats_for_feature(g["text"].tolist(), layer=layer, unit=unit)
        g["act_mean"] = act_mean
        g["act_max"] = act_max

        # single baseline activation for this feature
        singles = g[g["sample_type"] == "single"]
        single_median = float(np.median(singles["act_mean"])) if len(singles) else np.nan

        # flags
        g["ok_repetition"] = (g["unique_token_ratio"] >= MIN_UNIQUE_TOKEN_RATIO) & (g["max_ngram_count"] <= MAX_TRIGRAM_REPEAT)
        g["ok_coverage"] = (g["cov_min"].fillna(-1.0) >= MIN_COV_MIN)

        # activation flag:
        # - singles: just require > 0 (or keep all; your choice)
        # - mixed: require not “dead” compared to singles
        g["ok_activation"] = True
        if not np.isnan(single_median) and single_median > 0:
            is_mixed = (g["sample_type"] == "mixed")
            g.loc[is_mixed, "ok_activation"] = g.loc[is_mixed, "act_mean"] >= (MIN_ACT_RATIO_VS_SINGLE_MEDIAN * single_median)

        # final keep decision
        g["keep"] = g["ok_repetition"] & g["ok_coverage"] & g["ok_activation"]

        qc_parts.append(g)

    qc_df = pd.concat(qc_parts, ignore_index=True)

    print("QC keep rate overall:", qc_df["keep"].mean())
    print("QC keep rate by sample_type:")
    print(qc_df.groupby("sample_type")["keep"].mean())

    # Save QC table (useful for later debugging/analysis)
    qc_path = os.path.join(OUTPUT_DIR, "prompt_qc_table.csv")
    qc_df.to_csv(qc_path, index=False)
    print("Saved QC table:", qc_path)


QC keep rate overall: 0.0
QC keep rate by sample_type:
sample_type
mixed     0.0
single    0.0
Name: keep, dtype: float64
Saved QC table: C:\Users\thors\Documents\GitHub\prism\runtime_collision_results\prompt_qc_table.csv
