# Demo Notebook

Steps:
1. Download SAE with SAE Lens.
2. Create a dataset consistent with that SAE. 
3. Fold the SAE decoder norm weights so that feature activations are "correct".
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.
5. Run the SAE generator for the features you want.

# Set Up

In [None]:
%load_ext autoreload
%autoreload 2
import json
import torch
from sae_lens import SAE
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
from sae_dashboard.data_writing_fns import save_prompt_centric_vis

In [None]:
# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    webbrowser.open(filename);

## Step 1. Download / Initialize SAE

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", dtype="float16", device=device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
sae, cfg_dict, sparsity = SAE.from_eleuther(
    release="huypn16/sae-qwen-2.5-1.5B-OWM-16x",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="layers.14",  # won't always be a hook point
    device=device,
)
# fold w_dec norm so feature activations are accurate
sae.fold_W_dec_norm()


# 2. Get token dataset

In [None]:
from sae_lens import ActivationsStore
activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    dataset="open-web-math/open-web-math",
    token_columns=["text"],
    store_batch_size_prompts=16,
    n_batches_in_buffer=16,
    device=device,
)

In [None]:
from tqdm import tqdm
import os
def get_tokens(
    activations_store: ActivationsStore,
    n_prompts: int,
):
    all_tokens_list = []
    pbar = tqdm(range(n_prompts))
    for _ in pbar:
        batch_tokens = activations_store.get_batch_tokens()
        batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][
            : batch_tokens.shape[0]
        ]
        all_tokens_list.append(batch_tokens)

    all_tokens = torch.cat(all_tokens_list, dim=0)
    all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
    return all_tokens

# 1000 prompts is plenty for a demo.
if os.path.exists("qwen_owm.pt"):
    token_dataset = torch.load("qwen_owm.pt")
else:
    token_dataset = get_tokens(activations_store, n_prompts=1024)
    torch.save(token_dataset, "qwen_owm.pt")

print(tokenizer.decode(token_dataset[0][:32]))
print(token_dataset.shape) # [store_batch_size_prompts * n_prompts, 1024]

# Step 3 Evaluate the SAE

In [None]:
from sae_lens import run_evals
from sae_lens.evals import get_eval_everything_config

eval_metrics = run_evals(
    sae=sae,
    activation_store=activations_store,
    model=model,
    eval_config=get_eval_everything_config(
        batch_size_prompts=8,
        n_eval_reconstruction_batches=10,
        n_eval_sparsity_variance_batches=3,
    )
)
print(json.dumps(eval_metrics, indent=4))
# CE Loss score should be high for residual stream SAEs
print(eval_metrics["metrics/ce_loss_score"])
# ce loss without SAE should be fairly low < 3.5 suggesting the Model is being run correctly
print(eval_metrics["metrics/ce_loss_without_sae"])
# ce loss with SAE shouldn't be massively higher
print(eval_metrics["metrics/ce_loss_with_sae"])

# 4. Generate Feature Dashboards

In [6]:
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
from pathlib import Path
import random
# random features
test_feature_idx_qwen = random.sample(range(sae.cfg.d_sae), 100)
# test_feature_idx_llama = [21719]
# test_feature_idx_llama = random.sample(range(1024), 2)
# test_feature_idx_qwen = [1, 2, 3]
# test_feature_idx_qwen =  [101, 345, 4087, 4297, 4410, 4411, 4444, 4782, 4783, 4877, 4954, 6460, 6551, 6878, 7384, 7410, 8303, 9321, 9775, 10327, 10738, 11302, 11594, 13107, 14068, 14344, 15023, 15311, 16451, 17808, 17975, 18038, 18312, 18758, 18923, 20166, 21021, 21719, 21792, 21850, 22720, 23602, 24219]
# test_feature_idx_qwen = [200, 416, 694, 787, 837, 1189, 1262, 1265, 1418, 1536, 1879, 1908, 1941, 1948, 1993, 2073, 2363, 2365, 2388, 2489, 2536, 2701, 2722, 2981, 3390, 3499, 3779, 4274, 4627, 4812, 5980, 6000, 6133, 6454, 6650, 6809, 7115, 7503, 7597, 7836, 8042, 8454, 8694, 9655, 10000, 10046, 10372, 10466, 10509, 10605, 10800, 11501, 11551, 13075, 13357, 13371, 13724, 13793, 14715, 14841, 14986, 14996, 15303, 16754, 16852, 16882, 16887, 17110, 17186, 17313, 17394, 17430, 17438, 17996, 18090, 18256, 18451, 18540, 18613, 18992, 19281, 19298, 19664, 19985, 21002, 21558, 21874, 21967, 21968, 22083, 22679, 22765, 22928, 23296, 23317, 23423, 23545, 23609, 23682]
# test_feature_idx_qwen = range(sae.cfg.d_sae)

feature_vis_config_llama = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=test_feature_idx_qwen,
    minibatch_size_features=100,
    minibatch_size_tokens=256,  # this is number of prompts at a time.
    verbose=True,
    device="cuda",
    cache_dir=Path(
        "qwen25.layers.14_bs=256_nrows=16384"
    ),  # this will enable us to skip running the model for subsequent features.
    dtype="bfloat16",
)

data = SaeVisRunner(feature_vis_config_llama).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:8192],
)

In [3]:

filename = f"qwen_owm.layers.14.math_toks=16384x1024.html"
# prompt = "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 1000 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1"
prompt = """Problem:  Let $(a_1,b_1),$ $(a_2,b_2),$ $\dots,$ $(a_n,b_n)$ be all the ordered pairs $(a,b)$ of complex numbers with $a^2+b^2\\neq 0,$
\[a+\\frac{10b}{a^2+b^2}=5, \quad \\text{and} \quad b+\\frac{10a}{a^2+b^2}=4.\]Find $a_1 + b_1 + a_2 + b_2 + \dots + a_n + b_n.$ 

#### Step 1: If $a = 0,$ then $\\frac{10}{b} = 5,$ so $b = 2,$ which does not satisfy the second equation.

#### Step 2: If $b = 0,$ then $\\frac{10}{a} = 4,$ so $a = \\frac{5}{2},$ which does not satisfy the first equation.

#### Step 3: So, we can assume that both $a$ and $b$ are nonzero.

#### Step 4: Then $\\frac{5 - a}{b} = \\frac{4 - b}{a} = \\frac{10}{a^2 + b^2}.$

#### Step 5: \[\\frac{5b - ab}{b^2} = \\frac{4a - ab}{a^2} = \\frac{10}{a^2 + b^2},\]so
\[\\frac{4a + 5b - 2ab}{a^2 + b^2} = \\frac{10}{a^2 + b^2},\]so $4a + 5b - 2ab = 10.$

#### Step 6: Then $2ab - 4a - 5b + 10 = 0,$ which factors as $(2a - 5)(b - 2) = 0.$  Hence, $a = \\frac{5}{2}$ or $b = 2.$

#### Step 7: If $a = \\frac{5}{2},$ then \[\\frac{5/2}{b} = \\frac{10}{\\frac{25}{4} + b^2}.\]. This simplifies to $4b^2 - 16b + 25 = 0.$  By the quadratic formula,
\[b = 2 \pm \\frac{3i}{2}.\]"""

save_prompt_centric_vis(prompt=prompt, sae_vis_data=data, filename=filename)

NameError: name 'save_prompt_centric_vis' is not defined

In [None]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis

filename = f"qwen_owm.layers.14.math_toks=16384x1024.html"
save_feature_centric_vis(sae_vis_data=data, filename=filename)