# An investigation into use of SAEs as steering vectors

## Setup

### If in Colab, install deps. Otherwise, set up autoreload.

In [20]:
try:
    import google.colab

    IN_COLAB = True
    %pip install sae-lens transformer-lens

except ImportError:
    # Local
    IN_COLAB = False

    import IPython

    ipython = IPython.get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Imports

In [27]:
import collections
import json
import os
import requests
from pprint import pprint
import pathlib

import huggingface_hub
from sae_lens import SAE
from tqdm import tqdm, trange
import torch as t
from transformer_lens import HookedTransformer

### Other settings

In [22]:
# Gradients are switched off globally for the whole notebook.
t.set_grad_enabled(False)

# Device preference order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    DEVICE = "mps"
elif t.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"\nDevice: {DEVICE}")


Device: cuda


### HuggingFace Login

In [23]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably needed for gated checkpoints loaded later in the
# notebook (e.g. Gemma 2) — confirm.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Load GPT-2

In [24]:
# Load GPT-2 small onto the selected device and display its configuration.
gpt2_small = HookedTransformer.from_pretrained("gpt2-small", device=DEVICE)
pprint(gpt2_small.cfg)



Loaded pretrained model gpt2-small into HookedTransformer
HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'cuda',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap':

## Import GPT-2 J-B SAEs

In [25]:
# One SAE per GPT-2 layer from the "gpt2-small-res-jb" release, each attached
# to that layer's `hook_resid_pre`. List position equals layer number.
saes = []
for layer in trange(gpt2_small.cfg.n_layers):
    loaded = SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=DEVICE,
    )
    # Only element 0 of the returned value (the SAE itself) is kept.
    saes.append(loaded[0])

pprint(saes[0].cfg)

100%|██████████| 12/12 [00:13<00:00,  1.11s/it]

SAEConfig(architecture='standard',
          d_in=768,
          d_sae=24576,
          activation_fn_str='relu',
          apply_b_dec_to_input=True,
          finetuning_scaling_factor=False,
          context_size=128,
          model_name='gpt2-small',
          hook_name='blocks.0.hook_resid_pre',
          hook_layer=0,
          hook_head_index=None,
          prepend_bos=True,
          dataset_path='Skylion007/openwebtext',
          dataset_trust_remote_code=True,
          normalize_activations='none',
          dtype='torch.float32',
          device='cuda',
          sae_lens_training_version=None,
          activation_fn_kwargs={},
          neuronpedia_id='gpt2-small/0-res-jb')





## Export SAE feature explanations for later search

In [26]:
def load_explanations_from_saes(saes, save_path):
    """Load Neuronpedia feature explanations, downloading them on a cache miss.

    If ``save_path`` exists it is parsed as JSON and returned without any
    network access. Otherwise the explanation export of every SAE in ``saes``
    is fetched from the Neuronpedia API, concatenated in SAE order, written to
    ``save_path`` and returned.

    The cache file is written only after every download has succeeded, so an
    interrupted or failed run never leaves a partial cache behind that a later
    run would mistake for the complete export.

    Args:
        saes: Iterable of SAEs whose ``cfg.neuronpedia_id`` has the form
            ``"<model>/<sae_id>"``.
        save_path: Path of the JSON cache file.

    Returns:
        Tuple ``(explanations, save_path)`` where ``explanations`` is a list
        of explanation records.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        ValueError: If a response body is not a JSON list.
    """
    try:
        with open(save_path, "r") as f:
            return json.load(f), save_path
    except FileNotFoundError:
        pass  # cache miss: fall through to the download below

    url = "https://www.neuronpedia.org/api/explanation/export"
    # The API key comes from the environment so it never lives in the notebook.
    headers = {"X-Api-Key": os.getenv("NEURONPEDIA_TOKEN")}

    explanations = []
    for sae in tqdm(saes):
        model, sae_id = sae.cfg.neuronpedia_id.split("/")
        querystring = {"modelId": model, "saeId": sae_id}

        response = requests.get(url, headers=headers, params=querystring)
        # Fail loudly instead of merging an error payload into the results.
        response.raise_for_status()

        payload = response.json()
        if not isinstance(payload, list):
            raise ValueError(
                f"Unexpected explanation export payload for {sae.cfg.neuronpedia_id}: "
                f"expected a list, got {type(payload).__name__}"
            )
        explanations.extend(payload)

    # Persist once, after all SAEs were fetched successfully.
    with open(save_path, "w") as f:
        json.dump(explanations, f, indent=2)

    return explanations, save_path


# JSON cache for the GPT-2 explanations: load_explanations_from_saes reads it
# when it exists and otherwise downloads the explanations and creates it.
explanations_fpath = "gpt2-small_res-jb_explanations.json"
explanations, _ = load_explanations_from_saes(saes, explanations_fpath)

## Find all features whose explanations contain keywords

In [37]:
def get_explanations_with_keywords_by_layer(explanations, keywords):
    """Group explanations whose description mentions any keyword, by layer.

    Matching is a case-insensitive substring test: both the description and
    the keywords are upper-cased before comparison, so ``"australia"`` and
    ``"AUSTRALIA"`` select the same explanations. Explanations with a missing
    or ``None`` description never match.

    Args:
        explanations: Iterable of explanation dicts with ``"description"`` and
            ``"layer"`` entries; ``"layer"`` starts with the layer number
            followed by ``"-"`` (e.g. ``"2-res-jb"``).
        keywords: Iterable of keyword strings.

    Returns:
        ``collections.defaultdict(list)`` mapping layer number (int) to the
        list of matching explanation dicts, in input order. Layers without a
        match yield an empty list on lookup.
    """
    keywords_upper = [keyword.upper() for keyword in keywords]
    explanations_filtered = collections.defaultdict(list)

    for explanation in explanations:
        description = (explanation.get("description") or "").upper()
        if any(keyword in description for keyword in keywords_upper):
            layer = int(explanation["layer"].split("-")[0])
            explanations_filtered[layer].append(explanation)

    return explanations_filtered


keywords = ["AUSTRALIA"]
explanations_filtered = get_explanations_with_keywords_by_layer(explanations, keywords)

# Match counts for every SAE layer, including layers without any match.
counts_per_layer = [len(explanations_filtered[layer]) for layer in range(len(saes))]
for layer, explanation_count_in_layer in enumerate(counts_per_layer):
    print(f"Number of relevant features in layer {layer}: {explanation_count_in_layer}")

explanation_count = sum(counts_per_layer)
print(f"Total relevant features: {explanation_count}")


Number of relevant features in layer 0: 3
Number of relevant features in layer 1: 12
Number of relevant features in layer 2: 7
Number of relevant features in layer 3: 9
Number of relevant features in layer 4: 8
Number of relevant features in layer 5: 6
Number of relevant features in layer 6: 5
Number of relevant features in layer 7: 4
Number of relevant features in layer 8: 5
Number of relevant features in layer 9: 4
Number of relevant features in layer 10: 3
Number of relevant features in layer 11: 2
Total relevant features: 68


## Find SAE feature indices that correlate with intended steering direction 

In [102]:
def run_model_and_get_filtered_activations(
    model, prompt, explanations_filtered, activation_threshold, quiet=True
):
    """Run ``prompt`` through ``model`` and collect activations of pre-selected SAE features.

    For every layer that has at least one entry in ``explanations_filtered``,
    the cached activation at the SAE's hook point is encoded, the
    reconstruction is stored, and the activations of the selected features
    that reach ``activation_threshold`` are recorded.

    Note: the SAEs are taken from the module-level ``saes`` list, which is
    expected to hold one SAE per layer at the position of its layer number.

    Args:
        model: Model exposing ``run_with_cache``, ``to_tokens`` and
            ``to_str_tokens``.
        prompt: Prompt to run.
        explanations_filtered: Mapping from layer number to a list of
            explanation dicts; each dict's ``"index"`` is a feature index.
        activation_threshold: Minimum activation for a feature to be reported.
        quiet: If False, print the prompt's tokens and token strings.

    Returns:
        Tuple ``(cache, acts_filtered, saes_out)``:
        ``cache`` is the activation cache from ``run_with_cache``;
        ``acts_filtered[layer]`` is ``{"val": [...], "idx": [...]}`` with the
        above-threshold activations and their feature indices, flattened over
        batch and sequence and sorted descending within each position;
        ``saes_out[layer]`` is the SAE reconstruction for that layer.
    """
    _, cache = model.run_with_cache(prompt, prepend_bos=True)

    if not quiet:
        tokens = model.to_tokens(prompt)
        print(f"Tokens: {tokens}")
        print(f"Token strings: {model.to_str_tokens(tokens)}")

    saes_out = {}
    acts_filtered = {}

    for layer, sae in enumerate(tqdm(saes)):
        layer_explanations = explanations_filtered[layer]
        if not layer_explanations:
            continue

        feature_acts = sae.encode(
            cache[sae.cfg.hook_name]
        )  # shape (batch, sequence, features)

        # Decode the activations computed just above for this layer (the
        # previous code referenced an unrelated name here).
        saes_out[layer] = sae.decode(feature_acts)

        # Keep the index tensor on the activations' device so the gather
        # below works wherever the SAE lives.
        explanations_filtered_idx = t.tensor(
            [int(explanation["index"]) for explanation in layer_explanations],
            device=feature_acts.device,
        )

        # Sort only the selected features, per (batch, position), descending.
        feature_acts_vals_sorted, idx = feature_acts[
            :, :, explanations_filtered_idx
        ].sort(descending=True, dim=-1)

        # Translate sort positions back into SAE feature indices.
        feature_acts_idx_sorted = explanations_filtered_idx[idx]

        mask = feature_acts_vals_sorted >= activation_threshold

        acts_filtered[layer] = {
            "val": feature_acts_vals_sorted[mask].tolist(),
            "idx": feature_acts_idx_sorted[mask].tolist(),
        }

    return (
        cache,
        acts_filtered,
        saes_out,
    )


def print_relevant_features(feature_act_vals_sorted, feature_act_idx_sorted, model):
    """Print the sorted activations and feature indices of every relevant layer.

    Layers are taken from the module-level ``saes`` list, and a layer is
    skipped when the module-level ``explanations_filtered`` has no entries
    for it. ``model`` is accepted but not used.
    """
    layers_with_explanations = (
        layer for layer in range(len(saes)) if explanations_filtered[layer] != []
    )
    for layer in layers_with_explanations:
        print(f"\nSorted activations for layer {layer}")
        pprint(feature_act_vals_sorted[layer])
        pprint(feature_act_idx_sorted[layer])


# Prompt whose activations are inspected, and the minimum activation a
# keyword-matched feature must reach to be listed.
sv_prompt = "Sydney Opera House"
activation_threshold = 1

cache, acts_filtered, saes_out = run_model_and_get_filtered_activations(
    model=gpt2_small,
    prompt=sv_prompt,
    explanations_filtered=explanations_filtered,
    activation_threshold=activation_threshold,
    quiet=False,
)

# print_relevant_features(act_vals, act_idx, gpt2_small)

print("\nFiltered activations:")
pprint(acts_filtered)

Tokens: tensor([[50256,    50,  5173,  1681, 26049,  2097]], device='cuda:0')
Token strings: ['<|endoftext|>', 'S', 'yd', 'ney', ' Opera', ' House']


100%|██████████| 12/12 [00:00<00:00, 1332.44it/s]


Filtered activations:
{0: {'idx': [], 'val': []},
 1: {'idx': [14665], 'val': [5.82246732711792]},
 2: {'idx': [20864], 'val': [7.139648914337158]},
 3: {'idx': [19448], 'val': [2.3950774669647217]},
 4: {'idx': [19972], 'val': [4.3374762535095215]},
 5: {'idx': [], 'val': []},
 6: {'idx': [], 'val': []},
 7: {'idx': [], 'val': []},
 8: {'idx': [], 'val': []},
 9: {'idx': [3036], 'val': [3.887145757675171]},
 10: {'idx': [10541, 24077], 'val': [6.6507368087768555, 1.4402518272399902]},
 11: {'idx': [], 'val': []}}





In [96]:
def hooked_generate(
    prompt_batch, fwd_hooks=None, max_new_tokens=20, seed=None, model=None, **kwargs
):
    """Sample continuations for ``prompt_batch`` with forward hooks temporarily attached.

    Args:
        prompt_batch: Prompts to tokenize and continue.
        fwd_hooks: List of ``(hook_name, hook_fn)`` pairs active only during
            this call; ``None`` means no hooks.
        max_new_tokens: Number of tokens to generate per prompt.
        seed: If not ``None``, passed to ``torch.manual_seed`` before sampling.
        model: Model to generate with. When omitted, the module-level
            ``model`` is used so existing callers keep working.
        **kwargs: Extra sampling options forwarded to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns for the tokenized prompts.

    Raises:
        NameError: If no ``model`` is passed and no module-level ``model``
            exists.
    """
    if model is None:
        # Backward-compatible fallback to the notebook-wide `model`.
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError(
                "hooked_generate needs a model: pass `model=...` or define a "
                "module-level `model`"
            ) from None

    if seed is not None:
        t.manual_seed(seed)

    with model.hooks(fwd_hooks=[] if fwd_hooks is None else fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **kwargs,
        )
    return result


def run_generate_single_layer(
    prompt,
    layer,
    model,
    coeff,
    saes,
    saes_out,
    sampling_kwargs,
    n=3,
    steering_on=True,
    max_new_tokens=20,
    feature_idx=20864,
    seed=None,
):
    """Generate ``n`` continuations of ``prompt`` while steering a single layer.

    The steering vector is row ``feature_idx`` of ``saes[layer].W_dec``. It is
    scaled by ``coeff`` and added in place at ``blocks.{layer}.hook_resid_post``.

    Args:
        prompt: Prompt to continue.
        layer: Layer whose SAE supplies the steering vector and whose
            residual stream is edited.
        model: Model to generate with; also passed on to ``hooked_generate``.
        coeff: Multiplier applied to the steering vector.
        saes: Sequence of SAEs indexed by layer.
        saes_out: Mapping from layer to the SAE reconstruction of the
            steering prompt; only its sequence length (``shape[1]``) is used.
        sampling_kwargs: Sampling options forwarded to ``model.generate``.
        n: Number of copies of ``prompt`` to sample.
        steering_on: If False, the hook leaves the residual stream untouched.
        max_new_tokens: Number of tokens to generate.
        feature_idx: Index of the SAE feature used as steering direction.
        seed: Optional seed for reproducible sampling.

    Returns:
        The generated token batch, as returned by ``hooked_generate``.
    """
    # Use the requested layer, not a notebook-level variable.
    steering_vector = saes[layer].W_dec[feature_idx]
    sae_out = saes_out[layer]

    def steering_hook(resid_pre, hook):
        # Calls that see a single position are left unsteered
        # (presumably the per-token steps of cached generation — confirm).
        if resid_pre.shape[1] == 1:
            return

        position = sae_out.shape[1]
        if steering_on:
            # Add the scaled steering vector to the first `position - 1` positions.
            resid_pre[:, : position - 1, :] += coeff * steering_vector

    model.reset_hooks()
    # NOTE(review): the SAEs are loaded for `hook_resid_pre` while the edit is
    # applied at `hook_resid_post` — confirm this is intended.
    editing_hooks = [(f"blocks.{layer}.hook_resid_post", steering_hook)]

    # The tokens are returned undecoded; callers turn them into text.
    return hooked_generate(
        [prompt] * n,
        editing_hooks,
        max_new_tokens,
        seed=seed,
        model=model,
        **sampling_kwargs,
    )

In [97]:
steering_layer = 2

example_prompt = "The White House is located in a country called"
coeff = 100

# `hooked_generate` resolves the module-level `model` when it is not handed
# one explicitly; bind it to GPT-2 for this section.
model = gpt2_small

steering_kwargs = {"model": gpt2_small, "saes": saes, "saes_out": saes_out}
sampling_kwargs = {"temperature": 1.0, "top_p": 0.1, "freq_penalty": 1.0}

# Same prompt and sampling settings, with and without the steering vector.
# The function defined above is `run_generate_single_layer`; `run_generate`
# does not exist at this point in the notebook.
res_on = run_generate_single_layer(
    example_prompt,
    layer=steering_layer,
    steering_on=True,
    coeff=coeff,
    sampling_kwargs=sampling_kwargs,
    **steering_kwargs,
)
res_off = run_generate_single_layer(
    example_prompt,
    layer=steering_layer,
    steering_on=False,
    coeff=coeff,
    sampling_kwargs=sampling_kwargs,
    **steering_kwargs,
)

pprint(gpt2_small.to_string(res_on))
pprint(gpt2_small.to_string(res_off))

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

['<|endoftext|>The White House is located in a country called Australia, and '
 'the president has been visiting the country for more than a year.\n'
 '\n'
 'But he',
 '<|endoftext|>The White House is located in a country called the United '
 "States of America, and it's not just about being a place where people can "
 'come and',
 '<|endoftext|>The White House is located in a country called Australia, and '
 'the United States has been a major player in the world of sports since it '
 'was founded']
['<|endoftext|>The White House is located in a country called the United '
 "States of America, and it's not just that. It's also that the United States",
 '<|endoftext|>The White House is located in a country called the United '
 "States of America, and it's not just the president who has been on the "
 'receiving end',
 '<|endoftext|>The White House is located in a country called the United '
 "States of America, and it's not just that. It's also that the United States"]


## Switch to Gemma 2

In [98]:
# Load Gemma 2 2B onto the selected device and display its configuration.
gemma_2_2b = HookedTransformer.from_pretrained("gemma-2-2b", device=DEVICE)
pprint(gemma_2_2b.cfg)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Get Gemma 2 SAEs

In [None]:
width = 16

# One canonical Gemma Scope residual SAE per Gemma 2 layer, at the chosen
# width. This rebinds the notebook-wide `saes` list.
saes = []
for i in trange(gemma_2_2b.cfg.n_layers):
    print(f"Downloading canonical SAE for layer {i} and width {width}k")
    loaded = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res-canonical",
        sae_id=f"layer_{i}/width_{width}k/canonical",
        device=DEVICE,
    )
    # Only element 0 of the returned value (the SAE itself) is kept.
    saes.append(loaded[0])

pprint(saes[0].cfg)

In [None]:
# Save every loaded SAE into its own local directory, created on demand.
for sae in tqdm(saes):
    # The directory name encodes the layer and the width in thousands:
    # str(d_sae)[:-3] drops the last three digits of d_sae
    # (presumably 16384 -> "16" for the 16k SAEs — confirm).
    fpath = pathlib.Path(
        f"./gemma-scope-2b-pt-res-canonical/layer_{sae.cfg.hook_layer}__width_{str(sae.cfg.d_sae)[:-3]}k"
    )
    fpath.mkdir(parents=True, exist_ok=True)
    sae.save_model(fpath)

## Find all features whose explanations contain keywords (Gemma 2)

In [None]:
EXPLANATIONS_GEMMA_FPATH = "gemma-scope-2b-pt-res-canonical-w16k_explanations.json"

# Reuse the shared cache-or-download helper rather than repeating its body at
# module level. Doing the download inline would unpack
# `model, sae_id = ...` into notebook globals and rebind the notebook-wide
# `model` name to a string.
explanations = load_explanations_from_saes(saes, EXPLANATIONS_GEMMA_FPATH)[0]

In [None]:
KEYWORDS = ["POTTER"]

# Same keyword filter as in the GPT-2 section, via the shared helper.
explanations_by_layer = get_explanations_with_keywords_by_layer(explanations, KEYWORDS)

# Keep the list-of-lists layout (position == layer) that later cells index into.
explanations_filtered = [explanations_by_layer[i] for i in range(len(saes))]
explanation_count = sum(len(matches) for matches in explanations_by_layer.values())

for i in range(len(saes)):
    print(f"Number of relevant features in layer {i}: {len(explanations_filtered[i])}")

print(f"Total relevant features: {explanation_count}")


In [None]:
# Run the steering-vector prompt once and keep the activation cache.
sv_prompt = "Albus Dumbledore"
sv_logits, cache = gemma_2_2b.run_with_cache(sv_prompt, prepend_bos=True)
tokens = gemma_2_2b.to_tokens(sv_prompt)
str_tokens = gemma_2_2b.to_str_tokens(tokens)
print(f"Tokens: {tokens}")
print(f"Token strings: {str_tokens}")

k = 6
# A feature is reported when its activation reaches this fraction of the
# layer's maximum activation.
act_threshold_relative = 0.005

# Keyed by layer number: layers without matching explanations are skipped
# below, so a positional list would not line up with layer indices when it is
# later read as `saes_out[STEERING_LAYER]`.
saes_out = {}
sv_feature_acts_vals_sorted_all_layers = []
sv_feature_acts_idx_sorted_all_layers = []

for i, sae in enumerate(saes):
    explanations_filtered_idx = t.tensor(
        [int(explanation["index"]) for explanation in explanations_filtered[i]],
        device=DEVICE,
    )
    if explanations_filtered_idx.numel() == 0:
        continue

    sv_feature_acts = sae.encode(cache[sae.cfg.hook_name])
    saes_out[i] = sae.decode(sv_feature_acts)

    act_max = sv_feature_acts.max()
    act_threshold = act_threshold_relative * act_max

    sv_feature_acts_filtered = sv_feature_acts[:, :, explanations_filtered_idx]

    # NOTE(review): this compares sum * numel with the threshold; a mean
    # (sum / numel) or a max may have been intended — confirm before changing.
    if (
        sv_feature_acts_filtered.sum() * sv_feature_acts_filtered.numel()
        < act_threshold
    ):
        continue

    sv_feature_acts_vals_sorted, sv_feature_acts_idx_sorted = (
        sv_feature_acts_filtered.sort(descending=True, dim=-1)
    )
    # Translate sort positions back into SAE feature indices once per layer
    # instead of once per token.
    sv_feature_idx_sorted = explanations_filtered_idx[sv_feature_acts_idx_sorted]

    print(f"\nSorted activations for layer {i}")
    print(f"\tMax activation for layer {i}: {act_max}")
    for token_idx, str_token in enumerate(str_tokens):
        vals_for_token = sv_feature_acts_vals_sorted[0, token_idx]
        idx_for_token = sv_feature_idx_sorted[0, token_idx]
        mask = vals_for_token >= act_threshold
        if not mask.any():
            continue
        print(f"\tToken {token_idx}: '{str_token}'")
        print(f"\t\tRelevant feature activations: {vals_for_token[mask].tolist()}")
        print(f"\t\tRelevant feature indices: {idx_for_token[mask].tolist()}")

# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# get_neuronpedia_quick_list(
#     sae=sae, features=sv_feature_acts_idx[:, :, :].flatten().tolist()
# )

In [None]:
# Module-level state read by the steering hook and the generation helpers.
model = gemma_2_2b

STEERING_LAYER = 8
# Steering direction: row 3012 of the steering layer's SAE decoder matrix.
steering_vector = saes[STEERING_LAYER].W_dec[3012]
sae_out = saes_out[STEERING_LAYER]

example_prompt = "My favourite protagonist in any fantasy novel series is named"
coeff = 300
sampling_kwargs = {"temperature": 1.0, "top_p": 0.1, "freq_penalty": 1.0}

In [None]:
def steering_hook(resid_pre, hook):
    """Add the scaled steering vector to the leading positions, in place.

    Reads the module-level ``sae_out``, ``steering_on``, ``coeff`` and
    ``steering_vector``. Inputs with a single sequence position are left
    untouched. Returns ``None`` in every case.
    """
    if resid_pre.shape[1] != 1:
        prompt_len = sae_out.shape[1]
        if steering_on:
            # Steer the first `prompt_len - 1` positions only.
            resid_pre[:, : prompt_len - 1, :] += coeff * steering_vector


def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, max_new_tokens=50, **kwargs):
    """Sample continuations for ``prompt_batch`` with forward hooks temporarily attached.

    Uses the module-level ``model``.

    Args:
        prompt_batch: Prompts to tokenize and continue.
        fwd_hooks: List of ``(hook_name, hook_fn)`` pairs active only during
            this call; ``None`` means no hooks.
        seed: If not ``None``, passed to ``torch.manual_seed`` before sampling.
        max_new_tokens: Number of tokens to generate per prompt.
        **kwargs: Extra sampling options forwarded to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns for the tokenized prompts.
    """
    if seed is not None:
        t.manual_seed(seed)

    # A fresh list per call avoids sharing one mutable default between calls.
    with model.hooks(fwd_hooks=[] if fwd_hooks is None else fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **kwargs,
        )
    return result


In [None]:
def run_generate(example_prompt):
    """Sample three continuations of ``example_prompt`` and print them.

    Relies on the module-level ``model``, ``STEERING_LAYER``,
    ``steering_hook``, ``hooked_generate`` and ``sampling_kwargs``. The hook
    is attached at ``blocks.{STEERING_LAYER}.hook_resid_post``.
    """
    model.reset_hooks()

    hook_point = f"blocks.{STEERING_LAYER}.hook_resid_post"
    prompts = [example_prompt] * 3
    res = hooked_generate(
        prompts, [(hook_point, steering_hook)], seed=None, **sampling_kwargs
    )

    # Drop the first token of every row (the beginning-of-sequence token)
    # before decoding, then print the samples separated by a ruler line.
    separator = "\n" + "-" * 80 + "\n"
    print(separator.join(model.to_string(res[:, 1:])))

In [None]:
# Generate with steering enabled; steering_hook reads this module-level flag.
steering_on = True
run_generate(example_prompt)

In [None]:
# Generate again with steering disabled, for comparison with the steered run.
steering_on = False
run_generate(example_prompt)