# Steering an LLM using SAELens

This notebook shows how to steer an LLM toward a specific concept by adding a sparse autoencoder (SAE) feature's direction to the model's activations while it generates a response. It is inspired by the [SAELens](https://jbloomaus.github.io/SAELens/) tutorials and the [Neuronpedia steering feature](https://www.neuronpedia.org/gemma-2-9b-it/steer), and conceptually by Anthropic's Golden Gate Bridge example.

## Prerequisites

We ran this notebook on a `g6.12xlarge` EC2 instance using the [Deep Learning AMI](https://aws.amazon.com/ai/machine-learning/amis/).

## Identify a specific feature

Using [Neuronpedia search](https://www.neuronpedia.org/search-explanations), let's find a feature of interest for the Gemma 2 2B model (`gemma-2-2b`). We'll use the pre-built Gemma Scope residual-stream SAEs with 16k latents (`gemmascope-res-16k`) and drill into the SAE trained on layer 20.

For this example, I identified feature `4832` as related to racing sailboats.

In [1]:
from IPython.display import IFrame, display

# Neuronpedia embed URL: /{model}/{sae_id}/{feature_idx}, plus options for the embedded dashboard
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Build the Neuronpedia embed URL for one SAE feature (sae_release is the Neuronpedia model ID)."""
    return html_template.format(sae_release, sae_id, feature_idx)

def display_dashboard(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Print the dashboard URL and render the dashboard inline as an IFrame."""
    html = get_dashboard_html(sae_release=sae_release, sae_id=sae_id, feature_idx=feature_idx)
    print(html)
    display(IFrame(html, width=1200, height=600))



In [3]:
latent_idx = 4832

display_dashboard(feature_idx=latent_idx)

https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/4832?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


## Steering

Now that we have a feature identified, let's try steering a new output.

In [8]:
from functools import partial

from jaxtyping import Float
from torch import Tensor
from transformer_lens.hook_points import HookPoint
from sae_lens import SAE, HookedSAETransformer

def steering_hook(
    activations: Float[Tensor, "batch pos d_in"],
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
    sequence positions.
    """
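    # sae.W_dec has shape (d_sae, d_in); row latent_idx is the latent's direction and broadcasts over batch and pos
    return activations + steering_coefficient * sae.W_dec[latent_idx]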
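To see what the hook's single line of arithmetic does, here is a self-contained toy version that uses random tensors in place of real activations and a real SAE decoder (all sizes are made up for illustration). Adding a `(d_in,)` row to a `(batch, pos, d_in)` tensor broadcasts it, so every position in every batch element is shifted by exactly the same vector.

```python
import torch

torch.manual_seed(0)

# Toy sizes, chosen only for illustration
batch, pos, d_in, d_sae = 2, 5, 8, 16
activations = torch.randn(batch, pos, d_in)  # stand-in for residual-stream activations
W_dec = torch.randn(d_sae, d_in)  # stand-in for sae.W_dec: one row per latent
latent_idx, steering_coefficient = 3, 4.0

# Same arithmetic as steering_hook: the (d_in,) row broadcasts over batch and pos
steered = activations + steering_coefficient * W_dec[latent_idx]

shift = steered - activations
expected = (steering_coefficient * W_dec[latent_idx]).expand(batch, pos, d_in)
print(shift.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(shift, expected, atol=1e-5))  # True
```

The shift does not depend on the input at all, which is why a large coefficient can overwhelm the model's own activations.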

In [9]:
GENERATE_KWARGS = dict(temperature=0.5, freq_penalty=2.0, verbose=False)


def generate_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    prompt: str,
    latent_idx: int,
    steering_coefficient: float = 1.0,
    max_new_tokens: int = 50,
):
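    """
    Generates text with steering. A multiple of the steering vector (the decoder weight for this latent) is added to
    every sequence position in each forward pass. Because `generate` uses a KV cache by default, the first pass steers
    the whole prompt and each later pass steers only the newly generated token.
    """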
    _steering_hook = partial(
        steering_hook,
        sae=sae,
        latent_idx=latent_idx,
        steering_coefficient=steering_coefficient,
    )

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, _steering_hook)]):
        output = model.generate(prompt, max_new_tokens=max_new_tokens, **GENERATE_KWARGS)

    return output
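Because `steering_hook` adds the vector to every position it sees, the prompt tokens are steered too on the first forward pass. If you wanted to steer only the final position instead, a variant might look like the sketch below. `steering_hook_last_pos` is hypothetical (it is not part of SAELens or TransformerLens), and the `SimpleNamespace` object stands in for a real `SAE` only so the example runs without downloading weights.

```python
from types import SimpleNamespace

import torch


def steering_hook_last_pos(activations, hook, sae, latent_idx, steering_coefficient):
    """Hypothetical variant of steering_hook that only shifts the last sequence position."""
    steered = activations.clone()  # avoid mutating the tensor passed into the hook
    steered[:, -1, :] += steering_coefficient * sae.W_dec[latent_idx]
    return steered


# Toy check with a fake SAE object that only has a W_dec attribute
fake_sae = SimpleNamespace(W_dec=torch.ones(4, 3))
acts = torch.zeros(1, 5, 3)
out = steering_hook_last_pos(acts, hook=None, sae=fake_sae, latent_idx=2, steering_coefficient=2.0)
print(out[0, -1])  # tensor([2., 2., 2.])
print(out[0, :-1].abs().sum().item())  # 0.0
```

It could be swapped in for `steering_hook` inside the `partial(...)` call. With the default KV cache, later forward passes see only one position, so the two hooks behave the same once the prompt has been processed.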

In [10]:
import torch as t
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
gemma_2_2b = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)





Loaded pretrained model gemma-2-2b into HookedTransformer


In [11]:
prompt = "Should I travel by plane or by"

no_steering_output = gemma_2_2b.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)    

In [12]:
gemmascope_sae_release = "gemma-scope-2b-pt-res-canonical"
gemmascope_sae_id = "layer_20/width_16k/canonical"

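# In SAELens releases before v6, SAE.from_pretrained returns (sae, cfg_dict, sparsity); [0] keeps the SAE
gemma_2_2b_sae = SAE.from_pretrained(gemmascope_sae_release, gemmascope_sae_id, device=str(device))[0]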
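The same SAE goes by two identifiers in this notebook: Neuronpedia calls it `20-gemmascope-res-16k`, while SAELens loads it as release `gemma-scope-2b-pt-res-canonical` with ID `layer_20/width_16k/canonical`. The hypothetical helper below reproduces that naming pattern for a given layer and width so the two stay in sync. The pattern is inferred from these two strings only, so check that a given layer and width actually exist on Neuronpedia and in SAELens's list of pretrained SAEs before relying on it.

```python
def gemmascope_res_ids(layer: int, width: str = "16k") -> dict:
    """Hypothetical helper: build matching Neuronpedia and SAELens identifiers.

    The string patterns are copied from the identifiers used in this notebook
    (layer 20, 16k width); other combinations are not guaranteed to exist.
    """
    return {
        "neuronpedia_sae_id": f"{layer}-gemmascope-res-{width}",
        "saelens_release": "gemma-scope-2b-pt-res-canonical",
        "saelens_sae_id": f"layer_{layer}/width_{width}/canonical",
    }


ids = gemmascope_res_ids(20)
print(ids["neuronpedia_sae_id"])  # 20-gemmascope-res-16k
print(ids["saelens_sae_id"])  # layer_20/width_16k/canonical
```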

### Output

In the table below, the steered completions should include noticeably more references to sailing and racing than the unsteered ("Normal") completion.

In [15]:
from rich.table import Table
from rich import print as rprint
from tqdm.auto import tqdm
from functools import partial

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal", no_steering_output.replace("\n", "↵"))
for i in tqdm(range(3), "Generating steered examples..."):
    table.add_row(
        f"Steered #{i}",
        generate_with_steering(
            gemma_2_2b,
            gemma_2_2b_sae,
            prompt,
            latent_idx,
            steering_coefficient=240.0,  # roughly 1.5-2x the latent's max activation
        ).replace("\n", "↵"),
    )
rprint(table)

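Reading the table is the simplest check. For a rough number, one option is to count sailing-related words in each completion. This is a crude heuristic, not a method from the SAELens tutorials, and the `SAILING_TERMS` list is an arbitrary, hypothetical choice.

```python
import re

# Hypothetical keyword list; adjust it to the concept you are steering toward
SAILING_TERMS = {
    "sail", "sails", "sailing", "sailboat", "sailboats",
    "yacht", "regatta", "race", "racing", "boat", "boats",
}


def count_terms(text: str, terms: set = SAILING_TERMS) -> int:
    """Count whole-word, case-insensitive matches of `terms` in `text`."""
    words = re.findall(r"[a-z]+", text.lower())
    return sum(word in terms for word in words)


# Toy inputs for illustration only (not model outputs)
print(count_terms("Should I travel by plane or by train?"))  # 0
print(count_terms("Sailing a racing yacht in the regatta"))  # 4
```

In the notebook you could compare `count_terms(no_steering_output)` with the count for a completion returned by `generate_with_steering`; the `↵` substitution does not affect the count because the regex only picks up letters.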