In [2]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

# Choose the compute device: prefer Apple's MPS backend, then CUDA, and otherwise fall back to CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# from plotly_utils import imshow, line

# True only when this code is executed as the top-level script/kernel module.
MAIN = __name__ == "__main__"

In [6]:
# Load GPT-2 small wrapped so that SAEs can be attached to it.
gpt2 = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# One residual-stream SAE (hooked at `hook_resid_pre`) per transformer block, keyed by layer index.
# `SAE.from_pretrained` returns several items here; only the first one (the SAE itself) is kept.
gpt2_saes = dict(
    (
        layer_idx,
        SAE.from_pretrained(
            release="gpt2-small-res-jb",
            sae_id=f"blocks.{layer_idx}.hook_resid_pre",
            device=str(device),
        )[0],
    )
    for layer_idx in tqdm(range(gpt2.cfg.n_layers))
)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
100%|██████████| 12/12 [00:12<00:00,  1.05s/it]


In [3]:
class SparseTensor:
    """
    Bundles two equivalent views of one tensor whose entries are assumed to be non-negative:
        dense:  the full tensor, zeros included, with shape (n1, ..., nk).
        sparse: a triple holding the strictly positive values (shape (n_nonzero,)), their indices (shape
                (n_nonzero, k)), and the shape of the dense tensor.
    """

    sparse: tuple[Tensor, Tensor, tuple[int, ...]]
    dense: Tensor

    def __init__(self, sparse: tuple[Tensor, Tensor, tuple[int, ...]], dense: Tensor):
        # Both views are stored exactly as given; no check is made that they agree with each other.
        self.dense = dense
        self.sparse = sparse

    @classmethod
    def from_dense(cls, dense: Tensor) -> "SparseTensor":
        """Builds the sparse view from a dense tensor, keeping only entries that are strictly > 0."""
        positive_mask = dense > 0
        values = dense[positive_mask]
        indices = t.argwhere(positive_mask)
        return cls((values, indices, tuple(dense.shape)), dense)

    @classmethod
    def from_sparse(
        cls, sparse: tuple[Tensor, Tensor, tuple[int, ...]]
    ) -> "SparseTensor":
        """Builds the dense view by scattering the given values into a zero tensor of the given shape."""
        values, indices, dense_shape = sparse
        dense = t.zeros(dense_shape, dtype=values.dtype, device=values.device)
        # `unbind(-1)` turns the (n_nonzero, k) index matrix into k per-axis index vectors.
        dense[indices.unbind(-1)] = values
        return cls(sparse, dense)

    @property
    def values(self) -> Tensor:
        """The stored nonzero values, with any size-1 dimensions squeezed away."""
        stored_values, _, _ = self.sparse
        return stored_values.squeeze()

    @property
    def indices(self) -> Tensor:
        """The stored nonzero indices, with any size-1 dimensions squeezed away."""
        _, stored_indices, _ = self.sparse
        return stored_indices.squeeze()

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the dense tensor, as recorded in the sparse triple."""
        _, _, dense_shape = self.sparse
        return dense_shape

In [4]:
### LATENT2LATENT ATTRIBUTION ###
def latent_acts_to_later_latent_acts(
    latent_acts_nonzero: Float[Tensor, "nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
    error: Tensor,
) -> Tensor:
    """
    Given some latent activations for a residual stream SAE earlier in the model, computes the latent activations of a
    later SAE. It does this by mapping the latent activations through the path SAE decoder -> intermediate model layers
    -> later SAE encoder.

    This function must input & output sparse information (i.e. nonzero values and their indices) rather than dense
    tensors, because latent activations are sparse but jacrev() doesn't support gradients on real sparse tensors.

    Args:
        latent_acts_nonzero:      The nonzero latent activations of `sae_from` (the input we differentiate w.r.t.)
        latent_acts_nonzero_inds: The indices of those activations inside the dense activation tensor
        latent_acts_shape:        The shape of the dense activation tensor of `sae_from`
        sae_from:                 The earlier SAE, whose decoder maps latents back to the residual stream
        sae_to:                   The later SAE, whose encoder is applied to the downstream residual stream
        model:                    The model that is run from `sae_from`'s hook layer up to `sae_to`'s hook layer
        error:                    Tensor added to `sae_from`'s reconstruction before it is fed to the model (the
                                  caller passes the SAE's cached error term here)

    Returns:
        The strictly positive latent activations of `sae_to`, as a 1D tensor of values only (no indices, no dense
        tensor).
    """
    # Rebuild the dense activations so that the SAE decoder can be applied to them.
    latent_acts_from = SparseTensor.from_sparse(
        (latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)
    ).dense
    resid_stream_from = sae_from.decode(latent_acts_from)

    # Run only the layers that sit between the two SAEs' hook points.
    resid_stream_next = model.forward(
        resid_stream_from + error,
        start_at_layer=sae_from.cfg.hook_layer,
        stop_at_layer=sae_to.cfg.hook_layer,
    )

    latent_acts_next = SparseTensor.from_dense(sae_to.encode(resid_stream_next))

    # Only the values are returned: this keeps the output a single tensor that jacrev can differentiate.
    return latent_acts_next.sparse[0]


def latent_to_latent_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor, SparseTensor, Tensor]:
    """
    Computes the gradients between all active pairs of latents belonging to two SAEs.

    `sae_from.use_error_term` is switched on while this function runs and is put back to whatever value it had
    before the call, even if an exception is raised part-way through.

    Returns:
        latent_latent_gradients:    The gradients between all active pairs of latents, with shape
                                    (next_nonzero, from_nonzero)
        latent_acts_prev:           The latent activations of the first SAE
        latent_acts_next:           The latent activations of the second SAE
        latent_latent_attributions: The gradients multiplied elementwise by the first SAE's nonzero
                                    activations (same shape as the gradients)
    """
    acts_prev_name = f"{sae_from.cfg.hook_name}.hook_sae_acts_post"
    acts_next_name = f"{sae_to.cfg.hook_name}.hook_sae_acts_post"
    sae_from_error_name = f"{sae_from.cfg.hook_name}.hook_sae_error"

    # The error hook is only cached when the error term is enabled, so turn it on for the duration of the call.
    previous_use_error_term = sae_from.use_error_term
    sae_from.use_error_term = True
    try:
        with t.no_grad():
            # Get the true activations for both SAEs
            _, cache = model.run_with_cache_with_saes(
                tokens,
                names_filter=[acts_prev_name, acts_next_name, sae_from_error_name],
                stop_at_layer=sae_to.cfg.hook_layer + 1,
                saes=[sae_from, sae_to],
                remove_batch_dim=False,
            )

        latent_acts_prev = SparseTensor.from_dense(cache[acts_prev_name])
        latent_acts_next = SparseTensor.from_dense(cache[acts_next_name])
        sae_from_error = cache[sae_from_error_name]

        # jacrev differentiates w.r.t. the first positional argument only, i.e. the nonzero activation values.
        latent_latent_gradients = t.func.jacrev(latent_acts_to_later_latent_acts)(
            *latent_acts_prev.sparse, sae_from, sae_to, model, sae_from_error
        )
    finally:
        # Restore the SAE's state so that later code is not affected by this call.
        sae_from.use_error_term = previous_use_error_term

    latent_latent_attributions = einops.einsum(
        latent_latent_gradients,
        latent_acts_prev.sparse[0],
        "next_nonzero from_nonzero, from_nonzero -> next_nonzero from_nonzero",
    )

    return (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_latent_attributions,
    )

In [7]:
### Activation store built from the layer-0 SAE's own configuration (i.e. the data that SAE was trained on)
gpt2_act_store = ActivationsStore.from_sae(
    # What to read activations from
    sae=gpt2_saes[0],
    model=gpt2,
    device=str(device),
    # How the data is fetched and buffered
    streaming=True,
    n_batches_in_buffer=32,
    store_batch_size_prompts=1000,
)

Downloading builder script: 100%|██████████| 2.73k/2.73k [00:00<00:00, 228kB/s]
Downloading readme: 100%|██████████| 7.35k/7.35k [00:00<00:00, 305kB/s]


In [8]:
# Pull one batch of tokenised prompts from the activation store and keep only the first 8 token positions
# of every prompt (the slice acts on the second axis). The cells below loop over `batch` one prompt at a
# time and compute a Jacobian per prompt, which is why the sequences are kept this short.
batch = gpt2_act_store.get_batch_tokens()[:, :8] # truncate to 8 tokens due to cost constraints

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


In [9]:
### Compute average latent to latent non zero attribution rate
threshold = 1e-8
layer_from = 0
layer_to = 3

# Size of the full latent-pair grid: width of the earlier SAE times width of the later SAE. This does not
# depend on the example, so it is computed once outside the loop.
# NOTE(review): the attribution matrix is indexed by active latents over *all* sequence positions, whereas
# this denominator counts latent pairs for a single position - confirm this is the intended normalisation.
total_pairs = gpt2_saes[layer_from].cfg.d_sae * gpt2_saes[layer_to].cfg.d_sae

non_zero_rate = 0
for example in tqdm(batch):
    # Add the batch dimension explicitly: the function is documented as taking (batch, seq) tokens.
    (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_latent_attributions,
    ) = latent_to_latent_gradients(
        example.unsqueeze(0), gpt2_saes[layer_from], gpt2_saes[layer_to], gpt2
    )
    non_zero_pairs = (latent_latent_attributions.abs() > threshold).sum()
    l0_percentage = non_zero_pairs / total_pairs
    non_zero_rate += l0_percentage

# Mean rate over all examples (last expression, so the notebook displays it).
non_zero_rate / batch.shape[0]

100%|██████████| 1000/1000 [00:33<00:00, 30.06it/s]


tensor(1.7893e-05, device='cuda:0')

In [13]:
### TOKEN2LATENT GRADIENTS ###
def tokens_to_latent_acts(
    token_scales: Float[Tensor, "batch seq"],
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Maps per-token scale factors to the SAE's latent activations. The scale factors multiply the model's
    embeddings (i.e. the sum of token and positional embeddings) before the rest of the model is run, so
    differentiating this function w.r.t. `token_scales` measures how much each input token matters.

    Returns:
        latent_acts_sparse: The SAE's latents in sparse form (i.e. the tensor of values)
        latent_acts_dense:  The SAE's latents in dense tensor, in a length-1 tuple
    """
    # Stopping at layer 0 gives the residual stream right after the embeddings.
    embeddings = model(tokens, stop_at_layer=0)

    # Scale every position's embedding vector by that position's factor.
    scaled_embeddings = einops.einsum(
        embeddings, token_scales, "... s d, ... s -> ... s d"
    )

    # Resume the forward pass from the scaled embeddings, up to the layer the SAE is hooked at.
    sae_input = model(
        scaled_embeddings, start_at_layer=0, stop_at_layer=sae.cfg.hook_layer
    )

    latents = SparseTensor.from_dense(sae.encode(sae_input))
    nonzero_values, _, _ = latents.sparse
    return nonzero_values, (latents.dense,)


def token_to_latent_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor]:
    """
    Computes the gradients between an SAE's latents and all input tokens.

    The gradients are taken w.r.t. a tensor of per-token scale factors (all equal to 1) that multiply the
    model's embeddings; see `tokens_to_latent_acts`.

    Args:
        tokens: The input token ids; must be 2D, since its shape is unpacked as (batch, seq)
        sae:    The SAE whose latents we take gradients of
        model:  The model the SAE is attached to

    Returns:
        token_latent_grads: The gradients between input tokens and SAE latents, with shape
                            (batch, seq, d_sae_nonzero)
        latent_acts:        The SAE's latent activations
    """
    n_batch, n_seq = tokens.shape
    token_scales = t.ones(n_batch, n_seq, device=tokens.device)

    # has_aux=True: the function's second output (the dense latents) is passed through undifferentiated.
    token_latent_grads, (latent_acts,) = t.func.jacrev(
        tokens_to_latent_acts, has_aux=True
    )(token_scales, tokens, sae, model)

    # jacrev puts the output dimension first; move it last so that the token dimensions lead.
    token_latent_grads = einops.rearrange(
        token_latent_grads, "d_sae_nonzero batch seq -> batch seq d_sae_nonzero"
    )
    latent_acts = SparseTensor.from_dense(latent_acts)
    return (token_latent_grads, latent_acts)

In [16]:
### Compute average token to latent non zero attribution rate
sae_layer = 3
threshold = 1e-8

# Number of (token position, latent) pairs per example; the same for every example, so computed once.
# NOTE(review): the gradient tensor's last axis runs over active latents at *all* sequence positions,
# whereas this denominator counts a single position's latents - confirm this is the intended normalisation.
total_pairs = batch.shape[-1] * gpt2_saes[sae_layer].cfg.d_sae

# Start from zero so the result does not include whatever an earlier cell left in `non_zero_rate`.
non_zero_rate = 0
for example in tqdm(batch):
    (token_latent_grads, latent_acts) = token_to_latent_gradients(example.unsqueeze(0), gpt2_saes[sae_layer], gpt2)
    non_zero_pairs = (token_latent_grads.abs() >= threshold).sum()
    l0_percentage = non_zero_pairs / total_pairs
    non_zero_rate += l0_percentage

# Mean rate over all examples (last expression, so the notebook displays it).
non_zero_rate / batch.shape[0]

100%|██████████| 1000/1000 [00:22<00:00, 45.13it/s]


tensor(0.0054, device='cuda:0')

In [17]:
### LATENT2LOGIT ATTRIBUTIONS ###
def latent_acts_to_logits(
    latent_acts_nonzero: Float[Tensor, "nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae: SAE,
    model: HookedSAETransformer,
    sae_error_term: Tensor,
    token_ids: list[int] | None = None,
) -> Tensor:
    """
    Computes the logits as a downstream function of the SAE's reconstructed residual stream. If we
    supply `token_ids`, it means we only compute & return the logits for those specified tokens.

    Args:
        latent_acts_nonzero:      The SAE's nonzero latent activations (the input we differentiate w.r.t.)
        latent_acts_nonzero_inds: The indices of those activations inside the dense activation tensor
        latent_acts_shape:        The shape of the dense activation tensor
        sae:                      The SAE whose decoder maps the latents back to the residual stream
        model:                    The model that is run from the SAE's hook layer onwards
        sae_error_term:           Tensor added to the SAE's reconstruction before the model is run
        token_ids:                The vocabulary indices to return logits for, or None for all of them

    Returns:
        The logits at the final sequence position of the first batch element: a 1D tensor with one entry per
        id in `token_ids`, or the full logit vector if `token_ids` is None.
    """
    latent_acts = SparseTensor.from_sparse(
        (latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)
    )
    res_stream = (
        sae.decode(latent_acts.dense) + sae_error_term
    )  # can also do without sae error term but why? They use error term in Marks.
    logits = model(res_stream, start_at_layer=sae.cfg.hook_layer)[0, -1]  # [d_vocab]

    # Indexing with None would insert a new leading axis rather than select everything, so handle it explicitly.
    if token_ids is None:
        return logits
    return logits[token_ids]


def latent_to_logit_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
    k: int | None = None,
) -> tuple[Tensor, Tensor, list[int] | None, SparseTensor, Tensor]:
    """
    Computes the gradients between active latents and some top-k set of logits (we
    use k to avoid having to compute the gradients for all tokens).

    `sae.use_error_term` is switched on while this function runs and is put back to whatever value it had
    before the call, even if an exception is raised part-way through.

    Args:
        tokens: The input token ids; only batch size 1 is supported
        sae:    The SAE whose active latents we take gradients with respect to
        model:  The model the SAE is attached to
        k:      How many of the largest final-position logits to differentiate. If None, no subset is
                selected and gradients are computed for every logit, which is far more expensive.

    Returns:
        latent_logit_gradients:     The gradients between the SAE's active latents & downstream logits
        logits:                     The model's logits at the final sequence position
        token_ids:                  The tokens we computed the gradients for (None if `k` is None)
        latent_acts:                The SAE's latent activations
        latent_logit_attributions:  The gradients multiplied by the SAE's nonzero latent activations
    """
    assert tokens.shape[0] == 1, "Only supports batch size 1 for now"

    acts_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    sae_error_name = f"{sae.cfg.hook_name}.hook_sae_error"

    # The error hook is only cached when the error term is enabled, so turn it on for the duration of the call.
    previous_use_error_term = sae.use_error_term
    sae.use_error_term = True
    try:
        with t.no_grad():
            logits, cache = model.run_with_cache_with_saes(
                tokens,
                names_filter=[acts_hook_name, sae_error_name],
                saes=[sae],
                remove_batch_dim=False,
            )

        logits = logits[0, -1]

        # Get the tokens we'll actually compute gradients for (topk cannot be called with k=None).
        token_ids = None if k is None else logits.topk(k=k).indices.tolist()

        latent_acts = SparseTensor.from_dense(cache[acts_hook_name])
        sae_error_term = cache[sae_error_name]

        # jacrev differentiates w.r.t. the first positional argument only, i.e. the nonzero activation values.
        latent_logit_gradients = t.func.jacrev(latent_acts_to_logits)(
            *latent_acts.sparse,
            sae,
            model,
            sae_error_term,
            token_ids,
        )
    finally:
        # Restore the SAE's state so that later code is not affected by this call.
        sae.use_error_term = previous_use_error_term

    # Broadcasts the nonzero activations over the gradients' last (latent) dimension.
    latent_logit_attributions = latent_logit_gradients * latent_acts.sparse[0]

    return (
        latent_logit_gradients,
        logits,
        token_ids,
        latent_acts,
        latent_logit_attributions,
    )

In [19]:
### LATENT2LOGIT L0
k = 25
layer = 9
threshold = 1e-8

# Number of (logit, latent) pairs per example. Derived from `k` so that it stays correct if `k` is changed;
# it is the same for every example, so it is computed once.
# NOTE(review): the attribution matrix's latent axis runs over active latents at *all* sequence positions,
# whereas this denominator counts a single position's latents - confirm this is the intended normalisation.
total_pairs = k * gpt2_saes[layer].cfg.d_sae

non_zero_rate = 0
for example in tqdm(batch):
    (
        latent_logit_grads,
        logits,
        token_ids,
        latent_acts,
        latent_logit_attributions,
    ) = latent_to_logit_gradients(example.unsqueeze(0), sae=gpt2_saes[layer], model=gpt2, k=k)
    non_zero_pairs = (latent_logit_attributions.abs() >= threshold).sum()
    l0_percentage = non_zero_pairs / total_pairs
    non_zero_rate += l0_percentage

# Mean rate over all examples (last expression, so the notebook displays it).
non_zero_rate / batch.shape[0]

100%|██████████| 1000/1000 [00:41<00:00, 23.96it/s]


tensor(0.0154, device='cuda:0')