# Demo Notebook

Steps:
1. Download SAE with SAE Lens.
2. Create a dataset consistent with that SAE. 
3. Fold the SAE decoder norm weights so that feature activations are "correct".
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.
5. Run the SAE generator for the features you want.

# Set Up

In [1]:
%load_ext autoreload
%autoreload 2
import os
import json
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import torch
from sae_lens import SAE
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
from sae_dashboard.data_writing_fns import save_prompt_centric_vis


In [2]:
!nvidia-smi

Fri Oct 18 17:49:52 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off |   00000001:00:00.0 Off |                    0 |
| N/A   38C    P0             74W /  300W |   69475MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off |   00

In [3]:
# Identifiers for the base model and the SAE checkpoint trained on it.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
SAE_RELEASE = "huypn16/sae-qwen-2.5-1.5B-OMS-16x"  # see other options in sae_lens/pretrained_saes.yaml
SAE_ID = "layers.14"  # won't always be a hook point

# Prefer the GPU when one is visible to this process, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sae, cfg_dict, sparsity = SAE.from_eleuther(release=SAE_RELEASE, sae_id=SAE_ID, device=device)  # type: ignore

model = HookedTransformer.from_pretrained(MODEL_NAME, device=device, dtype="bfloat16")

# Fold the decoder norms into the SAE weights so feature activations are accurate.
sae.fold_W_dec_norm()


Device: cuda


layers.14/cfg.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

-----Loading from eleuther-----
/datadrive5/.cache/hub/models--huypn16--sae-qwen-2.5-1.5B-OMS-16x/snapshots/af2cd1bd1a7cfc848469a0a7c70d2baa23e0eb56/layers.14/cfg.json
{'model_name': 'Qwen/Qwen2.5-1.5B-Instruct', 'architecture': 'topk', 'hook_name': 'blocks.14.hook_resid_post', 'hook_layer': 14, 'layer': 14, 'k': 64, 'activation_fn_str': 'relu', 'd_sae': 24576, 'd_in': 1536, 'multi_topk': False, 'device': 'cuda', 'apply_b_dec_to_input': False, 'finetuning_scaling_factor': False, 'context_size': 1024, 'hook_head_index': None, 'prepend_bos': True, 'normalize_activations': 'none', 'dtype': 'float32', 'sae_lens_training_version': 'eleuther', 'neuronpedia_id': None, 'activation_fn_kwargs': {}, 'model_from_pretrained_kwargs': {}}
{'model_name': 'Qwen/Qwen2.5-1.5B-Instruct', 'architecture': 'topk', 'hook_name': 'blocks.14.hook_resid_post', 'hook_layer': 14, 'layer': 14, 'k': 64, 'activation_fn_str': 'relu', 'd_sae': 24576, 'd_in': 1536, 'multi_topk': False, 'device': 'cuda', 'apply_b_dec_to_i



Loaded pretrained model Qwen/Qwen2.5-1.5B-Instruct into HookedTransformer


In [4]:
from sae_lens import ActivationsStore

# Dataset and buffering options for the store; the model and SAE are passed separately below.
store_options = dict(
    streaming=True,
    dataset="lighteval/MATH",
    token_columns=["problem", "solution"],
    store_batch_size_prompts=16,
    n_batches_in_buffer=16,
    device=device,
)
activations_store = ActivationsStore.from_sae(model=model, sae=sae, **store_options)

Path, databilder:  lighteval/MATH
Token columns: ['problem', 'solution']




In [5]:
from tqdm import tqdm
def get_tokens(
    activations_store: ActivationsStore,
    n_prompts: int,
):
    """Collect token batches from the store and return them as one shuffled tensor.

    Args:
        activations_store: Source of token batches; ``get_batch_tokens()`` is
            called once per iteration.
        n_prompts: Number of *batches* to draw (one ``get_batch_tokens()`` call
            each). Despite the name, this is not a count of individual prompts.

    Returns:
        All drawn batches concatenated along dim 0, with rows shuffled first
        within each batch and then once more across the whole result.
    """
    shuffled_batches = []
    for _ in tqdm(range(n_prompts)):
        batch = activations_store.get_batch_tokens()
        # Permute the rows of this batch before storing it.
        row_order = torch.randperm(batch.shape[0])
        shuffled_batches.append(batch[row_order])

    stacked = torch.cat(shuffled_batches, dim=0)
    # Final shuffle so rows from different batches are interleaved.
    return stacked[torch.randperm(stacked.shape[0])]

# Cache the sampled tokens on disk so re-running the notebook skips the slow streaming pass.
# NOTE: get_tokens' `n_prompts` counts *batches* drawn from the store, not individual prompts,
# so this collects N_TOKEN_BATCHES batches worth of rows (far more than a demo strictly needs).
TOKEN_CACHE_PATH = "token_dataset_lighteval.pt"
N_TOKEN_BATCHES = 2048

if os.path.exists(TOKEN_CACHE_PATH):
    # The cache is a plain tensor written by torch.save below, so weights_only=True suffices.
    # It avoids unpickling arbitrary objects and silences torch.load's FutureWarning.
    token_dataset = torch.load(TOKEN_CACHE_PATH, weights_only=True)
else:
    token_dataset = get_tokens(activations_store, n_prompts=N_TOKEN_BATCHES)
    torch.save(token_dataset, TOKEN_CACHE_PATH)

print(tokenizer.decode(token_dataset[0]))
print(token_dataset.shape) # [store_batch_size_prompts * n_prompts, 1024]

<|im_end|> [asy]

pair X,Y,Z;

Z = (0,0);

Y = (sqrt(51),0);

X = (0,7);

draw(X--Y--Z--X);

draw(rightanglemark(Y,Z,X,15));

label("$X$",X,NE);

label("$Y$",Y,SE);

label("$Z$",Z,SW);

label("$10$",(X+Y)/2,NE);

label("$\sqrt{51}$",(Z+Y)/2,S);

[/asy]

Because this is a right triangle, $\tan X = \frac{YZ}{XZ}$.

Using the Pythagorean Theorem, we find $XZ = \sqrt{XY^2 - YZ^2} = \sqrt{100-51} = 7$.

So $\tan X = \boxed{\frac{\sqrt{51}}{7}}$.<|im_end|>problem: The points $B(1, 1)$, $I(2, 4)$ and $G(5, 1)$ are plotted in the standard rectangular coordinate system to form triangle $BIG$. Triangle $BIG$ is translated five units to the left and two units upward to triangle $B'I'G'$, in such a way that $B'$ is the image of $B$, $I'$ is the image of $I$, and $G'$ is the image of $G$. What is the midpoint of segment $B'G'$? Express your answer as an ordered pair.
solution: Since triangle $B^\prime I^\prime G^\prime$ is translated from triangle $BIG,$ the midpoint of $B^\prime G ^\prime $ is the

  token_dataset = torch.load("token_dataset_lighteval.pt")


# Step 3: Evaluate the SAE

In [6]:
from sae_lens import run_evals
from sae_lens.evals import get_eval_everything_config

# Build the evaluation configuration first, then hand it to the runner.
eval_config = get_eval_everything_config(
    batch_size_prompts=8,
    n_eval_reconstruction_batches=10,
    n_eval_sparsity_variance_batches=10,
)
eval_metrics = run_evals(
    sae=sae,
    activation_store=activations_store,
    model=model,
    eval_config=eval_config,
)
print(json.dumps(eval_metrics, indent=4))

# Headline sanity checks, printed in this order:
#   ce_loss_score       - should be high for residual stream SAEs
#   ce_loss_without_sae - should be fairly low (< 3.5), suggesting the model is being run correctly
#   ce_loss_with_sae    - shouldn't be massively higher than the value without the SAE
for metric_key in (
    "metrics/ce_loss_score",
    "metrics/ce_loss_without_sae",
    "metrics/ce_loss_with_sae",
):
    print(eval_metrics[metric_key])

standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
standard replacement hook:  blocks.14.hook_resid_post
{
    "metrics/kl_div_with_sae": 0.310546875,
    "metrics/kl_div_with_ablation": 16.5,
    "metrics/ce_loss_with_sae": 1.109375,
    "metrics/ce_loss_without_sae": 0.79296875,
    "metrics/ce_loss_with_ablation": 17.25,
    "metrics/kl_div_score": 0.9811789772727273,
    "metrics/ce_loss_score": 0.9807737953952054,
    "metrics/l2_norm_in": 62.75,
    "metrics/l2_norm_out": 80.74810028076172,
    "metrics/l2_ratio": 1.3440078496932983,
    "metrics/l0": 6

# Step 4: Generate Feature Dashboards

In [7]:
# Free memory before building dashboards: run Python's garbage collector, then ask
# PyTorch to release the CUDA memory it is caching but no longer using.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
from pathlib import Path
import random

# Sizing knobs. They are also embedded in the cache directory name below, so the cache
# label always matches the data that was actually fed to the runner.
N_VIS_PROMPTS = 8192             # rows of token_dataset passed to the runner
VIS_MINIBATCH_SIZE_TOKENS = 256  # number of prompts processed at a time

# random features
test_feature_idx_qwen = random.sample(range(sae.cfg.d_sae), 4)
# test_feature_idx_qwen = [21719]
# test_feature_idx_qwen = random.sample(range(1024), 2)
# test_feature_idx_qwen = [1, 2, 3]
# test_feature_idx_qwen =  [101, 345, 4087, 4297, 4410, 4411, 4444, 4782, 4783, 4877, 4954, 6460, 6551, 6878, 7384, 7410, 8303, 9321, 9775, 10327, 10738, 11302, 11594, 13107, 14068, 14344, 15023, 15311, 16451, 17808, 17975, 18038, 18312, 18758, 18923, 20166, 21021, 21719, 21792, 21850, 22720, 23602, 24219]
# test_feature_idx_qwen = range(sae.cfg.d_sae)

# NOTE: the `_llama` suffix is historical; this config targets the Qwen SAE loaded above.
feature_vis_config_llama = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=test_feature_idx_qwen,
    minibatch_size_features=4,
    minibatch_size_tokens=VIS_MINIBATCH_SIZE_TOKENS,  # this is number of prompts at a time.
    verbose=True,
    # Reuse the device selected when the model was loaded instead of hardcoding "cuda",
    # so the cell also runs on CPU-only machines.
    device=device,
    cache_dir=Path(
        # dataset.model.layer.bs.nrows
        f"math.qwen25.layers.14_bs={VIS_MINIBATCH_SIZE_TOKENS}_nrows={N_VIS_PROMPTS}"
    ),  # this will enable us to skip running the model for subsequent features.
    dtype="bfloat16",
)

data = SaeVisRunner(feature_vis_config_llama).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:N_VIS_PROMPTS],
)

In [None]:
# filename_4 = f"prompt.qwen_math.layers.14.math_toks=16384x1024_complexsent_step4.html"
# filename_5 = f"prompt.qwen_math.layers.14.math_toks=16384x1024_complexsent_step5.html"

# prompt = "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 1000 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1"

# prompt = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
# \[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ 

# #### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation.

# #### Step 2: If $b = 0,$ then $\\frac{10}{a} = 4,$ so $a = \\frac{5}{2},$ which does not satisfy the first equation.

# #### Step 3: So, we can assume that both $a$ and $b$ are nonzero.

# #### Step 4: Then $\\frac{5 - a}{b} = \\frac{4 - b}{a} = \\frac{10}{a^2 + b^2}.$

# #### Step 5: \[\\frac{5b - ab}{b^2} = \\frac{4a - ab}{a^2} = \\frac{10}{a^2 + b^2},\]so
# \[\\frac{4a + 5b - 2ab}{a^2 + b^2} = \\frac{10}{a^2 + b^2},\]so $4a + 5b - 2ab = 10.$

# #### Step 6: Then $2ab - 4a - 5b + 10 = 0,$ which factors as $(2a - 5)(b - 2) = 0.$  Hence, $a = \\frac{5}{2}$ or $b = 2.$"""

# These are ordinary (non-raw) strings, so every LaTeX backslash is written doubled.
# That includes the display-math delimiters: a single backslash before "[" or "]" is an
# invalid escape sequence and triggers a SyntaxWarning on newer Python versions.
# The runtime text is unchanged.
prompt = """Problem:  Find all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,
$\\[a+\\frac{10b}{a^2+b^2}=5, \\text{and} b+\\frac{10a}{a^2+b^2}=4.\\]$"""

suffix = "#### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2$."

# ------------------------------------- Step 1 -------------------------------------
# prompt = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
# \[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ """

# suffix = "#### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation."

# ------------------------------------- Step 2 -------------------------------------
# prompt = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
# \[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ 

# #### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation."""

# suffix = "#### Step 2: If $b = 0,$ then $\\frac{10}{a} = 4,$ so $a = \\frac{5}{2},$ which does not satisfy the first equation."


# ------------------------------------- Step 3 -------------------------------------
# prompt = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
# \[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ 

# #### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation.

# #### Step 2: If $b = 0,$ then $\\frac{10}{a} = 4,$ so $a = \\frac{5}{2},$ which does not satisfy the first equation."""

# suffix = "#### Step 3: So, we can assume that both $a$ and $b$ are nonzero."

# ------------------------------------- Step 4 -------------------------------------
# prompt_4 = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
# \[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ 

# #### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation.

# #### Step 2: If $b = 0,$ then $\\frac{10}{a} = 4,$ so $a = \\frac{5}{2},$ which does not satisfy the first equation.

# #### Step 3: So, we can assume that both $a$ and $b$ are nonzero."""

# suffix_4 =  "#### Step 4: Then $\\frac{5 - a}{b} = \\frac{4 - b}{a} = \\frac{10}{a^2 + b^2}.$"
# ------------------------------------- Step 5 -------------------------------------
# prompt_5 = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
# \[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ 

# #### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation.

# #### Step 2: If $b = 0,$ then $\\frac{10}{a} = 4,$ so $a = \\frac{5}{2},$ which does not satisfy the first equation.

# #### Step 3: So, we can assume that both $a$ and $b$ are nonzero.

# #### Step 4: Then $\\frac{5 - a}{b} = \\frac{4 - b}{a} = \\frac{10}{a^2 + b^2}.$"""

# suffix_5 = """#### Step 5: \[\\frac{5b - ab}{b^2} = \\frac{4a - ab}{a^2} = \\frac{10}{a^2 + b^2},\]so
# # \[\\frac{4a + 5b - 2ab}{a^2 + b^2} = \\frac{10}{a^2 + b^2},\]so $4a + 5b - 2ab = 10.$"""
# Output file for the prompt-centric dashboard. This is a plain literal: the name has
# nothing to interpolate, so no f-string is needed.
fn = "prompt.qwen_math.layers.14.math_toks=8192x1024_complexsent_step0.html"
save_prompt_centric_vis(prompt=prompt+suffix, sae_vis_data=data, filename=fn, num_top_features=20)
# save_prompt_centric_vis(prompt=prompt_4+suffix_4, sae_vis_data=data, filename=filename_4, num_top_features=20)
# save_prompt_centric_vis(prompt=prompt_5+suffix_5, sae_vis_data=data, filename=filename_5, num_top_features=20)

In [None]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis

# Output file for the feature-centric dashboard. This is a plain literal: the name has
# nothing to interpolate, so no f-string is needed.
# NOTE(review): the "4096f" and "toks=16384x1024" labels do not match the 4 sampled
# features and the token_dataset[:8192] slice used to build `data` - confirm the intended name.
filename = "feature.qwen25.layers.14.complexsent.4096f.toks=16384x1024.html"
save_feature_centric_vis(sae_vis_data=data, filename=filename)

In [22]:
# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

def display_vis_inline(filename: str, height: int = 850):
    '''
    Opens a saved dashboard HTML file in the system's default web browser.

    Args:
        filename: Path to the HTML file (relative or absolute). A value that already
            contains a URL scheme (e.g. "http://...") is passed through unchanged.
        height: Unused; kept so existing callers that pass it keep working.

    Returns:
        None.
    '''
    import webbrowser
    from pathlib import Path

    if "://" in filename:
        target = filename
    else:
        # webbrowser.open expects a URL; handing it a bare filename is documented as
        # unsupported and non-portable, so convert the path to a file:// URI first.
        target = Path(filename).resolve().as_uri()
    webbrowser.open(target)
# `filename_4` only exists in commented-out code, so it raises NameError on a fresh kernel.
# Open the prompt-centric dashboard that this notebook actually writes instead.
display_vis_inline(fn)

In [None]:
# Embed the prompt-centric dashboard inside the notebook.
# This must be a plain Python cell: a `%%html` cell magic would render the cell text as
# HTML instead of executing it. Autoreload is already enabled in the first cell, so it is
# not loaded again here.
from IPython.display import IFrame

# Use the file written by save_prompt_centric_vis (relative to the notebook's working
# directory) rather than a machine-specific absolute path.
IFrame(src=fn, width=700, height=600)
