In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import os
import random
from copy import deepcopy
from typing import Callable, Optional
from tqdm import tqdm
import pickle

import mypkg.whitebox_infra.attribution as attribution
import mypkg.whitebox_infra.dictionaries.batch_topk_sae as batch_topk_sae
import mypkg.whitebox_infra.data_utils as data_utils
import mypkg.whitebox_infra.model_utils as model_utils
import mypkg.whitebox_infra.interp_utils as interp_utils
import mypkg.pipeline.setup.dataset as dataset_setup
import mypkg.pipeline.infra.hiring_bias_prompts as hiring_bias_prompts
from mypkg.eval_config import EvalConfig
import mypkg.pipeline.infra.model_inference as model_inference

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "mistralai/Ministral-8B-Instruct-2410"
model_name = "mistralai/Mistral-Small-24B-Instruct-2501"
model_name = "google/gemma-2-2b-it"
# model_name = "google/gemma-2-27b-it"

bias_type = "gender"
bias_type = "race"
# bias_type = "political_orientation"

anti_bias_statement_file = "v1.txt"
anti_bias_statement_file = "v3.txt"
# anti_bias_statement_file = "v17.txt"

args = hiring_bias_prompts.HiringBiasArgs(
    political_orientation=bias_type == "political_orientation",
    employment_gap=bias_type == "employment_gap",
    pregnancy=bias_type == "pregnancy",
    race=bias_type == "race",
    gender=bias_type == "gender",
    misc=bias_type == "misc",
)


dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map=device
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

gradient_checkpointing = False

if model_name == "google/gemma-2-27b-it":
    gradient_checkpointing = True
    batch_size = 1
elif model_name == "mistralai/Mistral-Small-24B-Instruct-2501" or model_name == "google/gemma-2-2b-it":
    batch_size = 1
else:
    batch_size = 3

if gradient_checkpointing:
    model.config.use_cache = False
    model.gradient_checkpointing_enable()


chosen_layer_percentage = [25]
# chosen_layer_percentage = [50]

system_prompt = "yes_no.txt"

chosen_layers = []
for layer_percent in chosen_layer_percentage:
    chosen_layers.append(model_utils.MODEL_CONFIGS[model_name]["layer_mappings"][layer_percent]["layer"])

eval_config = EvalConfig(
        model_name=model_name,
        political_orientation=True,
        pregnancy=False,
        employment_gap=False,
        anthropic_dataset=False,
        downsample=50,
        # downsample=10,
        gpu_inference=True,
        anti_bias_statement_file=anti_bias_statement_file,
        job_description_file="short_meta_job_description.txt",
        system_prompt_filename=system_prompt,
    )


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.68it/s]


In [5]:
# sae_repo = "adamkarvonen/ministral_saes"
# sae_path = f"mistralai_Ministral-8B-Instruct-2410_batch_top_k/resid_post_layer_{chosen_layers[0]}/trainer_1/ae.pt"

# sae = batch_topk_sae.load_dictionary_learning_batch_topk_sae(
#     repo_id=sae_repo,
#     filename=sae_path,
#     model_name=model_name,
#     device=device,
#     dtype=dtype,
#     layer=chosen_layers[0],
#     local_dir="downloaded_saes",
# )
trainer_id = 2

if "gemma" in model_name:
    trainer_id = 0
    
sae = model_utils.load_model_sae(model_name, device, dtype, chosen_layer_percentage[0], trainer_id=trainer_id)

submodules = [model_utils.get_submodule(model, chosen_layers[0])]

In [6]:
df = dataset_setup.load_raw_dataset()

industry = "INFORMATION-TECHNOLOGY"
downsample = eval_config.downsample
random_seed = eval_config.random_seed

random.seed(random_seed)
torch.manual_seed(random_seed)

df = dataset_setup.filter_by_industry(df, industry)

df = dataset_setup.balanced_downsample(df, downsample, random_seed)



prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(df, args, eval_config)

Downsampled to 50 unique resumes
Total samples after maintaining demographic variations: 200


In [None]:
train_texts, train_labels, train_prompt_details = hiring_bias_prompts.process_hiring_bias_resumes_prompts(prompts, args)

train_texts = model_utils.add_chat_template(train_texts, model_name)

dataloader = data_utils.create_simple_dataloader(
    train_texts, train_labels, model_name, device, batch_size=batch_size, max_length=2500
)

Filtered out 4 samples that exceeded max_length (2500)
Original dataset size: 200, new size: 196


In [8]:
# model_answers = model_inference.run_single_forward_pass_transformers(
#     prompts, model_name, batch_size=batch_size * 6, model=model
# )

# bias_scores = hiring_bias_prompts.evaluate_bias(
#     model_answers,
#     system_prompt
# )
# print(bias_scores)

In [9]:
diff_acts_F = attribution.get_activations(model, sae, dataloader, submodules, chosen_layers, device)

  0%|          | 0/196 [00:00<?, ?it/s]

100%|██████████| 196/196 [00:12<00:00, 15.14it/s]


In [10]:
acts_top_k_ids = diff_acts_F.abs().topk(20).indices
print(acts_top_k_ids)

acts_top_k_vals = diff_acts_F[acts_top_k_ids]
print(acts_top_k_vals)

output_filename = os.path.join(
    "diff_acts",
    f"{eval_config.anti_bias_statement_file.replace('.txt', '')}_trainer_{trainer_id}_model_{model_name.replace('/', '_')}_layer_{chosen_layer_percentage[0]}_attrib_data.pt",
)
print(output_filename)

os.makedirs("diff_acts", exist_ok=True)

diff_acts = {"diff_acts_F": diff_acts_F}
diff_acts["config"] = {
        "model_name": model_name,
        "layer": chosen_layers[0],
        "bias_categories": "N\A",
        "anti_bias_statement_file": eval_config.anti_bias_statement_file,
        "downsample": eval_config.downsample,
        "random_seed": eval_config.random_seed,
        "batch_size": batch_size,
        "chosen_layer_percentage": chosen_layer_percentage,
        "trainer_id": trainer_id,
    }

torch.save(diff_acts, output_filename)

# tensor([23759, 42925, 33394, 45780, 10085, 23574, 30460, 61472,   521, 59020,
#           775,  4261, 41205, 29588, 44488,   983, 55773, 44225, 42612, 30063],
#        device='cuda:0')
# tensor([-2.2821e-05, -1.8463e-05, -1.8400e-05,  1.6533e-05, -1.6091e-05,
#          1.5563e-05,  1.5402e-05,  1.4945e-05, -1.4770e-05, -1.4396e-05,
#         -1.4326e-05, -1.4277e-05, -1.3993e-05,  1.3824e-05, -1.3621e-05,
#         -1.2731e-05,  1.2631e-05,  1.2598e-05, -1.2597e-05, -1.2533e-05],
#        device='cuda:0')
# diff_acts/v1_trainer_3_model_mistralai_Ministral-8B-Instruct-2410_layer_50_attrib_data.pt

# tensor([24895, 37973, 29596, 35104, 19677, 64690, 23575, 14460, 54626, 60262,
#         36264, 10894,  2577,  6381, 25218, 17486, 50206,  1279,  9861, 26352],
#        device='cuda:0')
# tensor([-0.0399, -0.0257,  0.0186, -0.0184,  0.0138,  0.0096,  0.0089, -0.0089,
#         -0.0082, -0.0079, -0.0077,  0.0074,  0.0067, -0.0066, -0.0062, -0.0060,
#          0.0056,  0.0050,  0.0046,  0.0046], device='cuda:0')
# diff_acts/v17_trainer_2_model_mistralai_Mistral-Small-24B-Instruct-2501_layer_50_attrib_data.pt

tensor([ 3513,  2841,   300,  4559, 15884,   461,  6517,  5023,   384,  5464,
         7549,  4255, 14280, 16038, 13324, 13715, 11782,  7574,  3844, 14850],
       device='cuda:0')
tensor([ 0.0163,  0.0116,  0.0108, -0.0083, -0.0082, -0.0075,  0.0067,  0.0065,
        -0.0064,  0.0064,  0.0062,  0.0056,  0.0056,  0.0055,  0.0054, -0.0052,
        -0.0049, -0.0048,  0.0048, -0.0047], device='cuda:0')
diff_acts/v3_trainer_0_model_google_gemma-2-2b-it_layer_25_attrib_data.pt


In [11]:
# Build the custom loss function
yes_vs_no_loss_fn = attribution.make_yes_no_loss_fn(
    tokenizer,
    yes_candidates=["yes", " yes", "Yes", " Yes", "YES", " YES"],
    no_candidates=["no", " no", "No", " No", "NO", " NO"],
    device=device,
)

effects_F, error_effect, predicted_tokens = attribution.get_effects(
    model,
    tokenizer,
    sae,
    dataloader,
    yes_vs_no_loss_fn,
    submodules,
    chosen_layers,
    device,
    verbose=True,
)

# Print peak memory usage
if torch.cuda.is_available():
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2  # Convert to MB
    print(f"Peak CUDA memory usage: {peak_memory:.2f} MB")

# Peak CUDA memory usage: 33432.72 MB

  0%|          | 0/196 [00:00<?, ?it/s]

100%|██████████| 196/196 [01:29<00:00,  2.20it/s]


Label 1:
Total samples: 98
Yes rate: 67.3%
No rate: 32.7%
Invalid rate: 0.0%

Label 0:
Total samples: 98
Yes rate: 66.3%
No rate: 33.7%
Invalid rate: 0.0%
Peak CUDA memory usage: 15365.44 MB





In [12]:
top_k_ids = effects_F.abs().topk(20).indices
print(top_k_ids)

top_k_vals = effects_F[top_k_ids]
print(top_k_vals)

print(error_effect)

# tensor([ 4393, 15242,  9049,  3959, 11802, 14960,   428,  9920,  2715,  3509,
#          9444, 12979,  9319,  8910, 12243,  7781, 11637, 10283,  4204,  2557],
#        device='cuda:0')
# tensor([0.0074, 0.0058, 0.0033, 0.0016, 0.0016, 0.0013, 0.0011, 0.0009, 0.0008,
#         0.0008, 0.0008, 0.0008, 0.0007, 0.0007, 0.0007, 0.0007, 0.0006, 0.0006,
#         0.0006, 0.0005], device='cuda:0')

# tensor([ 4393, 15242,  9049, 13855, 11802,  3959,  2039, 14960,  1683,  3509,
#           428,  4794,  3645,  9920,  5911,  1160,  1656, 16078,  9319,   394],
#        device='cuda:0')
# tensor([ 0.0305,  0.0242,  0.0133, -0.0078,  0.0064,  0.0063, -0.0058,  0.0052,
#         -0.0051,  0.0050,  0.0045, -0.0043, -0.0042,  0.0040, -0.0039, -0.0034,
#         -0.0033, -0.0033,  0.0033, -0.0031], device='cuda:0')
# tensor(-0.0181, device='cuda:0')

# tensor([ 4393, 15242,  9049, 13855, 11802,  3645,  2039,  3959,  1683,  3509,
#         14960,  9920,  4794,  1656,  5911, 16078,  9319,  1160,  5286,   394],
#        device='cuda:0')
# tensor([ 0.0276,  0.0227,  0.0114, -0.0077,  0.0062, -0.0058, -0.0056,  0.0055,
#         -0.0049,  0.0047,  0.0046,  0.0038, -0.0036, -0.0035, -0.0034, -0.0034,
#          0.0033, -0.0032, -0.0030, -0.0029], device='cuda:0')
# tensor(-0.0127, device='cuda:0')

# tensor([15658, 54626, 17662, 37447, 26352, 13920, 47911,   413,   204, 39717,
#         52330, 24031, 62241, 53133, 14038,  4682, 13032, 61127, 46104,  9405],
#        device='cuda:0')
# tensor([ 0.3070,  0.1780,  0.1460,  0.1291,  0.1048, -0.0993, -0.0823,  0.0779,
#          0.0748,  0.0666,  0.0586,  0.0581, -0.0442, -0.0442,  0.0392,  0.0377,
#         -0.0370,  0.0332, -0.0313, -0.0293], device='cuda:0')
# tensor(-0.0644, device='cuda:0')

# tensor([33592,  3662, 47196, 52209,  9593, 18947, 64744, 50582, 45795,   574,
#         57031, 61625, 33006, 53980, 46136,  4975, 12936, 32874,  6954, 38436],
#        device='cuda:0')
# tensor([ 0.1563, -0.1505, -0.1254,  0.1122,  0.1059,  0.1016, -0.0913, -0.0874,
#         -0.0841, -0.0828, -0.0805,  0.0714,  0.0673,  0.0635, -0.0592, -0.0583,
#          0.0563,  0.0545, -0.0545, -0.0544], device='cuda:0')
# tensor(0.0565, device='cuda:0')

# downsample 10
# tensor([63073,  3662, 45795, 33006, 58683, 50582, 45304, 42524, 33592, 27866,
#         21772, 12936, 25266, 57031, 15390, 60822, 64744, 44535, 60211,   169],
#        device='cuda:0')
# tensor([-0.5895, -0.2480, -0.2339,  0.2267,  0.2038, -0.1890, -0.1808,  0.1536,
#          0.1423, -0.1383, -0.1340,  0.1250, -0.1238, -0.1221,  0.1199,  0.1138,
#         -0.1091, -0.1051,  0.1040,  0.0992], device='cuda:0')
# tensor(0.6601, device='cuda:0')

tensor([ 2870,  4048, 13789,  3513, 15263,   949, 15216,  2908,  6482,  4235,
        14392, 14217, 10273,  3235,  2841,  6652, 10847, 12553,  3587,  5859],
       device='cuda:0')
tensor([ 3.6280, -1.8414,  1.4296,  0.6981,  0.6451, -0.5354, -0.5200, -0.5059,
        -0.5012, -0.4925, -0.4562, -0.4433,  0.4365, -0.4199,  0.4192,  0.4149,
        -0.3746, -0.3451, -0.3224, -0.3199], device='cuda:0')
tensor(0.8407, device='cuda:0')


In [13]:
set1 = set(acts_top_k_ids.cpu().tolist())
set2 = set(top_k_ids.cpu().tolist())

print(set1)
print(set2)

print(set1.intersection(set2))

{384, 14850, 3844, 11782, 15884, 13324, 13715, 7574, 2841, 4255, 5023, 16038, 300, 3513, 14280, 461, 4559, 5464, 6517, 7549}
{3587, 14217, 12553, 4235, 2841, 15263, 10273, 3235, 949, 2870, 14392, 3513, 4048, 6482, 2908, 13789, 10847, 5859, 15216, 6652}
{2841, 3513}


In [14]:
print(sae.W_dec.shape[0])

16384


In [15]:
acts_dir = "max_acts"
acts_filename = f"acts_{model_name}_layer_{chosen_layers[0]}_trainer_{trainer_id}_layer_percent_{chosen_layer_percentage[0]}.pt".replace("/", "_")
print(acts_filename)
acts_path = os.path.join(acts_dir, acts_filename)
if not os.path.exists(acts_path):
    from huggingface_hub import hf_hub_download
    path_to_config = hf_hub_download(
        repo_id="adamkarvonen/sae_max_acts",
        filename=acts_filename,
        force_download=False,
        local_dir=acts_dir,
        repo_type="dataset",
    )

    
    acts_data = torch.load(acts_path)
    
    # max_tokens, max_acts = interp_utils.get_interp_prompts(
    #     model,
    #     submodules[0],
    #     sae,
    #     torch.tensor(list(range(sae.W_dec.shape[0]))),
    #     context_length=128,
    #     tokenizer=tokenizer,
    #     batch_size=batch_size * 32,
    #     num_tokens=30_000_000,
    # )
    # acts_data = {
    #     "max_tokens": max_tokens,
    #     "max_acts": max_acts,
    # }
    # torch.save(acts_data, acts_filename)
else:
    acts_data = torch.load(acts_path)
max_tokens = acts_data["max_tokens"].cpu()
max_acts = acts_data["max_acts"].cpu()


acts_google_gemma-2-2b-it_layer_5_trainer_0_layer_percent_25.pt


In [16]:
from circuitsvis.activations import text_neuron_activations
import gc
from IPython.display import clear_output, display


# def _list_decode(x):
#     if len(x.shape) == 0:
#         return tokenizer.decode(x, skip_special_tokens=False)
#     else:
#         return [_list_decode(y) for y in x]

def _list_decode(x: torch.Tensor):
    assert len(x.shape) == 1 or len(x.shape) == 2
    # Convert to list of lists, even if x is 1D
    if len(x.shape) == 1:
        x = x.unsqueeze(0)  # Make it 2D for consistent handling

    # Convert tensor to list of list of ints
    token_ids = x.tolist()
    
    # Convert token ids to token strings
    return [tokenizer.batch_decode(seq, skip_special_tokens=False) for seq in token_ids]


def create_html_activations(
    selected_tokens_FKL: list[str],
    selected_activations_FKL: list[torch.Tensor],
    num_display: int = 10,
    k: int = 5,
):

    all_html_activations = []

    for i in range(num_display):

        selected_activations_KL11 = [
            selected_activations_FKL[i, k, :, None, None] for k in range(k)
        ]
        selected_tokens_KL = selected_tokens_FKL[i]
        selected_token_strs_KL = _list_decode(selected_tokens_KL)

        # print(selected_token_strs_KL)

        for k in range(len(selected_token_strs_KL)):
            if (
                "<s>" in selected_token_strs_KL[k][0]
                or "<bos>" in selected_token_strs_KL[k][0]
            ):
                selected_token_strs_KL[k][0] = "BOS>"

        # selected_token_strs_KL = tokenizer.batch_decode(selected_token_KL, skip_special_tokens=False)
        # for k in range(len(selected_token_strs_KL)):
        #     string = selected_token_strs_KL[k]
        #     print(string[:10])
        # print("".join(string))

        html_activations = text_neuron_activations(
            selected_token_strs_KL, selected_activations_KL11
        )

        all_html_activations.append(html_activations)
    
    return all_html_activations

clear_output(wait=True)
gc.collect()
html_activations = create_html_activations(max_tokens[top_k_ids.cpu()], max_acts[top_k_ids.cpu()], num_display=20)

with open("autointerp_html_activations.pkl", "wb") as f:
    pickle.dump(html_activations, f)

for i, html_activation in enumerate(html_activations):
    feature_idx = top_k_ids[i]
    feature_val = top_k_vals[i]
    print(f"Feature {i}, feature idx: {feature_idx}, value: {feature_val}")



Feature 0, feature idx: 2870, value: 3.62801194190979
Feature 1, feature idx: 4048, value: -1.8413792848587036
Feature 2, feature idx: 13789, value: 1.429613471031189
Feature 3, feature idx: 3513, value: 0.6981467008590698
Feature 4, feature idx: 15263, value: 0.6450632214546204
Feature 5, feature idx: 949, value: -0.5353862643241882
Feature 6, feature idx: 15216, value: -0.5199776887893677
Feature 7, feature idx: 2908, value: -0.5058706402778625
Feature 8, feature idx: 6482, value: -0.5011627078056335
Feature 9, feature idx: 4235, value: -0.4924779534339905
Feature 10, feature idx: 14392, value: -0.45618849992752075
Feature 11, feature idx: 14217, value: -0.44330069422721863
Feature 12, feature idx: 10273, value: 0.43647849559783936
Feature 13, feature idx: 3235, value: -0.4198872148990631
Feature 14, feature idx: 2841, value: 0.4192456603050232
Feature 15, feature idx: 6652, value: 0.4148939251899719
Feature 16, feature idx: 10847, value: -0.3745908737182617
Feature 17, feature idx: 

In [17]:
# from copy import deepcopy

# random.seed(random_seed)
# torch.manual_seed(random_seed)

# inspection_eval_config = deepcopy(eval_config)
# inspection_eval_config.downsample = 3

# inspect_df = dataset_setup.balanced_downsample(df, inspection_eval_config.downsample, random_seed)

# inspect_prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(inspect_df, args, inspection_eval_config)

# inspect_texts, inspect_labels = hiring_bias_prompts.process_hiring_bias_resumes_prompts(inspect_prompts, args)

# inspect_texts = model_utils.add_chat_template(inspect_texts, model_name)

# inspect_tokens_FBL, inspect_activations_FBL = interp_utils.get_interp_prompts_user_inputs(
#     model,
#     submodules[0],
#     sae,
#     top_k_ids,
#     inspect_texts,
#     tokenizer=tokenizer,
#     batch_size=batch_size * 6,
#     k=len(inspect_texts),
#     sort_by_activation=False,
# )

# print(inspect_tokens_FBL.shape)
# print(inspect_activations_FBL.shape)
# print(inspect_labels)

# labels_tensor = torch.tensor(inspect_labels)

# pos_mask_K = labels_tensor == 1
# neg_mask_K = labels_tensor == 0

# print(pos_mask_K)
# print(neg_mask_K)

# print(inspect_activations_FBL.shape)

# pos_acts_FKL = inspect_activations_FBL[:, pos_mask_K, :]
# neg_acts_FKL = inspect_activations_FBL[:, neg_mask_K, :]

# mean_pos_acts_F = pos_acts_FKL.mean(dim=(1,2))
# mean_neg_acts_F = neg_acts_FKL.mean(dim=(1,2))

# torch.set_printoptions(precision=4, sci_mode=False)

# print(mean_pos_acts_F)
# print(mean_neg_acts_F)

# ratios = mean_pos_acts_F / mean_neg_acts_F

# print(ratios)




In [18]:
def get_attrib_acts(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    sae: batch_topk_sae.BatchTopKSAE,
    top_k_ids: torch.Tensor,
    batch_size: int,
    device: torch.device,
    chosen_layers: list[int],
    submodules: list[torch.nn.Module],
    yes_vs_no_loss_fn: Callable,
    downsample: int
) -> list[dict]:
    random.seed(random_seed)
    torch.manual_seed(random_seed)

    inspection_eval_config = deepcopy(eval_config)
    inspection_eval_config.downsample = downsample
    inspection_eval_config.anti_bias_statement_file = anti_bias_statement_file
    # inspection_eval_config.anti_bias_statement_file = "v1.txt"
    # inspection_eval_config.anti_bias_statement_file = "v17.txt"
    # inspection_eval_config.anti_bias_statement_file = "v6.txt"

    inspect_df = dataset_setup.balanced_downsample(
        df, inspection_eval_config.downsample, random_seed
    )

    inspect_prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(
        inspect_df, args, inspection_eval_config
    )

    inspect_texts, inspect_labels = (
        hiring_bias_prompts.process_hiring_bias_resumes_prompts(inspect_prompts, args)
    )

    inspect_texts = model_utils.add_chat_template(inspect_texts, model_name)

    inspect_dataloader = data_utils.create_simple_dataloader(
        inspect_texts,
        inspect_labels,
        model_name,
        device=device,
        batch_size=batch_size,
        shuffle=False,
    )

    all_batch_results = []

    for batch in tqdm(inspect_dataloader):
        input_ids, attention_mask, labels, idx_batch = batch

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        batch_results = attribution.compute_attributions(
            model,
            tokenizer,
            sae,
            model_inputs,
            labels,
            chosen_layers,
            submodules,
            yes_vs_no_loss_fn,
        )
        batch_results = batch_results[chosen_layers[0]]

        batch_results["encoded_acts_BLF"] = batch_results["encoded_acts_BLF"][
            :, :, top_k_ids
        ]
        batch_results["effects_BLF"] = batch_results["effects_BLF"][:, :, top_k_ids]
        batch_results["grad_x_dot_decoder_BLF"] = batch_results[
            "grad_x_dot_decoder_BLF"
        ][:, :, top_k_ids]
        all_batch_results.append(batch_results)
    return all_batch_results

all_attrib_batch_results = get_attrib_acts(model, tokenizer, sae, top_k_ids, batch_size, device, chosen_layers, submodules, yes_vs_no_loss_fn, downsample=10)

Downsampled to 10 unique resumes
Total samples after maintaining demographic variations: 40


100%|██████████| 40/40 [00:18<00:00,  2.18it/s]


In [19]:
# attrib_html_activations = []


# def reshape_attrib_acts(
#     all_batch_results: dict, gradients: bool = False
# ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
#     max_len = 0
#     for batch_results in all_batch_results:
#         model_inputs = batch_results["model_inputs"]
#         input_tokens_BL = model_inputs["input_ids"].cpu()
#         max_len = max(max_len, input_tokens_BL.shape[1])

#     padded_tokens_list = []
#     padded_acts_list = []
#     labels_list = []

#     print(max_len)

#     for batch_results in all_batch_results:
#         # print(batch_results.keys())
#         # print(batch_results)
#         model_inputs = batch_results["model_inputs"]
#         predicted_tokens = batch_results["predicted_tokens"]
#         input_tokens_BL = model_inputs["input_ids"].cpu()
#         sae_acts_BLF = batch_results["encoded_acts_BLF"]
#         effects_BLF = batch_results["effects_BLF"]
#         grad_x_dot_decoder_BLF = batch_results["grad_x_dot_decoder_BLF"]
#         labels_B = batch_results["labels"]
#         labels_list.append(labels_B)

#         # print(sae_acts_BLF.shape)
#         # print(effects_BLF.shape)
#         # print(grad_x_dot_decoder_BLF.shape)
#         # print(predicted_tokens)
#         # print(input_tokens_BL)

#         B, current_L, F = sae_acts_BLF.shape

#         input_tokens_FBL = einops.repeat(input_tokens_BL, "B L -> F B L", F=F)
#         # print(input_tokens_FBL.shape)

#         if gradients:
#             chosen_acts_BLF = grad_x_dot_decoder_BLF.clone()
#         else:
#             chosen_acts_BLF = sae_acts_BLF.clone()
#         chosen_acts_FBL = einops.rearrange(chosen_acts_BLF, "B L F -> F B L")
#         chosen_acts_FBL = chosen_acts_FBL.to(dtype=torch.float32)
#         # print(chosen_acts_FBL.shape)

#         pad_len = max_len - current_L

#         # Pad the last dimension (L)
#         # The padding tuple is (pad_left, pad_right) for the last dimension
#         padded_input_tokens_FBL = torch.nn.functional.pad(
#             input_tokens_FBL,
#             (0, pad_len),
#             mode="constant",
#             value=tokenizer.pad_token_id,
#         )
#         padded_chosen_acts_FBL = torch.nn.functional.pad(
#             chosen_acts_FBL, (0, pad_len), mode="constant", value=0
#         )
#         # --- Padding End ---

#         padded_tokens_list.append(padded_input_tokens_FBL)
#         padded_acts_list.append(padded_chosen_acts_FBL)

#     all_tokens_FBL = torch.cat(padded_tokens_list, dim=1)
#     all_acts_FBL = torch.cat(padded_acts_list, dim=1)
#     labels_B = torch.cat(labels_list, dim=0)

#     print(all_tokens_FBL.shape)
#     print(all_acts_FBL.shape)
#     print(labels_B.shape)

#     return all_tokens_FBL, all_acts_FBL, labels_B


# def analyze_acts(
#     acts_FBL: torch.Tensor,
#     tokens_FBL: torch.Tensor,
#     labels_B: torch.Tensor,
#     top_k_ids: torch.Tensor,
#     top_k_vals: torch.Tensor,
#     gradients: bool = False
# ):
#     print(acts_FBL.shape)
#     print(tokens_FBL.shape)
#     print(labels_B)

#     pos_mask_K = labels_B == 1
#     neg_mask_K = labels_B == 0

#     print(pos_mask_K)
#     print(neg_mask_K)

#     print(acts_FBL.shape)

#     pos_acts_FKL = acts_FBL[:, pos_mask_K, :]
#     neg_acts_FKL = acts_FBL[:, neg_mask_K, :]

#     if gradients:
#         pos_acts_FKL = pos_acts_FKL[:, :, -10:]
#         neg_acts_FKL = neg_acts_FKL[:, :, -10:]
#         neg_acts_FKL = neg_acts_FKL * -1

#     mean_pos_acts_F = pos_acts_FKL.mean(dim=(1, 2))
#     mean_neg_acts_F = neg_acts_FKL.mean(dim=(1, 2))

#     torch.set_printoptions(precision=6, sci_mode=False)

#     print(mean_pos_acts_F)
#     print(mean_neg_acts_F)

#     ratios = mean_pos_acts_F / mean_neg_acts_F

#     print(ratios)

#     for i in range(top_k_ids.shape[0]):
#         print()
#         print(f"Feature {i}, feature idx: {top_k_ids[i]}, value: {top_k_vals[i]:.5f}")
#         print(f"Pos acts: {mean_pos_acts_F[i]:.5f}, Neg acts: {mean_neg_acts_F[i]:.5f}, Ratio: {ratios[i]:.5f}")



# all_attrib_tokens_FBL, all_attrib_acts_FBL, all_attrib_labels_B = reshape_attrib_acts(
#     all_attrib_batch_results, gradients=True
# )

# analyze_acts(
#     all_attrib_acts_FBL,
#     all_attrib_tokens_FBL,
#     all_attrib_labels_B,
#     top_k_ids,
#     top_k_vals,
#     gradients=True
# )

# # html_activations = create_html_activations(all_attrib_tokens_FBL, all_attrib_acts_FBL, num_display=5, k=all_attrib_acts_FBL.shape[1])

# # with open("prompts_html_activations.pkl", "wb") as f:
# #     pickle.dump(html_activations, f)


In [20]:
# ... existing code ...


def reshape_attrib_acts(
    all_batch_results: dict, gradients: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # ... existing code in reshape_attrib_acts ...
    # No changes needed inside this function itself
    max_len = 0
    for batch_results in all_batch_results:
        model_inputs = batch_results["model_inputs"]
        input_tokens_BL = model_inputs["input_ids"].cpu()
        max_len = max(max_len, input_tokens_BL.shape[1])

    padded_tokens_list = []
    padded_acts_list = []
    labels_list = []

    # print(max_len)

    for batch_results in all_batch_results:
        # print(batch_results.keys())
        # print(batch_results)
        model_inputs = batch_results["model_inputs"]
        predicted_tokens = batch_results["predicted_tokens"]
        input_tokens_BL = model_inputs["input_ids"].cpu()
        sae_acts_BLF = batch_results["encoded_acts_BLF"]
        effects_BLF = batch_results["effects_BLF"]
        grad_x_dot_decoder_BLF = batch_results["grad_x_dot_decoder_BLF"]
        labels_B = batch_results["labels"]
        labels_list.append(labels_B)

        # print(sae_acts_BLF.shape)
        # print(effects_BLF.shape)
        # print(grad_x_dot_decoder_BLF.shape)
        # print(predicted_tokens)
        # print(input_tokens_BL)

        B, current_L, F = sae_acts_BLF.shape

        input_tokens_FBL = einops.repeat(input_tokens_BL, "B L -> F B L", F=F)
        # print(input_tokens_FBL.shape)

        if gradients:
            chosen_acts_BLF = grad_x_dot_decoder_BLF.clone()
        else:
            chosen_acts_BLF = sae_acts_BLF.clone()
        chosen_acts_FBL = einops.rearrange(chosen_acts_BLF, "B L F -> F B L")
        chosen_acts_FBL = chosen_acts_FBL.to(dtype=torch.float32)
        # print(chosen_acts_FBL.shape)

        pad_len = max_len - current_L

        # Pad the last dimension (L)
        # The padding tuple is (pad_left, pad_right) for the last dimension
        padded_input_tokens_FBL = torch.nn.functional.pad(
            input_tokens_FBL,
            (0, pad_len),
            mode="constant",
            value=tokenizer.pad_token_id,
        )
        padded_chosen_acts_FBL = torch.nn.functional.pad(
            chosen_acts_FBL, (0, pad_len), mode="constant", value=0
        )
        # --- Padding End ---

        padded_tokens_list.append(padded_input_tokens_FBL)
        padded_acts_list.append(padded_chosen_acts_FBL)

    all_tokens_FBL = torch.cat(padded_tokens_list, dim=1)
    all_acts_FBL = torch.cat(padded_acts_list, dim=1)
    labels_B = torch.cat(labels_list, dim=0)

    # print(all_tokens_FBL.shape)
    # print(all_acts_FBL.shape)
    # print(labels_B.shape)

    return all_tokens_FBL, all_acts_FBL, labels_B


def analyze_acts(
    acts_FBL: torch.Tensor,
    acts_grad_FBL: torch.Tensor,  # Added gradient activations
    tokens_FBL: torch.Tensor,
    labels_B: torch.Tensor,
    top_k_ids: torch.Tensor,
    top_k_vals: torch.Tensor,
    output_filename: str,
):
    print(f"Acts shape: {acts_FBL.shape}")
    print(f"Grad Acts shape: {acts_grad_FBL.shape}")
    print(f"Tokens shape: {tokens_FBL.shape}")
    print(f"Labels: {labels_B}")

    pos_mask_K = labels_B == 1
    neg_mask_K = labels_B == 0

    # print(f"Pos mask: {pos_mask_K}")
    # print(f"Neg mask: {neg_mask_K}")

    # --- Regular Activations ---
    pos_acts_FKL = acts_FBL[:, pos_mask_K, :]
    neg_acts_FKL = acts_FBL[:, neg_mask_K, :]
    mean_pos_acts_F = pos_acts_FKL.mean(dim=(1, 2))
    mean_neg_acts_F = neg_acts_FKL.mean(dim=(1, 2))
    ratios = mean_pos_acts_F / mean_neg_acts_F
    ratios = torch.nan_to_num(ratios, nan=0.0)  # Handle potential division by zero

    # --- Gradient Activations ---
    pos_acts_grad_FKL = acts_grad_FBL[:, pos_mask_K, :]
    neg_acts_grad_FKL = acts_grad_FBL[:, neg_mask_K, :]
    # Take last 10 tokens for gradients and invert negative examples as before
    # pos_acts_grad_FKL = pos_acts_grad_FKL[:, :, -40:]
    # neg_acts_grad_FKL = neg_acts_grad_FKL[:, :, -40:]

    pos_acts_grad_FKL = pos_acts_grad_FKL[:, :, -10:]
    neg_acts_grad_FKL = neg_acts_grad_FKL[:, :, -10:]
    neg_acts_grad_FKL = neg_acts_grad_FKL * -1
    mean_pos_acts_grad_F = pos_acts_grad_FKL.mean(dim=(1, 2))
    mean_neg_acts_grad_F = neg_acts_grad_FKL.mean(dim=(1, 2))
    ratios_grad = mean_pos_acts_grad_F / mean_neg_acts_grad_F
    ratios_grad = torch.nan_to_num(
        ratios_grad, nan=0.0
    )  # Handle potential division by zero

    torch.set_printoptions(precision=5, sci_mode=False)

    print("\n--- Analysis per Feature ---")
    header = f"{'Feature':<7} {'Idx':<5} {'Value':<10} | {'Acts Pos':<10} {'Acts Neg':<10} {'Acts Ratio':<10} | {'Grad Pos':<10} {'Grad Neg':<10} {'Grad Ratio':<10}"
    print(header)
    print("-" * len(header))

    for i in range(top_k_ids.shape[0]):
        print(
            f"Feature {i:<2}, "
            f"idx: {top_k_ids[i]:<5}, "
            f"value: {top_k_vals[i]:<10.5f} | "
            f"acts pos: {mean_pos_acts_F[i]:<10.5f}, "
            f"acts neg: {mean_neg_acts_F[i]:<10.5f}, "
            f"acts ratio: {ratios[i]:<10.5f} | "
            # f"grad pos: {mean_pos_acts_grad_F[i]:<10.5f}, "
            # f"grad neg: {mean_neg_acts_grad_F[i]:<10.5f}, "
            # f"grad ratio: {ratios_grad[i]:<10.5f}"
        )

    data = {
        "top_k_ids": top_k_ids,
        "top_k_vals": top_k_vals,
        "mean_pos_acts_F": mean_pos_acts_F,
        "mean_neg_acts_F": mean_neg_acts_F,
        "ratios": ratios,
        "mean_pos_acts_grad_F": mean_pos_acts_grad_F,
        "mean_neg_acts_grad_F": mean_neg_acts_grad_F,
        "ratios_grad": ratios_grad,
    }

    with open(output_filename, "wb") as f:
        pickle.dump(data, f)

output_dir = "bias_data_notebook"
output_filename = f"{bias_type}_{model_name}_layer_{chosen_layers[0]}_downsample_{downsample}_trainer_id_{trainer_id}.pkl".replace("/", "_")
output_filename = os.path.join(output_dir, output_filename)

os.makedirs(output_dir, exist_ok=True)

# --- Call reshape_attrib_acts twice ---
print("Reshaping standard activations...")
all_attrib_tokens_FBL, all_attrib_acts_FBL, all_attrib_labels_B = reshape_attrib_acts(
    all_attrib_batch_results, gradients=False
)
print("\nReshaping gradient activations...")
_, all_attrib_acts_grad_FBL, _ = reshape_attrib_acts(  # Tokens and labels are the same
    all_attrib_batch_results, gradients=True
)
print("\nRunning analysis...")
analyze_acts(
    all_attrib_acts_FBL,
    all_attrib_acts_grad_FBL,  # Pass gradient acts
    all_attrib_tokens_FBL,
    all_attrib_labels_B,
    top_k_ids,
    top_k_vals,
    output_filename,
)

html_activations = create_html_activations(
    all_attrib_tokens_FBL,
    # all_attrib_acts_grad_FBL,
    all_attrib_acts_FBL,
    num_display=12,
    k=all_attrib_acts_FBL.shape[1],
)

with open("prompts_html_activations.pkl", "wb") as f:
    pickle.dump(html_activations, f)

Reshaping standard activations...

Reshaping gradient activations...

Running analysis...
Acts shape: torch.Size([20, 40, 2408])
Grad Acts shape: torch.Size([20, 40, 2408])
Tokens shape: torch.Size([20, 40, 2408])
Labels: tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0], device='cuda:0')

--- Analysis per Feature ---
Feature Idx   Value      | Acts Pos   Acts Neg   Acts Ratio | Grad Pos   Grad Neg   Grad Ratio
----------------------------------------------------------------------------------------------
Feature 0 , idx: 2870 , value: 3.62801    | acts pos: 0.24207   , acts neg: 0.24221   , acts ratio: 0.99945    | 
Feature 1 , idx: 4048 , value: -1.84138   | acts pos: 0.13651   , acts neg: 0.13655   , acts ratio: 0.99968    | 
Feature 2 , idx: 13789, value: 1.42961    | acts pos: 5.41644   , acts neg: 5.41484   , acts ratio: 1.00030    | 
Feature 3 , idx: 3513 , value: 0.69815    | acts pos: 0.02527 

In [21]:
# attrib_html_activations = []


# def reshape_attrib_acts(
#     all_batch_results: dict, gradients: bool = False
# ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
#     max_len = 0
#     for batch_results in all_batch_results:
#         model_inputs = batch_results["model_inputs"]
#         input_tokens_BL = model_inputs["input_ids"].cpu()
#         max_len = max(max_len, input_tokens_BL.shape[1])

#     padded_tokens_list = []
#     padded_acts_list = []
#     labels_list = []

#     print(max_len)

#     for batch_results in all_batch_results:
#         # print(batch_results.keys())
#         # print(batch_results)
#         model_inputs = batch_results["model_inputs"]
#         predicted_tokens = batch_results["predicted_tokens"]
#         input_tokens_BL = model_inputs["input_ids"].cpu()
#         sae_acts_BLF = batch_results["encoded_acts_BLF"]
#         effects_BLF = batch_results["effects_BLF"]
#         grad_x_dot_decoder_BLF = batch_results["grad_x_dot_decoder_BLF"]
#         labels_B = batch_results["labels"]
#         labels_list.append(labels_B)

#         # print(sae_acts_BLF.shape)
#         # print(effects_BLF.shape)
#         # print(grad_x_dot_decoder_BLF.shape)
#         # print(predicted_tokens)
#         # print(input_tokens_BL)

#         B, current_L, F = sae_acts_BLF.shape

#         input_tokens_FBL = einops.repeat(input_tokens_BL, "B L -> F B L", F=F)
#         # print(input_tokens_FBL.shape)

#         if gradients:
#             chosen_acts_BLF = grad_x_dot_decoder_BLF.clone()
#         else:
#             chosen_acts_BLF = sae_acts_BLF.clone()
#         chosen_acts_FBL = einops.rearrange(chosen_acts_BLF, "B L F -> F B L")
#         chosen_acts_FBL = chosen_acts_FBL.to(dtype=torch.float32)
#         # print(chosen_acts_FBL.shape)

#         pad_len = max_len - current_L

#         # Pad the last dimension (L)
#         # The padding tuple is (pad_left, pad_right) for the last dimension
#         padded_input_tokens_FBL = torch.nn.functional.pad(
#             input_tokens_FBL,
#             (0, pad_len),
#             mode="constant",
#             value=tokenizer.pad_token_id,
#         )
#         padded_chosen_acts_FBL = torch.nn.functional.pad(
#             chosen_acts_FBL, (0, pad_len), mode="constant", value=0
#         )
#         # --- Padding End ---

#         padded_tokens_list.append(padded_input_tokens_FBL)
#         padded_acts_list.append(padded_chosen_acts_FBL)

#     all_tokens_FBL = torch.cat(padded_tokens_list, dim=1)
#     all_acts_FBL = torch.cat(padded_acts_list, dim=1)
#     labels_B = torch.cat(labels_list, dim=0)

#     print(all_tokens_FBL.shape)
#     print(all_acts_FBL.shape)
#     print(labels_B.shape)

#     return all_tokens_FBL, all_acts_FBL, labels_B


# def analyze_acts(
#     acts_FBL: torch.Tensor,
#     tokens_FBL: torch.Tensor,
#     labels_B: torch.Tensor,
#     top_k_ids: torch.Tensor,
#     top_k_vals: torch.Tensor,
#     gradients: bool = False
# ):
#     print(acts_FBL.shape)
#     print(tokens_FBL.shape)
#     print(labels_B)

#     pos_mask_K = labels_B == 1
#     neg_mask_K = labels_B == 0

#     print(pos_mask_K)
#     print(neg_mask_K)

#     print(acts_FBL.shape)

#     pos_acts_FKL = acts_FBL[:, pos_mask_K, :]
#     neg_acts_FKL = acts_FBL[:, neg_mask_K, :]

#     if gradients:
#         pos_acts_FKL = pos_acts_FKL[:, :, -50:]
#         neg_acts_FKL = neg_acts_FKL[:, :, -50:]
#         neg_acts_FKL = neg_acts_FKL * -1

#     mean_pos_acts_F = pos_acts_FKL.mean(dim=(1, 2))
#     mean_neg_acts_F = neg_acts_FKL.mean(dim=(1, 2))

#     torch.set_printoptions(precision=6, sci_mode=False)

#     print(mean_pos_acts_F)
#     print(mean_neg_acts_F)

#     ratios = mean_pos_acts_F / mean_neg_acts_F

#     print(ratios)

#     for i in range(top_k_ids.shape[0]):
#         print()
#         print(f"Feature {i}, feature idx: {top_k_ids[i]}, value: {top_k_vals[i]:.5f}")
#         print(f"Pos acts: {mean_pos_acts_F[i]:.5f}, Neg acts: {mean_neg_acts_F[i]:.5f}, Ratio: {ratios[i]:.5f}")



# all_attrib_tokens_FBL, all_attrib_acts_FBL, all_attrib_labels_B = reshape_attrib_acts(
#     all_attrib_batch_results, gradients=True
# )

# analyze_acts(
#     all_attrib_acts_FBL,
#     all_attrib_tokens_FBL,
#     all_attrib_labels_B,
#     top_k_ids,
#     top_k_vals,
#     gradients=True
# )

# html_activations = create_html_activations(all_attrib_tokens_FBL, all_attrib_acts_FBL, num_display=5, k=all_attrib_acts_FBL.shape[1])

# with open("prompts_html_activations.pkl", "wb") as f:
#     pickle.dump(html_activations, f)


In [22]:
# clear_output(wait=True)
# gc.collect()
# html_activations = create_html_activations(all_attrib_tokens_FBL, all_attrib_acts_FBL, num_display=2, k=all_attrib_acts_FBL.shape[1])
# for i, html_activation in enumerate(html_activations):
#     feature_idx = top_k_ids[i]
#     feature_val = top_k_vals[i]
#     print(f"Feature {i}, feature idx: {feature_idx}, value: {feature_val}")
#     display(html_activation)

In [23]:
# from copy import deepcopy

# inspection_eval_config = deepcopy(eval_config)
# inspection_eval_config.downsample = 3

# inspect_df = dataset_setup.balanced_downsample(df, inspection_eval_config.downsample, random_seed)

# inspect_prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(inspect_df, args, inspection_eval_config)

# inspect_texts, inspect_labels = hiring_bias_prompts.process_hiring_bias_resumes_prompts(inspect_prompts, args)

# inspect_texts = model_utils.add_chat_template(inspect_texts, model_name)

# inspect_tokens_FBL, inspect_activations_FBL = interp_utils.get_interp_prompts_user_inputs(
#     model,
#     submodules[0],
#     sae,
#     top_k_ids,
#     inspect_texts,
#     tokenizer=tokenizer,
#     batch_size=batch_size * 6,
#     k=len(inspect_texts),
# )

# num_display = 2

# html_activations = display_html_activations(inspect_tokens_FBL, inspect_activations_FBL, num_display=num_display, k=len(inspect_texts))

# for i, html_activation in enumerate(html_activations):
#     feature_idx = top_k_ids[i]
#     feature_val = top_k_vals[i]
#     print(f"Feature {i}, feature idx: {feature_idx}, value: {feature_val}")
#     display(html_activation)