In [1]:
%%capture
!uv pip install --upgrade sae-lens transformer-lens sae-dashboard

In [2]:
from torch import Tensor
from transformer_lens import utils
from functools import partial
from tqdm import tqdm
from jaxtyping import Int, Float

import torch, pathlib, pandas as pd
import huggingface_hub as hf_hub, safetensors as st

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

---
## metadata for SAE exploration

In [4]:
# Build a table of the pretrained SAE releases reported by
# get_pretrained_saes_directory(): each directory entry's attributes
# (its __dict__) become one row after the transpose.
# TODO: Make this nicer.
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
# Drop the columns that are not needed for browsing. errors="ignore" keeps the
# cell runnable if the installed sae-lens no longer exposes one of these
# attributes (the install cell upgrades without pinning a version), and the
# non-inplace form leaves no half-mutated frame behind on failure.
df = df.drop(
    columns=["expected_var_explained","expected_l0",
             "config_overrides","conversion_func"],
    errors="ignore",
)

# Layer whose SAE is loaded further down. It must be one of the layers that
# the chosen release ships; its SAE ids are named "layer_<N>/width_.../...".
# NOTE(review): an earlier comment here said "only layer 6 works pretty well",
# which does not match the value below - confirm it referred to another model.
layer = 20
MODEL = "google/gemma-2-9b-it"
SAE_ID="gemma-scope-9b-it-res"

In [5]:
# Bare last expression: render the release table built in the previous cell.
df

Unnamed: 0,release,repo_id,model,saes_map,neuronpedia_id
gemma-2b-it-res-jb,gemma-2b-it-res-jb,jbloom/Gemma-2b-IT-Residual-Stream-SAEs,gemma-2b-it,{'blocks.12.hook_resid_post': 'gemma_2b_it_blo...,{'blocks.12.hook_resid_post': 'gemma-2b-it/12-...
gemma-2b-res-jb,gemma-2b-res-jb,jbloom/Gemma-2b-Residual-Stream-SAEs,gemma-2b,{'blocks.0.hook_resid_post': 'gemma_2b_blocks....,{'blocks.0.hook_resid_post': 'gemma-2b/0-res-j...
gemma-scope-27b-pt-res,gemma-scope-27b-pt-res,google/gemma-scope-27b-pt-res,gemma-2-27b,{'layer_10/width_131k/average_l0_106': 'layer_...,"{'layer_10/width_131k/average_l0_106': None, '..."
gemma-scope-27b-pt-res-canonical,gemma-scope-27b-pt-res-canonical,google/gemma-scope-27b-pt-res,gemma-2-27b,{'layer_10/width_131k/canonical': 'layer_10/wi...,{'layer_10/width_131k/canonical': 'gemma-2-27b...
gemma-scope-2b-pt-att,gemma-scope-2b-pt-att,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/average_l0_104': 'layer_0/...,"{'layer_0/width_16k/average_l0_104': None, 'la..."
gemma-scope-2b-pt-att-canonical,gemma-scope-2b-pt-att-canonical,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...,{'layer_0/width_16k/canonical': 'gemma-2-2b/0-...
gemma-scope-2b-pt-mlp,gemma-scope-2b-pt-mlp,google/gemma-scope-2b-pt-mlp,gemma-2-2b,{'layer_0/width_16k/average_l0_119': 'layer_0/...,"{'layer_0/width_16k/average_l0_119': None, 'la..."
gemma-scope-2b-pt-mlp-canonical,gemma-scope-2b-pt-mlp-canonical,google/gemma-scope-2b-pt-mlp,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...,{'layer_0/width_16k/canonical': 'gemma-2-2b/0-...
gemma-scope-2b-pt-res,gemma-scope-2b-pt-res,google/gemma-scope-2b-pt-res,gemma-2-2b,{'embedding/width_4k/average_l0_6': 'embedding...,"{'embedding/width_4k/average_l0_6': None, 'emb..."
gemma-scope-2b-pt-res-canonical,gemma-scope-2b-pt-res-canonical,google/gemma-scope-2b-pt-res,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...,{'layer_0/width_16k/canonical': 'gemma-2-2b/0-...


In [6]:
# List every SAE id registered for the selected release, together with the
# value stored for it in the release's `saes_map` entry.
print(f"SAEs in the {SAE_ID}")
release_saes_map = df.loc[df.release == SAE_ID, "saes_map"].values[0]
for k, v in release_saes_map.items():
    print(f"SAE id: {k} for hook point: {v}")

SAEs in the gemma-scope-9b-it-res
SAE id: layer_20/width_131k/average_l0_13 for hook point: layer_20/width_131k/average_l0_13
SAE id: layer_20/width_131k/average_l0_153 for hook point: layer_20/width_131k/average_l0_153
SAE id: layer_20/width_131k/average_l0_24 for hook point: layer_20/width_131k/average_l0_24
SAE id: layer_20/width_131k/average_l0_43 for hook point: layer_20/width_131k/average_l0_43
SAE id: layer_20/width_131k/average_l0_81 for hook point: layer_20/width_131k/average_l0_81
SAE id: layer_20/width_16k/average_l0_14 for hook point: layer_20/width_16k/average_l0_14
SAE id: layer_20/width_16k/average_l0_189 for hook point: layer_20/width_16k/average_l0_189
SAE id: layer_20/width_16k/average_l0_25 for hook point: layer_20/width_16k/average_l0_25
SAE id: layer_20/width_16k/average_l0_47 for hook point: layer_20/width_16k/average_l0_47
SAE id: layer_20/width_16k/average_l0_91 for hook point: layer_20/width_16k/average_l0_91
SAE id: layer_31/width_131k/average_l0_109 for hook 

---
## models

We will be using the [GemmaScope](https://huggingface.co/google/gemma-scope-9b-it-res/tree/main) SAEs together with [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) for exploration.

In [7]:
# Load MODEL onto `device` as a HookedSAETransformer.
# NOTE: Make sure to only run this once
# (presumably because re-running loads the checkpoint again while the previous
# copy is still referenced by the kernel - TODO confirm).
model = HookedSAETransformer.from_pretrained(MODEL, device=device)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [8]:
# Load one pretrained SAE from the SAE_ID release for the configured layer,
# placing it on the same device as the model.
# NOTE(review): the 3-way unpacking assumes from_pretrained returns
# (sae, cfg_dict, sparsity); the install cell upgrades sae-lens without a
# version pin, so confirm this against the installed version.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=SAE_ID, device=device,
    sae_id=f"layer_{layer}/width_131k/average_l0_81", # test with L0_81
)

In [9]:
# get hook point: the name of the model hook this SAE is attached to. It is
# used later as the key into the activation cache.
hook_point = sae.cfg.hook_name
# Bare last expression: show every field of the SAE config as a dict.
sae.cfg.__dict__

{'architecture': 'jumprelu',
 'd_in': 3584,
 'd_sae': 131072,
 'activation_fn_str': 'relu',
 'apply_b_dec_to_input': False,
 'finetuning_scaling_factor': False,
 'context_size': 1024,
 'model_name': 'gemma-2-9b-it',
 'hook_name': 'blocks.20.hook_resid_post',
 'hook_layer': 20,
 'hook_head_index': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'dataset_trust_remote_code': True,
 'normalize_activations': None,
 'dtype': 'float32',
 'device': 'cuda',
 'sae_lens_training_version': None,
 'activation_fn_kwargs': {},
 'neuronpedia_id': None,
 'model_from_pretrained_kwargs': {},
 'seqpos_slice': (None,)}

In [10]:
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

# Load the training split of the NeelNanda/pile-10k dataset (not streamed).
dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)

# Tokenize with the model's tokenizer and pack into sequences of the SAE's
# context size, adding BOS according to the SAE config.
# NOTE(review): `streaming=True` is passed here although the dataset above was
# loaded with streaming=False - confirm this is intentional.
token_dataset = tokenize_and_concatenate(
    dataset=dataset, tokenizer=model.tokenizer,
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

In [11]:
sv_prompt="""The chiffonier stood a few feet from the foot of the bed. He had emptied the drawers into cartons that morning, which were in the living room."""

# Inference only: no_grad avoids keeping an autograd graph alive for the
# cached activations of the forward pass and for the SAE encode/decode.
with torch.no_grad():
    sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)

    # get the feature activations from our SAE
    sv_feature_acts = sae.encode(cache[hook_point])

    # get sae_out (the SAE's reconstruction of the hooked activations)
    sae_out = sae.decode(sv_feature_acts)

# Tokenize with the same BOS setting as the forward pass above, so the printed
# ids are the ones the activations correspond to.
tokens = model.to_tokens(sv_prompt, prepend_bos=True)
print(tokens)

# print out the top activations, focus on the indices
# (top-k is computed once and reused for both the printout and the index list)
top_acts = torch.topk(sv_feature_acts, 3)
print(top_acts)
top_indices = top_acts.indices.tolist()

tensor([[     2,    651,  92276,   1129,  12671,    476,   2619,   5368,    774,
            573,   4701,    576,    573,   3609, 235265,   1315,   1093, 112935,
            573,  64145,   1280, 137007,    674,   5764, 235269,    948,   1049,
            575,    573,   5027,   2965, 235265]], device='cuda:0')
torch.return_types.topk(
values=tensor([[[3156.4102,  774.5602,  577.5037],
         [1156.9087,  405.6550,  163.4940],
         [  35.3094,   31.4556,   30.7801],
         [  33.7741,   30.0334,   29.1854],
         [  44.0230,   37.2889,   32.9577],
         [  37.9170,   32.5813,   28.3535],
         [  70.3743,   39.1346,   29.2105],
         [  42.5608,   31.9773,   30.9554],
         [  44.6310,   35.2725,   33.9535],
         [  36.2088,   23.6369,   18.6098],
         [  47.2001,   35.4091,   30.4748],
         [  40.7974,   39.5112,   35.0373],
         [  47.3596,   30.3908,   26.2560],
         [  68.0637,   63.9519,   50.8768],
         [  53.7122,   39.3916,   30.7550

In [12]:
# The loaded config reports neuronpedia_id=None, so set it by hand.
# Neuronpedia addresses an SAE as "<model>/<layer>-<sae set>" (this is the path
# used by the dashboard URL in this notebook), not by the Hugging Face repo
# name, so strip the "google/" organisation prefix from MODEL.
sae.cfg.neuronpedia_id = f"{MODEL.split('/')[-1]}/{layer}-gemmascope-res-131k"

In [25]:
from IPython.display import IFrame, HTML

# get a random feature from the SAE (an index in [0, d_sae))
# NOTE(review): the dashboard call at the end of this cell passes
# feature_idx=43499 explicitly, so this random index is not what gets shown.
feature_idx = torch.randint(0, sae.cfg.d_sae, (1,)).item()

# Neuronpedia embed URL; the three positional placeholders are filled, in
# order, by get_dashboard_html's sae_release, sae_id and feature_idx.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gemma-2-9b-it", sae_id=f"{layer}-gemmascope-res-131k", feature_idx=0):
    """Build, print and return the Neuronpedia embed URL for one feature.

    Args:
        sae_release: first path segment substituted into ``html_template``.
        sae_id: second path segment; the default is built from the global
            ``layer`` once, when this function is defined.
        feature_idx: feature index, substituted as the third path segment.

    Returns:
        The formatted URL string (a URL, despite the function's name).
    """
    url = html_template.format(sae_release, sae_id, feature_idx)
    print(url)
    return url


# Build the embed URL for feature 43499 with the default release / SAE id and
# show it inline. NOTE: `html` holds a URL string here, not HTML markup.
html = get_dashboard_html(feature_idx=43499)
IFrame(html, width=1200, height=600)

https://neuronpedia.org/gemma-2-9b-it/20-gemmascope-res-131k/43499?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [26]:
def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    """
    Find the maximum activation for a given feature index. This is useful for
    calibrating the right amount of the feature to add.

    Args:
        model: object whose ``run_with_cache(tokens, stop_at_layer=...,
            names_filter=...)`` returns ``(_, cache)`` with ``cache`` indexable
            by hook name.
        sae: SAE providing ``encode`` plus ``cfg.hook_name`` / ``cfg.hook_layer``.
        activation_store: object whose ``get_batch_tokens()`` yields one batch
            of tokens per call.
        feature_idx: index of the feature along the last axis of the encoded
            activations.
        num_batches: number of batches to scan.

    Returns:
        The largest activation seen for ``feature_idx`` as a Python float;
        0.0 if ``num_batches`` is 0 or the feature never exceeds 0.
    """
    max_activation = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        tokens = activation_store.get_batch_tokens()

        # Only values are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            _, cache = model.run_with_cache(
                tokens,
                # the forward pass can stop right after the SAE's layer
                stop_at_layer=sae.cfg.hook_layer + 1,
                names_filter=[sae.cfg.hook_name],
            )
            feature_acts = sae.encode(cache[sae.cfg.hook_name])

        # Select the feature on the last axis and reduce over every leading
        # axis. The earlier squeeze()/flatten(0, 1) combination dropped a
        # size-1 batch (or sequence) axis and then failed on `[:, feature_idx]`.
        batch_max_activation = feature_acts[..., feature_idx].max().item()
        max_activation = max(max_activation, batch_max_activation)

        pbar.set_description(f"Max activation: {max_activation:.4f}")

    return max_activation


def steering(activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0):
    """Forward-hook function that shifts activations along a steering vector.

    Args:
        activations: tensor produced at the hook point.
        hook: the hook-point argument of a forward hook; not used here.
        steering_strength: multiplier applied on top of ``max_act``.
        steering_vector: direction to add; must broadcast against
            ``activations``. ``None`` (the default) means "no steering".
        max_act: scale of the feature, multiplied with ``steering_strength``.

    Returns:
        ``activations + max_act * steering_strength * steering_vector``, or
        ``activations`` unchanged when no steering vector is given.
    """
    # With the default (no vector) there is nothing to add; previously this
    # path failed with a TypeError from multiplying by None.
    if steering_vector is None:
        return activations
    return activations + max_act * steering_strength * steering_vector


def generate_with_steering(
    model,
    sae,
    prompt,
    steering_feature,
    max_act,
    steering_strength=1.0,
    max_new_tokens=95,
):
    """Sample a continuation of ``prompt`` while steering along one SAE feature.

    The decoder row ``sae.W_dec[steering_feature]`` is added, scaled by
    ``max_act * steering_strength``, to the activations at ``sae.cfg.hook_name``
    for the duration of generation. The prompt and the continuation are also
    displayed as colour-coded HTML.

    Args:
        model: model providing ``to_tokens``, ``hooks``, ``generate``,
            ``tokenizer`` and ``cfg.device``.
        sae: SAE providing ``W_dec``, ``cfg.hook_name`` and ``cfg.prepend_bos``.
        prompt: text to continue.
        steering_feature: row index into ``sae.W_dec``.
        max_act: activation scale for the feature.
        steering_strength: multiplier on top of ``max_act``.
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded text of the whole first output sequence (prompt plus
        continuation), exactly as the tokenizer decodes it.
    """
    # Function-scope import: the notebook binds the global name `html` to a URL
    # string, so the stdlib module is not imported under that name at top level.
    from html import escape

    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)

    steering_hook = partial(
        steering,
        steering_vector=steering_vector,
        steering_strength=steering_strength,
        max_act=max_act,
    )

    # standard transformerlens syntax for a hook context for generation
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            # EOS stopping is switched off only when running on MPS
            stop_at_eos=device != "mps",
            prepend_bos=sae.cfg.prepend_bos,
        )

    sequence = output[0]
    full_text = model.tokenizer.decode(sequence)

    # Decode the prompt and generated text separately
    prompt_len = input_ids.shape[1]
    prompt_text = model.tokenizer.decode(sequence[:prompt_len])
    generated_text = model.tokenizer.decode(sequence[prompt_len:])

    # Create HTML with different colors. The decoded text is escaped first:
    # it can contain "<bos>" or any "<", ">" and "&" the model produces, which
    # a browser would otherwise treat as markup.
    html_output = f"""
    <div style="font-family: monospace;">
        <span style="color: black;">{escape(prompt_text)}</span>
        <span style="color: blue; font-weight: bold;">{escape(generated_text)}</span>
    </div>
    """

    # Display the colored output
    display(HTML(html_output))
    return full_text

In [28]:
# Choose a feature to steer
steering_feature = 43499

# Maximum activation for this feature. The value is hard-coded here rather
# than measured with find_max_activation.
# NOTE: we could also get the max activation from Neuronpedia (https://www.neuronpedia.org/api-doc#tag/lookup/GET/api/feature/{modelId}/{layer}/{index})
max_act = 50.66
print(f"Maximum activation for feature {steering_feature}: {max_act:.4f}")


# Generate text without steering for comparison. The sampling settings
# (temperature / top_p) mirror the ones generate_with_steering passes to
# model.generate, so the two texts differ because of the steering hook and
# sampling noise, not because of different decoding parameters.
normal_text = model.generate(
    sv_prompt,
    max_new_tokens=95,
    temperature=0.7,
    top_p=0.9,
    stop_at_eos=device != "mps",
    prepend_bos=sae.cfg.prepend_bos,
)

print("\nNormal text (without steering):")
print(normal_text)

# Generate text with steering
steered_text = generate_with_steering(
    model, sae, sv_prompt, steering_feature, max_act, steering_strength=2.0
)
print("Steered text:")
print(steered_text)

Maximum activation for feature 43499: 50.6600


  0%|          | 0/95 [00:00<?, ?it/s]


Normal text (without steering):
The chiffonier stood a few feet from the foot of the bed. He had emptied the drawers into cartons that morning, which were in the living room.  He stood for a moment, staring at the empty space where the bristly monster of wood and fabric had housed all the ghosts of his past.

He closed his eyes, trying to recall the feeling of opening the top drawer. The smell of dust and cedarwood always hit him first, followed by the whisper of paper and the faint scent of lavender, a vapor trail from the hand-painted brooch his grandmother had sewn against a faded silk dress. His grandmother. His


  0%|          | 0/95 [00:00<?, ?it/s]

Steered text:
<bos>The chiffonier stood a few feet from the foot of the bed. He had emptied the drawers into cartons that morning, which were in the living room. Now, all that remained was the chiffonier itself, a solid piece of mahogany with intricate carvings and brass handles.

The man stood in the middle of the room, looking at it with a mixture of fondness and regret. It had been his grandfather's, a piece of furniture that had seen generations come and go.

"It's too big," he sighed.

The room was small, the walls close. The chiffonier, with its imposing presence,
