# Exploratory Notebook for architecture setup

We will be using GPT2 small for quick testing. Currently, I'm following the [SAELens basic loading and analysing tutorial](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb#scrollTo=sNSfL80Uv611).

In [3]:
%load_ext autoreload
%autoreload 2
# The above lines are used to automatically reload modules before executing code

from fsrl import SAEAdapter
import torch

In [13]:
# MPS support maybe?
device = "cuda" if torch.cuda.is_available() else "cpu"

release = "gpt2-small-res-jb"
sae_id = "blocks.7.hook_resid_pre"

sae, cfg_dict, sparsity = SAEAdapter.from_pretrained(release, sae_id, device=device)

We should be mindful of the loading warning. For now, I will ignore it.

In [6]:
print(sae)

SAEAdapter(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
  (adapter_activation): ReLU()
  (adapter_layers): ModuleList(
    (0): Linear(in_features=768, out_features=24576, bias=True)
  )
  (hook_sae_adapter_pre): HookPoint()
  (hook_sae_adapter_post): HookPoint()
  (hook_sae_fusion): HookPoint()
)


Hookpoints are used for caching activations in the model. The cached activations can be accessed through run_with_cache, which stores them in a dictionary keyed by hookpoint name.
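
To make the caching idea concrete, here is a minimal sketch using plain PyTorch forward hooks on a toy network. It is not the HookPoint implementation from transformer_lens; the hook names `hook_pre` and `hook_post` and the toy layers are made up for illustration.

```python
import torch
import torch.nn as nn

# Toy network standing in for a model with named hook points (illustrative only).
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

cache = {}  # maps a (hypothetical) hook name to the activation captured there


def make_hook(name):
    # Returns a forward hook that stores the module output under `name`.
    def hook(module, inputs, output):
        cache[name] = output.detach()
    return hook


handles = [
    toy[0].register_forward_hook(make_hook("hook_pre")),   # before the ReLU
    toy[1].register_forward_hook(make_hook("hook_post")),  # after the ReLU
]

with torch.no_grad():
    out = toy(torch.randn(3, 4))

# Remove the hooks so later forward passes are not cached by accident.
for handle in handles:
    handle.remove()

print({name: tuple(act.shape) for name, act in cache.items()})
# {'hook_pre': (3, 8), 'hook_post': (3, 8)}
```

The dictionary plays the same role as the cache returned by run_with_cache: one entry per named point, filled during a single forward pass.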

In [7]:
for k, v in sae.cfg.__dict__.items():
    print(f"{k}: {v}")

architecture: standard
d_in: 768
d_sae: 24576
activation_fn_str: relu
apply_b_dec_to_input: True
finetuning_scaling_factor: False
context_size: 128
model_name: gpt2-small
hook_name: blocks.7.hook_resid_pre
hook_layer: 7
hook_head_index: None
prepend_bos: True
dataset_path: Skylion007/openwebtext
dataset_trust_remote_code: True
normalize_activations: none
dtype: torch.float32
device: cuda
sae_lens_training_version: None
activation_fn_kwargs: {}
neuronpedia_id: gpt2-small/7-res-jb
model_from_pretrained_kwargs: {'center_writing_weights': True}
seqpos_slice: (None,)
release: gpt2-small-res-jb
sae_id: blocks.7.hook_resid_pre


There are some important variables to note:

- d_in = Width of the residual stream (768 for GPT2 small)
- d_sae = Number of features in the SAE (24576 for this SAE, i.e. 32 × d_in)
- hook_name = Name of the transformer_lens hookpoint on which the SAE was trained

The neuronpedia_id might also be worth keeping in mind if we want to get labels for the features.
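
As a rough illustration of how d_in and d_sae relate, here is a sketch of a standard ReLU sparse autoencoder forward pass with toy sizes. The parameter names (`W_enc`, `b_enc`, `W_dec`, `b_dec`) and the random weights are assumptions for illustration; this is not the loaded SAE or the SAEAdapter code.

```python
import torch

torch.manual_seed(0)
d_in, d_sae = 16, 64  # toy sizes; the SAE above uses 768 and 24576

# Randomly initialized stand-ins for the encoder and decoder parameters.
W_enc = torch.randn(d_in, d_sae) * 0.1
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_in) * 0.1
b_dec = torch.zeros(d_in)


def sae_forward(x):
    # With apply_b_dec_to_input=True, the decoder bias is subtracted first.
    feature_acts = torch.relu((x - b_dec) @ W_enc + b_enc)  # (..., d_sae)
    recon = feature_acts @ W_dec + b_dec                    # (..., d_in)
    return feature_acts, recon


x = torch.randn(2, 5, d_in)  # (batch, position, d_in)
feature_acts, recon = sae_forward(x)
print(feature_acts.shape, recon.shape)
# torch.Size([2, 5, 64]) torch.Size([2, 5, 16])
```

The encoder widens each residual vector from d_in to d_sae features, and the decoder maps back to d_in so the reconstruction can replace the original activation.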

In [8]:
from transformer_lens import HookedTransformer
from datasets import load_dataset

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Small subset of the Pile
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
from transformer_lens.utils import tokenize_and_concatenate

# This function turns all dataset examples into tokens, combines them, and then splits them into chunks based on the max length
# It's useful in a pre-training / unsupervised context, where we have a continuous stream of unlabelled text data
# The huggingface tokenizer is unaware of the concatenation step, which takes place immediately after, and thus raises a harmless warning
token_dataset = tokenize_and_concatenate(
    dataset=dataset,
    tokenizer=model.tokenizer,
    streaming=False,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)
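
The concatenate-then-chunk idea described in the comments above can be sketched in plain Python. This is a simplified illustration, not the transformer_lens implementation: the helper name `concat_and_chunk`, the token ids, and the choice to drop the leftover tail are all assumptions.

```python
def concat_and_chunk(docs, max_length, bos_id, eos_id, add_bos=True):
    """Join tokenized documents into one stream, then cut fixed-size rows."""
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)  # separator between documents

    # Reserve one slot per row for the BOS token when it is prepended.
    body_len = max_length - 1 if add_bos else max_length
    n_rows = len(stream) // body_len  # any ragged tail is dropped

    rows = []
    for i in range(n_rows):
        chunk = stream[i * body_len:(i + 1) * body_len]
        rows.append([bos_id] + chunk if add_bos else chunk)
    return rows


# Hypothetical token ids: three short "documents".
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
rows = concat_and_chunk(docs, max_length=4, bos_id=0, eos_id=99)
print(rows)
# [[0, 1, 2, 3], [0, 99, 4, 5], [0, 99, 6, 7], [0, 8, 9, 99]]
```

Every row has exactly max_length tokens, and document boundaries can fall anywhere inside a row, which is why a row is not guaranteed to start at the beginning of a document.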

In [10]:
def calculate_l0_norm(feature_acts: torch.Tensor) -> float:
    """
    Calculates the average L0 norm for a batch of feature activations.
    Ignores the BOS token position.
    """
    # Ensure the tensor is on the CPU and detached for calculation
    feature_acts = feature_acts.detach().cpu()
    
    # Exclude the BOS token [:, 1:] and count non-zero features
    l0_per_token = (feature_acts[:, 1:] > 0).float().sum(dim=-1)
    
    # Return the average L0 norm across all tokens in the batch
    return l0_per_token.mean().item()

BATCH_SIZE = 32

## Testing SAE Adapter

In [11]:
with torch.no_grad():
    
    # Get a batch of tokens from your dataset
    batch_tokens = token_dataset[:BATCH_SIZE]["tokens"]
    print(f"Batch tokens shape: {batch_tokens.shape}")

    # Get the LLM's activations at the SAE's hook point
    _, llm_cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    sae_input_activations = llm_cache[sae.cfg.hook_name]
    
    # Use run_with_cache on the SAE itself to get its internal activations
    sae_reconstruction, sae_cache = sae.run_with_cache(sae_input_activations)
    
    # Check the keys in the SAE cache
    print(sae_cache.keys())

    # Clean up memory
    del llm_cache
    
    # a) L0 norm of the original, unsteered SAE features
    base_sae_features = sae_cache["hook_sae_acts_post"]
    l0_base = calculate_l0_norm(base_sae_features)
    print(f"Average L0 (Base SAE): {l0_base:.2f}")

    # b) L0 norm of the final, steered/fused features
    fused_features = sae_cache["hook_sae_fusion"]
    l0_fused = calculate_l0_norm(fused_features)
    print(f"Average L0 (Fused/Steered): {l0_fused:.2f}")
    
    # Clean up more memory
    del sae_cache

Batch tokens shape: torch.Size([32, 128])
dict_keys(['hook_sae_input', 'hook_sae_acts_pre', 'hook_sae_acts_post', 'hook_sae_error', 'hook_sae_adapter_pre', 'hook_sae_adapter_post', 'hook_sae_fusion', 'hook_sae_recons', 'hook_sae_output'])
Average L0 (Base SAE): 61.42
Average L0 (Fused/Steered): 12372.82


This seems to be working. The L0 norm of the fused vector is very high (about half of the 24576 features are active) since the adapter is randomly initialized with small values.
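
A quick way to sanity-check that figure is to push random inputs through a small randomly initialized linear layer followed by a ReLU and count the active outputs. This is a sketch under assumptions: the toy sizes, the `std=1e-4` init scale, and the Gaussian inputs are illustrative, and the real SAEAdapter may initialize and fuse its adapter differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_sae = 64, 2048  # toy sizes; the notebook uses 768 and 24576

# Stand-in for a freshly initialized adapter layer (assumed init scale).
adapter = nn.Linear(d_in, d_sae)
nn.init.normal_(adapter.weight, std=1e-4)
nn.init.zeros_(adapter.bias)

with torch.no_grad():
    x = torch.randn(64, d_in)                # random stand-in activations
    adapter_out = torch.relu(adapter(x))     # small, but nonzero where positive
    l0 = (adapter_out > 0).float().sum(dim=-1).mean().item()

fraction_active = l0 / d_sae
print(f"Average L0: {l0:.0f} of {d_sae} ({fraction_active:.1%})")
```

Because the pre-activations are symmetric around zero, roughly half of them survive the ReLU regardless of how small the weights are. If those outputs are then added to the base SAE features, the fused L0 ends up near d_sae / 2.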

## Testing saving and loading
Should probably put these in the test folder at some point.

In [14]:
# Testing save
save_path = "../models/test"
sae.save_adapter(save_path)

Adapter saved to ../models/test


In [15]:
# Testing load
sae = SAEAdapter.load_from_pretrained_adapter(save_path, device=device)

Adapter loaded from ../models/test


## Testing out the HookedModel for Feature Steering

In [16]:
from fsrl import HookedModel
hooked_model = HookedModel(model, sae)

# Displays the trainable parameters (here, the adapter layer's weight and bias)
print(hooked_model.get_trainable_parameters())

[Parameter containing:
tensor([[ 1.8859e-05,  1.6555e-04,  3.9675e-05,  ...,  1.4030e-05,
         -7.6774e-05, -4.5187e-05],
        [ 1.1209e-04,  3.7414e-06,  1.9793e-04,  ..., -4.4484e-05,
         -4.6172e-05, -2.3969e-06],
        [-1.7648e-05,  1.6273e-04,  1.9007e-04,  ..., -9.0681e-05,
         -1.4844e-04,  1.7187e-04],
        ...,
        [-1.8724e-04, -3.5663e-05, -1.0638e-04,  ...,  1.9664e-05,
          2.3716e-05,  1.1651e-04],
        [-1.3992e-05,  2.6648e-05, -4.1684e-05,  ...,  2.8051e-05,
          9.1322e-05, -2.3333e-05],
        [-5.1574e-05,  1.4032e-04,  8.9403e-05,  ...,  2.0003e-05,
         -7.6993e-05, -5.2337e-06]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)]


In [17]:
# Test a forward pass using example tokens
example_tokens = batch_tokens[:1]  # Take the first example from the batch

# Input
print(model.tokenizer.decode(example_tokens[0], skip_special_tokens=True), end="\n-----------------------------\n\n")

output = hooked_model(example_tokens)

# Decode the output using the tokenizer, take argmax to get the most likely token
output = output.logits.argmax(dim=-1)
decoded_output = model.tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Decoded output: {decoded_output}")

It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playing on the web works, but you have to simulate multi-touch for table moving and that can be a bit confusing.

There’s a lot I’d like to talk about. I’ll go through every topic, insted of making the typical what went right/wrong list.

Concept

Working over the theme was probably one of the hardest tasks I had to face.

Originally, I had an idea of what kind of
-----------------------------

Decoded output: ,'s a. it it,
 can submit it�️iv of the Fundes Heroes� on the. iOS you iOS iPhone.
 on Android web is just too it can to be the-thread gestures it-. table's be a pain of.

The are�t a way of want�ve like to say about,
'm�d talk over the single I anda, the a table Android� wrong,wrong,,

Iclusionss
The with the past of a the of the most things I ever to do. I
I, I wanted a idea for how I of table


In [18]:
clean_output = model(example_tokens)
clean_output = clean_output.argmax(dim=-1)
clean_decoded_output = model.tokenizer.decode(clean_output[0], skip_special_tokens=True)
print(f"Clean decoded output: {clean_decoded_output}")

Clean decoded output: 
's a. it it,
 can see it�️iv of the Femptes Survivors� on your. iOS you iOS iPhone.
 on Android web is just too it can to be the-player gestures it-. other's be a pain tricky.

The are�t a way of want�ve like to say about.
'm�d talk into the single in anda in the a table table� wrong,wrong,,

Iclusionss
The with the past of a the of the hardest things I had to do. I
I, I wanted a idea for how I of game


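I guess it works? Both outputs should be identical with a newly initialized model. They match at most positions but not all (e.g. "submit" vs. "see", "Heroes" vs. "Survivors"), possibly because the adapter weights printed above are small but not exactly zero.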
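
Comparing decoded strings by eye is error-prone, so a numeric comparison of the two sets of logits may be more reliable. The helper below is a hypothetical sketch (`token_agreement` is not part of fsrl or transformer_lens), demonstrated on synthetic logits rather than the real model outputs.

```python
import torch


def token_agreement(logits_a, logits_b):
    """Return (fraction of positions with equal argmax, max abs logit gap)."""
    preds_a = logits_a.argmax(dim=-1)
    preds_b = logits_b.argmax(dim=-1)
    agreement = (preds_a == preds_b).float().mean().item()
    max_abs_diff = (logits_a - logits_b).abs().max().item()
    return agreement, max_abs_diff


torch.manual_seed(0)
clean_logits = torch.randn(1, 8, 50)            # (batch, position, vocab)
steered_logits = clean_logits.clone()
steered_logits[0, 3] = -steered_logits[0, 3]    # perturb a single position

agreement, max_abs_diff = token_agreement(clean_logits, steered_logits)
print(f"Token agreement: {agreement:.3f}, max |logit diff|: {max_abs_diff:.3f}")
```

In the notebook, the same helper could be called with hooked_model(example_tokens).logits and model(example_tokens). An agreement of 1.0 with a maximum difference near zero would indicate that the hooked model really is a no-op at initialization.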