In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import os
import random

import mypkg.whitebox_infra.attribution as attribution
import mypkg.whitebox_infra.dictionaries.batch_topk_sae as batch_topk_sae
import mypkg.whitebox_infra.data_utils as data_utils
import mypkg.whitebox_infra.model_utils as model_utils
import mypkg.whitebox_infra.interp_utils as interp_utils
import mypkg.pipeline.setup.dataset as dataset_setup
import mypkg.pipeline.infra.hiring_bias_prompts as hiring_bias_prompts
from mypkg.eval_config import EvalConfig
import mypkg.pipeline.infra.model_inference as model_inference

  from .autonotebook import tqdm as notebook_tqdm


INFO 04-09 01:43:41 [__init__.py:239] Automatically detected platform cuda.


2025-04-09 01:43:42,223	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "mistralai/Ministral-8B-Instruct-2410"
model_name = "mistralai/Mistral-Small-24B-Instruct-2501"
# model_name = "google/gemma-2-9b-it"
# model_name = "google/gemma-2-27b-it"

args = hiring_bias_prompts.HiringBiasArgs(
    political_orientation=True,
    employment_gap=False,
    pregnancy=False,
    race=False,
    gender=False,
)


args = hiring_bias_prompts.HiringBiasArgs(
    political_orientation=False,
    employment_gap=False,
    pregnancy=False,
    race=True,
    gender=False,
)

dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map=device
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

gradient_checkpointing = False

if model_name == "google/gemma-2-27b-it":
    gradient_checkpointing = True
    batch_size = 1
elif model_name == "mistralai/Mistral-Small-24B-Instruct-2501":
    batch_size = 1
else:
    batch_size = 3

if gradient_checkpointing:
    model.config.use_cache = False
    model.gradient_checkpointing_enable()


chosen_layer_percentage = [25]
# chosen_layer_percentage = [50]

system_prompt = "yes_no.txt"

chosen_layers = []
for layer_percent in chosen_layer_percentage:
    chosen_layers.append(model_utils.MODEL_CONFIGS[model_name]["layer_mappings"][layer_percent]["layer"])

eval_config = EvalConfig(
        model_name=model_name,
        political_orientation=True,
        pregnancy=False,
        employment_gap=False,
        anthropic_dataset=False,
        downsample=50,
        # downsample=10,
        gpu_inference=True,
        anti_bias_statement_file="v1.txt",
        # anti_bias_statement_file="v3.txt",
        # anti_bias_statement_file="v17.txt",
        job_description_file="short_meta_job_description.txt",
        system_prompt_filename=system_prompt,
    )


Loading checkpoint shards: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]


In [4]:
# sae_repo = "adamkarvonen/ministral_saes"
# sae_path = f"mistralai_Ministral-8B-Instruct-2410_batch_top_k/resid_post_layer_{chosen_layers[0]}/trainer_1/ae.pt"

# sae = batch_topk_sae.load_dictionary_learning_batch_topk_sae(
#     repo_id=sae_repo,
#     filename=sae_path,
#     model_name=model_name,
#     device=device,
#     dtype=dtype,
#     layer=chosen_layers[0],
#     local_dir="downloaded_saes",
# )
trainer_id = 2

if "gemma" in model_name:
    trainer_id = 0
    
sae = model_utils.load_model_sae(model_name, device, dtype, chosen_layer_percentage[0], trainer_id=trainer_id)

submodules = [model_utils.get_submodule(model, chosen_layers[0])]

Original keys in state_dict: dict_keys(['b_dec', 'k', 'threshold', 'decoder.weight', 'encoder.weight', 'encoder.bias'])
Renamed keys in state_dict: dict_keys(['b_dec', 'k', 'threshold', 'W_dec', 'W_enc', 'b_enc'])


In [5]:
df = dataset_setup.load_raw_dataset()

industry = "INFORMATION-TECHNOLOGY"
downsample = eval_config.downsample
random_seed = eval_config.random_seed

random.seed(random_seed)
torch.manual_seed(random_seed)

df = dataset_setup.filter_by_industry(df, industry)

df = dataset_setup.balanced_downsample(df, downsample, random_seed)



prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(df, args, eval_config)

Downsampled to 50 unique resumes
Total samples after maintaining demographic variations: 200


In [6]:
train_texts, train_labels = hiring_bias_prompts.process_hiring_bias_resumes_prompts(prompts, args)

train_texts = model_utils.add_chat_template(train_texts, model_name)

dataloader = data_utils.create_simple_dataloader(
    train_texts, train_labels, model_name, device, batch_size=batch_size, max_length=2500
)

No pad token found, setting eos token as pad token
Filtered out 4 samples that exceeded max_length (2500)
Original dataset size: 200, new size: 196


In [7]:
# model_answers = model_inference.run_single_forward_pass_transformers(
#     prompts, model_name, batch_size=batch_size * 6, model=model
# )

# bias_scores = hiring_bias_prompts.evaluate_bias(
#     model_answers,
#     system_prompt
# )
# print(bias_scores)

In [8]:
diff_acts_F = attribution.get_activations(model, sae, dataloader, submodules, chosen_layers, device)

  0%|          | 0/196 [00:00<?, ?it/s]

100%|██████████| 196/196 [00:08<00:00, 23.61it/s]


In [9]:
acts_top_k_ids = diff_acts_F.abs().topk(20).indices
print(acts_top_k_ids)

acts_top_k_vals = diff_acts_F[acts_top_k_ids]
print(acts_top_k_vals)

output_filename = os.path.join(
    "diff_acts",
    f"{eval_config.anti_bias_statement_file.replace('.txt', '')}_trainer_{trainer_id}_model_{model_name.replace('/', '_')}_layer_{chosen_layer_percentage[0]}_attrib_data.pt",
)
print(output_filename)

os.makedirs("diff_acts", exist_ok=True)

diff_acts = {"diff_acts_F": diff_acts_F}
diff_acts["config"] = {
        "model_name": model_name,
        "layer": chosen_layers[0],
        "bias_categories": "N\A",
        "anti_bias_statement_file": eval_config.anti_bias_statement_file,
        "downsample": eval_config.downsample,
        "random_seed": eval_config.random_seed,
        "batch_size": batch_size,
        "chosen_layer_percentage": chosen_layer_percentage,
        "trainer_id": trainer_id,
    }

torch.save(diff_acts, output_filename)

# tensor([23759, 42925, 33394, 45780, 10085, 23574, 30460, 61472,   521, 59020,
#           775,  4261, 41205, 29588, 44488,   983, 55773, 44225, 42612, 30063],
#        device='cuda:0')
# tensor([-2.2821e-05, -1.8463e-05, -1.8400e-05,  1.6533e-05, -1.6091e-05,
#          1.5563e-05,  1.5402e-05,  1.4945e-05, -1.4770e-05, -1.4396e-05,
#         -1.4326e-05, -1.4277e-05, -1.3993e-05,  1.3824e-05, -1.3621e-05,
#         -1.2731e-05,  1.2631e-05,  1.2598e-05, -1.2597e-05, -1.2533e-05],
#        device='cuda:0')
# diff_acts/v1_trainer_3_model_mistralai_Ministral-8B-Instruct-2410_layer_50_attrib_data.pt

# tensor([24895, 37973, 29596, 35104, 19677, 64690, 23575, 14460, 54626, 60262,
#         36264, 10894,  2577,  6381, 25218, 17486, 50206,  1279,  9861, 26352],
#        device='cuda:0')
# tensor([-0.0399, -0.0257,  0.0186, -0.0184,  0.0138,  0.0096,  0.0089, -0.0089,
#         -0.0082, -0.0079, -0.0077,  0.0074,  0.0067, -0.0066, -0.0062, -0.0060,
#          0.0056,  0.0050,  0.0046,  0.0046], device='cuda:0')
# diff_acts/v17_trainer_2_model_mistralai_Mistral-Small-24B-Instruct-2501_layer_50_attrib_data.pt

tensor([61599, 50902, 33592, 47196,  6638, 41619, 39383, 61307, 17161, 57264,
        53451, 29812, 46652, 49757, 46425, 44249,   574, 43212, 24257, 40698],
       device='cuda:0')
tensor([-0.0014,  0.0014,  0.0013, -0.0011,  0.0008, -0.0008,  0.0008,  0.0007,
        -0.0006,  0.0006,  0.0006, -0.0006, -0.0005, -0.0005, -0.0005,  0.0004,
         0.0004,  0.0004,  0.0004,  0.0004], device='cuda:0')
diff_acts/v1_trainer_2_model_mistralai_Mistral-Small-24B-Instruct-2501_layer_25_attrib_data.pt


In [10]:
# Build the custom loss function
yes_vs_no_loss_fn = attribution.make_yes_no_loss_fn(
    tokenizer,
    yes_candidates=["yes", " yes", "Yes", " Yes", "YES", " YES"],
    no_candidates=["no", " no", "No", " No", "NO", " NO"],
    device=device,
)

effects_F, error_effect, predicted_tokens = attribution.get_effects(
    model,
    tokenizer,
    sae,
    dataloader,
    yes_vs_no_loss_fn,
    submodules,
    chosen_layers,
    device,
    verbose=True,
)

# Print peak memory usage
if torch.cuda.is_available():
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2  # Convert to MB
    print(f"Peak CUDA memory usage: {peak_memory:.2f} MB")

# Peak CUDA memory usage: 33432.72 MB

  0%|          | 0/196 [00:00<?, ?it/s]

100%|██████████| 196/196 [00:51<00:00,  3.78it/s]


Label 1:
Total samples: 98
Yes rate: 70.4%
No rate: 29.6%
Invalid rate: 0.0%

Label 0:
Total samples: 98
Yes rate: 73.5%
No rate: 26.5%
Invalid rate: 0.0%
Peak CUDA memory usage: 67290.98 MB





In [11]:
top_k_ids = effects_F.abs().topk(20).indices
print(top_k_ids)

top_k_vals = effects_F[top_k_ids]
print(top_k_vals)

print(error_effect)

# tensor([ 4393, 15242,  9049,  3959, 11802, 14960,   428,  9920,  2715,  3509,
#          9444, 12979,  9319,  8910, 12243,  7781, 11637, 10283,  4204,  2557],
#        device='cuda:0')
# tensor([0.0074, 0.0058, 0.0033, 0.0016, 0.0016, 0.0013, 0.0011, 0.0009, 0.0008,
#         0.0008, 0.0008, 0.0008, 0.0007, 0.0007, 0.0007, 0.0007, 0.0006, 0.0006,
#         0.0006, 0.0005], device='cuda:0')

# tensor([ 4393, 15242,  9049, 13855, 11802,  3959,  2039, 14960,  1683,  3509,
#           428,  4794,  3645,  9920,  5911,  1160,  1656, 16078,  9319,   394],
#        device='cuda:0')
# tensor([ 0.0305,  0.0242,  0.0133, -0.0078,  0.0064,  0.0063, -0.0058,  0.0052,
#         -0.0051,  0.0050,  0.0045, -0.0043, -0.0042,  0.0040, -0.0039, -0.0034,
#         -0.0033, -0.0033,  0.0033, -0.0031], device='cuda:0')
# tensor(-0.0181, device='cuda:0')

# tensor([ 4393, 15242,  9049, 13855, 11802,  3645,  2039,  3959,  1683,  3509,
#         14960,  9920,  4794,  1656,  5911, 16078,  9319,  1160,  5286,   394],
#        device='cuda:0')
# tensor([ 0.0276,  0.0227,  0.0114, -0.0077,  0.0062, -0.0058, -0.0056,  0.0055,
#         -0.0049,  0.0047,  0.0046,  0.0038, -0.0036, -0.0035, -0.0034, -0.0034,
#          0.0033, -0.0032, -0.0030, -0.0029], device='cuda:0')
# tensor(-0.0127, device='cuda:0')

# tensor([15658, 54626, 17662, 37447, 26352, 13920, 47911,   413,   204, 39717,
#         52330, 24031, 62241, 53133, 14038,  4682, 13032, 61127, 46104,  9405],
#        device='cuda:0')
# tensor([ 0.3070,  0.1780,  0.1460,  0.1291,  0.1048, -0.0993, -0.0823,  0.0779,
#          0.0748,  0.0666,  0.0586,  0.0581, -0.0442, -0.0442,  0.0392,  0.0377,
#         -0.0370,  0.0332, -0.0313, -0.0293], device='cuda:0')
# tensor(-0.0644, device='cuda:0')

# tensor([33592,  3662, 47196, 52209,  9593, 18947, 64744, 50582, 45795,   574,
#         57031, 61625, 33006, 53980, 46136,  4975, 12936, 32874,  6954, 38436],
#        device='cuda:0')
# tensor([ 0.1563, -0.1505, -0.1254,  0.1122,  0.1059,  0.1016, -0.0913, -0.0874,
#         -0.0841, -0.0828, -0.0805,  0.0714,  0.0673,  0.0635, -0.0592, -0.0583,
#          0.0563,  0.0545, -0.0545, -0.0544], device='cuda:0')
# tensor(0.0565, device='cuda:0')

# downsample 10
# tensor([63073,  3662, 45795, 33006, 58683, 50582, 45304, 42524, 33592, 27866,
#         21772, 12936, 25266, 57031, 15390, 60822, 64744, 44535, 60211,   169],
#        device='cuda:0')
# tensor([-0.5895, -0.2480, -0.2339,  0.2267,  0.2038, -0.1890, -0.1808,  0.1536,
#          0.1423, -0.1383, -0.1340,  0.1250, -0.1238, -0.1221,  0.1199,  0.1138,
#         -0.1091, -0.1051,  0.1040,  0.0992], device='cuda:0')
# tensor(0.6601, device='cuda:0')

tensor([33592,  3662, 47196, 52209, 18947,  9593, 64744, 50582, 57031,   574,
        45795, 61625, 33006, 53980,  4975, 46136, 12936, 32874, 56610, 46957],
       device='cuda:0')
tensor([ 0.1536, -0.1528, -0.1309,  0.1148,  0.1071,  0.1047, -0.0910, -0.0882,
        -0.0862, -0.0824, -0.0781,  0.0728,  0.0669,  0.0657, -0.0587, -0.0573,
         0.0556,  0.0549, -0.0548,  0.0543], device='cuda:0')
tensor(0.0866, device='cuda:0')


In [12]:
set1 = set(acts_top_k_ids.cpu().tolist())
set2 = set(top_k_ids.cpu().tolist())

print(set1)
print(set2)

print(set1.intersection(set2))

{17161, 41619, 61599, 57264, 33592, 46652, 574, 24257, 53451, 43212, 50902, 39383, 44249, 46425, 47196, 49757, 6638, 29812, 40698, 61307}
{18947, 12936, 50582, 56610, 33592, 61625, 46136, 574, 57031, 3662, 53980, 47196, 45795, 64744, 32874, 46957, 33006, 4975, 52209, 9593}
{33592, 47196, 574}


In [13]:
print(sae.W_dec.shape[0])

65536


In [14]:
acts_dir = "max_acts"
acts_filename = f"acts_{model_name}_layer_{chosen_layers[0]}_trainer_{trainer_id}_layer_percent_{chosen_layer_percentage[0]}.pt".replace("/", "_")
if not os.path.exists(acts_filename):
    from huggingface_hub import hf_hub_download
    path_to_config = hf_hub_download(
        repo_id="adamkarvonen/sae_max_acts",
        filename=acts_filename,
        force_download=False,
        local_dir=acts_dir,
        repo_type="dataset",
    )

    acts_path = os.path.join(acts_dir, acts_filename)
    acts_data = torch.load(acts_path)
    
    # max_tokens, max_acts = interp_utils.get_interp_prompts(
    #     model,
    #     submodules[0],
    #     sae,
    #     torch.tensor(list(range(sae.W_dec.shape[0]))),
    #     context_length=128,
    #     tokenizer=tokenizer,
    #     batch_size=batch_size * 32,
    #     num_tokens=30_000_000,
    # )
    # acts_data = {
    #     "max_tokens": max_tokens,
    #     "max_acts": max_acts,
    # }
    # torch.save(acts_data, acts_filename)
else:
    acts_path = os.path.join(acts_dir, acts_filename)
    acts_data = torch.load(acts_path)
max_tokens = acts_data["max_tokens"].cpu()
max_acts = acts_data["max_acts"].cpu()


In [15]:
from circuitsvis.activations import text_neuron_activations
import gc
from IPython.display import clear_output, display


def _list_decode(x):
    if len(x.shape) == 0:
        return tokenizer.decode(x, skip_special_tokens=False)
    else:
        return [_list_decode(y) for y in x]


def display_html_activations(
    selected_tokens_FKL: list[str],
    selected_activations_FKL: list[torch.Tensor],
    num_display: int = 10,
    k: int = 5,
):

    all_html_activations = []

    for i in range(num_display):

        selected_activations_KL11 = [
            selected_activations_FKL[i, k, :, None, None] for k in range(k)
        ]
        selected_tokens_KL = selected_tokens_FKL[i]
        selected_token_strs_KL = _list_decode(selected_tokens_KL)

        for k in range(len(selected_token_strs_KL)):
            if (
                "<s>" in selected_token_strs_KL[k][0]
                or "<bos>" in selected_token_strs_KL[k][0]
            ):
                selected_token_strs_KL[k][0] = "BOS>"

        # selected_token_strs_KL = tokenizer.batch_decode(selected_token_KL, skip_special_tokens=False)
        # for k in range(len(selected_token_strs_KL)):
        #     string = selected_token_strs_KL[k]
        #     print(string[:10])
        # print("".join(string))

        html_activations = text_neuron_activations(
            selected_token_strs_KL, selected_activations_KL11
        )

        all_html_activations.append(html_activations)
    
    return all_html_activations

clear_output(wait=True)
gc.collect()
html_activations = display_html_activations(max_tokens[top_k_ids.cpu()], max_acts[top_k_ids.cpu()], num_display=2)
for i, html_activation in enumerate(html_activations):
    feature_idx = top_k_ids[i]
    feature_val = top_k_vals[i]
    print(f"Feature {i}, feature idx: {feature_idx}, value: {feature_val}")
    # display(html_activation)

Feature 0, feature idx: 33592, value: 0.15362504124641418
Feature 1, feature idx: 3662, value: -0.15279383957386017


In [28]:
from copy import deepcopy

random.seed(random_seed)
torch.manual_seed(random_seed)

inspection_eval_config = deepcopy(eval_config)
inspection_eval_config.downsample = 3

inspect_df = dataset_setup.balanced_downsample(df, inspection_eval_config.downsample, random_seed)

inspect_prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(inspect_df, args, inspection_eval_config)

inspect_texts, inspect_labels = hiring_bias_prompts.process_hiring_bias_resumes_prompts(inspect_prompts, args)

inspect_texts = model_utils.add_chat_template(inspect_texts, model_name)

inspect_tokens_FBL, inspect_activations_FBL = interp_utils.get_interp_prompts_user_inputs(
    model,
    submodules[0],
    sae,
    top_k_ids,
    inspect_texts,
    tokenizer=tokenizer,
    batch_size=batch_size * 6,
    k=len(inspect_texts),
    sort_by_activation=False,
)

print(inspect_tokens_FBL.shape)
print(inspect_activations_FBL.shape)
print(inspect_labels)

labels_tensor = torch.tensor(inspect_labels)

pos_mask_K = labels_tensor == 1
neg_mask_K = labels_tensor == 0

print(pos_mask_K)
print(neg_mask_K)

print(inspect_activations_FBL.shape)

pos_acts_FKL = inspect_activations_FBL[:, pos_mask_K, :]
neg_acts_FKL = inspect_activations_FBL[:, neg_mask_K, :]

mean_pos_acts_F = pos_acts_FKL.mean(dim=(1,2))
mean_neg_acts_F = neg_acts_FKL.mean(dim=(1,2))

torch.set_printoptions(precision=4, sci_mode=False)

print(mean_pos_acts_F)
print(mean_neg_acts_F)

ratios = mean_pos_acts_F / mean_neg_acts_F

print(ratios)




Downsampled to 3 unique resumes
Total samples after maintaining demographic variations: 12


Processing batches:   0%|          | 0/2 [00:00<?, ?it/s]

Processing batches: 100%|██████████| 2/2 [00:00<00:00,  2.64it/s]

torch.Size([20, 12, 2357])
torch.Size([20, 12, 2357])
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
tensor([False,  True, False,  True, False,  True, False,  True, False,  True,
        False,  True])
tensor([ True, False,  True, False,  True, False,  True, False,  True, False,
         True, False])
torch.Size([20, 12, 2357])
tensor([0.0022, 0.0047, 0.3281, 0.0115, 0.0154, 0.0039, 0.0019, 0.0026, 0.0330,
        0.0334, 0.0786, 0.0038, 0.0183, 0.0034, 0.0017, 0.0151, 0.0038, 0.0029,
        0.0040, 0.0195], dtype=torch.bfloat16)
tensor([    0.0003,     0.0047,     0.3301,     0.0117,     0.0154,     0.0040,
            0.0018,     0.0025,     0.0327,     0.0330,     0.0786,     0.0038,
            0.0183,     0.0035,     0.0014,     0.0151,     0.0034,     0.0032,
            0.0040,     0.0195], dtype=torch.bfloat16)
tensor([6.8125, 0.9883, 0.9922, 0.9805, 0.9961, 0.9922, 1.0312, 1.0234, 1.0078,
        1.0156, 1.0000, 0.9883, 1.0000, 0.9883, 1.2266, 1.0078, 1.1172, 0.9023,
        0.9922, 1.




In [None]:
# from copy import deepcopy

# inspection_eval_config = deepcopy(eval_config)
# inspection_eval_config.downsample = 3

# inspect_df = dataset_setup.balanced_downsample(df, inspection_eval_config.downsample, random_seed)

# inspect_prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(inspect_df, args, inspection_eval_config)

# inspect_texts, inspect_labels = hiring_bias_prompts.process_hiring_bias_resumes_prompts(inspect_prompts, args)

# inspect_texts = model_utils.add_chat_template(inspect_texts, model_name)

# inspect_tokens_FBL, inspect_activations_FBL = interp_utils.get_interp_prompts_user_inputs(
#     model,
#     submodules[0],
#     sae,
#     top_k_ids,
#     inspect_texts,
#     tokenizer=tokenizer,
#     batch_size=batch_size * 6,
#     k=len(inspect_texts),
# )

# num_display = 2

# html_activations = display_html_activations(inspect_tokens_FBL, inspect_activations_FBL, num_display=num_display, k=len(inspect_texts))

# for i, html_activation in enumerate(html_activations):
#     feature_idx = top_k_ids[i]
#     feature_val = top_k_vals[i]
#     print(f"Feature {i}, feature idx: {feature_idx}, value: {feature_val}")
#     display(html_activation)