In [1]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

#from plotly_utils import imshow, line

MAIN = __name__ == "__main__"

In [2]:
gpt2 = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

gpt2_saes = {
    layer: SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=str(device),
    )[0]
    for layer in tqdm(range(gpt2.cfg.n_layers))
}

Loaded pretrained model gpt2-small into HookedTransformer


  0%|          | 0/12 [00:00<?, ?it/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [9]:
gpt2_saes[4].cfg

SAEConfig(architecture='standard', d_in=768, d_sae=24576, activation_fn_str='relu', apply_b_dec_to_input=True, finetuning_scaling_factor=False, context_size=128, model_name='gpt2-small', hook_name='blocks.4.hook_resid_pre', hook_layer=4, hook_head_index=None, prepend_bos=True, dataset_path='Skylion007/openwebtext', dataset_trust_remote_code=True, normalize_activations='none', dtype='torch.float32', device='cpu', sae_lens_training_version=None, activation_fn_kwargs={}, neuronpedia_id='gpt2-small/4-res-jb', model_from_pretrained_kwargs={'center_writing_weights': True}, seqpos_slice=(None,))

In [3]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))

In [4]:
class SparseTensor:
    """
    Handles 2D tensor data (assumed to be non-negative) in 2 different formats:
        dense:  The full tensor, which contains zeros. Shape is (n1, ..., nk).
        sparse: A tuple of nonzero values (shape (n_nonzero,)), nonzero indices (shape (n_nonzero, k)), and the shape of
                the dense tensor.
    """

    sparse: tuple[Tensor, Tensor, tuple[int, ...]]
    dense: Tensor

    def __init__(self, sparse: tuple[Tensor, Tensor, tuple[int, ...]], dense: Tensor):
        self.sparse = sparse
        self.dense = dense

    @classmethod
    def from_dense(cls, dense: Tensor) -> "SparseTensor":
        sparse = (dense[dense > 0], t.argwhere(dense > 0), tuple(dense.shape))
        return cls(sparse, dense)

    @classmethod
    def from_sparse(
        cls, sparse: tuple[Tensor, Tensor, tuple[int, ...]]
    ) -> "SparseTensor":
        nonzero_values, nonzero_indices, shape = sparse
        dense = t.zeros(shape, dtype=nonzero_values.dtype, device=nonzero_values.device)
        dense[nonzero_indices.unbind(-1)] = nonzero_values
        return cls(sparse, dense)

    @property
    def values(self) -> Tensor:
        return self.sparse[0].squeeze()

    @property
    def indices(self) -> Tensor:
        return self.sparse[1].squeeze()

    @property
    def shape(self) -> tuple[int, ...]:
        return self.sparse[2]

In [5]:
### LATENT2LATENT ATTRIBUTION ###
def latent_acts_to_later_latent_acts(
    latent_acts_nonzero: Float[Tensor, "nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
    error
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Given some latent activations for a residual stream SAE earlier in the model, computes the latent activations of a
    later SAE. It does this by mapping the latent activations through the path SAE decoder -> intermediate model layers
    -> later SAE encoder.

    This function must input & output sparse information (i.e. nonzero values and their indices) rather than dense
    tensors, because latent activations are sparse but jacrev() doesn't support gradients on real sparse tensors.
    """
    latent_acts_from = SparseTensor.from_sparse((latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)).dense
    resid_stream_from = sae_from.decode(latent_acts_from)

    resid_stream_next = model.forward(
        resid_stream_from + error,
        start_at_layer=sae_from.cfg.hook_layer,
        stop_at_layer=sae_to.cfg.hook_layer,
    )

    latent_acts_next = sae_to.encode(resid_stream_next)

    latent_acts_next = SparseTensor.from_dense(latent_acts_next)

    return latent_acts_next.sparse[0]

def latent_to_latent_gradients(
    tokens: Float[Tensor, "batch seq"],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor, SparseTensor, SparseTensor]:
    """
    Computes the gradients between all active pairs of latents belonging to two SAEs.

    Returns:
        latent_latent_gradients:    The gradients between all active pairs of latents
        latent_acts_prev:           The latent activations of the first SAE
        latent_acts_next:           The latent activations of the second SAE
        latent_acts_next_recon:     The reconstructed latent activations of the second SAE (i.e.
                                    based on the first SAE's reconstructions)
    """
    acts_prev_name = f"{sae_from.cfg.hook_name}.hook_sae_acts_post"
    acts_next_name = f"{sae_to.cfg.hook_name}.hook_sae_acts_post"
    sae_from_error_name = f"{sae_from.cfg.hook_name}.hook_sae_error"
    sae_from.use_error_term = True

    with t.no_grad():
        # Get the true activations for both SAEs
        _, cache = model.run_with_cache_with_saes(
            tokens,
            names_filter=[acts_prev_name, acts_next_name, sae_from_error_name],
            stop_at_layer=sae_to.cfg.hook_layer + 1,
            saes=[sae_from, sae_to],
            remove_batch_dim=False,
        )

    latent_acts_prev = SparseTensor.from_dense(cache[acts_prev_name])
    latent_acts_next = SparseTensor.from_dense(cache[acts_next_name])
    sae_from_error = cache[sae_from_error_name]

    latent_acts_to_later_latent_acts_and_gradients = t.func.jacrev(latent_acts_to_later_latent_acts)

    latent_latent_gradients = latent_acts_to_later_latent_acts_and_gradients(
            *latent_acts_prev.sparse, sae_from, sae_to, model, sae_from_error
        )

    latent_latent_attributions = einops.einsum(latent_latent_gradients, latent_acts_prev.sparse[0], 'next_nonzero from_nonzero, from_nonzero -> next_nonzero from_nonzero')


    # Set SAE state back to default
    sae_from.use_error_term = False

    return (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_latent_attributions,
    )

In [6]:
### LATENT2LATENT PLOT ###
prompt = "The Eiffel tower is in Paris"
tokens = gpt2.to_tokens(prompt)
str_toks = gpt2.to_str_tokens(prompt)
layer_from = 0
layer_to = 3

# Get latent-to-latent gradients
t.cuda.empty_cache()
t.set_grad_enabled(True)
(
    latent_latent_gradients,
    latent_acts_prev,
    latent_acts_next,
    latent_latent_attributions,
) = latent_to_latent_gradients(tokens, gpt2_saes[layer_from], gpt2_saes[layer_to], gpt2)
t.set_grad_enabled(False) 

px.imshow(
    to_numpy(latent_latent_attributions.T),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[
        f"F{layer_to}.{latent}, {str_toks[seq]!r} ({seq})"
        for (_, seq, latent) in latent_acts_next.indices
    ],
    y=[
        f"F{layer_from}.{latent}, {str_toks[seq]!r} ({seq})"
        for (_, seq, latent) in latent_acts_prev.indices
    ],
    labels={"x": f"To layer {layer_to}", "y": f"From layer {layer_from}"},
    title=f'Attributions between SAE latents in layer {layer_from} and SAE latents in layer {layer_to}<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1600,
    height=1000,
).show()

In [13]:
## Count L0 percentage
total_pairs = gpt2_saes[3].cfg.d_sae ** 2
non_zero_pairs = latent_latent_attributions.numel()
l0_percentage = non_zero_pairs / total_pairs

In [14]:
l0_percentage

2.996789084540473e-05

Observations:
 -- The majority of large attribution pairs are between latents in the residual stream of the same token
 -- What is more interesting to look at is large attributions pairs between latents in residual streams of different tokens
 -- Plot is upper triangular cuz thats how autoregressive transformers work

In [73]:
### Get off staircase latent pairs ###
off_staircase_indices = [] # tuples of the form ([latent_prev_token_idx, latent_prev_id], [latent_next_token_idx, latent_next_id])
for i in range(latent_latent_attributions.shape[0]):
    for j in range(latent_latent_attributions.shape[1]):
        if latent_latent_attributions[i, j].item() > 2:
            latent_next_idx = latent_acts_next.indices[i].tolist()
            latent_prev_idx = latent_acts_prev.indices[j].tolist()
            latent_next_token_idx = latent_next_idx[1]
            latent_prev_token_idx = latent_prev_idx[1]
            if latent_next_token_idx != latent_prev_token_idx:
                off_staircase_indices.append((latent_prev_idx[1:], latent_next_idx[1:]))

In [79]:
str_toks

['<|endoftext|>', 'The', ' E', 'iff', 'el', ' tower', ' is', ' in', ' Paris']

In [75]:
off_staircase_indices

[([2, 16911], [3, 15266]),
 ([3, 2124], [4, 2562]),
 ([3, 15033], [4, 8922]),
 ([2, 16911], [4, 15266]),
 ([3, 2124], [4, 18588]),
 ([7, 11786], [8, 20822])]

In [None]:
# E - iff bigram
display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=16911)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=15266)

https://neuronpedia.org/gpt2-small/0-res-jb/16911?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


https://neuronpedia.org/gpt2-small/3-res-jb/15266?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [None]:
# Iff - el bigram
display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=2124)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=2562)

https://neuronpedia.org/gpt2-small/0-res-jb/2124?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


https://neuronpedia.org/gpt2-small/3-res-jb/2562?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [None]:
# iff - el bigram
display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=15033)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=8922)

https://neuronpedia.org/gpt2-small/0-res-jb/15033?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


https://neuronpedia.org/gpt2-small/3-res-jb/8922?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [83]:
display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=2124)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=18588)

https://neuronpedia.org/gpt2-small/0-res-jb/2124?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


https://neuronpedia.org/gpt2-small/3-res-jb/18588?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [84]:
display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=11786)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=20822)

https://neuronpedia.org/gpt2-small/0-res-jb/11786?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


https://neuronpedia.org/gpt2-small/3-res-jb/20822?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


On the sparsity of latent-to-latent attributions, which we interpret as forming bigram pairs:
    -- Even for the off-staircase latents, the residual stream locations are almost always adjacent.
        -- The single example of a skip bigram we find in fact also has a normal bigram
    -- Comparison to baseline: sometimes the bigram pair can be clearly explained by looking at the autointerp descriptions, sometimes you would
       never guess from autointerp that a pair of latents would have high attribution
    -- Latent 2124 is part of two separate E - iff bigram circuits

In [15]:
### TOKEN2LATENT GRADIENTS ###
def tokens_to_latent_acts(
    token_scales: Float[Tensor, "batch seq"],
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Given scale factors for model's embeddings (i.e. scale factors applied after we compute the sum
    of positional and token embeddings), returns the SAE's latents.

    Returns:
        latent_acts_sparse: The SAE's latents in sparse form (i.e. the tensor of values)
        latent_acts_dense:  The SAE's latents in dense tensor, in a length-1 tuple
    """
    resid_after_emb = model(tokens, stop_at_layer = 0)
    resid_after_emb_scaled = einops.einsum(resid_after_emb, token_scales, '... s d, ... s -> ... s d')
    resid_before_sae = model(resid_after_emb_scaled, start_at_layer = 0, stop_at_layer = sae.cfg.hook_layer)
    sae_latents = SparseTensor.from_dense(sae.encode(resid_before_sae))
    return sae_latents.sparse[0], (sae_latents.dense,)

def token_to_latent_gradients(
    tokens: Float[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor]:
    """
    Computes the gradients between an SAE's latents and all input tokens.

    Returns:
        token_latent_grads: The gradients between input tokens and SAE latents
        latent_acts:        The SAE's latent activations
    """
    batch, seq = tokens.shape
    token_scales = t.ones(batch, seq, device = tokens.device)
    token_latent_grads, (latent_acts,) = t.func.jacrev(tokens_to_latent_acts, has_aux=True)(
        token_scales, tokens, sae, model
    )
    token_latent_grads = einops.rearrange(token_latent_grads, 'd_sae_nonzero batch seq -> batch seq d_sae_nonzero')
    latent_acts = SparseTensor.from_dense(latent_acts)
    return (token_latent_grads, latent_acts)

In [16]:
### TOKEN2LATENT PLOT ###
prompt = "The Eiffel tower is in Paris"
tokens = gpt2.to_tokens(prompt)
str_toks = gpt2.to_str_tokens(prompt)
sae_layer = 3

t.cuda.empty_cache()
t.set_grad_enabled(True)
token_latent_grads, latent_acts = token_to_latent_gradients(
    tokens, sae=gpt2_saes[sae_layer], model=gpt2
)
t.set_grad_enabled(False)

px.imshow(
    to_numpy(token_latent_grads[0]),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[
        f"F{sae_layer}.{latent:05}, {str_toks[seq]!r} ({seq})"
        for (_, seq, latent) in latent_acts.indices
    ],
    y=[f"{str_toks[i]!r} ({i})" for i in range(len(str_toks))],
    labels={"x": f"To layer {sae_layer}", "y": "From tokens"},
    title=f'Gradients between input tokens and SAE latents in layer {sae_layer}<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1900,
    height=450,
)

In [23]:
total_pairs = len(str_toks) * gpt2_saes[3].cfg.d_sae
non_zero_pairs = token_latent_grads.numel()

In [26]:
l0_percentage = non_zero_pairs / total_pairs
l0_percentage

0.007364908854166667

In [None]:
### Get off staircase token-latent pairs ###
off_staircase_indices = [] # tuples of the form ([latent_prev_token_idx, latent_prev_id], [latent_next_token_idx, latent_next_id])
for i in range(latent_latent_attributions.shape[0]):
    for j in range(latent_latent_attributions.shape[1]):
        if latent_latent_attributions[i, j].item() > 2:
            latent_next_idx = latent_acts_next.indices[i].tolist()
            latent_prev_idx = latent_acts_prev.indices[j].tolist()
            latent_next_token_idx = latent_next_idx[1]
            latent_prev_token_idx = latent_prev_idx[1]
            if latent_next_token_idx != latent_prev_token_idx:
                off_staircase_indices.append((latent_prev_idx[1:], latent_next_idx[1:]))

In [27]:
### LATENT2LOGIT ATTRIBUTIONS ###
def latent_acts_to_logits(
    latent_acts_nonzero: Float[Tensor, "nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae: SAE,
    model: HookedSAETransformer,
    sae_error_term,
    token_ids: list[int] | None = None,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Computes the logits as a downstream function of the SAE's reconstructed residual stream. If we
    supply `token_ids`, it means we only compute & return the logits for those specified tokens.
    """
    latent_acts = SparseTensor.from_sparse((latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape))
    res_stream = sae.decode(latent_acts.dense) + sae_error_term # can also do without sae error term but why? They use error term in Marks.
    logits = model(res_stream, start_at_layer = sae.cfg.hook_layer)[0, -1] # [d_vocab] 
    return logits[token_ids]

def latent_to_logit_gradients(
    tokens: Float[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
    k: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, list[int] | None, SparseTensor]:
    """
    Computes the gradients between active latents and some top-k set of logits (we
    use k to avoid having to compute the gradients for all tokens).

    Returns:
        latent_logit_gradients:  The gradients between the SAE's active latents & downstream logits
        logits:                  The model's true logits
        logits_recon:            The model's reconstructed logits (i.e. based on SAE reconstruction)
        token_ids:               The tokens we computed the gradients for
        latent_acts:             The SAE's latent activations
    """
    assert tokens.shape[0] == 1, "Only supports batch size 1 for now"

    acts_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    sae_error_name = f"{sae.cfg.hook_name}.hook_sae_error"
    sae.use_error_term = True
    with t.no_grad():
        logits, cache = model.run_with_cache_with_saes(
            tokens,
            names_filter=[acts_hook_name, sae_error_name],
            saes=[sae],
            remove_batch_dim=False,
        )

    logits = logits[0, -1]
    _, token_ids = logits.topk(k = k)
    token_ids = token_ids.tolist()

    latent_acts = cache[acts_hook_name]
    latent_acts = SparseTensor.from_dense(latent_acts)
    sae_error_term = cache[sae_error_name]

    latent_logit_gradients = t.func.jacrev(latent_acts_to_logits)(
        *latent_acts.sparse,
        sae,
        model,
        sae_error_term,
        token_ids,
    )

    latent_logit_attributions = latent_logit_gradients * latent_acts.sparse[0]
    sae.use_error_term = False

    return (
        latent_logit_gradients,
        logits,
        token_ids,
        latent_acts,
        latent_logit_attributions,
    )

In [28]:
### LATENT2LOGIT GRADIENT PLOT ###
prompt = "The Eiffel tower is in the city of"
answer = " Paris"

tokens = gpt2.to_tokens(prompt, prepend_bos=True)
str_toks = gpt2.to_str_tokens(prompt, prepend_bos=True)
k = 25
layer = 9

(
    latent_logit_grads,
    logits,
    token_ids,
    latent_acts,
    latent_logit_attributions,
) = latent_to_logit_gradients(tokens, sae=gpt2_saes[layer], model=gpt2, k=k)

# We can, for instance, sort the attributions by their value for the first logit in token_ids
sorted_indices = latent_logit_grads[0].argsort(descending=True)

# Reorder the attributions
latent_logit_grads = latent_logit_grads[:, sorted_indices]

px.imshow(
    to_numpy(latent_logit_grads),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[
        f"{str_toks[seq]!r} ({seq}), latent {latent:05}"
        for (_, seq, latent) in latent_acts.indices[sorted_indices]
    ],
    y=[f"{tok!r} ({gpt2.to_single_str_token(tok)})" for tok in token_ids],
    labels={"x": f"Features in layer {layer}", "y": "Logits"},
    title=(
        f"Gradients between SAE latents in layer {layer} "
        f"and final logits (top {k} logits)\nPrompt: {' '.join(str_toks)}"
    ),
    width=1900,
    height=800,
    aspect="auto",
).show()

In [12]:
### LATENT2LOGIT ATTRIBUTIONS PLOT ###
prompt = "The Eiffel tower is in the city of"
answer = " Paris"

tokens = gpt2.to_tokens(prompt, prepend_bos=True)
str_toks = gpt2.to_str_tokens(prompt, prepend_bos=True)
k = 25
layer = 9

(
    latent_logit_grads,
    logits,
    token_ids,
    latent_acts,
    latent_logit_attributions,
) = latent_to_logit_gradients(tokens, sae=gpt2_saes[layer], model=gpt2, k=k)

# We can, for instance, sort the attributions by their value for the first logit in token_ids
sorted_indices = latent_logit_attributions[0].argsort(descending=True)

# Reorder the attributions
latent_logit_attributions = latent_logit_attributions[:, sorted_indices]

px.imshow(
    to_numpy(latent_logit_attributions),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[
        f"{str_toks[seq]!r} ({seq}), latent {latent:05}"
        for (_, seq, latent) in latent_acts.indices[sorted_indices]
    ],
    y=[f"{tok!r} ({gpt2.to_single_str_token(tok)})" for tok in token_ids],
    labels={"x": f"Features in layer {layer}", "y": "Logits"},
    title=(
        f"**Attributions** (gradient × activation) between SAE latents in layer {layer} "
        f"and final logits (top {k} logits)\nPrompt: {' '.join(str_toks)}"
    ),
    width=1900,
    height=800,
    aspect="auto",
).show()

In [30]:
total_pairs = 25 * gpt2_saes[8].cfg.d_sae
non_zero_pairs = latent_logit_attributions.numel()

In [31]:
non_zero_pairs / total_pairs

0.021321614583333332

In [15]:
### Compute L1_norm of latents and L1_norm of latent attributions ###
(
    latent_logit_grads,
    logits,
    token_ids,
    latent_acts,
    latent_logit_attributions,
) = latent_to_logit_gradients(tokens, sae=gpt2_saes[layer], model=gpt2, k=k)

t.sum(t.abs(latent_logit_attributions))
t.sum(t.abs(latent_acts.sparse[0]))

tensor(4671.0801)

In [64]:
### Fact: Michael Jordan plays the sport of
def latent_acts_to_sport(
    latent_acts_nonzero: Float[Tensor, "nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae: SAE,
    model: HookedSAETransformer,
    sae_error_term,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Computes the logit difference between " basketball" and " golf"
    """
    latent_acts = SparseTensor.from_sparse(
        (latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)
    )
    res_stream = (
        sae.decode(latent_acts.dense) + sae_error_term
    )  # can also do without sae error term but why? They use error term in Marks.
    logits = model(res_stream, start_at_layer=sae.cfg.hook_layer)[0, -1]  # [d_vocab]
    difference = logits[model.to_tokens(" tennis", prepend_bos=False).item()] - logits[model.to_tokens(" golf", prepend_bos = False).item()]
    return difference


def latent_to_sport_gradients(
    tokens: Float[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
    k: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, list[int] | None, SparseTensor]:
    """
    Computes the gradients between active latents and some top-k set of logits (we
    use k to avoid having to compute the gradients for all tokens).

    Returns:
        latent_logit_gradients:  The gradients between the SAE's active latents & downstream logits
        logits:                  The model's true logits
        logits_recon:            The model's reconstructed logits (i.e. based on SAE reconstruction)
        token_ids:               The tokens we computed the gradients for
        latent_acts:             The SAE's latent activations
    """
    assert tokens.shape[0] == 1, "Only supports batch size 1 for now"

    acts_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    sae_error_name = f"{sae.cfg.hook_name}.hook_sae_error"
    sae.use_error_term = True
    with t.no_grad():
        logits, cache = model.run_with_cache_with_saes(
            tokens,
            names_filter=[acts_hook_name, sae_error_name],
            saes=[sae],
            remove_batch_dim=False,
        )

    latent_acts = cache[acts_hook_name]
    latent_acts = SparseTensor.from_dense(latent_acts)
    sae_error_term = cache[sae_error_name]

    latent_logit_gradients = t.func.jacrev(latent_acts_to_sport)(
        *latent_acts.sparse,
        sae,
        model,
        sae_error_term,
    )

    latent_logit_attributions = latent_logit_gradients * latent_acts.sparse[0]
    sae.use_error_term = False

    return (
        latent_logit_gradients,
        latent_acts,
        latent_logit_attributions,
    )

In [65]:
prompt = "Fact: Serena Williams plays the sport of"
answer = " basketball"

tokens = gpt2.to_tokens(prompt, prepend_bos=True)
str_toks = gpt2.to_str_tokens(prompt, prepend_bos=True)
layer = 9

(
    latent_logit_grads,
    latent_acts,
    latent_logit_attributions,
) = latent_to_sport_gradients(tokens, sae=gpt2_saes[layer], model=gpt2)

In [66]:
# Sort by absolute value:
sorted_indices = t.abs(latent_logit_attributions).argsort(descending=True)
latent_logit_attributions_sorted = latent_logit_attributions[sorted_indices]

px.imshow(
    to_numpy(latent_logit_attributions_sorted.reshape(1, -1)),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[
        f"{str_toks[seq]!r} ({seq}), latent {latent:05}"
        for (_, seq, latent) in latent_acts.indices[sorted_indices]
    ],
    labels={"x": f"Features in layer {layer}", "y": "Basketball vs Golf logit diff"},
    title=(
        f"Attributions between SAE latents in layer {layer} and logit difference between 'basketball' and 'golf'\n"
        f"Prompt: {prompt}"
    ),
    width=1900,
    height=400,
).show()

In [63]:
display_dashboard(sae_id="blocks.9.hook_resid_pre", latent_idx=21685)

https://neuronpedia.org/gpt2-small/9-res-jb/21685?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [30]:
display_dashboard(sae_id="blocks.9.hook_resid_pre", latent_idx=23219)

https://neuronpedia.org/gpt2-small/9-res-jb/23219?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [31]:
display_dashboard(sae_id="blocks.9.hook_resid_pre", latent_idx=2792)

https://neuronpedia.org/gpt2-small/9-res-jb/2792?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300
