In [1]:
from __future__ import annotations
"""
TODO(Adriano) after getting some plots for the blobs post-SAE in GPT2, it's important to check whether these
SAEs are actually any good. Unfortunately, I have really bad FVUs, MSEs, etc... It's also unclear if a error
norm of 30 is normal. People seem not to be reporting this very well and it's deeply annoying.
"""
import os
from pathlib import Path
from typing import Dict, List
import torch
import torch.nn as nn
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
from torch.utils.data import DataLoader
import dotenv
from transformers import AutoTokenizer

from transformer_lens.utils import tokenize_and_concatenate

dotenv.load_dotenv()

class SAEExtractor:
    """
    Wraps a HookedSAETransformer and a bank of pre-trained SAEs to combine
    (1) loading the SAEs from the repository using SAELens, (2) Running them in
    a hooked transformer, getting activations, etc...

    This is meant ONLY for gpt2 and JBloom's SAEs.
    """

    def __init__(
        self,
        *,
        device: str | torch.device | None = None,
    ) -> None:
        self.model_name = "gpt2"
        self.sae_release = "gpt2-small-res-jb"
        self.device = "cuda" # NOTE: you should use CUDA_VISIBLE_DEVICES to select the GPU
        self.model = HookedSAETransformer.from_pretrained(self.model_name, device=self.device)
        self.tokenizer = self.model.tokenizer
        # self.saes: Dict[str, SAE] = {}
        self._load_saes()

    def _load_saes(self) -> None:
        # TODO(Adriano) make this more modular...
        self.block2sae = []
        self.cfg_dics = []
        self.sparsities = []
        # print(self.model.cfg.n_layers)
        for layer in range(self.model.cfg.n_layers):
            sae, cfg_dict, sparsity = SAE.from_pretrained(
                release = self.sae_release, # see other options in sae_lens/pretrained_saes.yaml
                sae_id = f"blocks.{layer}.hook_resid_pre", # won't always be a hook point
                device = self.device
            )
            self.block2sae.append(sae)
            self.cfg_dics.append(cfg_dict)
            self.sparsities.append(sparsity)

    def cache_activations(self, tokens: torch.Tensor, hooks: List[str] | None = None) -> Dict[str, torch.Tensor]:
        hooks = hooks or list(self.saes)
        _, cache = self.model.run_with_cache_with_saes(tokens, act_names=hooks)
        out: Dict[str, torch.Tensor] = {}
        for hook in hooks:
            out_key = f"{hook}.hook_sae_acts_post"
            out[out_key] = cache[out_key].cpu()
        return out

print("="*50 + " [Loading Dataset] " + "="*50) # DEBUG
# dataset = load_dataset("openwebtext", split="train", trust_remote_code=True)
dataset = load_dataset("stas/openwebtext-10k", split="train", trust_remote_code=True) # Smaller version
tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=tokenizer,  # type: ignore
    streaming=True,
    max_length=1024, #sae.cfg.context_size,
    add_bos_token=True, #sae.cfg.prepend_bos,
)
# print(token_dataset) # Sanity
# print(token_dataset[0]['tokens']) # Sanity

print("="*50 + " [Loading Model] " + "="*50) # DEBUG
extractor = SAEExtractor()

  from .autonotebook import tqdm as notebook_tqdm




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gpt2 into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
import tqdm
import torch
import gc
import matplotlib.pyplot as plt
gc.collect()
torch.cuda.empty_cache()

token_dataset_short = token_dataset[:300]['tokens']
dataset_length = token_dataset_short.shape[0]
sequence_length = token_dataset_short.shape[1]
sae_ins = torch.zeros((len(extractor.block2sae), dataset_length, sequence_length, extractor.model.cfg.d_model), device="cpu")
sae_outs = torch.zeros((len(extractor.block2sae), dataset_length, sequence_length, extractor.model.cfg.d_model), device="cpu")
print(sae_ins.shape)
print(sae_outs.shape)
for i in range(len(extractor.block2sae)):
    extractor.block2sae[i].eval()

batch_size = 30
while True:
    try:
        pbar = tqdm.trange(0, len(token_dataset_short), batch_size, desc=f"Batch Size = {batch_size}")
        with torch.no_grad():
            for i in pbar:
                j = min(i + batch_size, len(token_dataset_short))
                # activation store can give us tokens.
                batch_tokens = token_dataset_short[i:j]
                _, cache = extractor.model.run_with_cache(batch_tokens, prepend_bos=True)

                # print(cache.keys())
                # Use the SAE
                # print(len(extractor.block2sae))
                # print(f"hook_name={extractor.block2sae[8].cfg.hook_name}") # Nope
                sae_in = torch.stack([cache[extractor.block2sae[i].cfg.hook_name].detach() for i in range(len(extractor.block2sae))])
                del cache
                gc.collect()
                torch.cuda.empty_cache()
                # feature_acts = [extractor.block2sae[i].encode(sae_in[i])
                sae_out = torch.stack([extractor.block2sae[i](sae_in[i]).detach().cpu() for i in range(len(extractor.block2sae))])

                assert sae_in.shape == sae_out.shape
                assert sae_in.shape[0] == len(extractor.block2sae)
                assert sae_in.shape[1] == j - i
                assert sae_in.shape[2] == sequence_length
                assert sae_in.shape[3] == extractor.model.cfg.d_model, f"sae_in.shape={sae_in.shape}, need [2] = {extractor.model.cfg.d_model}" # fmt: skip
                assert sae_in.ndim == 4
                sae_ins[:, i:j, :] = sae_in.cpu()
                sae_outs[:, i:j, :] = sae_out.cpu()
        break # OK
    except torch.OutOfMemoryError as e:
        print(type(e), e)
        if batch_size == 1:
            raise e
        batch_size = min(1, batch_size // 2)

torch.Size([12, 300, 1024, 768])
torch.Size([12, 300, 1024, 768])


Batch Size = 30: 100%|██████████| 10/10 [00:32<00:00,  3.25s/it]


In [7]:
#         # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
#         # l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
#         for i in range(len(extractor.block2sae)):
#             err_norm = (sae_out[i] - sae_in[i]).norm(dim=-1).detach()
#             fvu = ((sae_out[i] - sae_in[i]).pow(2).sum(dim=-1) / sae_in[i].pow(2).sum(dim=-1))
#             mse = (sae_out[i] - sae_in[i]).pow(2).mean(dim=-1).detach() # "variance explained"
#             # TODO(Adriano) error norm seems REALLY BAD
#             print(f"average error norm: {err_norm.mean().item()} +/- {err_norm.std().item()}; min={err_norm.min().item()}, max={err_norm.max().item()}")
#             print(f"average mse: {mse.mean().item()} +/- {mse.std().item()}; min={mse.min().item()}, max={mse.max().item()}")
#             print(f"average fvu: {fvu.mean().item()} +/- {fvu.std().item()}; min={fvu.min().item()}, max={fvu.max().item()}")
#             plt.hist(err_norm.flatten().cpu().log10().numpy(), bins=100) # around 100-1000 error => maybe around 0.7-7 error? => around 1.2 for this shit
#             plt.show()
#             gc.collect()
#             torch.cuda.empty_cache()
    

In [12]:
import einops
print("Flattening")

flatten = "layer batch seq dim -> layer (batch seq) dim"
sae_in_flat = einops.rearrange(sae_in, flatten).cpu()
sae_out_flat = einops.rearrange(sae_out, flatten).cpu()

print("Computing explained variances")
# https://github.com/jbloomAus/SAELens/blob/be0e55f69d360a0100027de1cf3f1a1606cf5552/sae_lens/evals.py#L512
sq_errs = (sae_in_flat - sae_out_flat).pow(2).sum(dim=-1) # norms of errors
var = (sae_in_flat - sae_in_flat.mean(dim=0)).pow(2).sum(dim=-1) # variance of 1st dim then sum those
explained_variances = 1 - sq_errs / var
print("\n".join(map(lambda x: str(x.shape), explained_variances)))
# Across batch and sequence unique per layer
print("="*100)
mean_evs = explained_variances.mean(dim=-1)
std_evs = explained_variances.std(dim=-1)
max_evs = explained_variances.max(dim=-1).values
min_evs = explained_variances.min(dim=-1).values
print(mean_evs.shape)
print(std_evs.shape)
print(max_evs.shape)
print(min_evs.shape)
print("="*100)
for mean, std, _max, _min in zip(
    mean_evs, std_evs, max_evs, min_evs
):
    print(f"Explained variance: {mean.item():.4f} +/- {std.item():.4f}; min={_min.item():.4f}, max={_max.item():.4f}")

# # Plot the explained variances for each layer
# plt.figure(figsize=(10, 6))
# plt.errorbar(
#     range(len(mean_evs)), 
#     mean_evs, 
#     yerr=std_evs, 
#     fmt='o-', 
#     capsize=5, 
#     label='Mean Explained Variance with Std Dev'
# )
# plt.fill_between(
#     range(len(mean_evs)),
#     min_evs,
#     max_evs,
#     alpha=0.2,
#     label='Min-Max Range'
# )
# plt.xlabel('Layer')
# plt.ylabel('Explained Variance')
# plt.title('Explained Variance by Layer')
# plt.grid(True, linestyle='--', alpha=0.7)
# plt.legend()
# plt.tight_layout()
# plt.show()

# # Print summary statistics
# print(f"\nOverall mean explained variance: {sum(mean_evs)/len(mean_evs):.4f}")
# print(f"Best layer: {torch.argmax(torch.Tensor(mean_evs)).item()} with {max(mean_evs):.4f}")
# print(f"Worst layer: {torch.argmin(torch.Tensor(mean_evs)).item()} with {min(mean_evs):.4f}")

Flattening
Computing explained variances
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([30720])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
Explained variance: 0.9251 +/- 0.1782; min=-0.7076, max=1.0000
Explained variance: -18.3155 +/- 56.2406; min=-1238.9437, max=1.0000
Explained variance: -5626.2036 +/- 18073.5391; min=-263304.5938, max=1.0000
Explained variance: -7083.8779 +/- 21524.2480; min=-304812.6250, max=0.9997
Explained variance: -1295.5112 +/- 5120.8442; min=-90918.4766, max=0.9999
Explained variance: -307.4859 +/- 1080.4110; min=-28341.6680, max=1.0000
Explained variance: -160.1618 +/- 345.7065; min=-6152.4971, max=0.9999
Explained variance: -138.6607 +/- 262.8350; min=-6145.1611, max=0.9999
Explained variance: -51.0636 +/- 80.9067; min=-1168.5859, max=0.9999
Explained

In [8]:
# from transformer_lens import utils
# from functools import partial
# model = extractor.model
# sae = extractor.block2sae[8]

# # next we want to do a reconstruction test.
# def reconstr_hook(activation, hook, sae_out):
#     return sae_out


# def zero_abl_hook(activation, hook):
#     return torch.zeros_like(activation)


# print("Orig", model(batch_tokens, return_type="loss").item())
# print(
#     "reconstr",
#     model.run_with_hooks(
#         batch_tokens,
#         fwd_hooks=[
#             (
#                 sae.cfg.hook_name,
#                 partial(reconstr_hook, sae_out=sae_out[8]),
#             )
#         ],
#         return_type="loss",
#     ).item(),
# )
# print(
#     "Zero",
#     model.run_with_hooks(
#         batch_tokens,
#         return_type="loss",
#         fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
#     ).item(),
# )

In [9]:
# example_prompt = "When John and Mary went to the shops, John gave the bag to"
# example_answer = " Mary"
# utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
# tokens = model.to_tokens(example_prompt)
# sae_out = sae(cache[sae.cfg.hook_name])


# def reconstr_hook(activations, hook, sae_out):
#     return sae_out


# def zero_abl_hook(mlp_out, hook):
#     return torch.zeros_like(mlp_out)


# hook_name = sae.cfg.hook_name

# print("Orig", model(tokens, return_type="loss").item())
# print(
#     "reconstr",
#     model.run_with_hooks(
#         tokens,
#         fwd_hooks=[
#             (
#                 hook_name,
#                 partial(reconstr_hook, sae_out=sae_out),
#             )
#         ],
#         return_type="loss",
#     ).item(),
# )
# print(
#     "Zero",
#     model.run_with_hooks(
#         tokens,
#         return_type="loss",
#         fwd_hooks=[(hook_name, zero_abl_hook)],
#     ).item(),
# )


# with model.hooks(
#     fwd_hooks=[
#         (
#             hook_name,
#             partial(reconstr_hook, sae_out=sae_out),
#         )
#     ]
# ):
#     utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
# from sae_dashboard.sae_vis_data import SaeVisConfig
# from sae_dashboard.sae_vis_runner import SaeVisRunner
# device = "cuda"
# test_feature_idx_gpt = list(range(10)) + [14057]

# feature_vis_config_gpt = SaeVisConfig(
#     hook_point=hook_name,
#     features=test_feature_idx_gpt,
#     minibatch_size_features=64,
#     minibatch_size_tokens=256,
#     verbose=True,
#     device=device,
# )

# visualization_data_gpt = SaeVisRunner(
#     feature_vis_config_gpt
# ).run(
#     encoder=sae,  # type: ignore
#     model=model,
#     tokens=token_dataset[:10000]["tokens"],  # type: ignore
# )
# # SaeVisData.create(
# #     encoder=sae,
# #     model=model, # type: ignore
# #     tokens=token_dataset[:10000]["tokens"],  # type: ignore
# #     cfg=feature_vis_config_gpt,
# # )

In [None]:
# from sae_dashboard.data_writing_fns import save_feature_centric_vis

# filename = f"demo_feature_dashboards.html"
# save_feature_centric_vis(sae_vis_data=visualization_data_gpt, filename=filename)

In [None]:
# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# # this function should open
# neuronpedia_quick_list = get_neuronpedia_quick_list(sae, test_feature_idx_gpt)

# print(neuronpedia_quick_list)