In [2]:
# Imports and Setup
%load_ext autoreload
%autoreload 2

from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
import json
from sentence_transformers import SentenceTransformer
from transformer_lens import HookedTransformer
import sae_bench_utils.activation_collection as activation_collection
from eval_config import EvalConfig
import sae_bench_utils.dataset_utils as dataset_utils
import pandas as pd
from tqdm import tqdm, trange
import gc
import torch
from sae_lens import SAE
from sae_lens.sae import TopK
import sae_bench_utils.formatting_utils as formatting_utils

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Configuration
config = EvalConfig()
# Prefer the GPU, but fall back to CPU so Restart & Run All still works on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Per-model batch size and dtype come from the project's lookup tables, keyed by model name.
llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[config.model_name]

In [4]:
# Initialize HookedTransformer
# Load the base LM named in the config onto the configured device and dtype.
# NOTE(review): the "_no_processing" loader is presumably chosen so activations match the
# raw checkpoint the SAEs were trained on; confirm.
model = HookedTransformer.from_pretrained_no_processing(
    config.model_name,
    device=device,
    dtype=llm_dtype,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [5]:
# Two unrelated prompts; their final-token residuals are encoded separately and summed below.
s1, s2 = "Miles Davis played the", "Paris is the capital of"



In [18]:
def get_activations(
    sentences: list[str], hook_name: str, hooked_model: "HookedTransformer | None" = None
) -> list[torch.Tensor]:
    """Run each sentence through the model on its own and capture one hook's activation.

    Args:
        sentences: Prompts to run. Each gets its own forward pass, so prompts of
            different lengths need no padding.
        hook_name: Name of the hook point whose activation is captured.
        hooked_model: Model to run. When ``None``, the notebook-level ``model`` is used,
            which was the original behaviour.

    Returns:
        One tensor per sentence: the captured activation for batch element 0. For a
        residual-stream hook this is (seq_len, d_model), e.g. (6, 512) for the prompts here.

    Raises:
        ValueError: If the hook did not fire for a sentence (e.g. a misspelled ``hook_name``).
    """
    runner = model if hooked_model is None else hooked_model
    captured: dict[str, torch.Tensor] = {}

    # Defined once and reused for every sentence; it records the latest activation seen.
    def activation_hook(resid_BLD: torch.Tensor, hook):
        captured["acts_BLD"] = resid_BLD

    out = []
    # The activations are only inspected, never backpropagated through. Skipping autograd
    # keeps the returned tensors from holding the whole forward graph in memory.
    with torch.no_grad():
        for s in sentences:
            captured.clear()
            runner.run_with_hooks(
                s, return_type=None, fwd_hooks=[(hook_name, activation_hook)]
            )
            if "acts_BLD" not in captured:
                raise ValueError(f"hook {hook_name!r} did not fire for sentence {s!r}")
            out.append(captured["acts_BLD"][0])
    return out

acts = get_activations(sentences=[s1, s2], hook_name=config.hook_name)  # one tensor per prompt

In [19]:
tuple(acts[i].shape for i in (0, 1))

(torch.Size([6, 512]), torch.Size([6, 512]))

In [20]:
# Activation vector at the final token position of each prompt.
r1, r2 = acts[0][-1], acts[1][-1]

## Feed through SAE

In [22]:
# SAE setup
# Define the release name once; it is both the dict key and the directory lookup key,
# so the two can no longer drift apart.
sae_release = "sae_bench_pythia70m_sweep_topk_ctx128_0730"
selected_saes_dict = {
    sae_release: [
        "pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_10",
        "pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_12",
    ]
}

# Table of every pretrained release's metadata; after .T, releases are on the index.
sae_map_df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T

# saes_map maps SAE id -> name within the release; invert it to find the id for a name.
sae_id_to_name_map = sae_map_df.saes_map[sae_release]
sae_name_to_id_map = {v: k for k, v in sae_id_to_name_map.items()}

# Only the first selected SAE (trainer_10) is analysed below.
sae_name = selected_saes_dict[sae_release][0]
sae_id = sae_name_to_id_map[sae_name]

In [23]:
# Load and prepare SAE
# Collect garbage and clear the CUDA cache first, presumably so memory from an earlier
# SAE load is released when this cell is re-run.
gc.collect()
torch.cuda.empty_cache()

# from_pretrained is unpacked into three values here; only `sae` is used below.
sae, cfg_dict, sparsity = SAE.from_pretrained(release=sae_release, sae_id=sae_id, device=device)
sae = sae.to(device=device)  # NOTE(review): probably redundant given device= above; kept as-is.

# A TopK SAE that did not load with a TopK activation is patched by the project helper.
if "topk" in sae_name:
    if not isinstance(sae.activation_fn, TopK):
        sae = formatting_utils.fix_topk_saes(
            sae, sae_release, sae_name, data_dir="../../"
        )
        assert isinstance(sae.activation_fn, TopK)

In [24]:
# Indices of the SAE latents that are active for each residual and for their sum.
# nonzero() on a 1-D input returns shape (k, 1). squeeze(-1) always yields shape (k,).
# A bare squeeze() would give a 0-d tensor when exactly one latent fires, whose .tolist()
# is an int rather than a list.
# no_grad: the encodings are only inspected, so no autograd graph is needed.
with torch.no_grad():
    f1 = sae.encode(r1).nonzero().squeeze(-1)
    f2 = sae.encode(r2).nonzero().squeeze(-1)
    f12 = sae.encode(r1 + r2).nonzero().squeeze(-1)

In [26]:
# Latents active for either prompt on its own vs. latents active for the summed residual.
# flatten() makes this robust to 0-d index tensors, whose .tolist() would be a bare int.
pre_enc_set = set(f1.flatten().tolist()) | set(f2.flatten().tolist())
post_enc_set = set(f12.flatten().tolist())

intersection = pre_enc_set & post_enc_set
# Despite the name, this is the SYMMETRIC difference: latents in exactly one of the two sets.
difference = pre_enc_set ^ post_enc_set

In [27]:
tuple(len(group) for group in (pre_enc_set, post_enc_set, intersection, difference))

(148, 80, 57, 114)