# Sparse Autoencoders for model steering
Authored by Mikkel Godsk Jørgensen (mgojo@dtu.dk)

In this notebook, we will explore using sparse autoencoders (abbr.: *SAEs*) for model steering in a similar fashion to [Templeton et al., 2024](https://transformer-circuits.pub/2024/scaling-monosemanticity/). We will be using the Gemma-2-2b model by Google [(Riviere et al., 2024)](https://arxiv.org/pdf/2408.00118), and the Gemma-Scope suite by DeepMind [(Lieberum et al., 2024)](https://arxiv.org/pdf/2408.05147).

A sparse autoencoder is a shallow autoencoder with a wide intermediate layer subject to a sparsity constraint/incentive. The Gemma-Scope architecture is as follows:
$$
\begin{split}
    \textbf{Encoder:}\quad\quad\boldsymbol{f}(\boldsymbol{x})&=\sigma(\boldsymbol{W}_{\rm enc}\boldsymbol{x}+\boldsymbol{b}_{\rm enc})\\
    \textbf{Decoder:}\quad\quad\boldsymbol{g}(\boldsymbol{f})&=\boldsymbol{W}_{\rm dec}\boldsymbol{f}+\boldsymbol{b}_{\rm dec}.
\end{split}
$$
Here the input $\boldsymbol{x}\in\mathbb{R}^n$ is an internal representation from the model subject to investigation/steering (i.e. Gemma 2), $\boldsymbol{f}(\boldsymbol{x})\in\mathbb{R}^M$ with $M\gg n$ is a vector of so-called *feature activations*, and $\hat{\boldsymbol{x}}=\boldsymbol{g}(\boldsymbol{f}(\boldsymbol{x}))\in\mathbb{R}^n$ is the reconstruction of the internal representation.
$\sigma:\mathbb{R}^M\rightarrow\mathbb{R}^M$ is the $\textrm{JumpReLU}_{\boldsymbol{\theta}}$ activation function parametrized by $\boldsymbol{\theta}$.
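To make the notation concrete, here is a minimal NumPy sketch of the encoder and decoder above. It assumes the elementwise definition $\textrm{JumpReLU}_{\boldsymbol{\theta}}(z_j)=z_j\,H(z_j-\theta_j)$ with $H$ the Heaviside step function (see Rajamanoharan et al., 2024). The random weights, toy sizes, and the helper names `jumprelu`, `sae_encode` and `sae_decode` are illustrative only; the notebook itself uses the pretrained Gemma-Scope SAE loaded via `sae_lens`.

```python
import numpy as np

def jumprelu(z, theta):
    # JumpReLU: keep z_j where z_j > theta_j, otherwise output 0
    return z * (z > theta)

def sae_encode(x, W_enc, b_enc, theta):
    # f(x) = JumpReLU_theta(W_enc x + b_enc), with W_enc of shape (M, n)
    return jumprelu(W_enc @ x + b_enc, theta)

def sae_decode(f, W_dec, b_dec):
    # g(f) = W_dec f + b_dec, with W_dec of shape (n, M)
    return W_dec @ f + b_dec

rng = np.random.default_rng(0)
n, M = 8, 64                          # toy sizes; in practice M >> n
W_enc = rng.normal(size=(M, n))
b_enc = np.zeros(M)
theta = np.full(M, 1.0)               # hypothetical per-feature thresholds
W_dec = rng.normal(size=(n, M)) / np.sqrt(M)
b_dec = np.zeros(n)

x = rng.normal(size=n)
f = sae_encode(x, W_enc, b_enc, theta)
x_hat = sae_decode(f, W_dec, b_dec)
print(f"{np.count_nonzero(f)} of {M} features active")
```

Unlike a plain ReLU, JumpReLU zeroes out small positive pre-activations below the threshold, which is what lets the SAE keep only a handful of features active per input.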

The SAE is trained to minimize $\mathcal{L}=||\boldsymbol{x}-\hat{\boldsymbol{x}}||_2^2+\lambda ||\boldsymbol{f}(\boldsymbol{x})||_0$, where the second term penalizes the number of active features and thereby incentivizes sparse feature activations.
For those wanting a more detailed technical explanation, refer to e.g. [Rajamanoharan et al., 2024](https://arxiv.org/pdf/2407.14435).
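The objective can be evaluated directly for a single input. The sketch below uses made-up toy vectors (not real activations). Note that the $\ell_0$ term simply counts non-zero entries and has zero gradient almost everywhere, which is why Rajamanoharan et al. train JumpReLU SAEs with straight-through gradient estimators rather than by plain backpropagation through this expression.

```python
import numpy as np

def sae_loss(x, x_hat, f, lam):
    # Squared L2 reconstruction error plus lambda times the number of active features
    reconstruction = np.sum((x - x_hat) ** 2)
    sparsity = np.count_nonzero(f)
    return reconstruction + lam * sparsity

x = np.array([1.0, 2.0])
x_hat = np.array([1.0, 1.0])
f = np.array([0.0, 3.0, 0.0, 2.0])
loss = sae_loss(x, x_hat, f, lam=0.1)
print(loss)  # 1.0 reconstruction + 0.1 * 2 active features
```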


Now the question remains: Why would we do this?<br>
To answer this question, consider the vast knowledge that large language models seem to possess despite the relatively moderate size of their hidden dimension (for Llama-3-8b, for instance, the representations have only 4096 dimensions). Although debated, it has even been hypothesized that the knowledge of many deep learning models is represented linearly (Concept Activation Vectors, [Kim et al., 2018](https://arxiv.org/pdf/1711.11279)). To make sense of these counterintuitive ideas, it has been proposed that LLMs utilize something called *superposition*, where different concepts need not be orthogonal (see e.g. [Elhage et al., 2022](https://transformer-circuits.pub/2022/toy_model/index.html) if you are curious), which, in turn, implies that single neurons don't respond to single concepts but are instead *polysemantic*. An SAE attempts to undo this superposition: by mapping each representation onto a much wider, sparsely active set of features, it aims to obtain features that each respond to a single interpretable concept, and it is these features we will use for steering.
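A quick numerical illustration of why non-orthogonal concepts are workable: in high dimensions, randomly drawn unit vectors are nearly orthogonal with high probability, so far more directions than dimensions can coexist with little interference. The sizes below are arbitrary toy choices, not properties of any particular LLM.

```python
import numpy as np

rng = np.random.default_rng(0)
n_dims, n_directions = 256, 2000          # many more directions than dimensions

# Draw random directions and normalize them to unit length
V = rng.normal(size=(n_directions, n_dims))
V /= np.linalg.norm(V, axis=1, keepdims=True)

# Pairwise cosine similarities, ignoring each vector's similarity with itself
cos = V @ V.T
np.fill_diagonal(cos, 0.0)
max_interference = np.abs(cos).max()
print(f"{n_directions} directions in {n_dims} dims, max |cos| = {max_interference:.3f}")
```

The typical cosine similarity scales roughly like $1/\sqrt{n}$, so the interference between any two directions stays small even though they cannot all be exactly orthogonal.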

## Getting started
**Access:**
To download Gemma 2 2b, you must sign up at HuggingFace, request access to the model [here](https://huggingface.co/google/gemma-2-2b-it), and create an access token by following these [instructions](https://huggingface.co/docs/hub/security-tokens).

**Compute requirements:**
It is recommended to upload this notebook to Google Colab and select a GPU instance (the free-tier T4 is sufficient).

In [None]:
try:
    import google.colab
    !pip install transformers bitsandbytes sae_lens==6.6.0 datasets --upgrade
except ModuleNotFoundError:
    pass

**Important:** Please restart the session before running the next cell!

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from sae_lens.saes.sae import SAE

access_token = "hf_..."     # Your HuggingFace access token

model_name = "google/gemma-2-2b-it"
device = "cuda:0"   # "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # 4-bit weights to fit on a small GPU
    device_map=device,
    token=access_token,
)

In [None]:
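import requests
import io
import json

# Download our feature labels (label -> SAE feature index)
response = requests.get("https://github.com/LenkaTetkova/Latent-space-navigation/raw/refs/heads/main/data/sae_feature_labels.json")
response.raise_for_status()
feature_label_dict = json.load(io.BytesIO(response.content))

# Load the Gemma-Scope residual-stream SAE for layer 21 (65k features, average L0 of about 20)
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_21/width_65k/average_l0_20",
    device=device,
)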

## Model steering
To steer the model with sparse autoencoders, we edit the activations during inference and take the approach of [Templeton et al., 2024](https://transformer-circuits.pub/2024/scaling-monosemanticity/). Here we let $\boldsymbol{x}\in\mathbb{R}^n$ be a hidden representation from the model.

To intervene on the hidden representation, we start by computing the feature activations $\boldsymbol{f}(\boldsymbol{x})$ using the SAE. We then compute the reconstruction error term
$$
\boldsymbol{e}=\boldsymbol{x}-\hat{\boldsymbol{x}}.
$$

To steer using feature $i$ with strength $\alpha$, we clamp the activation of that feature to $\alpha$ and leave all other features untouched:
$$
\boldsymbol{f}_{\rm int}=\boldsymbol{f}(\boldsymbol{x})\odot (1-\boldsymbol{m})+\alpha\boldsymbol{m}
$$
where $\boldsymbol{m}$ has the elements $$m_j=\begin{cases}1&\textrm{ if }j=i\\0&\textrm{ otherwise}\end{cases}.$$ Here $\odot$ denotes the elementwise product.

We now define the edited "reconstruction" as $\hat{\boldsymbol{x}}_{\rm int}=\boldsymbol{g}(\boldsymbol{f}_{\rm int})+\boldsymbol{e}$, which we pass through the rest of the model in place of $\boldsymbol{x}$. Adding $\boldsymbol{e}$ back preserves everything the SAE fails to reconstruct, so only the contribution of the clamped feature changes. In essence, we have forced a specific activation pattern into the model's inference, which affects all computations further downstream!
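The steps above (error term, clamping feature $i$ to $\alpha$, decoding and adding the error back) can be sketched in NumPy with a toy SAE. The random weights and the ReLU encoder (standing in for JumpReLU) are assumptions made for brevity; the helper `clamp_steer` is hypothetical and not part of `sae_lens`.

```python
import numpy as np

rng = np.random.default_rng(1)
n, M = 6, 32
W_enc, b_enc = rng.normal(size=(M, n)), np.zeros(M)
W_dec, b_dec = rng.normal(size=(n, M)) / np.sqrt(M), np.zeros(n)

def encode(x):
    # Toy encoder: ReLU stands in for JumpReLU here
    return np.maximum(W_enc @ x + b_enc, 0.0)

def decode(f):
    return W_dec @ f + b_dec

def clamp_steer(x, i, alpha):
    f = encode(x)
    e = x - decode(f)                    # reconstruction error term
    m = np.zeros(M)
    m[i] = 1.0                           # one-hot mask selecting feature i
    f_int = f * (1.0 - m) + alpha * m    # clamp feature i to alpha
    return decode(f_int) + e             # edited "reconstruction"

x = rng.normal(size=n)
i = 3
x_same = clamp_steer(x, i, alpha=encode(x)[i])   # clamp to the current value: no change
x_steered = clamp_steer(x, i, alpha=10.0)
```

Since the decoder is linear, $\hat{\boldsymbol{x}}_{\rm int}-\boldsymbol{x}=\boldsymbol{g}(\boldsymbol{f}_{\rm int})-\boldsymbol{g}(\boldsymbol{f}(\boldsymbol{x}))=(\alpha-f_i(\boldsymbol{x}))\,\boldsymbol{W}_{\rm dec}[:,i]$: clamping a feature amounts to moving along its decoder column, and clamping to the current value leaves $\boldsymbol{x}$ unchanged.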

In [None]:
from typing import Callable, List
from functools import partial


def clamp_intervention(
    latents: torch.Tensor,   # [batch_size, seq_len, d_sae] (SAE feature activations)
    feature_ixs: List[int],
    clamp_value: float,
) -> torch.Tensor:           # -> [batch_size, seq_len, d_sae]
    # One-hot mask m over the SAE features to clamp
    mask = torch.zeros((latents.shape[-1],), dtype=latents.dtype, device=latents.device)[None,None,:]
    mask[...,feature_ixs] = 1.
    # f_int = f * (1 - m) + alpha * m
    return (latents * (1.-mask)) + clamp_value*mask


class SAEIntervention:       # To be used in the intervention forward as the `intervention_fun`.
    def __init__(self, sae:SAE, intervention:Callable[[torch.Tensor], torch.Tensor] = lambda x:x,):
        self.sae = sae
        self.intervention = intervention

    def __call__(
            self, 
            acts:torch.Tensor   # [batch_size, seq_len, hidden_dim]
        ) -> torch.Tensor:      # -> [batch_size, seq_len, hidden_dim]
        latents = self.sae.encode(acts).to_dense()          # f(x), computed only once
        error = acts - self.sae.decode(latents)             # e = x - x_hat
        new_latents = self.intervention(latents)            # f_int, e.g. with a clamped feature
        acts_hat = error + self.sae.decode(new_latents)     # x_int = g(f_int) + e
        return acts_hat.to(acts.dtype)


class CAVIntervention:
    def __init__(self, cav:torch.Tensor, scale:float):
        self.cav = cav.flatten()
        self.scale = scale

    def __call__(
            self, 
            acts:torch.Tensor   # [batch_size, seq_len, hidden_dim]
        ) -> torch.Tensor:      # -> [batch_size, seq_len, hidden_dim]
        cav = self.cav.to(acts.device)      # The loaded CAVs may live on a different device
        return (acts + self.scale * cav[None,None,:]).to(acts.dtype)   # x_int = x + alpha * v


class InterventionForwardHook:
    def __init__(self, intervention_fun: Callable[[torch.Tensor], torch.Tensor]):
        self.intervention_fun = intervention_fun

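    def __call__(self, module, args, outputs):
        # Depending on the transformers version, a decoder layer may return the hidden states
        # directly or a tuple with the hidden states first; keep any remaining outputs intact.
        if isinstance(outputs, tuple):
            return (self.intervention_fun(outputs[0]),) + outputs[1:]
        return self.intervention_fun(outputs)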
    

def clear_all_hooks(model):
    for m in model.modules():
        m._forward_hooks.clear()


def generate(input_prompt, intervention_layer, intervention):
    # Tokenize prompt, ending with the start of the model's turn
    input_ids = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": input_prompt}
        ],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Add intervention to model
    handle = model.model.layers[intervention_layer].register_forward_hook(
        InterventionForwardHook(
            intervention,
        )
    )

    # Inference; remove the hook even if generation fails
    try:
        outputs = model.generate(input_ids, max_new_tokens=64, do_sample=True)
    finally:
        handle.remove()     # Clean up forward hook
    outputs = outputs[0][input_ids.shape[1]:]   # Remove prompt
    return tokenizer.decode(outputs, skip_special_tokens=True)     # Decode the generated text
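To see the hook mechanism in isolation, here is a small self-contained PyTorch sketch. `ToyLayer` and `make_shift_hook` are hypothetical stand-ins: the toy layer returns a tuple whose first element plays the role of the hidden states, and the hook replaces that element while keeping the rest. In PyTorch, a forward hook that returns a non-`None` value replaces the module's output, which is what `InterventionForwardHook` relies on.

```python
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a decoder layer returning (hidden_states, extra)."""
    def forward(self, x):
        return (2.0 * x, "extra")


def make_shift_hook(delta):
    # Forward hook signature: hook(module, args, output); a non-None return replaces the output
    def hook(module, args, output):
        if isinstance(output, tuple):
            return (output[0] + delta,) + output[1:]
        return output + delta
    return hook


layer = ToyLayer()
x = torch.ones(1, 3, 4)                      # [batch_size, seq_len, hidden_dim]
handle = layer.register_forward_hook(make_shift_hook(torch.full((4,), 0.5)))
try:
    steered, extra = layer(x)                # hook active: 2 * 1 + 0.5
finally:
    handle.remove()                          # always detach the hook
plain, _ = layer(x)                          # hook removed: 2 * 1
print(steered[0, 0], plain[0, 0])
```

Removing the handle in a `finally` block guarantees that a failed call does not leave the intervention attached to the model for subsequent generations.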

## Experiment time!
NB: Keep in mind that the LLM is on the smaller side and thereby much less capable than e.g. ChatGPT.

### Steering with Sparse Autoencoders
Since SAEs are trained unsupervised, their features don't come with ground-truth labels attached; we must find those ourselves. Several pipelines have been proposed that use LLMs to guess what a feature responds to from its maximally activating samples. You can explore this approach [here](https://www.neuronpedia.org/gemma-scope), but note that their choice of SAE for layer 21 has a different sparsity parameter from ours, which renders their labels incompatible with this notebook.


The labels provided in this notebook stem from our own auto-labelling pipeline. Keep in mind that mismatches between features and labels may occur.


Try for yourself to see if you can steer the model in a way you find favorable by adjusting the feature and the clamp value (the strength $\alpha$). You will likely find that a high value breaks the model, while a lower value may make it talk about your chosen topic without you soliciting it in the prompt!

In [None]:
from ipywidgets import widgets, interactive
from IPython.display import display


intervention_layer = int(sae.cfg.metadata['hook_name'].split('.')[1])
widget_feature_select = widgets.Select(
    options=list(feature_label_dict.keys()),
    value='baseball',
    description='Feature:',
    disabled=False
)
widget_clamp_select = widgets.FloatSlider(
    value=525.0,
    min=0,
    max=1000.0,
    step=25.0,
    description='Clamp value:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)
widget_text_input = widgets.Textarea(
    value='Write a haiku.',
    placeholder='Type in a prompt',
    description='Prompt:',
    disabled=False
)
w = interactive(
    lambda **kwargs: print(
        generate(
            input_prompt=kwargs["input_prompt"],
            intervention_layer=intervention_layer,
            intervention=SAEIntervention(
                sae=sae,
                intervention=partial(
                    clamp_intervention,
                    feature_ixs=[feature_label_dict[kwargs["feature_topic"]]],
                    clamp_value=kwargs["clamp_value"],
                )
            )
        ).strip()
    ),
    {'manual': True, 'manual_name': "Generate"},
    feature_topic=widget_feature_select,
    clamp_value=widget_clamp_select,
    input_prompt=widget_text_input
)
ui = widgets.HBox(w.children[:-1])
out = w.children[-1]
display(widgets.VBox([ui, out]))

### Steering via CAVs
Another approach to steering is to use Concept Activation Vectors (abbr.: *CAVs*, see e.g. [Kim et al., 2018](https://arxiv.org/pdf/1711.11279)). This approach is similar in spirit to steering with sparse autoencoders, but here each vector is learned from labelled data and therefore comes with a ground-truth label attached.
To obtain CAVs, I have trained a suite of ordinary linear SVMs on the activations of the model. To aggregate the activations over a piece of text, I average the LLM representation over all tokens, which seems to work reasonably well.
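The idea of a CAV as the normal of a linear classifier can be reproduced on toy data. In this sketch the "mean activations" are synthetic Gaussian clusters shifted along a hypothetical `concept_direction`; it fits scikit-learn's `LinearSVC` (as the training code further down does) and normalizes the learned weight vector.

```python
import numpy as np
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
d, n_per_class = 4, 200
concept_direction = np.array([1.0, 0.0, 0.0, 0.0])   # hypothetical ground-truth concept

# Toy "mean activations": texts with the concept are shifted along concept_direction
acts_pos = rng.normal(size=(n_per_class, d)) + 4.0 * concept_direction
acts_neg = rng.normal(size=(n_per_class, d)) - 4.0 * concept_direction
X = np.vstack([acts_pos, acts_neg])
y = np.array([1] * n_per_class + [0] * n_per_class)

svc = LinearSVC(fit_intercept=False).fit(X, y)
cav = svc.coef_[0] / np.linalg.norm(svc.coef_[0])   # unit-norm CAV
print(np.round(cav, 3))
```

For a binary problem `coef_` has a single row pointing towards the positive class; with several classes, `LinearSVC` trains one-vs-rest and returns one row (one CAV) per class.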

To intervene, I do the following:
$$
\boldsymbol{x}_{\rm int}=\boldsymbol{x}+\alpha\boldsymbol{v},
$$
where $\boldsymbol{v}$ is a CAV and $\alpha$ is the steering strength.
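The additive intervention is a single broadcast over all token positions. This NumPy sketch uses random toy activations and assumes a unit-norm $\boldsymbol{v}$ (the optional training code further down normalizes its CAVs in the same way); `cav_steer` is a hypothetical helper mirroring `CAVIntervention`.

```python
import numpy as np

def cav_steer(acts, v, alpha):
    # acts: [batch_size, seq_len, hidden_dim]; v: [hidden_dim]
    # Add the same shift alpha * v at every token position
    return acts + alpha * v[None, None, :]

rng = np.random.default_rng(0)
acts = rng.normal(size=(2, 5, 16))
v = rng.normal(size=16)
v /= np.linalg.norm(v)                 # unit-norm CAV
alpha = 3.0
steered = cav_steer(acts, v, alpha)

# With a unit-norm v, every token's projection onto v grows by exactly alpha
shift = steered @ v - acts @ v
```

Because only the component along $\boldsymbol{v}$ changes, $\alpha$ directly sets how far each representation is pushed in the concept direction.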

Further down, you have the option to train your own CAVs on the 20 Newsgroups dataset if you desire to do so.

In [None]:
# Download CAVs from our Github repository and load them
response = requests.get("https://github.com/LenkaTetkova/Latent-space-navigation/raw/refs/heads/main/data/cavs_layer_21.pt")
f = torch.load(io.BytesIO(response.content))
cavs = f["cavs"]            # Shape: [n_labels, hidden_dim]
cav_labels = f["labels"]    # List of length n_labels

In [None]:
widget_topic_select_cav = widgets.Select(
    options=cav_labels,
    description='CAV:',
    disabled=False
)
widget_select_cav_strength = widgets.FloatSlider(
    value=525.0,
    min=0,
    max=1000.0,
    step=25.0,
    description='Strength:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)
widget_text_input_cav = widgets.Textarea(
    value='Write a haiku.',
    placeholder='Type in a prompt',
    description='Prompt:',
    disabled=False
)
w_cav = interactive(
    lambda **kwargs: print(
        generate(
            input_prompt=kwargs["input_prompt"],
            intervention_layer=intervention_layer,
            intervention=CAVIntervention(
                cavs[cav_labels.index(kwargs["topic"])],
                kwargs["strength_value"],
            )
        ).strip()
    ),
    {'manual': True, 'manual_name': "Generate"},
    topic=widget_topic_select_cav,
    strength_value=widget_select_cav_strength,
    input_prompt=widget_text_input_cav,
)
ui_cav = widgets.HBox(w_cav.children[:-1])
out_cav = w_cav.children[-1]
display(widgets.VBox([ui_cav, out_cav]))

### Optional: Training your own CAVs
This code demonstrates how you might go about training your own CAVs. The approach is similar to how I obtained the shared ones, but uses a more easily accessible dataset.
Training will take some time, and the quality may vary. To use your trained CAVs in the widget above, simply rerun the widget cell.


Note: In the training code, I skip the first 64 tokens (plus the BOS token) when averaging, since the semantically relevant content in this dataset often seems to come a bit late in the text.
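The pooling step in `get_mean_activation` can be checked on a toy batch. This NumPy sketch mirrors its masking logic (right padding, dropping BOS plus a prefix of `skip` tokens) with made-up values; `mean_after_prefix` is a hypothetical helper. Sequences too short to survive the cut produce NaN, which is why the cell filters out NaN rows before training.

```python
import numpy as np

def mean_after_prefix(acts, attention_mask, skip):
    # acts: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] with 1 = real token
    # Drop the BOS token plus `skip` prefix tokens, then average over the remaining real tokens
    a = acts[:, skip + 1:, :]
    m = attention_mask[:, skip + 1:].astype(float)
    summed = np.einsum("ijk,ij->ik", a, m)
    counts = m.sum(axis=1, keepdims=True)
    with np.errstate(invalid="ignore", divide="ignore"):
        return summed / counts        # NaN when no tokens survive the cut

# Two right-padded sequences of lengths 6 and 3 (including BOS), hidden size 1
acts = np.arange(12, dtype=float).reshape(2, 6, 1)
mask = np.array([[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0, 0]])
pooled = mean_after_prefix(acts, mask, skip=2)
print(pooled)
```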

In [None]:
from tqdm.notebook import tqdm
from datasets import load_dataset
from sklearn.svm import LinearSVC
from transformers import BitsAndBytesConfig


################################
# Load dataset: 20 news groups #
################################
ds = load_dataset("SetFit/20_newsgroups")
dl = torch.utils.data.DataLoader(
    ds["train"], 
    batch_size=32, 
    shuffle=False,  # Deterministic ordering...
)
model_base = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map=device,
    token=access_token,
)   # The base model seems to give better results.


#############################
# Collect model activations #
#############################
@torch.no_grad()    # To avoid CUDA OOM
def get_mean_activation(batch, max_seq_len=128, aggregate_from_token=64):
    # Tokenize
    tokenizer.padding_side = "right"    # So we can easily remove BOS and other prefix tokens
    tokens = tokenizer(batch, return_tensors="pt", padding=True)
    tokens = {k:v[:,:max_seq_len+1,...].to(device) for k,v in tokens.items()}

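    # Compute model activations at the SAE's layer (hidden_states[0] is the embedding output,
    # hence the +1) and drop BOS plus the first `aggregate_from_token` tokens
    activations = model_base(**tokens, output_hidden_states=True).hidden_states[intervention_layer+1][:,aggregate_from_token+1:,:]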
    mask = tokens["attention_mask"][:,aggregate_from_token+1:]

    # Return average activation across sequence
    return torch.einsum('ijk,ij->ik', activations.float(), mask.float()) / mask.sum(dim=1, keepdim=True)


activations = []
labels = []
for batch in tqdm(dl, desc="Collecting activations from language model"):
    activations.append(
        get_mean_activation(batch["text"]).cpu()
    )
    labels.append(batch["label"])

activations = torch.concat(activations, dim=0)
labels = torch.concat(labels, dim=0)
remove_mask = activations.isnan().any(dim=1)    # Remove nans...
activations = activations[~remove_mask]
labels = labels[~remove_mask]

#############
# Train CAV #
#############
acts_np = activations.numpy()
labels_np = labels.numpy()
svc = LinearSVC(fit_intercept=False).fit(acts_np, labels_np)
cavs = svc.coef_
cavs = torch.from_numpy(cavs).to(model.device)
cavs /= cavs.norm(dim=1, keepdim=True)
cav_labels = list(
    map(    # 3: Remove the int-label
        lambda x: x[1],
        sorted(     # 2: Sort them in a list
            list(
                set(    # 1: Get unique int-label and str-label pairs
                    map(
                        lambda x: (x['label'],x['label_text']),
                        ds['train']
                    )
                )
            )
        )
    )
)
assert len(cav_labels) == len(set(cav_labels))  # We've got a problem if there are any duplicates!


# Purge model from memory
import gc
del model_base
gc.collect()
torch.cuda.empty_cache()
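`LinearSVC` orders the rows of `coef_` by `svc.classes_` (the sorted unique labels), so the CAV labels must follow the same order. The cell above achieves this by sorting `(label, label_text)` pairs; a more direct alternative, sketched here with hypothetical toy data in place of the dataset columns, maps each entry of `classes_` to its text and fails loudly on inconsistent labels.

```python
import numpy as np

# Hypothetical toy stand-ins for ds["train"]["label"], ds["train"]["label_text"] and svc.classes_
labels = [2, 0, 1, 0, 2]
label_texts = ["sci.space", "alt.atheism", "comp.graphics", "alt.atheism", "sci.space"]
classes_ = np.unique(labels)          # LinearSVC stores the sorted unique labels in classes_

# Build an int -> text lookup, checking that each int label has exactly one text
text_by_label = {}
for lab, text in zip(labels, label_texts):
    if text_by_label.setdefault(lab, text) != text:
        raise ValueError(f"label {lab} maps to several texts")

cav_labels = [text_by_label[int(c)] for c in classes_]
print(cav_labels)  # ['alt.atheism', 'comp.graphics', 'sci.space']
```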