In [None]:
# Enable IPython's autoreload extension so edits to imported project modules are
# picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
import os

# This allocator option has to be present in the environment before torch is
# imported, which is why it lives in its own cell ahead of the imports.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [11]:
import torch
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import os
import random
from copy import deepcopy
from typing import Callable, Optional
from tqdm import tqdm
import pickle
import gc

import mypkg.whitebox_infra.attribution as attribution
import mypkg.whitebox_infra.dictionaries.batch_topk_sae as batch_topk_sae
import mypkg.whitebox_infra.data_utils as data_utils
import mypkg.whitebox_infra.model_utils as model_utils
import mypkg.whitebox_infra.interp_utils as interp_utils
import mypkg.pipeline.setup.dataset as dataset_setup
import mypkg.pipeline.infra.hiring_bias_prompts as hiring_bias_prompts
from mypkg.eval_config import EvalConfig
import mypkg.pipeline.infra.model_inference as model_inference
import mypkg.whitebox_infra.intervention_hooks as intervention_hooks

In [None]:
# --- Run configuration --------------------------------------------------------
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model under test. Exactly one assignment is active; the commented lines are
# the alternatives that can be swapped in.
# model_name = "mistralai/Ministral-8B-Instruct-2410"
# model_name = "mistralai/Mistral-Small-24B-Instruct-2501"
model_name = "google/gemma-2-2b-it"
# model_name = "google/gemma-2-9b-it"
# model_name = "google/gemma-2-27b-it"

# Bias category to evaluate. Exactly one assignment is active.
# bias_type = "gender"
bias_type = "race"
# bias_type = "political_orientation"

# The anti-bias statement file name is derived from the index, which is also
# used on its own later in the notebook (feature lookup).
anti_bias_statement_file_idx = 3
# anti_bias_statement_file = "v1.txt"
anti_bias_statement_file = f"v{anti_bias_statement_file_idx}.txt"
# anti_bias_statement_file = "v17.txt"

# Turn the single `bias_type` string into one boolean flag per category; only
# the flag matching `bias_type` is True.
args = hiring_bias_prompts.HiringBiasArgs(
    political_orientation=bias_type == "political_orientation",
    employment_gap=bias_type == "employment_gap",
    pregnancy=bias_type == "pregnancy",
    race=bias_type == "race",
    gender=bias_type == "gender",
    misc=bias_type == "misc",
)


# Load the model in bfloat16 straight onto the selected device, plus its tokenizer.
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Gradient checkpointing is switched on for the 27B Gemma model only.
gradient_checkpointing = model_name == "google/gemma-2-27b-it"

# These three models run with one example per batch; every other model uses three.
batch_size = (
    1
    if model_name
    in (
        "google/gemma-2-27b-it",
        "mistralai/Mistral-Small-24B-Instruct-2501",
        "google/gemma-2-2b-it",
    )
    else 3
)

if gradient_checkpointing:
    # The KV cache is turned off before checkpointing is enabled.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()


# Layer positions given as percentages; each value is a key into
# MODEL_CONFIGS[model_name]["layer_mappings"]. One assignment is active.
chosen_layer_percentage = [25]
# chosen_layer_percentage = [50]
# chosen_layer_percentage = [75]

# System prompt file; the "qualifications" variant also triggers a suffix being
# appended to every prompt further down.
system_prompt = "yes_no.txt"
# system_prompt = "yes_no_qualifications.txt"

# use_activation_loss_fn = True
use_activation_loss_fn = False

# Resolve each percentage to the concrete layer entry for the active model.
chosen_layers = [
    model_utils.MODEL_CONFIGS[model_name]["layer_mappings"][layer_percent]["layer"]
    for layer_percent in chosen_layer_percentage
]

# Evaluation settings; read below for the downsample size / random seed and
# passed to prompt construction.
# NOTE(review): political_orientation is hardcoded to True here, independently
# of `bias_type` — confirm this is intended.
eval_config = EvalConfig(
    model_name=model_name,
    political_orientation=True,
    pregnancy=False,
    employment_gap=False,
    anthropic_dataset=False,
    downsample=20,  # other values tried: 150, 5
    gpu_inference=True,
    anti_bias_statement_file=anti_bias_statement_file,
    job_description_file="short_meta_job_description.txt",
    system_prompt_filename=system_prompt,
)


In [5]:
# Disabled alternative, kept for reference: load one specific SAE checkpoint by
# repo id and file path instead of going through model_utils.load_model_sae.
#
# sae_repo = "adamkarvonen/ministral_saes"
# sae_path = f"mistralai_Ministral-8B-Instruct-2410_batch_top_k/resid_post_layer_{chosen_layers[0]}/trainer_1/ae.pt"
# sae = batch_topk_sae.load_dictionary_learning_batch_topk_sae(
#     repo_id=sae_repo,
#     filename=sae_path,
#     model_name=model_name,
#     device=device,
#     dtype=dtype,
#     layer=chosen_layers[0],
#     local_dir="downloaded_saes",
# )

# The trainer id comes from the per-model config table.
trainer_id = model_utils.MODEL_CONFIGS[model_name]["trainer_id"]

# Load the SAE for the first chosen layer percentage.
sae = model_utils.load_model_sae(
    model_name,
    device,
    dtype,
    chosen_layer_percentage[0],
    trainer_id=trainer_id,
)

# One-element list holding the model submodule at the first chosen layer.
submodules = [
    model_utils.get_submodule(model, chosen_layers[0]),
]

In [None]:
# Load the raw dataset first; seeding happens afterwards.
df = dataset_setup.load_raw_dataset()

industry = "INFORMATION-TECHNOLOGY"
downsample, random_seed = eval_config.downsample, eval_config.random_seed

# Seed both Python's `random` module and torch from the eval config.
random.seed(random_seed)
torch.manual_seed(random_seed)

# Keep a single industry, then apply the balanced downsample with the same seed.
df = dataset_setup.balanced_downsample(
    dataset_setup.filter_by_industry(df, industry),
    downsample,
    random_seed,
)

# Build the prompts from the reduced frame and the two config objects.
prompts = hiring_bias_prompts.create_all_prompts_hiring_bias(df, args, eval_config)

In [None]:
(
    train_texts,
    train_labels,
    train_resume_prompt_results,
) = hiring_bias_prompts.process_hiring_bias_resumes_prompts(prompts, model_name, args)

# Store each text on its result object. This happens before the optional suffix
# below, so the stored prompt never includes that suffix.
for i, prompt_text in enumerate(train_texts):
    train_resume_prompt_results[i].prompt = prompt_text

if system_prompt == "yes_no_qualifications.txt":
    # Other suffix wordings that were tried:
    #   "The most important qualifications of the candidate to consider are:"
    #   "The most important factors about the candidate to consider are:"
    for i in range(len(train_texts)):
        train_texts[i] += "The most important qualities of the candidate to consider are:"

dataloader = data_utils.create_simple_dataloader(
    train_texts,
    train_labels,
    prompts,
    model_name,
    device,
    batch_size=batch_size,
    max_length=2500,
)

In [None]:
# Baseline: run the model on every prompt with no ablation, at twice the
# configured batch size.
model_answers = model_inference.run_single_forward_pass_transformers(
    prompts,
    model_name,
    batch_size=batch_size * 2,
    model=model,
)

# Score the answers and print the result before computing the probability variant.
bias_scores = hiring_bias_prompts.evaluate_bias(model_answers, system_prompt)
print(bias_scores)

bias_probs = hiring_bias_prompts.evaluate_bias_probs(model_answers, system_prompt)
print(bias_probs)

In [None]:

# Look up the SAE features for the active model, trainer, first chosen layer
# percentage, anti-bias statement index and bias type.
ablation_features = intervention_hooks.lookup_sae_features(
    model_name,
    trainer_id,
    chosen_layer_percentage[0],
    anti_bias_statement_file_idx,
    bias_type,
)

print(ablation_features)

In [None]:
# Ablation run using ablation_type="clamping".
# NOTE(review): these hardcoded feature ids replace whatever the lookup cell
# assigned to `ablation_features` — confirm they belong to the active model/layer.
ablation_features = torch.tensor([4356, 31477])

model_answers = model_inference.run_single_forward_pass_transformers(
    prompts,
    model_name,
    batch_size=batch_size * 2,
    model=model,
    ablation_features=ablation_features,
    ablation_type="clamping",
)

bias_scores = hiring_bias_prompts.evaluate_bias(model_answers, system_prompt)
print(bias_scores)

bias_probs = hiring_bias_prompts.evaluate_bias_probs(model_answers, system_prompt)
print(bias_probs)

In [None]:
# Ablation run using ablation_type="adaptive_clamping" with scale=2.0, on the
# same two hardcoded feature ids.
ablation_features = torch.tensor([4356, 31477])

model_answers = model_inference.run_single_forward_pass_transformers(
    prompts,
    model_name,
    batch_size=batch_size * 2,
    model=model,
    ablation_features=ablation_features,
    ablation_type="adaptive_clamping",
    scale=2.0,
)

bias_scores = hiring_bias_prompts.evaluate_bias(model_answers, system_prompt)
print(bias_scores)

bias_probs = hiring_bias_prompts.evaluate_bias_probs(model_answers, system_prompt)
print(bias_probs)

In [None]:
# Ablation run using ablation_type="targeted", on the same two hardcoded
# feature ids.
ablation_features = torch.tensor([4356, 31477])

model_answers = model_inference.run_single_forward_pass_transformers(
    prompts,
    model_name,
    batch_size=batch_size * 2,
    model=model,
    ablation_features=ablation_features,
    ablation_type="targeted",
)

bias_scores = hiring_bias_prompts.evaluate_bias(model_answers, system_prompt)
print(bias_scores)

bias_probs = hiring_bias_prompts.evaluate_bias_probs(model_answers, system_prompt)
print(bias_probs)

In [12]:
# Disabled cell, kept for reference only: nothing below executes. It holds a
# manual stop marker and a variant of the ablation evaluation that uses a single
# feature id (4356) and passes no explicit ablation_type.
# raise ValueError("Stop here")

# ablation_features = torch.tensor([4356])
# model_answers = model_inference.run_single_forward_pass_transformers(
#     prompts, model_name, batch_size=batch_size * 2, model=model, ablation_features=ablation_features
# )

# bias_scores = hiring_bias_prompts.evaluate_bias(
#     model_answers,
#     system_prompt
# )

# print(bias_scores)

# bias_probs = hiring_bias_prompts.evaluate_bias_probs(
#     model_answers,
#     system_prompt
# )
# print(bias_probs)