In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import os
import random

import mypkg.whitebox_infra.attribution as attribution
import mypkg.whitebox_infra.dictionaries.batch_topk_sae as batch_topk_sae
import mypkg.whitebox_infra.data_utils as data_utils
import mypkg.whitebox_infra.model_utils as model_utils
import mypkg.whitebox_infra.interp_utils as interp_utils
import mypkg.pipeline.setup.dataset as dataset_setup
import mypkg.pipeline.infra.hiring_bias_prompts as hiring_bias_prompts
from mypkg.eval_config import EvalConfig

In [None]:
# Pick the GPU when one is present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Exactly one model_name should be active; the batch-size and layer lookups
# later in this notebook are keyed on it.
# model_name = "mistralai/Ministral-8B-Instruct-2410"
model_name = "mistralai/Mistral-Small-24B-Instruct-2501"
# model_name = "google/gemma-2-9b-it"
# model_name = "google/gemma-2-27b-it"
dtype = torch.bfloat16

# Place the model on the same device that `device` resolved to. The placement
# used to be hard-coded to "cuda:0", which ignored the CPU fallback above.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map=str(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Per-model memory settings: only gemma-2-27b uses gradient checkpointing,
# and any model without an explicit entry runs with a batch size of 4.
gradient_checkpointing = model_name == "google/gemma-2-27b-it"
batch_size = {
    "google/gemma-2-27b-it": 1,
    "mistralai/Mistral-Small-24B-Instruct-2501": 2,
}.get(model_name, 4)

if gradient_checkpointing:
    # The KV cache is switched off before checkpointing is enabled.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()


# Depth of each hooked layer, expressed as a percentage key into the
# per-model "layer_mappings" table.
chosen_layer_percentage = [50]

# Resolve every percentage to the concrete layer index for this model.
chosen_layers = [
    model_utils.MODEL_CONFIGS[model_name]["layer_mappings"][layer_percent]["layer"]
    for layer_percent in chosen_layer_percentage
]

# Evaluation settings handed to the prompt builder further down.
eval_config = EvalConfig(
    model_name=model_name,
    gpu_inference=True,
    downsample=100,
    # Dataset / attribute toggles.
    anthropic_dataset=False,
    political_orientation=True,
    pregnancy=False,
    employment_gap=False,
    # Prompt template files.
    system_prompt_filename="yes_no.txt",
    job_description_file="short_meta_job_description.txt",
    anti_bias_statement_file="v1.txt",
)


In [None]:
# sae_repo = "adamkarvonen/ministral_saes"
# sae_path = f"mistralai_Ministral-8B-Instruct-2410_batch_top_k/resid_post_layer_{chosen_layers[0]}/trainer_1/ae.pt"

# sae = batch_topk_sae.load_dictionary_learning_batch_topk_sae(
#     repo_id=sae_repo,
#     filename=sae_path,
#     model_name=model_name,
#     device=device,
#     dtype=dtype,
#     layer=chosen_layers[0],
#     local_dir="downloaded_saes",
# )
# Load the SAE for the first chosen layer percentage; trainer_id is also part
# of the max-activation cache filename built later in the notebook.
trainer_id = 1
sae = model_utils.load_model_sae(
    model_name,
    device,
    dtype,
    chosen_layer_percentage[0],
    trainer_id=trainer_id,
)

# The model submodule at the first chosen layer, wrapped in a list.
submodules = [
    model_utils.get_submodule(model, chosen_layers[0]),
]

In [None]:
# Load the raw dataset first, then seed the RNGs, then narrow it down.
df = dataset_setup.load_raw_dataset()

industry = "INFORMATION-TECHNOLOGY"
downsample, random_seed = eval_config.downsample, eval_config.random_seed

# Seed both Python's and torch's generators with the configured seed.
random.seed(random_seed)
torch.manual_seed(random_seed)

# Keep a single industry, then downsample it with the same seed.
df = dataset_setup.balanced_downsample(
    dataset_setup.filter_by_industry(df, industry),
    downsample,
    random_seed,
)


# Choose which attribute flag is enabled for prompt generation.
# An earlier HiringBiasArgs(political_orientation=True, ...) assignment used to
# sit directly above this one and was overwritten immediately, so only the
# race configuration below ever took effect. To study a different attribute,
# flip the flags here rather than adding a second assignment.
# NOTE(review): eval_config is built with political_orientation=True while
# these args enable race only - confirm whether the two are meant to agree.
args = hiring_bias_prompts.HiringBiasArgs(
    political_orientation=False,
    employment_gap=False,
    pregnancy=False,
    race=True,
    gender=False,
)

# Build the prompts from the downsampled dataframe, these args and eval_config.
prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(df, args, eval_config)

In [None]:
# Turn the prompts into (text, label) pairs, then apply the chat template
# for this model to the texts.
train_texts, train_labels = hiring_bias_prompts.process_hiring_bias_resumes_prompts(
    prompts, args
)
train_texts = model_utils.add_chat_template(train_texts, model_name)

# Batch the templated texts and labels with the per-model batch size.
dataloader = data_utils.create_simple_dataloader(
    train_texts,
    train_labels,
    model_name,
    device,
    batch_size=batch_size,
)

In [None]:
# Custom loss built from yes/no answer candidates: each casing of the word,
# with and without a leading space, in the order
# ["yes", " yes", "Yes", " Yes", "YES", " YES"] (and likewise for "no").
yes_vs_no_loss_fn = attribution.make_yes_no_loss_fn(
    tokenizer,
    yes_candidates=[
        prefix + word for word in ("yes", "Yes", "YES") for prefix in ("", " ")
    ],
    no_candidates=[
        prefix + word for word in ("no", "No", "NO") for prefix in ("", " ")
    ],
    device=device,
)

# Run attribution over the dataloader; returns the feature effects and the
# SAE error-term effect.
effects_F, error_effect = attribution.get_effects(
    model, sae, dataloader, yes_vs_no_loss_fn, submodules, chosen_layers, device
)

# Report the peak CUDA allocation in MB (bytes / 1024**2) when a GPU is present.
if torch.cuda.is_available():
    peak_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)
    print(f"Peak CUDA memory usage: {peak_memory:.2f} MB")

# Peak CUDA memory usage: 33432.72 MB

In [None]:
# Rank features by absolute effect and keep the 20 largest; the signed effect
# values are then read back at those indices.
top_k_ids = torch.topk(effects_F.abs(), k=20).indices
top_k_vals = effects_F[top_k_ids]

print(top_k_ids)
print(top_k_vals)

# Effect attributed to the SAE error term.
print(error_effect)

# tensor([ 4393, 15242,  9049,  3959, 11802, 14960,   428,  9920,  2715,  3509,
#          9444, 12979,  9319,  8910, 12243,  7781, 11637, 10283,  4204,  2557],
#        device='cuda:0')
# tensor([0.0074, 0.0058, 0.0033, 0.0016, 0.0016, 0.0013, 0.0011, 0.0009, 0.0008,
#         0.0008, 0.0008, 0.0008, 0.0007, 0.0007, 0.0007, 0.0007, 0.0006, 0.0006,
#         0.0006, 0.0005], device='cuda:0')

# tensor([ 4393, 15242,  9049, 13855, 11802,  3959,  2039, 14960,  1683,  3509,
#           428,  4794,  3645,  9920,  5911,  1160,  1656, 16078,  9319,   394],
#        device='cuda:0')
# tensor([ 0.0305,  0.0242,  0.0133, -0.0078,  0.0064,  0.0063, -0.0058,  0.0052,
#         -0.0051,  0.0050,  0.0045, -0.0043, -0.0042,  0.0040, -0.0039, -0.0034,
#         -0.0033, -0.0033,  0.0033, -0.0031], device='cuda:0')
# tensor(-0.0181, device='cuda:0')

# tensor([ 4393, 15242,  9049, 13855, 11802,  3645,  2039,  3959,  1683,  3509,
#         14960,  9920,  4794,  1656,  5911, 16078,  9319,  1160,  5286,   394],
#        device='cuda:0')
# tensor([ 0.0276,  0.0227,  0.0114, -0.0077,  0.0062, -0.0058, -0.0056,  0.0055,
#         -0.0049,  0.0047,  0.0046,  0.0038, -0.0036, -0.0035, -0.0034, -0.0034,
#          0.0033, -0.0032, -0.0030, -0.0029], device='cuda:0')
# tensor(-0.0127, device='cuda:0')

In [None]:
# Size of the first dimension of the SAE decoder weight matrix.
print(sae.W_dec.size(0))

In [None]:
# Locate the precomputed max-activating examples for this model / layer / trainer.
acts_dir = "max_acts"
acts_filename = f"acts_{model_name}_layer_{chosen_layers[0]}_trainer_{trainer_id}_layer_percent_{chosen_layer_percentage[0]}.pt".replace("/", "_")
acts_path = os.path.join(acts_dir, acts_filename)

# Check the path the file is actually stored at. The check used to test
# `acts_filename` relative to the working directory, which never matched the
# copy downloaded into `acts_dir`.
if not os.path.exists(acts_path):
    from huggingface_hub import hf_hub_download

    # Fetch the file into acts_dir; it is then loaded from acts_path below.
    hf_hub_download(
        repo_id="adamkarvonen/sae_max_acts",
        filename=acts_filename,
        force_download=False,
        local_dir=acts_dir,
        repo_type="dataset",
    )

    # Alternative: recompute the max activations locally instead of downloading.
    # max_tokens, max_acts = interp_utils.get_interp_prompts(
    #     model,
    #     submodules[0],
    #     sae,
    #     torch.tensor(list(range(sae.W_dec.shape[0]))),
    #     context_length=128,
    #     tokenizer=tokenizer,
    #     batch_size=batch_size * 32,
    #     num_tokens=30_000_000,
    # )
    # acts_data = {
    #     "max_tokens": max_tokens,
    #     "max_acts": max_acts,
    # }
    # torch.save(acts_data, acts_path)

# NOTE(security): torch.load unpickles the file, so only load caches that come
# from a trusted source.
# The tensors are only used on CPU below, so load them there directly rather
# than onto whichever device they were saved from.
acts_data = torch.load(acts_path, map_location="cpu")
max_tokens = acts_data["max_tokens"].cpu()
max_acts = acts_data["max_acts"].cpu()


In [None]:
from circuitsvis.activations import text_neuron_activations
import gc
from IPython.display import clear_output, display

def _list_decode(x):
    """Decode a tensor of token ids into an identically nested list of strings.

    A 0-dimensional input is decoded with the notebook-level ``tokenizer``
    (special tokens are kept); any higher-dimensional input is decoded
    element by element along its first dimension, recursively.
    """
    if len(x.shape) > 0:
        return [_list_decode(item) for item in x]
    return tokenizer.decode(x, skip_special_tokens=False)
    

clear_output(wait=True)
gc.collect()

# Number of max-activating examples rendered per feature.
num_examples = 5

# Walk over every selected feature instead of a hard-coded range(20), so this
# cell keeps working if the top-k size changes.
for i, feature_idx in enumerate(top_k_ids.tolist()):
    feature_val = top_k_vals[i].item()
    # Print both the rank and the actual SAE feature index; the old label
    # showed the rank as if it were the feature id.
    print(f"Rank {i}: feature {feature_idx}, value: {feature_val}")

    # Only decode the examples that have an activation row passed to the
    # renderer; decoding every stored example token by token is wasted work.
    feature_tokens_KL = max_tokens[feature_idx]
    shown = min(num_examples, feature_tokens_KL.shape[0])
    selected_token_KL = feature_tokens_KL[:shown]

    # One entry per example; the two trailing singleton axes are added by the
    # [..., None, None] indexing.
    selected_activations_KL11 = [
        max_acts[feature_idx, k, :, None, None] for k in range(shown)
    ]
    selected_token_strs_KL = _list_decode(selected_token_KL)

    # Replace a leading BOS marker with a plain-text stand-in.
    # NOTE(review): presumably this keeps "<s>" / "<bos>" from being treated
    # as markup in the rendered HTML - confirm.
    for token_strs in selected_token_strs_KL:
        if token_strs and ("<s>" in token_strs[0] or "<bos>" in token_strs[0]):
            token_strs[0] = "BOS>"

    html_activations = text_neuron_activations(selected_token_strs_KL, selected_activations_KL11)
    display(html_activations)