In [None]:
# Imports and Setup
%load_ext autoreload
%autoreload 2

from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
import json
from sentence_transformers import SentenceTransformer
from transformer_lens import HookedTransformer
import utils.activation_collection as activation_collection
from eval_config import EvalConfig
import utils.dataset_utils as dataset_utils
import pandas as pd
from tqdm import tqdm
import gc
import torch
from sae_lens import SAE
from sae_lens.sae import TopK
import utils.formatting_utils as formatting_utils

In [None]:
# Configuration
config = EvalConfig()

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at model load.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Batch size and dtype are looked up per model name in activation_collection.
llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[config.model_name]

In [None]:
# Load concepts
# Read the file as UTF-8 explicitly: without `encoding`, open() uses the
# platform's locale default, which can mis-decode non-ASCII concept strings.
with open('concepts.json', encoding='utf-8') as f:
    concepts = json.load(f)

# Print first 5 adjectives
print(concepts['adjectives'][:5])

In [None]:
# Initialize SentenceTransformer model
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)

In [None]:
# Generate embeddings
queries = concepts['adjectives']
query_embeddings = model.encode(queries)
print(f'shape of query_embeddings: {query_embeddings.shape}')

In [None]:
# Calculate similarities
print(f"Similarity function: {model.similarity_fn_name}")
similarities = model.similarity(query_embeddings, query_embeddings)
print(similarities)

In [None]:
# Initialize HookedTransformer
# Load the LLM named in the config through TransformerLens, placed on `device`
# and using the per-model dtype selected in the configuration cell.
model = HookedTransformer.from_pretrained_no_processing(config.model_name, device=device, dtype=llm_dtype)

In [None]:
# Tokenize data and collect activations

# Tokenize the concepts with the LLM's own tokenizer; config.context_length
# is passed through to the helper and the result is placed on `device`.
tokenized = dataset_utils.tokenize_data(
    concepts,
    model.tokenizer,
    config.context_length,
    device=device,
)

# Run the LLM over the tokenized inputs in batches of `llm_batch_size`,
# collecting activations at the hook point named in the config.
all_llm_acts_BLD = activation_collection.get_all_llm_activations(tokenized, model, llm_batch_size, config.hook_name)

# Mean-reduce the collected activations via the project helper.
# NOTE(review): the name suffixes suggest (B, L, D) -> (B, D); confirm in activation_collection.
llm_acts_BD = activation_collection.create_meaned_model_activations(all_llm_acts_BLD)

In [None]:
# SAE setup
# Release under evaluation, and the SAE names within that release we care about.
sae_release = 'sae_bench_pythia70m_sweep_topk_ctx128_0730'
selected_saes_dict = {
    sae_release: [
        'pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_10',
        'pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_12',
    ],
}

# Tabulate the pretrained-SAE directory; after the transpose each release is a row label.
sae_map_df = pd.DataFrame.from_records(
    {release: vars(entry) for release, entry in get_pretrained_saes_directory().items()}
).T

# The directory maps SAE id -> SAE name for the release; invert it so an id
# can be looked up from a name.
sae_id_to_name_map = sae_map_df.loc[sae_release, "saes_map"]
sae_name_to_id_map = {name: id_ for id_, name in sae_id_to_name_map.items()}

# Work with the first selected SAE of the release.
sae_name = selected_saes_dict[sae_release][0]
sae_id = sae_name_to_id_map[sae_name]

In [None]:
# Load and prepare SAE
# Run a garbage-collection pass and release cached CUDA memory before loading.
gc.collect()
torch.cuda.empty_cache()

# from_pretrained returns the SAE together with its config dict and sparsity.
sae, cfg_dict, sparsity = SAE.from_pretrained(release=sae_release, sae_id=sae_id, device=device)
sae = sae.to(device=device)

# If the SAE name says top-k but the loaded activation function is not TopK,
# hand the SAE to the project helper and verify that the result is.
if not isinstance(sae.activation_fn, TopK) and "topk" in sae_name:
    sae = formatting_utils.fix_topk_saes(sae, sae_release, sae_name, data_dir="../../")
    assert isinstance(sae.activation_fn, TopK)

In [None]:
# Get SAE activations
# Feed the collected LLM activations through the SAE in batches of
# config.sae_batch_size, using the project helper that returns meaned activations.
all_sae_train_acts_BF = activation_collection.get_sae_meaned_activations(
    all_llm_acts_BLD,
    sae,
    config.sae_batch_size,
    llm_dtype,
)

# Show which keys the result exposes
print(all_sae_train_acts_BF.keys())