# Evaluating your SAE

## Set Up

In [1]:
import os
import sys
import torch
import wandb
import json
import plotly.express as px
from transformer_lens import utils
from datasets import load_dataset
from typing import  Dict
from pathlib import Path

from functools import partial

sys.path.append("..")

from sae_training.utils import LMSparseAutoencoderSessionloader
from sae_analysis.visualizer import data_fns, html_fns
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Choose the compute device: prefer Apple-silicon MPS, then CUDA, otherwise CPU.
# NOTE(review): `device` is not referenced by any of the cells shown here.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Everything below is inference/analysis only, so turn autograd off for the whole session.
torch.set_grad_enabled(False)


def imshow(x, **kwargs):
    """Show `x` as a plotly heatmap.

    `x` is converted with `transformer_lens.utils.to_numpy`; every keyword argument is
    passed straight through to `plotly.express.imshow`.
    """
    fig = px.imshow(utils.to_numpy(x), **kwargs)
    fig.show()
    

# Load your Autoencoder



In [3]:

# Load a pretrained SAE checkpoint together with the model and activation loader that
# `LMSparseAutoencoderSessionloader.load_session_from_pretrained` returns for it.
#
# The checkpoint is the Weights & Biases artifact
#   jbloom/mats_sae_training_gpt2_small/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2
# To fetch it, uncomment and run the lines below, then point `path` at the downloaded .pt file:
# run = wandb.init()
# artifact = run.use_artifact('jbloom/mats_sae_training_gpt2_small/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2', type='model')
# artifact_dir = artifact.download()

path = "../artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2/1200001024_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144.pt"

# Fail fast with an actionable message if the checkpoint has not been downloaded yet.
if not Path(path).is_file():
    raise FileNotFoundError(
        f"SAE checkpoint not found at {path!r}. Download the W&B artifact "
        "(see the commented-out wandb code in this cell) and update `path`."
    )

model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

VBox(children=(Label(value='0.077 MB of 0.077 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167762966619598, max=1.0…

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Loaded pretrained model gpt2-small into HookedTransformer


Token indices sequence length is longer than the specified maximum sequence length for this model (1218 > 1024). Running this sequence through the model will result in indexing errors


## Test the Autoencoder

### L0 Test and Reconstruction Test

In [4]:
with torch.no_grad():
    # Take one batch of tokens from the activation loader, cache the model's activations,
    # and run the SAE on the activation at the SAE's own hook point.
    batch_tokens = activations_loader.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )
    # Only the hook-point activation was needed; drop the rest of the cache to free memory.
    del cache

    # L0 = number of SAE features with a positive activation, per token position.
    l0 = (feature_acts > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(
        x=l0.flatten().cpu().numpy(),
        labels={"x": "L0 (active features per token)"},
        title="Distribution of L0 across token positions",
    ).show()


315.20526123046875


Orig 3.646359920501709
reconstr 3.8246564865112305
Zero 11.784465789794922


In [None]:
def reconstr_hook(activation, hook, new_mlp_out):
    """Forward hook that discards the hooked activation and returns `new_mlp_out` instead.

    Despite the historical keyword name (kept so existing `partial(...)` calls still work),
    this is spliced in at the SAE's hook point, which for this SAE is the residual stream,
    not an MLP output. `new_mlp_out` is expected to match the shape of the activation it
    replaces.
    """
    return new_mlp_out


def zero_abl_hook(activation, hook):
    """Forward hook that zero-ablates the hooked activation."""
    return torch.zeros_like(activation)


# Splice at the SAE's own hook point rather than a hard-coded layer, so this comparison stays
# correct if a different SAE checkpoint is loaded.
hook_point = sparse_autoencoder.cfg.hook_point

# `sae_out` was computed from exactly these `batch_tokens`, so shapes line up.
orig_loss = model(batch_tokens, return_type="loss").item()
reconstr_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[(hook_point, partial(reconstr_hook, new_mlp_out=sae_out))],
    return_type="loss",
).item()
zero_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[(hook_point, zero_abl_hook)],
    return_type="loss",
).item()

print("Orig", orig_loss)
print("reconstr", reconstr_loss)
print("Zero", zero_loss)

## Specific Capability Test

When studying a specific task, it is important to check that the model still performs that task when the SAE's reconstruction is substituted for the original activation.

In [5]:
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"

# Baseline: how does the unmodified model do on this indirect-object-identification prompt?
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Tokenize once (with BOS) and reuse the same tokens for the cache and the losses below.
tokens = model.to_tokens(example_prompt, prepend_bos=True)
logits, cache = model.run_with_cache(tokens)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
    cache[sparse_autoencoder.cfg.hook_point]
)

# Splice at the SAE's own hook point rather than a hard-coded layer.
hook_point = sparse_autoencoder.cfg.hook_point


def sae_reconstr_hook(activation, hook):
    """Replace the hooked activation with the SAE's reconstruction of that same activation.

    Reconstructing inside the hook, instead of splicing in the precomputed `sae_out`, keeps
    shapes consistent for any input. `utils.test_prompt` may run the model on a different
    token sequence (e.g. prompt plus answer) than `tokens`.
    """
    return sparse_autoencoder(activation)[0]


# `zero_abl_hook` is the zero-ablation hook defined in the reconstruction-test cell above.
print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[(hook_point, sae_reconstr_hook)],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[(hook_point, zero_abl_hook)],
        return_type="loss",
    ).item(),
)

# Same prompt test, but with the SAE reconstruction substituted at the hook point.
with model.hooks(fwd_hooks=[(hook_point, sae_reconstr_hook)]):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.19 Prob: 69.93% Token: | Mary|
Top 1th token. Logit: 15.82 Prob:  6.49% Token: | them|
Top 2th token. Logit: 15.48 Prob:  4.66% Token: | the|
Top 3th token. Logit: 14.93 Prob:  2.66% Token: | his|
Top 4th token. Logit: 14.86 Prob:  2.49% Token: | John|
Top 5th token. Logit: 14.12 Prob:  1.19% Token: | her|
Top 6th token. Logit: 13.99 Prob:  1.04% Token: | their|
Top 7th token. Logit: 13.70 Prob:  0.78% Token: | a|
Top 8th token. Logit: 13.53 Prob:  0.66% Token: | him|
Top 9th token. Logit: 13.39 Prob:  0.57% Token: | Mrs|


Orig 3.9790918827056885
reconstr 3.999037027359009
Zero 11.021940231323242
Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 17.44 Prob: 63.61% Token: | Mary|
Top 1th token. Logit: 14.90 Prob:  5.04% Token: | John|
Top 2th token. Logit: 14.86 Prob:  4.81% Token: | the|
Top 3th token. Logit: 14.81 Prob:  4.61% Token: | them|
Top 4th token. Logit: 14.42 Prob:  3.11% Token: | his|
Top 5th token. Logit: 14.05 Prob:  2.16% Token: | her|
Top 6th token. Logit: 13.92 Prob:  1.89% Token: | their|
Top 7th token. Logit: 13.03 Prob:  0.78% Token: | him|
Top 8th token. Logit: 12.97 Prob:  0.73% Token: | a|
Top 9th token. Logit: 12.44 Prob:  0.43% Token: | Mrs|


# Generating Feature Interfaces

In [6]:
# Top-k most active SAE features at the last position of the example prompt.
TOP_K = 10
vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), TOP_K)
# `.tolist()` yields plain Python numbers; `str()` on the 0-d tensors produced by iterating
# `inds` would label the bars "tensor(1234)" instead of "1234".
px.bar(
    x=[str(i) for i in inds.tolist()],
    y=vals.tolist(),
    labels={"x": "feature index", "y": "activation"},
    title=f"Top {TOP_K} SAE features at the final prompt token",
).show()

In [7]:
# Build an id -> token-string map (replace "Ġ" with a space and escape newlines) and save it
# next to the notebook.
vocab_dict = model.tokenizer.vocab
vocab_dict = {v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()}

# NOTE(review): presumably consumed by the sae_analysis visualizer. An existing file is never
# overwritten, so delete it if you switch to a model with a different tokenizer.
vocab_dict_filepath = Path(os.getcwd()) / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with open(vocab_dict_filepath, "w", encoding="utf-8") as f:
        json.dump(vocab_dict, f)

# Evaluation text. A small fixed dataset is used (rather than streaming) so that it can be
# tokenized up front.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]

# Dashboard settings. Only the top features found above (`inds`) are processed.
# For a larger run, e.g.: max_batch_size = 512, total_batch_size = 16384,
# feature_idx = list(range(1000)).
# NOTE: get_feature_data is the slow step; parallelizing the per-sequence indexing is the
# most likely remaining speed-up.
max_batch_size = 512
total_batch_size = 4096 * 5  # number of rows of `all_tokens` to scan for activating examples
feature_idx = list(inds.flatten().cpu().numpy())

tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = get_feature_data(
    encoder=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k=3,
    buffer=(5, 5),
    n_groups=10,
    first_group_size=20,
    other_groups_size=5,
    verbose=True,
)

# Write one HTML dashboard per feature. Token strings can contain arbitrary Unicode, so write
# UTF-8 explicitly rather than relying on the platform's default encoding.
for feat_idx in feature_idx:
    html_str = feature_data[feat_idx].get_all_html()
    with open(f"data_{feat_idx:04}.html", "w", encoding="utf-8") as f:
        f.write(html_str)

Storing model activations: 100%|██████████| 40/40 [01:48<00:00,  2.70s/it]
                                                                      

Estimated time for all 24576 features = 275 minutes



This produces one HTML file per selected feature, each containing a dashboard of that feature's activations on the sample data. It currently processes only a small amount of data, so the dashboards are of limited use.