# An investigation into using SAE features as steering vectors

## Imports / Setup

In [6]:
# Environment setup: detect Colab vs. a local IPython kernel, import
# dependencies, disable autograd, and choose a torch device.
try:
    # Google Colab
    # import google.colab  # type: ignore
    # This import only succeeds inside Colab; anywhere else it raises
    # ImportError and execution falls through to the local branch below.
    from google.colab import output

    COLAB = True
    # IPython magic (not plain Python): install the libraries this notebook
    # needs. It runs before the sae_lens / transformer_lens imports further down.
    %pip install sae-lens transformer-lens

except ImportError:
    # Local
    COLAB = False
    import IPython  # type: ignore

    # Enable autoreload so edits to imported modules are picked up without a
    # kernel restart. The assert fails when not running under IPython.
    ipython = IPython.get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
# (`requests` is also used later for the Neuronpedia explanation download.)
import webbrowser
import http.server
import requests
import socketserver
import threading

# Starting port for the Colab visualisation server used by `display_vis_inline`.
PORT = 8000

# General imports
import os
import torch as t
from tqdm import tqdm, trange
import plotly.express as px
from pprint import pprint
import pathlib

# Nothing in this notebook trains, so disable gradient tracking globally.
t.set_grad_enabled(False)

# NOTE(review): `notebook_login` is imported here, but the login cell calls
# `huggingface_hub.notebook_login()` directly.
from huggingface_hub import notebook_login

# package import
from functools import partial
from jaxtyping import Int, Float

from transformer_lens import HookedTransformer
from sae_lens import SAE
import json

# Device setup: prefer Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cuda" if t.cuda.is_available() else "cpu"

print(f"\nDevice: {DEVICE}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Device: cuda


## HuggingFace Login

In [2]:
# Log in to the Hugging Face Hub through the interactive notebook widget.
# NOTE(review): presumably needed for gated checkpoints such as gemma-2-2b
# loaded further down — confirm.
import huggingface_hub

huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
def display_vis_inline(filename: str, height: int = 850):
    """
    Display an HTML visualisation, in a browser locally or inline in Colab.

    Locally (``COLAB`` is False) the file is opened with ``webbrowser.open``.
    In Colab, an HTTP server serving ``/content`` is started on a background
    thread, and the file is embedded via ``output.serve_kernel_port_as_iframe``.
    Each Colab call uses the current value of the global ``PORT`` and then
    advances it. Every visualisation therefore gets its own port without the
    caller choosing one.

    Args:
        filename: HTML file to show. In Colab it is resolved relative to
            ``/content``.
        height: Iframe height in pixels (Colab only).

    Raises:
        OSError: In Colab, if the server cannot bind its port (e.g. the port is
            still held by a server started in an earlier run).
    """
    global PORT

    if not COLAB:
        webbrowser.open(filename)
        return

    # Reserve this call's port up front and keep it in a local variable. The
    # server and the iframe then always use the same port, even though the
    # global PORT changes.
    port = PORT
    PORT += 1
    directory = "/content"

    # Serve `directory` directly instead of os.chdir()-ing into it. A chdir
    # would change the working directory of the whole notebook process.
    handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)

    # Bind in the calling thread. The server is then listening before the
    # iframe is created, and a bind failure is raised to the caller instead of
    # only being reported from the background thread.
    httpd = socketserver.TCPServer(("", port), handler)
    print(f"Serving files from {directory} on port {port}")

    # daemon=True: serve_forever never returns. A non-daemon thread would keep
    # the interpreter alive at shutdown.
    threading.Thread(target=httpd.serve_forever, daemon=True).start()

    output.serve_kernel_port_as_iframe(
        port, path=f"/{filename}", height=height, cache_in_notebook=True
    )

In [None]:
# # get gpt2_small
# gpt2_small = HookedTransformer.from_pretrained(
#     "gemma-2b",
#     device=DEVICE,
# )


## Load GPT-2

In [None]:
# Load GPT-2 small as a TransformerLens HookedTransformer on DEVICE and show its config.
gpt2_small = HookedTransformer.from_pretrained("gpt2-small", device=DEVICE)
pprint(gpt2_small.cfg)

## Load GPT-2 SAEs

In [None]:
# Load the pretrained GPT-2 small residual-stream SAEs (release "gpt2-small-res-jb"),
# one per layer and in layer order, so that saes[i] belongs to layer i.
# The layer count comes from the model config rather than a hard-coded 12.
# SAE.from_pretrained returns several values; element 0 is the SAE module
# itself (its `.cfg` is printed below).
saes = [
    SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=DEVICE,
    )[0]
    for layer in range(gpt2_small.cfg.n_layers)
]

pprint(saes[0].cfg)

## Download SAE feature explanations from Neuronpedia (cached locally for later search)

In [None]:
EXPLANATIONS_FPATH = "gpt2-small_res-jb_explanations.json"

# Load the cached Neuronpedia explanations if the file exists. Otherwise
# download the explanations for every SAE in `saes` and write the cache.
try:
    with open(EXPLANATIONS_FPATH, "r", encoding="utf-8") as f:
        explanations = json.load(f)
except FileNotFoundError:
    url = "https://www.neuronpedia.org/api/explanation/export"
    # API key taken from the NEURONPEDIA_TOKEN environment variable.
    headers = {"X-Api-Key": os.getenv("NEURONPEDIA_TOKEN")}

    explanations = []
    for sae_to_fetch in tqdm(saes):
        # neuronpedia_id has the form "<model>/<sae id>",
        # e.g. "gemma-2-2b/0-gemmascope-res-16k".
        model_id, sae_id = sae_to_fetch.cfg.neuronpedia_id.split("/")

        response = requests.get(
            url,
            headers=headers,
            params={"modelId": model_id, "saeId": sae_id},
            timeout=(30, 300),  # (connect, read) seconds, so a stalled request cannot hang forever
        )
        # Fail loudly on HTTP errors. If the error body were a JSON object,
        # `list += dict` would silently append its keys as "explanations".
        response.raise_for_status()
        explanations += response.json()

    # Write the cache once, after every layer has been fetched. An interrupted
    # run then cannot leave a partial file that the `try` branch above would
    # later load as if it were complete.
    with open(EXPLANATIONS_FPATH, "w", encoding="utf-8") as f:
        json.dump(explanations, f, indent=2)

## Find all features whose explanations contain keywords

In [None]:
KEYWORDS = ["AUSTRALIA"]

# Matching is case-insensitive. Descriptions are upper-cased below, so the
# keywords are upper-cased too; a keyword written in lower case would
# otherwise never match.
keywords_upper = [keyword.upper() for keyword in KEYWORDS]

# One bucket of matching explanations per SAE (i.e. per layer).
explanations_filtered = [[] for _ in range(len(saes))]
explanation_count = 0

for explanation in explanations:
    # Treat a missing or null description as empty rather than crashing.
    description = (explanation.get("description") or "").upper()
    if any(keyword in description for keyword in keywords_upper):
        # NOTE(review): assumes "layer" starts with the layer index followed
        # by "-", like the SAE part of neuronpedia_id
        # (e.g. "0-gemmascope-res-16k"). Confirm for the GPT-2 export.
        layer = int(explanation["layer"].split("-")[0])
        explanations_filtered[layer].append(explanation)
        explanation_count += 1

for i, layer_explanations in enumerate(explanations_filtered):
    print(f"Number of relevant features in layer {i}: {len(layer_explanations)}")

print(f"Total relevant features: {explanation_count}")


## Find keyword-matching SAE features that activate on the steering prompt

In [None]:
sv_prompt = "Australia"
sv_logits, cache = gpt2_small.run_with_cache(sv_prompt, prepend_bos=True)
# Tokenize with the same BOS setting as the cached forward pass, so positions
# in `str_tokens` line up with positions in `cache`.
tokens = gpt2_small.to_tokens(sv_prompt, prepend_bos=True)
str_tokens = gpt2_small.to_str_tokens(tokens)
print(f"Tokens: {tokens}")
print(f"Token strings: {str_tokens}")

# NOTE(review): `k` is not used in this cell.
k = 6
# A feature is reported for a token when its activation is at least this
# fraction of the largest activation (any feature, any position) of that
# layer's SAE.
act_threshold_relative = 0.005

saes_out = []
# NOTE(review): these two lists are never populated in this cell.
sv_feature_acts_vals_sorted_all_layers = []
sv_feature_acts_idx_sorted_all_layers = []

for i, sae in enumerate(saes):
    # Layer-i SAE feature indices whose explanation matched KEYWORDS.
    # dtype=long is set explicitly because, with no matches, `t.tensor([])`
    # would be float32, which torch rejects as an index tensor.
    explanations_filtered_idx = t.tensor(
        [int(explanation["index"]) for explanation in explanations_filtered[i]],
        dtype=t.long,
        device=DEVICE,
    )

    # Encode the cached activations at this SAE's hook point. The decoded
    # reconstruction is stored for every layer (before any `continue`), so
    # saes_out[layer] stays aligned with saes[layer].
    sv_feature_acts = sae.encode(cache[sae.cfg.hook_name])
    saes_out.append(sae.decode(sv_feature_acts))

    act_max = sv_feature_acts.max()
    act_threshold = act_threshold_relative * act_max

    # Indexed as (batch, position, feature): keep only keyword-matched features.
    sv_feature_acts_filtered = sv_feature_acts[:, :, explanations_filtered_idx]

    # Skip the layer unless some keyword-matched feature reaches the threshold
    # at some position, i.e. unless the per-token report below would print
    # anything.
    if (
        sv_feature_acts_filtered.numel() == 0
        or sv_feature_acts_filtered.max() < act_threshold
    ):
        continue

    sv_feature_acts_vals_sorted, sv_feature_acts_idx_sorted = (
        sv_feature_acts_filtered.sort(descending=True, dim=-1)
    )
    # Map sorted positions within the filtered subset back to SAE feature
    # indices once per layer, rather than once per token.
    sv_feature_idx_sorted = explanations_filtered_idx[sv_feature_acts_idx_sorted]

    print(f"\nSorted activations for layer {i}")
    print(f"\tMax activation for layer {i}: {act_max}")
    for token_idx, str_token in enumerate(str_tokens):
        vals_for_token = sv_feature_acts_vals_sorted[0, token_idx]
        idx_for_token = sv_feature_idx_sorted[0, token_idx]
        mask = vals_for_token >= act_threshold
        if not mask.any():
            continue
        print(f"\tToken {token_idx}: '{str_token}'")
        print(f"\t\tRelevant feature activations: {vals_for_token[mask].tolist()}")
        print(f"\t\tRelevant feature indices: {idx_for_token[mask].tolist()}")

# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# get_neuronpedia_quick_list(
#     sae=sae, features=sv_feature_acts_idx[:, :, :].flatten().tolist()
# )

In [None]:
STEERING_LAYER = 0
# SAE feature whose decoder direction is used as the steering vector.
# NOTE(review): this index must refer to a feature of the STEERING_LAYER SAE
# (pick it from the search output above). Confirm 15642 is the intended
# feature for this layer.
STEERING_FEATURE_IDX = 15642

# Read the decoder row from the steering layer's SAE explicitly. Relying on
# the loop variable `sae` left over from the search cell would silently use
# the last layer's SAE, whatever STEERING_LAYER says.
steering_vector = saes[STEERING_LAYER].W_dec[STEERING_FEATURE_IDX]

example_prompt = "In which country would I find the White House?"
# Scale applied to steering_vector inside steering_hook.
coeff = 3000
# Forwarded to model.generate via run_generate -> hooked_generate.
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

# The generation helpers below read this global.
model = gpt2_small

# SAE reconstruction of the steering prompt at STEERING_LAYER; steering_hook
# only uses its sequence length.
sae_out = saes_out[STEERING_LAYER]

In [None]:
def steering_hook(resid_pre, hook):
    """
    TransformerLens forward hook that adds ``coeff * steering_vector`` in place.

    The vector is added to the first ``sae_out.shape[1] - 1`` positions of
    every sequence in the batch, and only when the global ``steering_on`` is
    truthy. Calls whose activation has a sequence length of 1 are left
    untouched (presumably incremental generation steps — confirm). Reads the
    globals ``sae_out``, ``steering_on``, ``coeff`` and ``steering_vector``.
    Always returns None; the effect is the in-place edit of ``resid_pre``.
    """
    if resid_pre.shape[1] == 1:
        return None

    n_steered = sae_out.shape[1] - 1
    if not steering_on:
        return None

    resid_pre[:, :n_steered, :] += coeff * steering_vector
    return None


def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, *, max_new_tokens=50, **kwargs):
    """
    Sample continuations of ``prompt_batch`` from the global ``model`` with hooks attached.

    Args:
        prompt_batch: Prompt or list of prompts, passed to ``model.to_tokens``.
        fwd_hooks: Optional list of ``(hook_name, hook_fn)`` pairs, active only
            for the duration of this call. ``None`` means no hooks.
        seed: If not None, ``torch.manual_seed(seed)`` is called first so the
            sampling is reproducible.
        max_new_tokens: Number of tokens to generate (default 50).
        **kwargs: Extra keyword arguments forwarded to ``model.generate``
            (e.g. ``temperature``, ``top_p``).

    Returns:
        The output of ``model.generate`` for the tokenized input (callers
        slice it as ``res[:, 1:]``, so presumably a (batch, pos) token tensor).
    """
    if seed is not None:
        t.manual_seed(seed)

    # A None default avoids the shared-mutable-default pitfall of `fwd_hooks=[]`.
    hooks = [] if fwd_hooks is None else fwd_hooks

    with model.hooks(fwd_hooks=hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **kwargs,
        )
    return result


In [None]:
def run_generate(example_prompt):
    """
    Print three sampled continuations of ``example_prompt`` with ``steering_hook`` attached.

    The hook is placed at the hook point of the steering layer's SAE
    (``saes[STEERING_LAYER].cfg.hook_name``), so the steering vector is added
    in the activation space its decoder direction was trained on. Whether the
    vector is actually added is decided inside ``steering_hook`` by the global
    ``steering_on``. Sampling settings come from the global ``sampling_kwargs``.

    Args:
        example_prompt: Prompt to continue; it is repeated three times as a batch.
    """
    model.reset_hooks()
    # Derive the hook from STEERING_LAYER. Formatting it with `layer` picked up
    # a variable leaked from the keyword-search loop, which attached the hook
    # to an unrelated layer. For the GPT-2 SAEs this is presumably
    # "blocks.{STEERING_LAYER}.hook_resid_pre" (the sae_id they were loaded with).
    hook_name = saes[STEERING_LAYER].cfg.hook_name
    editing_hooks = [(hook_name, steering_hook)]
    res = hooked_generate(
        [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs
    )

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, 1:])
    print(("\n\n" + "-" * 80 + "\n\n").join(res_str))

In [None]:
# Generate with the steering vector added (steering_hook checks this flag).
steering_on = True
run_generate(example_prompt)

In [None]:
# Baseline for comparison: the same prompt with steering disabled.
steering_on = False
run_generate(example_prompt)

## Switch to Gemma 2

In [7]:
# Load Gemma 2 2B as a TransformerLens HookedTransformer on DEVICE and show its config.
gemma_2_2b = HookedTransformer.from_pretrained("gemma-2-2b", device=DEVICE)
pprint(gemma_2_2b.cfg)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return t.to(


Loaded pretrained model gemma-2-2b into HookedTransformer
HookedTransformerConfig:
{'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'gl

## Get Gemma 2 SAEs

In [8]:
# Dictionary width of the Gemma Scope SAEs to fetch, in thousands of latents
# (width_16k -> d_sae=16384).
width = 16

# One canonical residual-stream SAE per Gemma layer, in layer order.
saes = []
for i in trange(gemma_2_2b.cfg.n_layers):
    print(f"Downloading canonical SAE for layer {i} and width {width}k")
    # from_pretrained returns several values; element 0 is the SAE module.
    loaded = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res-canonical",
        sae_id=f"layer_{i}/width_{width}k/canonical",
        device=DEVICE,
    )
    saes.append(loaded[0])

pprint(saes[0].cfg)

  0%|          | 0/26 [00:00<?, ?it/s]

Downloading canonical SAE for layer 0 and width 16k


  4%|▍         | 1/26 [00:12<05:18, 12.72s/it]

Downloading canonical SAE for layer 1 and width 16k


  8%|▊         | 2/26 [00:13<02:14,  5.61s/it]

Downloading canonical SAE for layer 2 and width 16k


params.npz:  45%|####5     | 136M/302M [00:00<?, ?B/s]

 12%|█▏        | 3/26 [00:44<06:33, 17.12s/it]

Downloading canonical SAE for layer 3 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 15%|█▌        | 4/26 [01:23<09:26, 25.76s/it]

Downloading canonical SAE for layer 4 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 19%|█▉        | 5/26 [02:03<10:50, 30.96s/it]

Downloading canonical SAE for layer 5 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 23%|██▎       | 6/26 [02:42<11:18, 33.90s/it]

Downloading canonical SAE for layer 6 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 27%|██▋       | 7/26 [03:17<10:45, 33.95s/it]

Downloading canonical SAE for layer 7 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 31%|███       | 8/26 [03:50<10:10, 33.90s/it]

Downloading canonical SAE for layer 8 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 35%|███▍      | 9/26 [04:29<10:00, 35.33s/it]

Downloading canonical SAE for layer 9 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 38%|███▊      | 10/26 [05:02<09:12, 34.55s/it]

Downloading canonical SAE for layer 10 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 42%|████▏     | 11/26 [05:37<08:42, 34.84s/it]

Downloading canonical SAE for layer 11 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 46%|████▌     | 12/26 [06:11<08:02, 34.48s/it]

Downloading canonical SAE for layer 12 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 50%|█████     | 13/26 [06:44<07:23, 34.09s/it]

Downloading canonical SAE for layer 13 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 54%|█████▍    | 14/26 [07:17<06:46, 33.90s/it]

Downloading canonical SAE for layer 14 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 58%|█████▊    | 15/26 [07:51<06:10, 33.70s/it]

Downloading canonical SAE for layer 15 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 62%|██████▏   | 16/26 [08:25<05:39, 33.95s/it]

Downloading canonical SAE for layer 16 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 65%|██████▌   | 17/26 [09:07<05:25, 36.21s/it]

Downloading canonical SAE for layer 17 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 69%|██████▉   | 18/26 [09:43<04:50, 36.27s/it]

Downloading canonical SAE for layer 18 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 73%|███████▎  | 19/26 [10:18<04:11, 35.97s/it]

Downloading canonical SAE for layer 19 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 77%|███████▋  | 20/26 [10:59<03:43, 37.30s/it]

Downloading canonical SAE for layer 20 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 81%|████████  | 21/26 [12:01<03:44, 44.81s/it]

Downloading canonical SAE for layer 21 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 85%|████████▍ | 22/26 [12:54<03:09, 47.31s/it]

Downloading canonical SAE for layer 22 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 88%|████████▊ | 23/26 [14:03<02:41, 53.77s/it]

Downloading canonical SAE for layer 23 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 92%|█████████▏| 24/26 [14:59<01:48, 54.41s/it]

Downloading canonical SAE for layer 24 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 96%|█████████▌| 25/26 [15:42<00:51, 51.16s/it]

Downloading canonical SAE for layer 25 and width 16k


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 26/26 [16:23<00:00, 37.84s/it]

SAEConfig(architecture='jumprelu',
          d_in=2304,
          d_sae=16384,
          activation_fn_str='relu',
          apply_b_dec_to_input=False,
          finetuning_scaling_factor=False,
          context_size=1024,
          model_name='gemma-2-2b',
          hook_name='blocks.0.hook_resid_post',
          hook_layer=0,
          hook_head_index=None,
          prepend_bos=True,
          dataset_path='monology/pile-uncopyrighted',
          dataset_trust_remote_code=True,
          normalize_activations=None,
          dtype='float32',
          device='cuda',
          sae_lens_training_version=None,
          activation_fn_kwargs={},
          neuronpedia_id='gemma-2-2b/0-gemmascope-res-16k')





In [9]:
# Save every loaded Gemma Scope SAE locally, one directory per layer
# (presumably so later sessions can load them without re-downloading).
# Fixes the previous cell, which had a syntax error (`Pathlib"./"`) and
# called save_model() without its required path.
SAE_SAVE_DIR = pathlib.Path("./gemma-scope-2b-pt-res-canonical")

for i, sae in enumerate(saes):
    fpath = SAE_SAVE_DIR / f"layer_{i}" / f"width_{width}k"
    # Create the whole directory tree up front, so save_model writes into an
    # existing directory.
    fpath.mkdir(parents=True, exist_ok=True)
    sae.save_model(str(fpath))

['T_destination',
 'W_dec',
 'W_enc',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_enable_hook',
 '_enable_hook_with_name',
 '_enable_hooks_for_points',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_may