# Setup: load the model, SAE, and token data

In [1]:
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer

import torch

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Pick the best available backend instead of hard-coding "mps", so the
# notebook also runs on CUDA machines and on CPU-only machines.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Set TOKENIZERS_PARALLELISM in this kernel's environment (IPython echoes the
# assignment). NOTE(review): presumably read by the Hugging Face tokenizers
# library to control its parallelism - confirm "true" is the intended value.
%env TOKENIZERS_PARALLELISM=true

Using device: mps
env: TOKENIZERS_PARALLELISM=true


In [2]:
# Location of the trained SAE checkpoint.
# NOTE: this is a machine-specific absolute path; point it at your own
# checkpoint directory before running the notebook elsewhere.
sae_checkpoint_path = (
    "/Users/sidbaskaran/Desktop/research/SAELens/checkpoints/hfvusbuw/final_10000384"
)

# Load the SAE first so that its config can drive how the base model is loaded.
sae = SAE.load_from_disk(sae_checkpoint_path).to(device=device, dtype=torch.float32)

# Forward the SAE's recorded `model_from_pretrained_kwargs` to the model
# loader, as the SAELens load-time warning recommends, so the model is set up
# with the options stored in the SAE config.
model = HookedSAETransformer.from_pretrained_no_processing(
    "pythia-14m", device=device, **sae.cfg.model_from_pretrained_kwargs
).to(torch.float32)

# Length (in tokens, including the prepended BOS) of every sequence in the batch.
seq_len = 64
# Number of sequences pulled from the dataset stream for the visualisation.
batch_size = 4096
# Size of the leading slice of the batch used when detecting alive features.
batch_size_for_computing_alive_feats = 512

# Stream the dataset named in the SAE config, tokenise each example's "text"
# field with the model's tokenizer, and keep only examples that have at least
# `seq_len` token ids.
original_dataset = (
    load_dataset(
        sae.cfg.dataset_path,
        split="train",
        streaming=True,
        trust_remote_code=True,
    )
    .map(lambda example: model.tokenizer(example["text"]))
    .filter(lambda example: len(example["input_ids"]) >= seq_len)
)


# Take up to `batch_size` examples from the stream, truncating each to
# `seq_len - 1` token ids so that one slot is left for the BOS token.
# `range` comes first in the zip so no extra example is consumed from the stream.
attn_tokens_as_list = [
    x["input_ids"][: seq_len - 1] for (_, x) in zip(range(batch_size), original_dataset)
]
tokens = torch.tensor(attn_tokens_as_list, device=device)

# Build one BOS id per row that was actually collected. Using `batch_size`
# here would make the concatenation below fail with a shape mismatch whenever
# the stream yields fewer than `batch_size` examples.
bos_token = torch.tensor(
    [model.tokenizer.bos_token_id] * tokens.shape[0], device=device
)  # type: ignore

# Prepend the BOS column so each row has `seq_len` tokens.
tokens = torch.cat([bos_token.unsqueeze(1), tokens], dim=1)
print(f"Tokens loaded for attn-only model: {tokens.shape=}")

# Cache key under which the SAE's post-activation feature values are stored.
post_acts_hook = f"{sae.cfg.hook_name}.hook_sae_acts_post"

# Run only the first `batch_size_for_computing_alive_feats` sequences with the
# SAE attached, caching just the SAE activations and stopping the forward pass
# right after the SAE's layer.
_, cache = model.run_with_cache_with_saes(
    tokens[:batch_size_for_computing_alive_feats],
    saes=[sae],
    names_filter=post_acts_hook,
    stop_at_layer=sae.cfg.hook_layer + 1,
)
acts = cache[post_acts_hook]

# A feature counts as alive if it exceeds 1e-8 at any position.
# (assumes `acts` is (batch, seq, d_sae) - TODO confirm)
# `nonzero()` yields one index row per alive feature; squeeze only the last
# dim so that a single alive feature still produces a one-element list rather
# than a bare int (which would break `len(...)` and later slicing).
alive_feats = (acts.flatten(0, 1) > 1e-8).any(dim=0).nonzero().squeeze(-1).tolist()
print(f"Alive features: {len(alive_feats)}/{sae.cfg.d_sae}\n")


Loaded pretrained model pythia-14m into HookedTransformer
Changing model dtype to torch.float32


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Tokens loaded for attn-only model: tokens.shape=torch.Size([4096, 64])
Alive features: 2726/8192



In [3]:
# Display the dataset identifier stored in the SAE config; this is the value
# passed to `load_dataset` when the token batch was built.
sae.cfg.dataset_path

'togethercomputer/RedPajama-Data-1T-Sample'

In [4]:
# Display the hook name stored in the SAE config; it is the prefix of the
# cache key used when computing alive features.
sae.cfg.hook_name

'blocks.0.hook_mlp_out'

In [5]:
# Build visualisation data for the first 32 alive features and write the
# feature-centric HTML page.
# NOTE(review): the filename says "pythia70m" while the model loaded in this
# notebook is "pythia-14m" - confirm the intended name.
feature_vis_cfg = SaeVisConfig(features=alive_feats[:32])
sae_vis_data = SaeVisData.create(sae, model=model, tokens=tokens, cfg=feature_vis_cfg)
sae_vis_data.save_feature_centric_vis(filename="demo_feature_vis_pythia70m_topk.html")

In [6]:
# Write the prompt-centric HTML page for a single demo prompt.
demo_prompt = "write fibonacci sequence in python"
prompt_vis_filename = "demo_prompt_vis_pythia70m_topk.html"
sae_vis_data.save_prompt_centric_vis(prompt=demo_prompt, filename=prompt_vis_filename)