# Demo Notebook

Steps:
1. Download SAE with SAE Lens.
2. Create a dataset consistent with that SAE. 
3. Fold the SAE decoder norm weights so that feature activations are "correct".
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.
5. Run the SAE dashboard generator (`SaeVisRunner`) for the features you want.

# Set Up

In [1]:
import json
import torch
from sae_lens import SAE
from transformers import AutoTokenizer, AutoModel
from transformer_lens import HookedTransformer
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner

## Step 1. Download / Initialize SAE

In [2]:
# Checkpoint id shared by the model and the tokenizer.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

model = HookedTransformer.from_pretrained(MODEL_NAME, device=device, dtype="bfloat16")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The loader's three return values are unpacked as the SAE, its config dict
# and a sparsity object.
sae, cfg_dict, sparsity = SAE.from_eleuther(
    release="huypn16/sae-llama-3.2-1B-32x",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="layers.8",  # won't always be a hook point
    device=device,
)

# Fold the W_dec norm into the SAE so feature activations are accurate.
sae.fold_W_dec_norm()


Device: cuda
Loading model config for meta-llama/Llama-3.2-1B-Instruct
Loaded model config for {'d_model': 2048, 'd_head': 64, 'n_heads': 32, 'd_mlp': 8192, 'n_layers': 16, 'n_ctx': 2048, 'eps': 1e-05, 'd_vocab': 128256, 'act_fn': 'silu', 'n_key_value_heads': 8, 'normalization_type': 'RMS', 'positional_embedding_type': 'rotary', 'rotary_adjacent_pairs': False, 'rotary_dim': 64, 'final_rms': True, 'gated_mlp': True, 'original_architecture': 'LlamaForCausalLM', 'tokenizer_name': 'meta-llama/Llama-3.2-1B-Instruct'}




Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


layers.8/cfg.json:   0%|          | 0.00/563 [00:00<?, ?B/s]

-----Loading from eleuther-----
/datadrive5/.cache/hub/models--huypn16--sae-llama-3.2-1B-32x/snapshots/c1d312b5f3e8b693867ef3089a83c6aae051ffc3/layers.8/cfg.json
{'architecture': 'topk', 'hook_name': 'blocks.8.hook_resid_post', 'hook_layer': 8, 'layer': 8, 'k': 32, 'activation_fn_str': 'relu', 'd_sae': 65536, 'd_in': 2048, 'multi_topk': False, 'device': 'cuda', 'apply_b_dec_to_input': False, 'finetuning_scaling_factor': False, 'context_size': 1024, 'hook_head_index': None, 'prepend_bos': True, 'normalize_activations': 'none', 'dtype': 'float32', 'sae_lens_training_version': 'eleuther', 'neuronpedia_id': None, 'activation_fn_kwargs': {}, 'model_from_pretrained_kwargs': {}}
{'architecture': 'topk', 'hook_name': 'blocks.8.hook_resid_post', 'hook_layer': 8, 'layer': 8, 'k': 32, 'activation_fn_str': 'relu', 'd_sae': 65536, 'd_in': 2048, 'multi_topk': False, 'device': 'cuda', 'apply_b_dec_to_input': False, 'finetuning_scaling_factor': False, 'context_size': 1024, 'hook_head_index': None, 'pr

## Step 2. Get token dataset

In [3]:
from sae_lens import ActivationsStore
# Dataset settings for the activations store. An earlier variant of this cell
# streamed "open-web-math/open-web-math" with token_columns=["text"] instead.
DATASET_PATH = "lighteval/MATH"
TOKEN_COLUMNS = ["problem", "solution"]
PROMPTS_PER_BATCH = 16
BATCHES_IN_BUFFER = 16

activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    dataset=DATASET_PATH,
    token_columns=TOKEN_COLUMNS,
    store_batch_size_prompts=PROMPTS_PER_BATCH,
    n_batches_in_buffer=BATCHES_IN_BUFFER,
    device=device,
)

Token columns: ['problem', 'solution']




In [4]:
from tqdm import tqdm
import os
def get_tokens(
    activations_store: "ActivationsStore",
    n_prompts: int,
) -> torch.Tensor:
    """Draw token batches from ``activations_store`` and return them shuffled.

    Args:
        activations_store: Object exposing ``get_batch_tokens()``; each call
            must return a tensor whose first dimension indexes rows (prompts).
        n_prompts: Number of *batches* to draw, i.e. the number of
            ``get_batch_tokens()`` calls. Despite the name this is not the
            number of prompts: the result has ``n_prompts`` times the
            per-batch row count rows.

    Returns:
        All drawn rows concatenated along dim 0, in random row order.

    Raises:
        ValueError: If ``n_prompts`` is smaller than 1 (there would be
            nothing to concatenate).
    """
    if n_prompts < 1:
        raise ValueError(f"n_prompts must be >= 1, got {n_prompts}")

    all_tokens_list = []
    for _ in tqdm(range(n_prompts)):
        batch_tokens = activations_store.get_batch_tokens()
        # Shuffle the rows within the batch. The original also sliced the
        # result with [: batch_tokens.shape[0]], which kept every row.
        batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])]
        all_tokens_list.append(batch_tokens)

    all_tokens = torch.cat(all_tokens_list, dim=0)
    # Shuffle once more so rows from different batches are interleaved.
    return all_tokens[torch.randperm(all_tokens.shape[0])]

# Cache the sampled tokens on disk so re-running the notebook skips sampling.
TOKEN_CACHE_PATH = "llama_lighteval.pt"

if os.path.exists(TOKEN_CACHE_PATH):
    # The cache holds a single tensor, so restrict unpickling to tensor data.
    # NOTE(review): this cell's recorded output looks like the warning torch
    # emits when weights_only is left unset - confirm it is gone after this.
    token_dataset = torch.load(TOKEN_CACHE_PATH, weights_only=True)
else:
    # n_prompts is the number of batches drawn from the store (1024 calls to
    # get_batch_tokens), not the number of individual prompts.
    token_dataset = get_tokens(activations_store, n_prompts=1024)
    torch.save(token_dataset, TOKEN_CACHE_PATH)

print(tokenizer.decode(token_dataset[0]))
print(token_dataset.shape)  # [store_batch_size_prompts * n_prompts, 1024]

<|begin_of_text|>),SE);
label("G",(4,3),SW);
label("6",(3,0),S);
label("1",(0.5,3),N);
label("2",(5,3),N);
label("3",(6,1.5),E);
[/asy]
solution: We first find the length of line segment $FG$. Since $DC$ has length $6$ and $DF$ and $GC$ have lengths $1$ and $2$ respectively, $FG$ must have length $3$. Next, we notice that $DC$ and $AB$ are parallel so $\angle EFG \cong \angle EAB$ because they are corresponding angles. Similarly, $\angle EGF \cong \angle EBA$. Now that we have two pairs of congruent angles, we know that $\triangle FEG \sim \triangle AEB$ by Angle-Angle Similarity.

Because the two triangles are similar, we have that the ratio of the altitudes of $\triangle FEG$ to $\triangle AEB$ equals the ratio of the bases. $FG:AB=3:6=1:2$, so the the ratio of the altitude of $\triangle FEG$ to that of $\triangle AEB$ is also $1:2$. Thus, the height of the rectangle $ABCD$ must be half of the altitude of $\triangle AEB$. Since the height of rectangle $ABCD$ is $3$, the altitude of $

  token_dataset = torch.load("llama_lighteval.pt")


## Step 3. Evaluate the SAE

In [5]:
from sae_lens import run_evals
from sae_lens.evals import get_eval_everything_config

# Build the evaluation config separately so the run_evals call stays short.
eval_config = get_eval_everything_config(
    batch_size_prompts=8,
    n_eval_reconstruction_batches=10,
    n_eval_sparsity_variance_batches=3,
)
eval_metrics = run_evals(
    sae=sae,
    activation_store=activations_store,
    model=model,
    eval_config=eval_config,
)
print(json.dumps(eval_metrics, indent=4))

# Headline sanity checks, printed in this order:
#   ce_loss_score       - should be high for residual stream SAEs
#   ce_loss_without_sae - should be fairly low (< 3.5), suggesting the model is being run correctly
#   ce_loss_with_sae    - shouldn't be massively higher than without the SAE
for metric_key in (
    "metrics/ce_loss_score",
    "metrics/ce_loss_without_sae",
    "metrics/ce_loss_with_sae",
):
    print(eval_metrics[metric_key])

standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
standard replacement hook:  blocks.8.hook_resid_post
{
    "metrics/kl_div_with_sae": 0.78125,
    "metrics/kl_div_with_ablation": 10.4375,
    "metrics/ce_loss_with_sae": 2.25,
    "metrics/ce_loss_without_sae": 1.5703125,
    "metrics/ce_loss_with_ablation": 11.75,
    "metrics/kl_div_score": 0.9251497005988024,
    "metrics/ce_loss_score": 0.9332310053722179,
    "metrics/l2_norm_in": 9.25,
    "metrics/l2_norm_out": 8.761981964111328,
    "metrics/l2_ratio": 0.9328801035881042,
    "metrics/l0": 32.0,
    "metrics

## Step 4. Generate Feature Dashboards

In [6]:
import gc

# Free memory before the dashboard run: collect unreachable Python objects
# first, then ask PyTorch to release its cached, unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [7]:
from pathlib import Path

# SAE feature indices to build dashboards for.
test_feature_idx_llama = [2705, 9766, 18472, 22648, 24905, 25939, 27169, 27353, 27368, 32379]

# NOTE(review): the cache directory name mentions nrows=16384, but only 4096
# token rows are passed to the runner below - confirm the name is intended.
feature_vis_config_llama = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=test_feature_idx_llama,
    minibatch_size_features=10,
    minibatch_size_tokens=16,  # this is number of prompts at a time.
    verbose=True,
    device=device,  # reuse the device resolved at setup instead of hard-coding "cuda"
    cache_dir=Path(
        "llama.layers.8_bs=128_nrows=16384"
    ),  # this will enable us to skip running the model for subsequent features.
    dtype="bfloat16",
)

# Build the dashboard data from the first 4096 rows of the token dataset.
data = SaeVisRunner(feature_vis_config_llama).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:4096],
)

n_token_batches: 256
len(feature_batches): 1
len(tokens): 4096
cfg.minibatch_size_tokens: 16


Forward passes to cache data for vis:   0%|          | 0/256 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/10 [00:00<?, ?it/s]

get_model_acts_time 0.08502626419067383
encode_time 0.004608154296875


  return torch.load(


postprocessing_time 0.337327241897583
get_model_acts_time 0.0870199203491211
encode_time 0.0022530555725097656
postprocessing_time 0.33298230171203613
get_model_acts_time 0.0524289608001709
encode_time 0.002359151840209961
postprocessing_time 0.31594014167785645
get_model_acts_time 0.053992509841918945
encode_time 0.0021483898162841797
postprocessing_time 0.3226654529571533
get_model_acts_time 0.05787801742553711
encode_time 0.010416746139526367
postprocessing_time 0.3277270793914795
get_model_acts_time 0.05559968948364258
encode_time 0.0020656585693359375
postprocessing_time 0.32517290115356445
get_model_acts_time 0.06272339820861816
encode_time 0.002248525619506836
postprocessing_time 0.32788968086242676
get_model_acts_time 0.05564689636230469
encode_time 0.002056598663330078
postprocessing_time 0.33024096488952637
get_model_acts_time 0.0579228401184082
encode_time 0.002100706100463867
postprocessing_time 0.31865477561950684
get_model_acts_time 0.05360674858093262
encode_time 0.00206

KeyboardInterrupt: 

In [11]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis

# Output HTML file for the feature-centric dashboard. A plain string is used
# because the name contains no placeholders to interpolate.
filename = "llama.layers.8.thresholdfire_toks=16384x1024.html"
save_feature_centric_vis(sae_vis_data=data, filename=filename)

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]