In [1]:
import os
import random
import sys

import pandas as pd
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import utils as tl_utils

from src import *

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f5319dd83a0>

### Load the Model, SAE, & Dataset
Loading all the things

In [2]:
model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.tokenizer = tokenizer

sae = Sae.load_from_hub("EleutherAI/sae-llama-3-8b-32x", layer=12).cuda()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
dataset = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",
    split="train",
    # TODO: Maybe set this to False by default? But RPJ requires it.
    trust_remote_code=True,
)

In [4]:
dataset = dataset.shuffle()
dataset[0]

{'text': "Home Arcade Games Hyper Shot iOS iOS games\nHyper Shot – Highly-Addictive iOS Arcade Game\nPosted By: Sara Carikj - 9:22 AM\nThe iPhone has been around for a decade now, but Apple's iOS remains fertile ground for a wide variety of great mobile games across every genre. Armed with strong graphics and responsive touch screens, the iPhone and iPad are solid gaming machines. So it's no surprise that developers continue to produce top-notch games for Apple's mobile devices. Whether you've just gotten an iOS device that you're looking to load up with games or you're a long-time iOS owner who wants to try something new, we've found a game to help you get started.\nWhat is Hyper Shot?\nThe name of the game that we’re talking about is Hyper Shot. It is an endless arcade game that can be enjoyed by users of any age! Compatible with both iPhones and iPads, this unique game is designed with smooth touch screen controls, upbeat graphics, beautiful colors and fun sounds. The best thing abo

In [5]:
SEQ_LEN = 128
seed = 42
n_examples = 100000

# Tokenize the data (using a utils function) and shuffle it
dataset = dataset.shuffle(seed)
tokenized_data = tl_utils.tokenize_and_concatenate(dataset.select(range(n_examples)), tokenizer, max_length=SEQ_LEN)
tokenized_data = tokenized_data.shuffle(seed)
all_tokens = tokenized_data["tokens"]

print(all_tokens.shape)

Map (num_proc=10):   0%|          | 0/100000 [00:00<?, ? examples/s]

torch.Size([972579, 128])


### Visualizations

In [6]:
sae = sae.to(torch.bfloat16) # Makes compute faster
sae_vis_config = SaeVisConfig(
    features = range(64),
    minibatch_size_tokens = 128,
    verbose = True,
)


if os.path.exists("sae_feats.json"):
    sae_vis_data = SaeVisData().load_json(
        filename="sae_feats.json",
        cfg=sae_vis_config,
        model=model,
        encoder=sae,
        encoder_B=None
    )
else:

    sae_vis_data = create_sae_vis_data(
        encoder=sae,
        encoder_layer=12,
        model=model,
        tokens=all_tokens,
        cfg=sae_vis_config
    )

    sae_vis_data.save_json("sae_feats.json")

Forward passes to cache data for vis:   0%|          | 0/7599 [00:00<?, ?it/s]

124490112


In [7]:
sae.W_dec.shape

torch.Size([131072, 4096])

In [8]:
dir(sae_vis_data.feature_data_dict.get(0))

['__annotations__',
 '__class__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__match_args__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_get_html_data_feature_centric',
 '_get_html_data_prompt_centric',
 'acts_histogram_data',
 'feature_tables_data',
 'get_component_from_config',
 'logits_histogram_data',
 'logits_table_data',
 'prompt_data',
 'sequence_data']

In [9]:
sae_vis_data.feature_data_dict.get(0).sequence_data

SequenceMultiGroupData(seq_group_data=[SequenceGroupData(title='TOP ACTIVATIONS<br>MAX = 0.154', seq_data=[SequenceData(token_ids=[117054, 41551, 4073, 4718, 6834, 3279, 435, 6341, 198, 3923, 1550], feat_acts=[0.0, 0.0, 0.0, 0.0, 0.0, 0.1543, 0.0, 0.0, 0.0, 0.0, 0.0], loss_contribution=[-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0006885528564453125, -0.0, -0.0, -0.0, -0.0], token_logits=[0.008333413861691952, -0.013808189891278744, -0.022604014724493027, 0.0017230864614248276, -0.004952811636030674, 0.009015126153826714, 0.008546494878828526, -0.008553413674235344, 0.021244024857878685, -0.0035560820251703262, -0.006625906564295292], top_token_ids=[[], [], [], [], [], [], [53403, 13649, 74326, 25225, 65043], [], [], [], []], top_logits=[[], [], [], [], [], [], [0.0047, 0.004, 0.0037, 0.0037, 0.0037], [], [], [], []], bottom_token_ids=[[], [], [], [], [], [], [23920, 54366, 87404, 24681, 88625], [], [], [], []], bottom_logits=[[], [], [], [], [], [], [-0.0046, -0.0044, -0.0041, -0.004, -0.0

In [10]:
sae_vis_data.feature_data_dict[0].feature_tables_data.neuron_alignment_indices

[]

In [13]:


# Save as HTML file & display vis
filename = "_feature_vis_demo.html"
sae_vis_data.save_feature_centric_vis(filename, feature_idx=8)
sae_vis_data.save_json("sae_feats.json")

Saving feature-centric vis:   0%|          | 0/64 [00:00<?, ?it/s]

In [15]:
torch.zeros(5).cuda().tolist()

[0.0, 0.0, 0.0, 0.0, 0.0]

In [19]:
tokenizer.decode(123412452)

''

In [41]:
from circuitsvis.tokens import colored_tokens

colored_tokens(["My", "123", "\n ", "1", "My", "123", "\n ", "1"], [0, 0.1, 0, 0.5])
