In [13]:
try:
    # for google colab users
    import google.colab  # type: ignore
    from google.colab import output

    COLAB = True
    %pip install sae-lens transformer-lens
except:
    # for local setup
    COLAB = False
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading

PORT = 8000

# general imports
import os
import torch as t
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

In [14]:
def display_vis_inline(filename: str, height: int = 850):
    """
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    """
    if not (COLAB):
        webbrowser.open(filename)
    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(
            PORT, path=f"/{filename}", height=height, cache_in_notebook=True
        )

        PORT += 1

In [18]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# device setup
if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {DEVICE}")

Device: cuda


In [35]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Choose a layer you want to focus on
# For this tutorial, we're going to use layer 2
layer = 5

# get model
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=DEVICE,
)

# get the SAE for this layer
sae, cfg_dict, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb", sae_id=f"blocks.{layer}.hook_resid_pre", device=DEVICE
)

print(sae.cfg)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

Loaded pretrained model gpt2-small into HookedTransformer
SAEConfig(architecture='standard', d_in=768, d_sae=24576, activation_fn_str='relu', apply_b_dec_to_input=True, finetuning_scaling_factor=False, context_size=128, model_name='gpt2-small', hook_name='blocks.5.hook_resid_pre', hook_layer=5, hook_head_index=None, prepend_bos=True, dataset_path='Skylion007/openwebtext', dataset_trust_remote_code=True, normalize_activations='none', dtype='torch.float32', device='cuda', sae_lens_training_version=None, activation_fn_kwargs={}, neuronpedia_id='gpt2-small/5-res-jb')
blocks.5.hook_resid_pre


In [41]:
sv_prompt = " The Golden Gate Bridge"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(f"Tokens: {tokens}")
print(f"Token strings: {model.to_str_tokens(tokens)}")

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

k = 3
sv_feature_acts_topk_vals, sv_feature_acts_topk_indices = t.topk(sv_feature_acts, k)

# print out the top activations, focus on the indices
print(f"\nTop {k} feature activations:\n{sv_feature_acts_topk_vals}")
print(f"\nTop {k} feature indices:\n{sv_feature_acts_topk_indices}")

Tokens: tensor([[50256,   383,  8407, 12816, 10290]], device='cuda:0')
Token strings: ['<|endoftext|>', ' The', ' Golden', ' Gate', ' Bridge']

Top 3 feature activations:
tensor([[[639.8358, 519.3585, 463.5883],
         [ 18.4421,  18.1516,   7.2985],
         [ 53.0257,   7.6162,   5.8503],
         [ 19.2280,  17.0333,   7.8110],
         [ 38.3003,   6.2576,   5.5530]]], device='cuda:0')

Top 3 feature indices:
tensor([[[ 4242,  1334, 12411],
         [19946, 12384,  8814],
         [12308,  3820,  5686],
         [ 7544,  6053,  1546],
         [17274,  6053, 20939]]], device='cuda:0')


In [44]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

get_neuronpedia_quick_list(sae=sae, features=sv_feature_acts_topk_indices.tolist())

'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%22%5B%5B4242%2C%201334%2C%2012411%5D%2C%20%5B19946%2C%2012384%2C%208814%5D%2C%20%5B12308%2C%203820%2C%205686%5D%2C%20%5B7544%2C%206053%2C%201546%5D%2C%20%5B17274%2C%206053%2C%2020939%5D%5D%22%7D%5D'