In [1]:
# CELL 1: CONFIG & IMPORTS

import os
import random
import itertools
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from sae_lens import SAE

# ----------------- PATHS: EDIT THESE -----------------

BASE_PRISM_DIR = r"C:\Users\thors\Documents\GitHub\prism"

# PRISM descriptions for GPT-2 small SAE
DESCRIPTION_DIR = os.path.join(
    BASE_PRISM_DIR,
    r"descriptions\gemini-1-5-pro\gpt2-small-sae"
)

# PRISM polysemanticity + COSY metrics for GPT-2 small SAE
METRICS_CSV = os.path.join(
    BASE_PRISM_DIR,
    r"results\meta-evaluation_cosine-similarity_target-gpt2-small-sae_textgen-gemini-1-5-pro_mean_evalgen-gemini-1-5-pro_cosmopedia_1000.csv"
)

# Where to save experiment results (if/when you want)
OUTPUT_DIR = os.path.join(BASE_PRISM_DIR, "runtime_collision_results")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ----------------- MODELS -----------------

# Evaluation model (HookedTransformer) – GPT-2 small
EVAL_MODEL_NAME = "gpt2-small"  # transformer_lens name

# Text generator for constructing A/B/AB prompts.
GEN_MODEL_NAME = "HuggingFaceTB/SmolLM2-360M-Instruct"  # HF name; can swap later

# SAE release used by PRISM for GPT-2 small: v5, width 32k, resid_post
SAE_RELEASE = "gpt2-small-resid-post-v5-32k"

# ----------------- EXPERIMENT KNOBS -----------------

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# How many different SAE features to analyse in this run
N_FEATURES_TO_TEST = 4

# How many descriptions (concepts) per feature to use at most
MAX_DESCRIPTIONS_PER_FEATURE = 3  # use all if small; cap if large

# How many samples per single concept (A, B, C, ...)
N_SAMPLES_PER_CONCEPT = 3

# How many samples per pair of concepts (AB, AC, ...)
N_SAMPLES_PER_PAIR = 3

# Max generation length for synthetic prompts
MAX_NEW_TOKENS = 128


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [2]:
# CELL 2: LOAD PRISM DESCRIPTIONS + METRICS FOR GPT-2 SMALL SAE

import glob

# 1) Load polysemanticity + description quality metrics
metrics_df = pd.read_csv(METRICS_CSV)
print("Loaded metrics for", len(metrics_df), "feature-description rows")

# 2) Load all description CSVs for gpt2-small-sae
desc_files = glob.glob(os.path.join(DESCRIPTION_DIR, "gpt2-small-sae_layer-*_unit-*.csv"))
print(f"Found {len(desc_files)} description files")

desc_dfs = [pd.read_csv(path) for path in desc_files]
descriptions_df = pd.concat(desc_dfs, ignore_index=True)
print("Loaded", len(descriptions_df), "description rows")

# 3) Merge descriptions with metrics on (layer, unit)
merged_df = descriptions_df.merge(metrics_df, on=["layer", "unit"], how="left")


Loaded metrics for 59 feature-description rows
Found 59 description files
Loaded 295 description rows


In [3]:
# CELL 3: SELECT FEATURES TO TEST AND PREP CONCEPT LISTS

# Group by (layer, unit)
grouped = merged_df.groupby(["layer", "unit"])

# Get candidate features with at least 2 descriptions
candidate_features = [
    (layer, unit)
    for (layer, unit), g in grouped
    if len(g) >= 2
]

print(f"Found {len(candidate_features)} features with >= 2 descriptions.")

# Take the first N_FEATURES_TO_TEST for this run
features_to_test = candidate_features[:N_FEATURES_TO_TEST]
print(f"Will test these features (layer, unit): {features_to_test}")

# Build a dict: (layer, unit) -> description rows + metrics (subsampled if needed)
feature_concepts = {}
for (layer, unit) in features_to_test:
    g = grouped.get_group((layer, unit)).reset_index(drop=True)
    if len(g) > MAX_DESCRIPTIONS_PER_FEATURE:
        g = g.sample(n=MAX_DESCRIPTIONS_PER_FEATURE, random_state=SEED).reset_index(drop=True)
    feature_concepts[(layer, unit)] = g


Found 59 features with >= 2 descriptions.
Will test these features (layer, unit): [(0, 259), (0, 2002), (0, 2236), (0, 2332)]


In [4]:
# CELL 4: LOAD GPT-2 SMALL + SAE PER LAYER, MATCHING PRISM

# Load eval model once
eval_model = HookedTransformer.from_pretrained(
    EVAL_MODEL_NAME,
    device=device,
    dtype=torch.float32,
)
eval_model.eval()
print("Loaded eval model:", EVAL_MODEL_NAME)

# Pre-compute, from merged_df, the max unit index per layer that PRISM uses
max_unit_by_layer = merged_df.groupby("layer")["unit"].max().to_dict()

sae_cache = {}

def get_sae_for_layer(layer: int) -> SAE:
    if layer in sae_cache:
        return sae_cache[layer]
    sae_id = f"blocks.{layer}.hook_resid_post"
    print(f"Loading SAE for layer {layer}: release={SAE_RELEASE}, sae_id={sae_id}")

    sae, cfg, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
        SAE_RELEASE,
        sae_id,
        device=device,
    )
    sae.eval()

    # Simple sanity check: SAE feature width must exceed max PRISM unit for this layer
    if layer in max_unit_by_layer:
        max_unit = max_unit_by_layer[layer]
        n_features = sae.cfg.d_sae  # number of SAE features

        if max_unit >= n_features:
            raise ValueError(
                f"PRISM units go up to {max_unit} in layer {layer}, "
                f"but SAE '{SAE_RELEASE}/{sae_id}' only has {n_features} features. "
                "Likely a mismatch between SAE release and PRISM descriptions."
            )

    sae_cache[layer] = sae
    return sae


def hook_name_for_layer(layer: int) -> str:
    """Return the TransformerLens hook name for this layer's resid_post."""
    return f"blocks.{layer}.hook_resid_post"


`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer
Loaded eval model: gpt2-small


In [5]:
# CELL 5: LOAD TEXT GENERATOR MODEL + HELPER

gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForCausalLM.from_pretrained(GEN_MODEL_NAME).to(device)
gen_model.eval()

if gen_tokenizer.pad_token_id is None:
    gen_tokenizer.pad_token_id = gen_tokenizer.eos_token_id

def generate_samples(prompt: str, n_samples: int, max_new_tokens: int = MAX_NEW_TOKENS):
    samples = []
    inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)
    for _ in range(n_samples):
        with torch.no_grad():
            out = gen_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
                pad_token_id=gen_tokenizer.pad_token_id,
            )
        full_text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
        if full_text.startswith(prompt):
            continuation = full_text[len(prompt):].strip()
        else:
            continuation = full_text
        samples.append(continuation)
    return samples


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


In [6]:
# CELL 6: LOSS COMPUTATION HELPERS (BASELINE)

def compute_losses_baseline(texts):
    """Baseline per-sample loss with no intervention."""
    if not texts:
        return np.array([]), float("nan")

    tokens = eval_model.to_tokens(texts).to(device)
    input_ids = tokens
    labels = tokens.clone()

    with torch.no_grad():
        logits = eval_model(input_ids, return_type="logits")

    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    logp = torch.log_softmax(shift_logits, dim=-1)
    nll = -logp.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    token_counts = (shift_labels != eval_model.tokenizer.pad_token_id).float().sum(dim=1)
    # For GPT-2 there is no pad in normal use, so this is effectively just seq_len.

    per_sample_loss = nll.sum(dim=1) / token_counts

    return per_sample_loss.cpu().numpy(), float(per_sample_loss.mean().item())


In [7]:
# CELL 7: SAE INTERVENTION HOOK (PER-LAYER) + LOSS WITH INTERVENTION

from functools import partial

def sae_intervention_hook(
    acts: torch.Tensor,
    hook,
    sae: SAE,
    feature_indices,
    mode: str = "ablate",
    clamp_values=None,
):
    """
    acts: [batch, seq, d_model] activation at blocks.{layer}.hook_resid_post
    mode:
      - "ablate": set selected features to 0
      - "clamp": set selected features to clamp_values (same length as feature_indices)
    """
    assert mode in ("ablate", "clamp")

    bsz, seq_len, d_model = acts.shape
    acts_flat = acts.reshape(-1, d_model)  

    with torch.no_grad():
        feats = sae.encode(acts_flat) 

        if not isinstance(feature_indices, (list, tuple, np.ndarray)):
            feature_indices_list = [feature_indices]
        else:
            feature_indices_list = list(feature_indices)

        if mode == "ablate":
            feats[:, feature_indices_list] = 0.0
        elif mode == "clamp":
            assert clamp_values is not None
            assert len(clamp_values) == len(feature_indices_list)
            for idx, val in zip(feature_indices_list, clamp_values):
                feats[:, idx] = val

        recon_flat = sae.decode(feats) 

    recon = recon_flat.reshape(bsz, seq_len, d_model)
    return recon


def compute_losses_with_intervention(
    texts,
    sae: SAE,
    hook_name: str,
    feature_indices,
    mode: str = "ablate",
    clamp_values=None,
):
    """Compute per-sample loss with SAE intervention at a specific layer."""
    if not texts:
        return np.array([]), float("nan")

    tokens = eval_model.to_tokens(texts).to(device)
    input_ids = tokens
    labels = tokens.clone()

    hook_fn = partial(
        sae_intervention_hook,
        sae=sae,
        feature_indices=feature_indices,
        mode=mode,
        clamp_values=clamp_values,
    )
    fwd_hooks = [(hook_name, hook_fn)]

    with torch.no_grad():
        logits = eval_model.run_with_hooks(
            input_ids,
            return_type="logits",
            fwd_hooks=fwd_hooks,
        )

    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    logp = torch.log_softmax(shift_logits, dim=-1)
    nll = -logp.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    token_counts = (shift_labels != eval_model.tokenizer.pad_token_id).float().sum(dim=1)
    per_sample_loss = nll.sum(dim=1) / token_counts

    return per_sample_loss.cpu().numpy(), float(per_sample_loss.mean().item())


In [8]:
# CELL 8: PROMPT TEMPLATES + GENERATION

def build_single_concept_prompt(desc: str) -> str:
    return (
        "Write a short paragraph (4–6 sentences) that strongly involves the following theme:\n"
        f"\"{desc}\"\n"
        "Focus ONLY on this theme.\n"
        "Avoid mentioning or alluding to unrelated topics.\n\n"
    )

def build_pair_concept_prompt(desc1: str, desc2: str) -> str:
    return (
        "Write a short paragraph (6–8 sentences) that strongly involves BOTH of the following themes:\n"
        f"1. \"{desc1}\"\n"
        f"2. \"{desc2}\"\n\n"
        "Make sure both themes appear multiple times and interact in a coherent way.\n"
        "Do not write a list; write a natural paragraph.\n\n"
    )

all_feature_entries = []

for (layer, unit), df_feat in feature_concepts.items():
    descriptions = df_feat["description"].tolist()
    n_desc = len(descriptions)
    print(f"\n=== Generating data for feature (layer={layer}, unit={unit}), n_desc={n_desc} ===")

    single_texts = {}  # concept index -> list of texts
    pair_texts = {}    # (i, j) -> list of texts

    # Single-concept texts
    for i, desc in enumerate(descriptions):
        prompt = build_single_concept_prompt(desc)
        texts_i = generate_samples(prompt, n_samples=N_SAMPLES_PER_CONCEPT)
        single_texts[i] = texts_i

    # Pair-concept texts (all unordered pairs)
    for i, j in itertools.combinations(range(n_desc), 2):
        desc_i, desc_j = descriptions[i], descriptions[j]
        prompt_pair = build_pair_concept_prompt(desc_i, desc_j)
        texts_ij = generate_samples(prompt_pair, n_samples=N_SAMPLES_PER_PAIR)
        pair_texts[(i, j)] = texts_ij

    all_feature_entries.append(
        {
            "layer": layer,
            "unit": unit,
            "df_feat": df_feat,
            "descriptions": descriptions,
            "single_texts": single_texts,
            "pair_texts": pair_texts,
        }
    )

print("\nFinished generating prompts for all selected features.")



=== Generating data for feature (layer=0, unit=259), n_desc=3 ===

=== Generating data for feature (layer=0, unit=2002), n_desc=3 ===

=== Generating data for feature (layer=0, unit=2236), n_desc=3 ===

=== Generating data for feature (layer=0, unit=2332), n_desc=3 ===

Finished generating prompts for all selected features.


In [9]:
all_feature_entries

[{'layer': 0,
  'unit': 259,
  'df_feat':    layer  unit                                        description  \
  0      0   259                   Diseases and negative attributes   
  1      0   259     Diseases, medical conditions, or health issues   
  2      0   259  Medical studies of the effects of various fact...   
  
     mean_activation                                         highlights  \
  0         3.352165  ['Text #1: cancer (4.638434410095215)', 'Text ...   
  1         6.872105  ['Text #1:  cancer (6.19657564163208)', 'Text ...   
  2         6.837861  ['Text #1:  cancer (7.169066429138184)', 'Text...   
  
     cosine_similarity  cosine_similarity_random  max_auc  max_mad  
  0           0.467154                  0.333642      0.5      0.0  
  1           0.467154                  0.333642      0.5      0.0  
  2           0.467154                  0.333642      0.5      0.0  ,
  'descriptions': ['Diseases and negative attributes',
   'Diseases, medical conditions, or h

In [10]:
# CELL 9: RUN BASELINE + ABLATION LOSSES FOR EACH FEATURE

results_rows = []

for feat_entry in all_feature_entries:
    layer = feat_entry["layer"]
    unit = feat_entry["unit"]
    descriptions = feat_entry["descriptions"]
    single_texts = feat_entry["single_texts"]
    pair_texts = feat_entry["pair_texts"]

    sae = get_sae_for_layer(layer)
    hook_name = hook_name_for_layer(layer)

    feature_idx = unit  # PRISM's 'unit' is the SAE feature index for this layer

    print(f"\n=== Evaluating feature (layer={layer}, unit={unit}) ===")

    # SINGLE-CONCEPT PROMPTS
    for concept_id, texts in single_texts.items():
        concept_label = f"C{concept_id}"

        # Baseline
        losses_base, _ = compute_losses_baseline(texts)

        # Ablated
        losses_abl, _ = compute_losses_with_intervention(
            texts,
            sae=sae,
            hook_name=hook_name,
            feature_indices=[feature_idx],
            mode="ablate",
            clamp_values=None,
        )

        for i, (text, lb, la) in enumerate(zip(texts, losses_base, losses_abl)):
            results_rows.append(
                {
                    "layer": layer,
                    "unit": unit,
                    "concept_set": concept_label,
                    "sample_type": "single",
                    "sample_idx": i,
                    "mode": "baseline",
                    "loss": float(lb),
                    "text": text,
                }
            )
            results_rows.append(
                {
                    "layer": layer,
                    "unit": unit,
                    "concept_set": concept_label,
                    "sample_type": "single",
                    "sample_idx": i,
                    "mode": "ablate",
                    "loss": float(la),
                    "text": text,
                }
            )

    # PAIR-CONCEPT PROMPTS
    for (i, j), texts in pair_texts.items():
        concept_label = f"C{i}+C{j}"

        losses_base, _ = compute_losses_baseline(texts)
        losses_abl, _ = compute_losses_with_intervention(
            texts,
            sae=sae,
            hook_name=hook_name,
            feature_indices=[feature_idx],
            mode="ablate",
            clamp_values=None,
        )

        for k, (text, lb, la) in enumerate(zip(texts, losses_base, losses_abl)):
            results_rows.append(
                {
                    "layer": layer,
                    "unit": unit,
                    "concept_set": concept_label,
                    "sample_type": "pair",
                    "sample_idx": k,
                    "mode": "baseline",
                    "loss": float(lb),
                    "text": text,
                }
            )
            results_rows.append(
                {
                    "layer": layer,
                    "unit": unit,
                    "concept_set": concept_label,
                    "sample_type": "pair",
                    "sample_idx": k,
                    "mode": "ablate",
                    "loss": float(la),
                    "text": text,
                }
            )

print("\nFinished running losses for all features and prompts.")


Loading SAE for layer 0: release=gpt2-small-resid-post-v5-32k, sae_id=blocks.0.hook_resid_post

=== Evaluating feature (layer=0, unit=259) ===

=== Evaluating feature (layer=0, unit=2002) ===

=== Evaluating feature (layer=0, unit=2236) ===

=== Evaluating feature (layer=0, unit=2332) ===

Finished running losses for all features and prompts.


In [11]:
# CELL 10: ASSEMBLE RESULTS AND (OPTIONALLY) SAVE RAW RESULTS

results_df = pd.DataFrame(results_rows)
print("Built results_df with", len(results_df), "rows")

# Optional: save raw per-sample results (commented out by default to avoid huge files)
# out_path = os.path.join(OUTPUT_DIR, "runtime_collision_gpt2_small_sae_raw.csv")
# results_df.to_csv(out_path, index=False)
# print(f"Saved raw results to: {out_path}")


Built results_df with 144 rows


In [12]:
# CELL 11: COLLISION METRICS (PER-SAMPLE -> PER-FEATURE) + JOIN WITH PRISM

# 1) PER-SAMPLE DELTA LOSS (ABLATE - BASELINE)

pivot_cols = ["layer", "unit", "concept_set", "sample_type", "sample_idx"]
pivot_df = results_df.pivot_table(
    index=pivot_cols,
    columns="mode",
    values="loss"
).reset_index()

pivot_df["delta"] = pivot_df["ablate"] - pivot_df["baseline"]

# 2) PER (FEATURE, CONCEPT_SET, SAMPLE_TYPE) SUMMARY

group_cols = ["layer", "unit", "sample_type", "concept_set"]

agg_df = pivot_df.groupby(group_cols).agg(
    mean_baseline=("baseline", "mean"),
    mean_ablate=("ablate", "mean"),
    mean_delta=("delta", "mean"),
    n_samples=("delta", "size"),
).reset_index()

# 3) PER-FEATURE SUMMARY (SINGLE vs PAIR) + COLLISION PENALTY

singles = agg_df[agg_df["sample_type"] == "single"]
pairs   = agg_df[agg_df["sample_type"] == "pair"]

single_feat = singles.groupby(["layer", "unit"]).agg(
    delta_single_mean=("mean_delta", "mean"),
    delta_single_std=("mean_delta", "std"),
    n_single_sets=("mean_delta", "size"),
).reset_index()

pair_feat = pairs.groupby(["layer", "unit"]).agg(
    delta_pair_mean=("mean_delta", "mean"),
    delta_pair_std=("mean_delta", "std"),
    n_pair_sets=("mean_delta", "size"),
).reset_index()

feat_summary = single_feat.merge(
    pair_feat,
    on=["layer", "unit"],
    how="outer",
    suffixes=("_single", "_pair")
)

feat_summary["collision_penalty"] = (
    feat_summary["delta_pair_mean"] - feat_summary["delta_single_mean"]
)

# 4) JOIN WITH PRISM METRICS

metrics_cols = [
    "layer",
    "unit",
    "cosine_similarity",
    "cosine_similarity_random",
    "max_auc",
    "max_mad",
]
metrics_unique = merged_df[metrics_cols].drop_duplicates(subset=["layer", "unit"])

feat_with_metrics = feat_summary.merge(
    metrics_unique,
    on=["layer", "unit"],
    how="left"
)

print("Built feat_with_metrics with", len(feat_with_metrics), "features")

# Optional: save feature-level metrics (collision_penalty + PRISM metrics)
# feat_out_path = os.path.join(OUTPUT_DIR, "runtime_collision_gpt2_small_sae_features.csv")
# feat_with_metrics.to_csv(feat_out_path, index=False)
# print(f"Saved feature-level metrics to: {feat_out_path}")


Built feat_with_metrics with 4 features


In [13]:
feat_summary

Unnamed: 0,layer,unit,delta_single_mean,delta_single_std,n_single_sets,delta_pair_mean,delta_pair_std,n_pair_sets,collision_penalty
0,0,259,-3.570649,4.004694,3,-1.984439,3.30375,3,1.586209
1,0,2002,-6.210857,9.03833,3,-22.715263,4.778617,3,-16.504406
2,0,2236,-0.047184,0.0608,3,-18.213164,31.447434,3,-18.16598
3,0,2332,-0.163259,0.304153,3,-0.906015,1.540595,3,-0.742756
