# Setup & loading stuff

In [1]:
import os
import sys

import gc
from pathlib import Path

from datasets import load_dataset
from IPython import get_ipython
from sae_lens import SAE, HookedSAETransformer

import torch
from huggingface_hub import hf_hub_download

from sae_vis.data_config_classes import SaeVisConfig, SaeVisLayoutConfig
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.model_fns import (
    load_demo_model_saes_and_data,
    load_othello_vocab,
)

torch.set_grad_enabled(False)
assert torch.cuda.is_available()
device = torch.device("cuda")

In [2]:
from IPython.display import IFrame, display
import os

def display_vis_inline(filename: str, height: int = 850):
    """
    Displays the HTML file inline in a Jupyter notebook.
    """
    # If the file is not in the current directory, adjust the path as needed
    display(IFrame(src=filename, width='100%', height=height))

In [3]:
# Setup for basic model (examples 1-3)

SEQ_LEN = 128
DATASET_PATH = "NeelNanda/c4-code-20k"
MODEL_NAME = "gelu-1l"
HOOK_NAME = "blocks.0.mlp.hook_post"

# For this, it's just a 1L model from Neel's library
sae, sae_B, model, all_tokens = load_demo_model_saes_and_data(SEQ_LEN, str(device))

Loaded pretrained model gelu-1l into HookedTransformer
torch.Size([215402, 128])


In [4]:
# Othello setup (example 4)

hf_repo_id = "callummcdougall/arena-demos-othellogpt"
sae_id = "blocks.5.mlp.hook_post-v1"
model_name = "othello-gpt"

othellogpt: HookedSAETransformer = HookedSAETransformer.from_pretrained(model_name)
othellogpt_sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

def hf_othello_load(filename: str):
    path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
    return torch.load(path, weights_only=True, map_location=device)

othello_tokens = hf_othello_load("tokens.pt")[:5000]
othello_target_logits = hf_othello_load("target_logits.pt")[:5000]
othello_linear_probes = hf_othello_load("linear_probes.pt")
print(f"{othello_tokens.shape=}")

# Get live features
_, cache = othellogpt.run_with_cache_with_saes(
    othello_tokens[:1000],
    saes=[othellogpt_sae],
    names_filter=(post_acts_hook := f"{othellogpt_sae.cfg.hook_name}.hook_sae_acts_post"),
)
acts = cache[post_acts_hook]
othello_alive_feats = (acts[:, 5:-5].flatten(0, 1) > 1e-8).any(dim=0).nonzero().squeeze().tolist()
print(f"Alive features: {len(othello_alive_feats)}/{othellogpt_sae.cfg.d_sae}")

del cache
torch.cuda.empty_cache()
gc.collect()

Loaded pretrained model othello-gpt into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


othello_tokens.shape=torch.Size([5000, 59])
Alive features: 6784/8192


24

In [7]:
gc.collect()
torch.cuda.empty_cache()

In [8]:
# Attention model setup (example 5)

attn_model: HookedSAETransformer = HookedSAETransformer.from_pretrained("attn-only-2l-demo")
hf_repo_id = "callummcdougall/arena-demos-attn2l"
sae_id = "blocks.0.attn.hook_z-v2"
attn_sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

original_dataset = load_dataset(attn_sae.cfg.dataset_path, split="train", streaming=True, trust_remote_code=True)
batch_size = 4096
seq_len = 256
seq_list = [x["input_ids"][: seq_len - 1] for (_, x) in zip(range(batch_size), original_dataset)]
tokens = torch.tensor(seq_list, device=device)
assert attn_model.tokenizer is not None
bos_token = torch.tensor([attn_model.tokenizer.bos_token_id for _ in range(batch_size)], device=device)
tokens = torch.cat([bos_token.unsqueeze(1), tokens], dim=1)
assert tokens.shape == (batch_size, seq_len)

# Get live features
_, cache = attn_model.run_with_cache_with_saes(
    tokens[:512],
    saes=[attn_sae],
    names_filter=(post_acts_hook := f"{attn_sae.cfg.hook_name}.hook_sae_acts_post"),
    stop_at_layer=attn_sae.cfg.hook_layer + 1,
)
acts = cache[post_acts_hook]
attn_alive_feats = (acts.flatten(0, 1) > 1e-8).any(dim=0).nonzero().squeeze().tolist()
print(f"Alive features: {len(attn_alive_feats)}/{attn_sae.cfg.d_sae}")

del cache
torch.cuda.empty_cache()
gc.collect()



Loaded pretrained model attn-only-2l-demo into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB. GPU 0 has a total capacity of 22.07 GiB of which 1.37 GiB is free. Including non-PyTorch memory, this process has 20.69 GiB memory in use. Of the allocated memory 19.87 GiB is allocated by PyTorch, and 568.01 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Demos

In [5]:
# [1/5] Basic demo, 1L model, default settings

sae_vis_data = SaeVisData.create(
    sae=sae,
    sae_B=sae_B,
    model=model,
    tokens=all_tokens[:8192],
    cfg=SaeVisConfig(features=range(128)),
    verbose=True,
)

filename = "demo_feature_vis.html"
sae_vis_data.save_feature_centric_vis(filename, feature=8)
display_vis_inline(filename)

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/128 [00:00<?, ?it/s]

In [6]:
# [2/5] Custom layout demo

from sae_vis.data_config_classes import (
    ActsHistogramConfig,
    Column,
    FeatureTablesConfig,
    SeqMultiGroupConfig,
)

layout = SaeVisLayoutConfig(
    columns=[
        Column(
            SeqMultiGroupConfig(buffer=None, n_quantiles=0, top_acts_group_size=30),
            width=1000,
        ),
        Column(ActsHistogramConfig(), FeatureTablesConfig(n_rows=5), width=500),
    ],
    height=1000,
)
layout.help()

sae_vis_data_custom = SaeVisData.create(
    sae=sae,
    sae_B=sae_B,
    model=model,
    tokens=all_tokens[:4096, :48],  # 4096
    cfg=SaeVisConfig(
        features=range(256),  # 256
        feature_centric_layout=layout,
    ),
    verbose=True,
)

filename = "demo_feature_vis_custom.html"
sae_vis_data_custom.save_feature_centric_vis(filename, feature=8)
display_vis_inline(filename)

Forward passes to cache data for vis:   0%|          | 0/64 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/256 [00:00<?, ?it/s]

In [7]:
# [3/5] Prompt-centric vis

prompt = "'first_name': ('django.db.models.fields"
seq_pos = model.tokenizer.tokenize(prompt).index("Ġ('")
metric = "act_quantile"

filename = "demo_prompt_vis.html"
sae_vis_data.save_prompt_centric_vis(filename, prompt=prompt, seq_pos=seq_pos, metric=metric)
display_vis_inline(filename)

In [None]:
# [4/5] OthelloGPT

# This one is a bit more complicated because I've included linear probes in the vis! They tell you
# the extent to which any given SAE latent reads from / writes to some particular probe direction.

sae_vis_data = SaeVisData.create(
    sae=othellogpt_sae,
    model=othellogpt,  # type: ignore
    linear_probes=[
        ("input", "theirs vs mine", othello_linear_probes["theirs vs mine"]),
        ("output", "theirs vs mine", othello_linear_probes["theirs vs mine"]),
        ("input", "empty", othello_linear_probes["empty"]),
        ("output", "empty", othello_linear_probes["empty"]),
    ],
    tokens=othello_tokens,
    target_logits=othello_target_logits,
    cfg=SaeVisConfig(
        features=othello_alive_feats[:16],
        seqpos_slice=(5, -5),
        feature_centric_layout=SaeVisLayoutConfig.default_othello_layout(),
    ),
    vocab_dict=load_othello_vocab(),
    verbose=True,
    clear_memory_between_batches=True,
)

filename = "demo_othello_vis.html"
sae_vis_data.save_feature_centric_vis(filename, verbose=True)
display_vis_inline(filename)

In [None]:
# [5/5] Attention SAE

sae_vis_data = SaeVisData.create(
    sae=attn_sae,
    model=attn_model,
    tokens=tokens,
    cfg=SaeVisConfig(features=attn_alive_feats[:32]),
    verbose=True,
    clear_memory_between_batches=True,
)

filename = "demo_feature_vis_attn2l.html"
sae_vis_data.save_feature_centric_vis(filename)
display_vis_inline(filename)