### Load models, datasets, SAEs

In [2]:
device = "cuda:3"

In [3]:
import os
import random
import sys

import pandas as pd
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import utils as tl_utils

from src import *

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2163b0f650>

In [4]:

model_name = "meta-llama/Meta-Llama-3-8B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).cuda()

tokenizer = AutoTokenizer.from_pretrained(model_name)
model.tokenizer = tokenizer


sae = Sae.load_from_hub("EleutherAI/sae-llama-3-8b-32x-v2", layer=24).cuda()
sae_layer = 24

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:

# Load Dataset

dataset = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",
    split="train",
    # TODO: Maybe set this to False by default? But RPJ requires it.
    trust_remote_code=True,
)

### Load and visualize features

In [6]:

# Cache all SAE activations if theyre not already saved

cache = FeatureCache(
    model=model,
    tokenizer=tokenizer,
    encoder=sae,
    encoder_layer=sae_layer,
)
cache.load_from_disk(
    folder_name="cached_sae_acts/",
)


ex = cache.load_example(10)
print(ex)
print(len(ex.str_tokens))


<|begin_of_text|> of 60% or letter grade ‘B’.
Only students who have the NYSC discharge, exemption or exclusion certificate can be allowed to clear into the different courses.
Uniben Postgraduate School Transcripts
Candidates should request for their transcript from their former institution by clicking on the link www.etx-ng.com/unibenpg to forward their academic transcript(s) to the “Secretary, School of Postgraduate Studies”, to reach him not later than 8 weeks from the date of this publication. Applications of candidates whose transcripts are not forwarded to the school shall not be processed.
ETX- NG TRANSCRIPT SERVICE:
ETX-NG
128


  latent_indices=torch.from_numpy(latent_indices),


In [7]:
common_features = cache.get_common_features(1000)
print(common_features)
common_features[0].get_quantiles(5, 10)

[<src.feature_caching.Feature object at 0x7f1fadce8310>, <src.feature_caching.Feature object at 0x7f1fabd187d0>, <src.feature_caching.Feature object at 0x7f1faddfd050>, <src.feature_caching.Feature object at 0x7f1fade54210>, <src.feature_caching.Feature object at 0x7f1fadaa92d0>, <src.feature_caching.Feature object at 0x7f1fac065590>, <src.feature_caching.Feature object at 0x7f1fac661350>, <src.feature_caching.Feature object at 0x7f1fabeaed50>, <src.feature_caching.Feature object at 0x7f1fada358d0>, <src.feature_caching.Feature object at 0x7f1fad912350>, <src.feature_caching.Feature object at 0x7f1fac192f50>, <src.feature_caching.Feature object at 0x7f1fac1ba690>, <src.feature_caching.Feature object at 0x7f1fade3f2d0>, <src.feature_caching.Feature object at 0x7f1fac02c0d0>, <src.feature_caching.Feature object at 0x7f1fac1d8c90>, <src.feature_caching.Feature object at 0x7f1fac2aa990>, <src.feature_caching.Feature object at 0x7f1fac4b3e90>, <src.feature_caching.Feature object at 0x7f1fab

{}

In [8]:
common_features[-1].get_num_nonzero()

863

In [9]:
common_features[203].get_num_nonzero()

48

In [10]:
import time
from circuitsvis.tokens import colored_tokens
from circuitsvis.activations import text_neuron_activations
import numpy as np

feature_idx = common_features[852].feature_id
n_top = 50

max_activation_examples = cache.features[feature_idx].get_max_activating(n_top)
tokens, acts = [], []
for idx in range(n_top):
    tokens_idx, acts_idx = max_activation_examples[idx].get_tokens_feature_lists(feature_idx)
    tokens.append(tokens_idx)
    acts.append(np.array([ [[x,],] for x in acts_idx]))
text_neuron_activations(tokens, acts)

In [12]:
from circuitsvis.tokens import colored_tokens

feature_idx = 124
feature_acts = cache.features[feature_idx].get_quantiles(5, 10)
print(list(feature_acts.keys())[-2])
for example in feature_acts[list(feature_acts.keys())[-4]]:
    # print(str(example).replace("\n", ""))
    tokens, acts = example.get_tokens_feature_lists(feature_idx)
    print(max(acts))
    colored_tokens(tokens, acts)


(0.22226563096046448, 0.24360352754592896)
0.1929931640625
0.1966552734375
0.19580078125
0.2010498046875
0.201171875
0.1986083984375
0.198974609375
0.1912841796875
0.1910400390625
0.19970703125


In [13]:
colored_tokens(tokens, acts)

In [14]:
cache.load_example(58704).get_feature_activation(0)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float16)

In [15]:
acts

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.19970703125,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]