In [1]:
# Imports and Setup
%load_ext autoreload
%autoreload 2

from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
import json
from transformer_lens import HookedTransformer
import sae_bench_utils.activation_collection as activation_collection
from eval_config import EvalConfig
import sae_bench_utils.dataset_utils as dataset_utils
import pandas as pd
from tqdm import tqdm, trange
import gc
import torch
from sae_lens import SAE
from sae_lens.sae import TopK
import sae_bench_utils.formatting_utils as formatting_utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration
config = EvalConfig()
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Per-model defaults looked up from activation_collection's tables.
llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[config.model_name]

In [3]:
# Initialize HookedTransformer via from_pretrained_no_processing,
# on the configured device and dtype.
model = HookedTransformer.from_pretrained_no_processing(
    config.model_name,
    device=device,
    dtype=llm_dtype,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [4]:
# Load the known_1000 records; each record is read for its "prompt" and "attribute" fields.
# Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
with open("known_1000.json", "r", encoding="utf-8") as f:
    known_1000 = json.load(f)

# Append the attribute to each prompt, separated by a single space.
prompts = [f"{s['prompt']} {s['attribute']}" for s in known_1000]
data = dict(prompts=prompts)

In [9]:
# NOTE(review): max_length=64 is a magic number, presumably large enough for the
# known_1000 prompts -- confirm against dataset_utils.tokenize_data.
tokenized_data = dataset_utils.tokenize_data(
    data,
    model.tokenizer,
    max_length=64,
    device=device,
)

Tokenizing data:   0%|          | 0/1 [00:00<?, ?it/s]

Tokenizing data: 100%|██████████| 1/1 [00:00<00:00, 21.98it/s]


In [29]:
# Use the per-model batch size computed in the config cell rather than a hard-coded 512,
# so the batch size stays consistent with the rest of the configuration.
llm_activations_BLD = activation_collection.get_all_llm_activations(
    tokenized_data,
    model,
    batch_size=llm_batch_size,
    hook_name=config.hook_name,
)

Collecting activations for class prompts: 100%|██████████| 3/3 [00:00<00:00, 20.38it/s]


In [36]:
# Presumably one activation vector per prompt (B x D, per the suffix) -- see activation_collection.
final_acts_BD = activation_collection.filter_final_token_activations(
    llm_activations_BLD,
    tokenized_data,
)
# Select the entry for the "prompts" key built in the data-loading cell.
acts = final_acts_BD["prompts"]

## Feed activations through the SAE

In [37]:
# Activations of the first two prompts (rows 0 and 1 of `acts`).
r1, r2 = acts[0], acts[1]

In [38]:
# SAE setup
# Define the release name once and key the selection dict by it, so the two can't drift apart.
sae_release = 'sae_bench_pythia70m_sweep_topk_ctx128_0730'
selected_saes_dict = {
    sae_release: [
        'pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_10',
        'pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_12',
    ]
}

# One row per pretrained release (after .T); the `saes_map` column holds each release's id -> name map.
sae_map_df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T

sae_id_to_name_map = sae_map_df.saes_map[sae_release]
sae_name_to_id_map = {name: id_ for id_, name in sae_id_to_name_map.items()}

# Analyse the first selected SAE of the release.
sae_name = selected_saes_dict[sae_release][0]
sae_id = sae_name_to_id_map[sae_name]

In [42]:
# Load and prepare SAE
def load_sae(sae_release, sae_id, device, sae_name=None):
    """Load a pretrained SAE and ensure TopK SAEs use a TopK activation function.

    Args:
        sae_release: Pretrained SAE release name (a key of ``get_pretrained_saes_directory()``).
        sae_id: ID of the SAE within ``sae_release``.
        device: Device the SAE is loaded onto.
        sae_name: SAE name used to detect TopK SAEs ("topk" substring) and passed to
            ``formatting_utils.fix_topk_saes``. When omitted, it is looked up from the
            release's ``saes_map`` using ``sae_id`` instead of relying on a notebook global.

    Returns:
        The SAE, moved to ``device``; if its name contains "topk" and it does not already
        use ``TopK``, it is replaced by the output of ``formatting_utils.fix_topk_saes``.

    Raises:
        AssertionError: if ``fix_topk_saes`` does not yield a ``TopK`` activation.
    """
    # Collect unreferenced objects and release cached CUDA memory before loading.
    gc.collect()
    torch.cuda.empty_cache()

    if sae_name is None:
        # Derive the name from the arguments so the TopK check refers to the SAE actually loaded.
        sae_name = get_pretrained_saes_directory()[sae_release].saes_map[sae_id]

    sae, _cfg_dict, _sparsity = SAE.from_pretrained(
        release=sae_release,
        sae_id=sae_id,
        device=device,
    )
    sae = sae.to(device=device)

    if "topk" in sae_name and not isinstance(sae.activation_fn, TopK):
        sae = formatting_utils.fix_topk_saes(sae, sae_release, sae_name, data_dir="../../")
        assert isinstance(sae.activation_fn, TopK)

    return sae

In [43]:
sae = load_sae(sae_release, sae_id, device)


def active_features(sae, x):
    """Return a 1-D tensor of indices of SAE latents that are nonzero when encoding ``x``.

    Assumes ``x`` is a single activation vector (so the encoding is 1-D) -- TODO confirm.
    ``flatten()`` is used instead of ``squeeze()`` so the result stays 1-D even when exactly
    one latent fires; a 0-d tensor would make ``.tolist()`` return an int, not a list.
    """
    with torch.no_grad():
        return sae.encode(x).nonzero().flatten()


f1 = active_features(sae, r1)
f2 = active_features(sae, r2)
# Latents active on the summed activation, to compare with the union of f1 and f2.
f12 = active_features(sae, r1 + r2)

In [44]:
# Union of latents active on r1 or r2 individually vs. latents active on r1 + r2.
# flatten() guards against 0-d tensors (one active latent), whose .tolist() is an int, not a list.
pre_enc_set = set(f1.flatten().tolist()) | set(f2.flatten().tolist())
post_enc_set = set(f12.flatten().tolist())

intersection = pre_enc_set & post_enc_set
# Symmetric difference: latents present in exactly one of the two sets.
difference = pre_enc_set ^ post_enc_set

In [45]:
# Sizes of: union before encoding the sum, latents on the sum, their overlap, symmetric difference.
tuple(len(s) for s in (pre_enc_set, post_enc_set, intersection, difference))

(152, 80, 56, 120)

In [46]:
def covering_metric(f1, f2, f12):
    """Fraction of latents active on either input that are also active on the combined input.

    Args:
        f1: Indices of SAE latents active on the first activation. Tensors (including
            0-d tensors, e.g. from ``.squeeze()`` on a single index) or sequences of ints.
        f2: Indices of SAE latents active on the second activation.
        f12: Indices of SAE latents active on the combined activation (r1 + r2 here).

    Returns:
        ``len((f1 | f2) & f12) / len(f1 | f2)`` computed on index sets, as a float;
        ``nan`` when ``f1`` and ``f2`` are both empty, since the ratio is then undefined.
    """
    def _index_set(indices):
        # flatten() turns a 0-d tensor into a 1-element tensor, so tolist() is always a list.
        return set(torch.as_tensor(indices).flatten().tolist())

    pre_enc_set = _index_set(f1) | _index_set(f2)
    if not pre_enc_set:
        return float("nan")
    post_enc_set = _index_set(f12)
    return len(pre_enc_set & post_enc_set) / len(pre_enc_set)

In [49]:
# Share of individually-active latents that survive when the two activations are summed.
covering_metric(f1=f1, f2=f2, f12=f12)

0.3684210526315789