# Visualizing and Understanding SAE Features

In [1]:
%load_ext autoreload
%autoreload 2

from fsrl.utils import SAEfeatureAnalyzer
from fsrl import SAEAdapter, HookedModel
import os
from dotenv import load_dotenv
import torch
from transformer_lens import HookedTransformer
import json

load_dotenv()

False

Instantiating GPT-2 small with its associated SAE:

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

release = "gpt2-small-res-jb"
sae_id = "blocks.7.hook_resid_pre"

adapter_kwargs = {
    "use_lora_adapter": True,
    "lora_rank": 64,
    "lora_alpha": 32,
    "fusion_mode": "additive",
}

sae, cfg_dict, sparsity = SAEAdapter.from_pretrained(release, sae_id, device=device, **adapter_kwargs)
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
sae_model = HookedModel(model, sae)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Loaded pretrained model gpt2-small into HookedTransformer


The `SAEfeatureAnalyzer` class takes a hooked model with an attached SAE and exposes tools for inspecting that SAE. These include:
* Retrieving feature explanations.
* Visualizing the logit weight distribution, which characterizes how much each token's logit increases when feature $i$ activates.
* Retrieving SAE feature dashboards.

In [3]:
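# Read the key from the environment (e.g. a .env file loaded above) rather than hardcoding it.
assert "NEURONPEDIA_API_KEY" in os.environ, "Set NEURONPEDIA_API_KEY before running this cell"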
sae_analyzer = SAEfeatureAnalyzer(sae_model)

Fetching all explanations for gpt2-small/7-res-jb...
Successfully loaded 24570 feature explanations.


In [4]:
print(sae_analyzer.feature_info[0]['description'])

mathematical equations involving variables and functions


In [5]:
sae_analyzer.get_feature_page(0)

## Logit Weight Distribution

Let $W_{\text{dec}} \in \mathbb{R}^{d_{\text{sae}} \times d_{\text{model}}}$ be the decoder weight matrix of the SAE, where $d_{\text{sae}}$ is the dimensionality of the SAE's latent space and $d_{\text{model}}$ is the dimensionality of the model's embeddings. Let $W_U \in \mathbb{R}^{d_{\text{model}} \times |V|}$ denote the Transformer's unembedding matrix, which maps residuals to logits, where $|V|$ is the vocabulary size.

We are interested in the following quantity:
$$
W_{\text{dec}} W_U \in \mathbb{R}^{d_{\text{sae}} \times |V|}
$$
Each row is the logit weight distribution for one feature. It answers the question: "If this feature's activation increases by 1, how much does the logit of each token in the vocabulary increase?"
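The product above can be sketched with small random matrices. This is a minimal illustration, not the analyzer's internal implementation: the toy sizes and the random `W_dec` and `W_U` are stand-ins, and in the notebook the real matrices would presumably come from the SAE decoder and the model's unembedding.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes standing in for d_sae, d_model and |V| (the real ones are much larger).
d_sae, d_model, vocab_size = 8, 4, 10

# Random stand-ins (assumption): replace with the SAE decoder and model unembedding weights.
W_dec = rng.normal(size=(d_sae, d_model))
W_U = rng.normal(size=(d_model, vocab_size))

# Row i holds the logit change of every vocabulary token when feature i activates by +1.
logit_weights = W_dec @ W_U

# Token ids whose logits feature 0 increases the most.
feature_idx = 0
top_tokens = np.argsort(logit_weights[feature_idx])[::-1][:3]

print(logit_weights.shape)
print(top_tokens)
```

With a real tokenizer, the ids in `top_tokens` could be decoded to see which tokens a feature promotes.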

Given the logit weight vector $w^{(i)} = W_{\text{dec}}[i] \cdot W_U \in \mathbb{R}^{|V|}$ of feature $i$, we summarize its distribution with four statistics:

1. **Mean**:

$$
\mu_i = \frac{1}{|V|} \sum_{t=1}^{|V|} w^{(i)}_t
$$

The average logit shift across all tokens.

2. **Standard Deviation**:

$$
\sigma_i = \sqrt{ \frac{1}{|V|} \sum_{t=1}^{|V|} (w^{(i)}_t - \mu_i)^2 }
$$

Measures how much the feature's effect varies across tokens, i.e. how token-selective the feature is.

3. **Skewness**:

$$
\text{skew}_i = \frac{1}{|V|} \sum_{t=1}^{|V|} \left( \frac{w^{(i)}_t - \mu_i}{\sigma_i} \right)^3
$$

Indicates whether the feature boosts a few tokens disproportionately. Skewness quantifies the asymmetry of a distribution: a long tail to the right gives positive skew, while a long tail to the left gives negative skew.

4. **Kurtosis**:

$$
\text{kurt}_i = \frac{1}{|V|} \sum_{t=1}^{|V|} \left( \frac{w^{(i)}_t - \mu_i}{\sigma_i} \right)^4
$$

Measures how heavy the tails of the distribution are (e.g., a sharp preference for a few tokens). A normal distribution has kurtosis 3, so values greater than 3 indicate heavier tails than a normal distribution.
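The four statistics above can be computed per feature in a few lines. The sketch below uses NumPy on a toy matrix; the function name `logit_weight_moments` is hypothetical and is not part of `SAEfeatureAnalyzer`. It follows the formulas as written: population standard deviation (dividing by $|V|$) and non-excess kurtosis.

```python
import numpy as np

def logit_weight_moments(logit_weights):
    """Per-feature mean, std, skewness and non-excess kurtosis over the vocab axis."""
    mu = logit_weights.mean(axis=1, keepdims=True)
    sigma = logit_weights.std(axis=1, keepdims=True)  # population std, divides by |V|
    z = (logit_weights - mu) / sigma  # standardized logit weights
    return {
        "mean": mu[:, 0],
        "std": sigma[:, 0],
        "skew": (z ** 3).mean(axis=1),
        "kurt": (z ** 4).mean(axis=1),
    }

# Two toy "features" over a 4-token vocabulary:
# a symmetric row, and a row that boosts a single token.
toy = np.array([
    [-1.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 4.0],
])
stats = logit_weight_moments(toy)
print(stats["skew"], stats["kurt"])
```

The symmetric row has zero skew, while the row that boosts a single token has positive skew, matching the interpretation given above.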


In [6]:
# Looking at the 3rd/4th moments
sae_analyzer.plot_logit_distr_skewness()

In [7]:
sae_analyzer.plot_logit_distr_kurtosis()

In [8]:
# Looking at their joint distribution
sae_analyzer.plot_logit_distr_skewness_vs_kurtosis()

Note that this plot may differ between layers; here we are looking at the distribution for layer 7. The plot helps identify feature types:
* Local Context: high kurtosis, low std; these promote specific tokens such as brackets or quotes.
* Partition: high std, right skew, low kurtosis; often bimodal and linked to spacing or capitalization.
* Prediction: high skew, moderate std; these promote token groups such as digits or verbs.
* Suppression: negative skew, low kurtosis; these likely suppress specific tokens and overlap with partition features.

See https://www.alignmentforum.org/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens for more details.
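As a rough illustration of how the feature types listed above could be turned into filters, the sketch below builds boolean masks from per-feature statistics. The function name and every cutoff (`std_cut`, `skew_cut`, `kurt_cut`) are hypothetical placeholders, not values from the linked post or from `SAEfeatureAnalyzer`; sensible thresholds would have to be read off the plots for a given layer.

```python
import numpy as np

def rough_feature_type_masks(std, skew, kurt, std_cut=1.0, skew_cut=1.0, kurt_cut=10.0):
    """Boolean masks per rough feature type; all cutoffs are hypothetical placeholders."""
    std, skew, kurt = map(np.asarray, (std, skew, kurt))
    return {
        "local_context": (kurt > kurt_cut) & (std < std_cut),
        "partition": (std > std_cut) & (skew > 0) & (kurt < kurt_cut),
        "prediction": skew > skew_cut,
        "suppression": (skew < 0) & (kurt < kurt_cut),
    }

# Made-up statistics for four features, one per type.
masks = rough_feature_type_masks(
    std=[0.5, 2.0, 0.8, 0.6],
    skew=[0.2, 0.5, 3.0, -1.5],
    kurt=[40.0, 2.5, 8.0, 4.0],
)
for name, mask in masks.items():
    print(name, np.flatnonzero(mask))
```

The masks are not mutually exclusive, which is consistent with the overlap between suppression and partition features noted above.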