# Select Discriminative Features per Role

Filter the `[n_roles, sae_dim]` aggregated matrix with 3 conditions:
1. **Prefers a role** — one role's activation is meaningfully higher than the rest
2. **Fires often enough** — appears in >= X% of that role's responses
3. **Not flat** — activation varies across roles (high CV)

In [1]:
import numpy as np
from pathlib import Path
from collections import defaultdict
from tqdm.auto import tqdm

AGGREGATED_DIR = Path("../outputs/aggregated/general/gemma-3-27b-it_layer_40_width_65k_l0_medium")
FEATURES_DIR = Path("../outputs/features/general/gemma-3-27b-it_layer_40_width_65k_l0_medium")
STRATEGY = "mean"
NEURONPEDIA_ID = "gemma-3-27b-it/40-gemmascope-2-res-65k"

# Tune these to get ~10-20 features per role
MIN_PREFERENCE_RATIO = 5.0  # C1: top role activates >= Nx the average of other roles
MIN_FIRING_RATE = 0.5      # C2: fires in >= X% of preferred role's responses
MIN_CV = 0.5                # C3: coefficient of variation across roles

data = np.load(AGGREGATED_DIR / STRATEGY / "per_role.npz", allow_pickle=True)
features = data["features"]       # [n_roles, sae_dim]
role_names = data["role_names"]   # [n_roles]
n_roles, sae_dim = features.shape
print(f"{n_roles} roles, {sae_dim} SAE features, strategy={STRATEGY}")

7 roles, 65536 SAE features, strategy=mean


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Condition 3: not flat (high CV across roles)
feat_mean = features.mean(axis=0)
feat_std = features.std(axis=0)
cv = np.divide(feat_std, feat_mean, where=feat_mean > 0, out=np.zeros(sae_dim))
pass_c3 = cv >= MIN_CV

# Condition 1: one role clearly dominates
preferred_role_idx = features.argmax(axis=0)  # [sae_dim]
top_val = features.max(axis=0)
mean_rest = (features.sum(axis=0) - top_val) / (n_roles - 1)
pref_ratio = np.divide(top_val, mean_rest, where=mean_rest > 0, out=np.full(sae_dim, np.inf))
pass_c1 = pref_ratio >= MIN_PREFERENCE_RATIO

candidate_idx = np.where(pass_c1 & pass_c3 & (feat_mean > 0))[0]
print(f"C1 (pref ratio >= {MIN_PREFERENCE_RATIO}): {pass_c1.sum()}")
print(f"C3 (CV >= {MIN_CV}): {pass_c3.sum()}")
print(f"Candidates (C1 & C3): {len(candidate_idx)}")

C1 (pref ratio >= 5.0): 20749
C3 (CV >= 0.5): 42378
Candidates (C1 & C3): 17512


In [3]:
# Condition 2: fires in >= X% of preferred role's responses
# Group candidates by preferred role to load each file once
role_to_feats = defaultdict(list)
for fi in candidate_idx:
    role_to_feats[preferred_role_idx[fi]].append(fi)

firing_rate = {}
feat_key = f"{STRATEGY}_features"

for ri, feat_list in tqdm(role_to_feats.items(), desc="Computing firing rates"):
    with np.load(FEATURES_DIR / f"{role_names[ri]}.npz") as npz:
        resp = npz[feat_key]  # [n_responses, sae_dim]
        rates = (resp[:, feat_list] > 0).mean(axis=0)
        for fi, rate in zip(feat_list, rates):
            firing_rate[fi] = rate

selected = np.array([fi for fi in candidate_idx if firing_rate[fi] >= MIN_FIRING_RATE])
print(f"C2 (firing rate >= {MIN_FIRING_RATE}): {len(selected)} of {len(candidate_idx)} candidates passed")
print(f"Total selected features: {len(selected)}")

Computing firing rates:   0%|          | 0/7 [00:00<?, ?it/s]

Computing firing rates: 100%|██████████| 7/7 [00:03<00:00,  1.93it/s]

C2 (firing rate >= 0.5): 875 of 17512 candidates passed
Total selected features: 875





In [4]:
# Group selected features by preferred role, sorted by activation
role_features = defaultdict(list)
for fi in selected:
    role_features[role_names[preferred_role_idx[fi]]].append(fi)

for role in role_features:
    ri = np.where(role_names == role)[0][0]
    role_features[role].sort(key=lambda fi, ri=ri: features[ri, fi], reverse=True)

counts = [len(v) for v in role_features.values()]
print(f"Roles with features: {len(role_features)} / {n_roles}")
if counts:
    print(f"Features per role: median={np.median(counts):.0f}, mean={np.mean(counts):.1f}, min={min(counts)}, max={max(counts)}")

Roles with features: 7 / 7
Features per role: median=105, mean=125.0, min=27, max=278


In [25]:
def show_role(role, k=20):
    if role not in role_features:
        print(f"{role}: no features passed filtering")
        return
    ri = np.where(role_names == role)[0][0]
    feats = role_features[role][:k]
    print(f"\n=== {role} ({len(role_features[role])} features) ===")
    for fi in feats:
        url = f"https://neuronpedia.org/{NEURONPEDIA_ID}/{fi}"
        print(f"  {fi:>6d}  act={features[ri, fi]:.4f}  ratio={pref_ratio[fi]:.1f}x  cv={cv[fi]:.2f}  fire={firing_rate[fi]:.0%}  {url}")

for role in ["biologist", "assistant", "doctor", "mathematician", "historian"]:
    if role in role_names:
        show_role(role)


=== biologist (123 features) ===
    2074  act=659.8511  ratio=46.3x  cv=2.12  fire=100%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/2074
    3280  act=260.5252  ratio=146.3x  cv=2.34  fire=100%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/3280
    6986  act=164.0516  ratio=118.9x  cv=2.31  fire=93%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/6986
   30793  act=150.0612  ratio=24.4x  cv=1.89  fire=79%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/30793
   18366  act=124.2868  ratio=9.0x  cv=1.31  fire=91%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/18366
    1763  act=121.8418  ratio=6.2x  cv=1.26  fire=85%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/1763
    1779  act=120.3632  ratio=17.2x  cv=1.71  fire=77%  https://neuronpedia.org/gemma-3-27b-it/40-gemmascope-2-res-65k/1779
    3107  act=105.1357  ratio=34.3x  cv=2.03  fire=92%  https://neuronpedia.org/gemma-3-27b-it

## Feature Map (Decoder Weight Space)

In [6]:
import pandas as pd
import plotly.express as px
from sae_lens import SAE
from umap import UMAP
from numpy.linalg import norm

sae = SAE.from_pretrained(release="gemma-scope-2-27b-it-res", sae_id="layer_40_width_65k_l0_medium", device="cuda")
W_dec = sae.W_dec.detach().cpu().numpy()  # [d_sae, d_model]
del sae

W_selected = W_dec[selected]  # [n_selected, d_model]
n_neighbors = min(15, len(selected) - 1)
embedding = UMAP(n_components=2, metric="cosine", n_neighbors=n_neighbors, min_dist=0.1, random_state=42).fit_transform(W_selected)

feat_df = pd.DataFrame({
    "x": embedding[:, 0],
    "y": embedding[:, 1],
    "feature_idx": selected,
    "preferred_role": [role_names[preferred_role_idx[fi]] for fi in selected],
    "activation": [features[preferred_role_idx[fi], fi] for fi in selected],
    "pref_ratio": [pref_ratio[fi] for fi in selected],
    "firing_rate": [firing_rate[fi] for fi in selected],
    "neuronpedia": [f"https://neuronpedia.org/{NEURONPEDIA_ID}/{fi}" for fi in selected],
})
print(f"Projected {len(selected)} features from {W_dec.shape[1]}D to 2D")

  warn(


Projected 875 features from 5376D to 2D


In [17]:
import requests
from concurrent.futures import ThreadPoolExecutor

def fetch_description(feat_idx):
    url = f"https://neuronpedia.org/api/feature/{NEURONPEDIA_ID}/{feat_idx}"
    try:
        data = requests.get(url, timeout=10).json()
        if data.get("explanations"):
            return data["explanations"][0].get("description", "")
    except Exception:
        pass
    return ""

with ThreadPoolExecutor(max_workers=10) as pool:
    descriptions = list(tqdm(pool.map(fetch_description, selected), total=len(selected), desc="Fetching descriptions"))

feat_df["description"] = descriptions
print(f"Fetched {sum(1 for d in descriptions if d):,} / {len(selected):,} descriptions")

Fetching descriptions: 100%|██████████| 875/875 [02:19<00:00,  6.28it/s]

Fetched 874 / 875 descriptions





In [24]:
HOVER_COLS = ["description", "preferred_role", "activation", "pref_ratio", "firing_rate"]
HOVER_TEMPLATE = (
    "<b>Feature %{hovertext}</b><br>"
    "%{customdata[0]}<br>"
    "Role: %{customdata[1]}<br>"
    "Act: %{customdata[2]:.4f} · Ratio: %{customdata[3]:.1f}x · Fire: %{customdata[4]:.0%}"
    "<extra></extra>"
)

def style_feature_map(fig, showlegend=False):
    fig.update_traces(
        hovertemplate=HOVER_TEMPLATE,
        marker=dict(size=5, opacity=0.75, line=dict(width=0.5, color="white")),
    )
    fig.update_layout(
        template="simple_white",
        width=1100, height=800,
        showlegend=showlegend,
        xaxis=dict(showticklabels=False, title="", showgrid=False, zeroline=False, showline=False),
        yaxis=dict(showticklabels=False, title="", showgrid=False, zeroline=False, showline=False),
    )
    return fig

fig = px.scatter(
    feat_df, x="x", y="y",
    color="preferred_role",
    hover_name="feature_idx",
    hover_data=HOVER_COLS,
    title=f"Feature Decoder Map — {len(selected)} discriminative features",
)
style_feature_map(fig).show()