# Attribution Demo for Neuronpedia



Steps:
- Pick a prompt.
- Look at the next-token probabilities.
- Pick a positive and a negative answer token.
- Pick an SAE.
- Attribute the logit difference between the two tokens to SAE features.


In [3]:
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple, Callable

import torch
from sae_lens import SAE, HookedSAETransformer
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint


# Per-hook snapshot of one SAE application: the SAE input, its feature
# activations, the reconstruction, and the reconstruction error term.
SaeReconstructionCache = NamedTuple(
    "SaeReconstructionCache",
    [
        ("sae_in", torch.Tensor),
        ("feature_acts", torch.Tensor),
        ("sae_out", torch.Tensor),
        ("sae_error", torch.Tensor),
    ],
)


def track_grad(tensor: torch.Tensor) -> None:
    """Enable gradient tracking on ``tensor`` and keep its ``.grad`` after backward.

    ``requires_grad_`` returns the tensor itself, so ``retain_grad`` can be
    chained directly onto it. Works for both leaf and intermediate tensors.
    """
    tensor.requires_grad_(True).retain_grad()


@dataclass
class ApplySaesAndRunOutput:
    """Everything recorded by a single ``apply_saes_and_run`` call.

    Holds the model output, the tracked model activations keyed by hook point,
    and one ``SaeReconstructionCache`` per SAE hook point.
    """

    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]

    def zero_grad(self) -> None:
        """Reset ``.grad`` to ``None`` on every tensor held by this object."""
        to_clear = [self.model_output]
        to_clear.extend(self.model_activations.values())
        for cache in self.sae_activations.values():
            to_clear += [cache.sae_in, cache.feature_acts, cache.sae_out, cache.sae_error]
        for tensor in to_clear:
            tensor.grad = None


def apply_saes_and_run(
    model: HookedTransformer,
    saes: dict[str, SAE],
    input: Any,
    include_error_term: bool = True,
    track_model_hooks: list[str] | None = None,
    return_type: Literal["logits", "loss"] = "logits",
    track_grads: bool = False,
) -> ApplySaesAndRunOutput:
    """
    Apply the SAEs to the model at the specific hook points, and run the model.

    At each SAE hook point the activation is replaced by the SAE reconstruction.
    By default a detached SAE error term (``sae_in - sae_out``) is added back, so
    the value passed on equals the original activation (up to float rounding) and
    the SAE does not change the model output. A backward hook is also registered at
    each SAE hook point so that gradients can be used for feature attribution.

    Args:
        model: the model to run
        saes: the SAEs to apply, keyed by the hook point name to apply each one at
        input: the input to the model
        include_error_term: whether to include the SAE error term to ensure the SAE doesn't affect model output. Default True
        track_model_hooks: a list of hook points to record the activations and gradients. Default None
        return_type: this is passed to the model's forward call as `return_type`. Default "logits"
        track_grads: whether to enable requires_grad/retain_grad on recorded tensors. Default False

    Returns:
        ApplySaesAndRunOutput with the model output, the recorded model
        activations keyed by hook point, and one SaeReconstructionCache per SAE
        hook point.
    """

    fwd_hooks = []
    bwd_hooks = []

    sae_activations: dict[str, SaeReconstructionCache] = {}
    model_activations: dict[str, torch.Tensor] = {}

    # Forward hook for SAE hook points: encodes/decodes the activation, records
    # input, features, reconstruction and error, and returns the value that
    # replaces the activation. If `track_grads=True` it also calls track_grad on
    # each recorded tensor so their `.grad` is kept after backward.
    def reconstruction_hook(
        sae_in: torch.Tensor, hook: HookPoint, hook_point: str
    ):  # noqa: ARG001
        sae = saes[hook_point]
        feature_acts = sae.encode(sae_in)
        sae_out = sae.decode(feature_acts)
        # Detached copy: the error term is a constant in the graph, so gradients
        # reach the SAE only through sae_out.
        sae_error = (sae_in - sae_out).detach().clone()
        if track_grads:
            track_grad(sae_error)
            track_grad(sae_out)
            track_grad(feature_acts)
            track_grad(sae_in)
        sae_activations[hook_point] = SaeReconstructionCache(
            sae_in=sae_in,
            feature_acts=feature_acts,
            sae_out=sae_out,
            sae_error=sae_error,
        )

        if include_error_term:
            return sae_out + sae_error
        return sae_out

    def sae_bwd_hook(output_grads: torch.Tensor, hook: HookPoint):  # noqa: ARG001
        # Returns the gradient arriving at the hook point's output as the gradient
        # for its input. Assuming TransformerLens uses this return value to replace
        # the input gradient (PyTorch full-backward-hook semantics — TODO confirm),
        # layers upstream of the SAE get the same gradient they would without the
        # SAE / error-term substitution.
        return (output_grads,)

    # Forward hook for plain model hook points: records the activation and, if
    # `track_grads=True`, keeps its gradient. The activation itself is unchanged.
    def tracking_hook(
        hook_input: torch.Tensor, hook: HookPoint, hook_point: str
    ):  # noqa: ARG001
        model_activations[hook_point] = hook_input
        if track_grads:
            track_grad(hook_input)
        return hook_input

    for hook_point in saes.keys():
        fwd_hooks.append(
            (hook_point, partial(reconstruction_hook, hook_point=hook_point))
        )
        bwd_hooks.append((hook_point, sae_bwd_hook))
    for hook_point in track_model_hooks or []:
        fwd_hooks.append((hook_point, partial(tracking_hook, hook_point=hook_point)))

    # Run the model with the hooks attached.
    # NOTE(review): `model.hooks(...)` presumably removes the hooks when the `with`
    # block exits, i.e. before any caller runs `.backward()`. The backward hooks
    # still fire only if they are captured during the forward pass (as PyTorch's
    # full backward hooks are) — confirm with the installed TransformerLens/PyTorch.
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        model_output = model(input, return_type=return_type)

    return ApplySaesAndRunOutput(
        model_output=model_output,
        model_activations=model_activations,
        sae_activations=sae_activations,
    )
    
from dataclasses import dataclass
from transformer_lens.hook_points import HookPoint
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint

# Added to denominators so ratios never divide by zero.
EPS: float = 1e-8

# Attribution relies on backpropagation, so make sure autograd is globally on.
torch.set_grad_enabled(True)


@dataclass()
class AttributionGrads:
    """Result of a forward + backward pass through the model (and any SAEs).

    As built by ``calculate_attribution_grads``, the tensors held here have had
    gradients retained, so their ``.grad`` fields hold gradients of ``metric``.
    """

    # The metric value that was backpropagated.
    metric: torch.Tensor
    # Raw model output (logits or loss, depending on the requested return type).
    model_output: torch.Tensor
    # Tracked model activations, keyed by hook point name.
    model_activations: dict[str, torch.Tensor]
    # One reconstruction cache per SAE, keyed by hook point name.
    sae_activations: dict[str, SaeReconstructionCache]


@dataclass()
class Attribution:
    """Attribution results; every dict is keyed by hook point name."""

    # gradient * activation for each tracked model hook point
    model_attributions: dict[str, torch.Tensor]
    # detached copies of the tracked model activations
    model_activations: dict[str, torch.Tensor]
    # detached copies of the gradients at the tracked model hook points
    model_grads: dict[str, torch.Tensor]
    # gradient * activation for each SAE's feature activations
    sae_feature_attributions: dict[str, torch.Tensor]
    # detached copies of each SAE's feature activations
    sae_feature_activations: dict[str, torch.Tensor]
    # detached copies of the gradients w.r.t. each SAE's feature activations
    sae_feature_grads: dict[str, torch.Tensor]
    # per-SAE share of the attribution carried by the SAE error term
    sae_errors_attribution_proportion: dict[str, float]


def calculate_attribution_grads(
    model: HookedSAETransformer,
    prompt: str,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> AttributionGrads:
    """
    Run the model with the given SAEs spliced in and backpropagate ``metric_fn``.

    Gradients are retained on the tracked model activations and on each SAE's
    input, feature activations, reconstruction and error term; everything is
    returned together as an ``AttributionGrads``.
    """
    requested_output = "logits" if return_logits else "loss"
    run = apply_saes_and_run(
        model,
        saes=include_saes or {},
        input=prompt,
        return_type=requested_output,
        track_model_hooks=track_hook_points,
        include_error_term=include_error_term,
        track_grads=True,
    )

    metric_value = metric_fn(run.model_output)
    # Clear any stale gradients before this backward pass populates them.
    run.zero_grad()
    metric_value.backward()

    return AttributionGrads(
        metric=metric_value,
        model_output=run.model_output,
        model_activations=run.model_activations,
        sae_activations=run.sae_activations,
    )


def calculate_feature_attribution(
    model: HookedSAETransformer,
    input: Any,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> Attribution:
    """
    Calculate feature attribution for SAE features and model neurons following
    the procedure in https://transformer-circuits.pub/2024/march-update/index.html#feature-heads.
    The attribution of an activation is its gradient w.r.t. the metric multiplied
    elementwise by the activation itself.

    This includes the SAE error term by default, so inserting the SAE into the calculation is
    guaranteed not to affect the model output. This can be disabled by setting `include_error_term=False`.

    Args:
        model: The model to calculate feature attribution for.
        input: The input to the model.
        metric_fn: A function that takes the model output and returns a scalar metric.
        track_hook_points: A list of model hook points to track activations for, if desired
        include_saes: A dictionary of SAEs to include in the calculation. The key is the hook point to apply the SAE to.
        return_logits: Whether to return the model logits or loss. This is passed to TLens, so should match whatever the metric_fn expects (probably logits)
        include_error_term: Whether to include the SAE error term in the calculation. This is recommended, as it ensures that the SAE will not affect the model output.

    Returns:
        An ``Attribution`` whose dicts are keyed by hook point name. For each SAE,
        ``sae_errors_attribution_proportion`` is
        ``||grad * sae_error|| / (||grad * sae_out|| + ||grad * sae_error|| + EPS)``,
        i.e. the share of attribution at that hook point carried by the error term
        rather than the SAE reconstruction. It is 0 when ``include_error_term=False``.
    """
    # First, calculate gradients w.r.t. the metric_fn. These are multiplied with
    # the activation values to get the attributions.
    outputs_with_grads = calculate_attribution_grads(
        model,
        input,
        metric_fn,
        track_hook_points,
        include_saes=include_saes,
        return_logits=return_logits,
        include_error_term=include_error_term,
    )
    model_attributions = {}
    model_activations = {}
    model_grads = {}
    sae_feature_attributions = {}
    sae_feature_activations = {}
    sae_feature_grads = {}
    sae_error_proportions = {}
    # Multiply grads by activations, and record grads, acts, and attributions
    # (all detached copies) in dictionaries to return to the user.
    with torch.no_grad():
        for name, act in outputs_with_grads.model_activations.items():
            assert act.grad is not None
            raw_activation = act.detach().clone()
            model_attributions[name] = (act.grad * raw_activation).detach().clone()
            model_activations[name] = raw_activation
            model_grads[name] = act.grad.detach().clone()
        for name, cache in outputs_with_grads.sae_activations.items():
            assert cache.feature_acts.grad is not None
            assert cache.sae_out.grad is not None
            raw_activation = cache.feature_acts.detach().clone()
            sae_feature_attributions[name] = (
                (cache.feature_acts.grad * raw_activation).detach().clone()
            )
            sae_feature_activations[name] = raw_activation
            sae_feature_grads[name] = cache.feature_acts.grad.detach().clone()
            # When the error term is included, sae_out and sae_error are summed in
            # the forward pass, so their gradients are identical and a ratio of
            # gradient norms alone is always ~0.5. Weighting each gradient by its
            # own activation measures the attribution each part actually carries.
            sae_out_attr_norm = (cache.sae_out.grad * cache.sae_out).norm().item()
            if include_error_term:
                assert cache.sae_error.grad is not None
                error_attr_norm = (
                    (cache.sae_error.grad * cache.sae_error).norm().item()
                )
            else:
                error_attr_norm = 0.0
            sae_error_proportions[name] = error_attr_norm / (
                sae_out_attr_norm + error_attr_norm + EPS
            )
    return Attribution(
        model_attributions=model_attributions,
        model_activations=model_activations,
        model_grads=model_grads,
        sae_feature_attributions=sae_feature_attributions,
        sae_feature_activations=sae_feature_activations,
        sae_feature_grads=sae_feature_grads,
        sae_errors_attribution_proportion=sae_error_proportions,
    )


In [7]:

# Pick the best available accelerator instead of hard-coding Apple's "mps",
# which is unavailable on most non-Mac machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# The model and SAE must share a device: the SAE runs inside the model's
# forward hooks on the model's activations.
model = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.10.hook_resid_pre",
    device=device,
)

In [14]:


prompt = " Tiger Woods plays the sport of"
# The metric compares exactly one token each; to_single_token fails loudly if
# an answer string is not a single token, instead of silently keeping only its
# first sub-token as `tokenizer.encode(...)[0]` would.
pos_token = model.to_single_token(" golf")
neg_token = model.to_single_token(" tennis")


Tokenized prompt: ['<|endoftext|>', ' Tiger', ' Woods', ' plays', ' the', ' sport', ' of']
Tokenized answer: [' golf']


Top 0th token. Logit: 16.41 Prob: 59.12% Token: | golf|
Top 1th token. Logit: 13.69 Prob:  3.90% Token: | tennis|
Top 2th token. Logit: 13.49 Prob:  3.21% Token: | Tiger|
Top 3th token. Logit: 13.18 Prob:  2.35% Token: | baseball|
Top 4th token. Logit: 12.80 Prob:  1.60% Token: | basketball|
Top 5th token. Logit: 12.59 Prob:  1.29% Token: | football|
Top 6th token. Logit: 12.31 Prob:  0.98% Token: | the|
Top 7th token. Logit: 11.76 Prob:  0.56% Token: | cricket|
Top 8th token. Logit: 11.74 Prob:  0.56% Token: | gol|
Top 9th token. Logit: 11.60 Prob:  0.48% Token: | Golf|


In [62]:
import pandas as pd 
from typing import Optional

def test_prompt_pandas(
    prompt: str,
    model: HookedTransformer,  # Can't give type hint due to circular imports
    prepend_bos: Optional[bool] = True,
    top_k: int = 10,
) -> pd.DataFrame:
    """Test if the Model Can Give the Correct Answer to a Prompt.
    Returns results as a pandas DataFrame containing token predictions and their metrics.

    Args: [previous args remain the same]

    Returns:
        pd.DataFrame: Contains columns for:
            - answer_token: The expected next token
            - rank: Position of the answer token in sorted predictions
            - logit: Raw logit value for the answer token  
            - probability: Probability assigned to the answer token
            - is_answer: Boolean indicating if this is the actual answer token
    """
        
    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)

    logits = model(prompt_tokens)

    if logits.shape[0] == 1:
        logits = logits.squeeze(0)
    else:
        logits = logits

    probs = logits.softmax(dim=-1)
    
    # get top k predictions
    top_k_probs, top_k_inds = probs[-1].topk(top_k)
    
    top_k_probs_str = [model.to_string(v) for v in top_k_inds]
    
    # make a table with the tokem, rank, logit, probability, and is_answer
    df = pd.DataFrame({
        'token': top_k_probs_str,
        'rank': torch.arange(len(top_k_probs_str)),
        'logit': logits[-1, top_k_inds].detach().cpu().numpy(),
        'probability': top_k_probs.detach().cpu().numpy(),
        'is_answer': [t == " golf" for t in top_k_probs_str]
    })
    
    return df

from transformer_lens.utils import test_prompt
# Printed summary from TransformerLens (it tokenizes the answer as " golf"),
# followed by the same top-k predictions as a DataFrame for the rich display.
test_prompt(prompt, "golf", model)
test_prompt_pandas(prompt=prompt, model=model)



Tokenized prompt: ['<|endoftext|>', ' Tiger', ' Woods', ' plays', ' the', ' sport', ' of']
Tokenized answer: [' golf']


Top 0th token. Logit: 16.41 Prob: 59.12% Token: | golf|
Top 1th token. Logit: 13.69 Prob:  3.90% Token: | tennis|
Top 2th token. Logit: 13.49 Prob:  3.21% Token: | Tiger|
Top 3th token. Logit: 13.18 Prob:  2.35% Token: | baseball|
Top 4th token. Logit: 12.80 Prob:  1.60% Token: | basketball|
Top 5th token. Logit: 12.59 Prob:  1.29% Token: | football|
Top 6th token. Logit: 12.31 Prob:  0.98% Token: | the|
Top 7th token. Logit: 11.76 Prob:  0.56% Token: | cricket|
Top 8th token. Logit: 11.74 Prob:  0.56% Token: | gol|
Top 9th token. Logit: 11.60 Prob:  0.48% Token: | Golf|


Unnamed: 0,token,rank,logit,probability,is_answer
0,golf,0,16.408287,0.591173,True
1,tennis,1,13.689954,0.039008,False
2,Tiger,2,13.493741,0.032059,False
3,baseball,3,13.182872,0.023493,False
4,basketball,4,12.801028,0.016036,False
5,football,5,12.586832,0.012944,False
6,the,6,12.305334,0.009768,False
7,cricket,7,11.756577,0.005643,False
8,gol,8,11.743678,0.005571,False
9,Golf,9,11.604863,0.004849,False


In [15]:


def metric_fn(
    logits: torch.Tensor,
    pos_token: int = pos_token,
    neg_token: int = neg_token,
) -> torch.Tensor:
    """Logit difference between the positive and the negative answer token.

    Args:
        logits: Model output logits, indexed as ``[batch, position, vocab]``.
        pos_token: Token id whose logit should rise; defaults to the notebook's
            ``pos_token`` (" golf") captured when this function is defined.
        neg_token: Token id whose logit should fall; defaults to the notebook's
            ``neg_token`` (" tennis") captured when this function is defined.

    Returns:
        Scalar tensor ``logit[pos_token] - logit[neg_token]`` at the final
        position of the first batch element.
    """
    final_position_logits = logits[0, -1]
    return final_position_logits[pos_token] - final_position_logits[neg_token]


# Attribute the golf-vs-tennis logit difference to the features of the loaded SAE.
hook_name = sae.cfg.hook_name
attribution_output = calculate_feature_attribution(
    model=model,
    input=prompt,
    metric_fn=metric_fn,
    include_saes={hook_name: sae},
    include_error_term=True,
    return_logits=True,
)

# Despite the `_df` suffix, this is a tensor of feature attributions
# (grad * feature activation) for the SAE at `hook_name`.
feature_attribution_df = attribution_output.sae_feature_attributions[hook_name]
feature_attribution_df

tensor([[[0., 0., 0.,  ..., -0., -0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [-0., -0., 0.,  ..., -0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., -0., -0., 0.],
         [0., 0., 0.,  ..., -0., -0., 0.],
         [0., 0., -0.,  ..., -0., -0., 0.]]], device='mps:0')

In [18]:
import plotly.express as px

tokens = model.to_str_tokens(prompt)
# Prefix each token with its position so repeated tokens stay distinct bars.
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

# Total SAE feature attribution at each prompt position (first batch element,
# summed over all features).
attribution_per_position = (
    attribution_output.sae_feature_attributions[sae.cfg.hook_name][0]
    .sum(-1)
    .detach()
    .cpu()
    .numpy()
)

px.bar(
    x=unique_tokens,
    y=attribution_per_position,
    labels={"x": "position/token", "y": "summed SAE feature attribution"},
    title="SAE feature attribution to the golf-vs-tennis logit difference, by position",
)

In [21]:
import pandas as pd

# Long-format table of the nonzero SAE feature attributions for the prompt
# (first batch element), one row per (position, feature) pair. Rows are built
# feature-major, so the index groups rows by feature.
feature_attributions = (
    attribution_output.sae_feature_attributions[sae.cfg.hook_name][0]
    .detach()
    .cpu()
)
nonzero_features, nonzero_positions = feature_attributions.T.nonzero(as_tuple=True)
df_long_nonzero = pd.DataFrame(
    {
        "position": nonzero_positions.numpy(),
        "feature": nonzero_features.numpy(),
        "attribution": feature_attributions[nonzero_positions, nonzero_features].numpy(),
    }
)

df_long_nonzero.sort_values("attribution", ascending=False)

Unnamed: 0,position,feature,attribution
176,2,14535,0.388522
11,2,1059,0.344248
275,2,23219,0.124019
120,2,10499,0.117193
158,2,13075,0.087566
...,...,...,...
281,6,23581,-0.091620
0,1,332,-0.174772
47,6,3646,-0.195749
108,6,8880,-0.226573


In [24]:
from IPython.display import IFrame

# Pick a random feature from the loaded SAE. A dedicated, seeded generator makes
# the pick reproducible without touching torch's global RNG state.
feature_idx = torch.randint(
    0, sae.cfg.d_sae, (1,), generator=torch.Generator().manual_seed(0)
).item()

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gpt2-small", sae_id="7-res-jb", feature_idx=0):
    """Build the Neuronpedia embed URL for one SAE feature.

    Args:
        sae_release: First path segment of the URL (the Neuronpedia model id).
        sae_id: Second path segment (the Neuronpedia SAE id, e.g. "10-res-jb").
        feature_idx: Index of the feature within that SAE.

    Returns:
        The URL string built from ``html_template``.
    """
    return html_template.format(sae_release, sae_id, feature_idx)


# The Neuronpedia SAE id must match the SAE actually loaded (its layer comes
# from sae.cfg.hook_layer); a hard-coded layer shows an unrelated feature.
html = get_dashboard_html(
    sae_release="gpt2-small",
    sae_id=f"{sae.cfg.hook_layer}-res-jb",
    feature_idx=feature_idx,
)
IFrame(html, width=1200, height=600)

In [25]:
# Sum attributions per feature at position 2 (" Woods") and embed the
# Neuronpedia dashboard for the five features with the largest totals.
top_features_at_position_2 = (
    df_long_nonzero.query("position==2")
    .groupby("feature")
    .attribution.sum()
    .sort_values(ascending=False)
    .head(5)
)
for feature_id, total_attribution in top_features_at_position_2.items():
    print(f"Feature {feature_id} had a total attribution of {total_attribution:.2f}")
    html = get_dashboard_html(
        sae_release="gpt2-small",
        sae_id=f"{sae.cfg.hook_layer}-res-jb",
        feature_idx=int(feature_id),
    )
    display(IFrame(html, width=1200, height=300))

Feature 14535 had a total attribution of 0.39


Feature 1059 had a total attribution of 0.34


Feature 23219 had a total attribution of 0.12


Feature 10499 had a total attribution of 0.12


Feature 13075 had a total attribution of 0.09
