# Exploratory Notebook for architecture setup

We will use GPT-2 small for quick testing. For now, I'm following the SAELens tutorial on basic loading and analysing, available [here](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb#scrollTo=sNSfL80Uv611).

In [1]:
%load_ext autoreload
%autoreload 2
# The above lines are used to automatically reload modules before executing code

from fsrl import SAEAdapter
import torch

In [2]:
# MPS support maybe?
device = "cuda" if torch.cuda.is_available() else "cpu"

release = "gpt2-small-res-jb"
sae_id = "blocks.7.hook_resid_pre"

sae, cfg_dict, sparsity = SAEAdapter.from_pretrained(release, sae_id, device=device)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


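We should be mindful of the loading warning: it recommends loading the model with from_pretrained_no_processing and the SAE's model_from_pretrained_kwargs (here center_writing_weights=True). For now, I will ignore it.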
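On the "MPS support maybe?" note in the cell above: a minimal sketch of a device picker that also considers Apple's MPS backend. `pick_device` is a hypothetical helper, not part of fsrl or SAELens, and whether SAEAdapter and the rest of the pipeline run correctly on MPS is untested here.

```python
import torch


def pick_device() -> str:
    """Return "cuda", "mps", or "cpu", in that order of preference."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds, hence the guard.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(device)
```

The returned string can be passed as the device argument in the same way as the "cuda"/"cpu" value used above.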

In [3]:
print(sae)

SAEAdapter(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
  (adapter): Linear(in_features=768, out_features=24576, bias=True)
  (hook_sae_adapter): HookPoint()
)


Hook points are used for caching activations in the model. The cached activations can be accessed through run_with_cache, which stores them in a dictionary keyed by hook point name.
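To make the caching idea concrete, here is a rough analogy built on plain PyTorch forward hooks. It is not how TransformerLens implements run_with_cache; the toy model, the submodule names, and `run_with_toy_cache` are all made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model; the submodule names play the role of hook point names.
toy_model = nn.Sequential()
toy_model.add_module("encoder", nn.Linear(4, 8))
toy_model.add_module("act", nn.ReLU())
toy_model.add_module("decoder", nn.Linear(8, 4))


def run_with_toy_cache(model: nn.Module, x: torch.Tensor):
    """Run a forward pass and return (output, {submodule name: activation})."""
    cache = {}
    handles = []
    for name, module in model.named_children():
        # Bind `name` as a default argument so each hook keeps its own key.
        def save_output(mod, inputs, output, name=name):
            cache[name] = output.detach()

        handles.append(module.register_forward_hook(save_output))
    try:
        out = model(x)
    finally:
        # Always remove the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return out, cache


out, cache = run_with_toy_cache(toy_model, torch.randn(2, 4))
for name, act in cache.items():
    print(name, tuple(act.shape))
```

The real cache is indexed the same way, by name, which is why cache[sae.cfg.hook_name] works later in this notebook.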

In [4]:
for k, v in sae.cfg.__dict__.items():
    print(f"{k}: {v}")

architecture: standard
d_in: 768
d_sae: 24576
activation_fn_str: relu
apply_b_dec_to_input: True
finetuning_scaling_factor: False
context_size: 128
model_name: gpt2-small
hook_name: blocks.7.hook_resid_pre
hook_layer: 7
hook_head_index: None
prepend_bos: True
dataset_path: Skylion007/openwebtext
dataset_trust_remote_code: True
normalize_activations: none
dtype: torch.float32
device: cuda
sae_lens_training_version: None
activation_fn_kwargs: {}
neuronpedia_id: gpt2-small/7-res-jb
model_from_pretrained_kwargs: {'center_writing_weights': True}
seqpos_slice: (None,)


There are some important variables to note:

- d_in = Width of the residual stream (e.g. 768 for GPT-2 small)
- d_sae = Number of features in the SAE (24576 for this SAE, i.e. 32 × d_in)
- hook_name = Name of the hook point in transformer_lens on which the SAE was trained

The neuronpedia_id might also be worth keeping in mind if we want to get labels for the features.
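As a shape check on d_in and d_sae, here is a toy sketch of a standard ReLU SAE with tiny dimensions. The parameter names (W_enc, b_enc, W_dec, b_dec) and the exact encode/decode formulas are assumptions about the usual "standard" architecture, the weights are random rather than trained, and the extra adapter layer of SAEAdapter is not modelled.

```python
import torch

torch.manual_seed(0)

d_in, d_sae = 8, 32  # toy sizes; the SAE above uses 768 and 24576

# Randomly initialised parameters standing in for trained weights.
W_enc = torch.randn(d_in, d_sae) * 0.1
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_in) * 0.1
b_dec = torch.zeros(d_in)


def encode(x: torch.Tensor) -> torch.Tensor:
    # Subtract the decoder bias first, as apply_b_dec_to_input=True suggests.
    return torch.relu((x - b_dec) @ W_enc + b_enc)


def decode(feature_acts: torch.Tensor) -> torch.Tensor:
    return feature_acts @ W_dec + b_dec


resid = torch.randn(2, 5, d_in)  # (batch, position, d_in)
feature_acts = encode(resid)     # (batch, position, d_sae)
recons = decode(feature_acts)    # (batch, position, d_in)
print(feature_acts.shape, recons.shape)
```

The encoder maps each residual stream vector from d_in up to d_sae non-negative feature activations, and the decoder maps back down to d_in.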

In [5]:
from transformer_lens import HookedTransformer
from datasets import load_dataset

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Small subset of the Pile
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [6]:
from transformer_lens.utils import tokenize_and_concatenate

# This function tokenizes all dataset examples, concatenates them, and then splits them into chunks of max_length tokens
# It's useful in a pre-training / unsupervised context, where we have a continuous stream of unlabelled text data
# The Hugging Face tokenizer is unaware of the chunking step that takes place immediately afterwards and thus raises a harmless warning
token_dataset = tokenize_and_concatenate(
    dataset=dataset,
    tokenizer=model.tokenizer,
    streaming=False,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)
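The concatenate-then-chunk idea can be sketched in pure Python. This is a simplification under stated assumptions, not the transformer_lens implementation: `concat_and_chunk` is a hypothetical helper, the token ids are made up, and the details (an EOS separator between documents, one slot per row reserved for BOS, dropping the incomplete tail) are my reading of the behaviour.

```python
from typing import List


def concat_and_chunk(
    docs: List[List[int]], max_length: int, bos_id: int, eos_id: int
) -> List[List[int]]:
    """Join tokenized documents into one stream and cut fixed-length rows."""
    stream: List[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)  # separator between documents

    body_len = max_length - 1  # leave one slot per row for the BOS token
    n_rows = len(stream) // body_len  # the incomplete tail is dropped
    return [
        [bos_id] + stream[i * body_len : (i + 1) * body_len]
        for i in range(n_rows)
    ]


# Hypothetical token ids, with 0 standing in for both BOS and EOS.
docs = [[11, 12, 13], [21, 22], [31, 32, 33]]
rows = concat_and_chunk(docs, max_length=4, bos_id=0, eos_id=0)
print(rows)
```

Every row has exactly max_length tokens and starts with BOS, which is why the batch in the next cell has shape (batch_size, context_size).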

In [7]:
batch_size = 32

with torch.no_grad():  # Inference only, so skip gradient tracking (training would need gradients)
    
    # An activation store could also provide tokens; here we slice the token dataset directly.
    batch_tokens = token_dataset[:batch_size]["tokens"]
    
    print(batch_tokens.shape)  # Should be (32, context_size)
    
    # prepend_bos should only matter for string inputs; these tokens already start with BOS
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # Ignore the BOS token and count the features active at each token; the mean below averages across batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0 (base SAE)", l0.mean().item())

torch.Size([32, 128])
average l0 (base SAE) 61.4202766418457
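To see what the L0 number measures, here is the same computation on a hand-written toy tensor. The values are invented for illustration; position 0 stands in for the BOS token and is excluded, as in the cell above.

```python
import torch

# Toy feature activations with shape (batch=2, position=3, d_sae=4).
# Position 0 plays the role of the BOS token.
feature_acts = torch.tensor(
    [
        [[9.0, 9.0, 9.0, 9.0], [0.0, 1.5, 0.0, 0.0], [0.2, 0.0, 0.7, 0.0]],
        [[9.0, 9.0, 9.0, 9.0], [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]],
    ]
)

# Drop position 0, then count the strictly positive features per token.
l0_per_token = (feature_acts[:, 1:] > 0).float().sum(-1)  # shape (2, 2)
print(l0_per_token)
print("average l0:", l0_per_token.mean().item())
```

Here the per-token counts are 1, 2, 0 and 3, so the average is 1.5. For the real SAE above, an average of about 61.4 active features out of 24576 means roughly 0.25% of features fire on a typical token.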
