In [None]:
# Uncomment if on colab

# import os
# os.chdir('/content')
# !rm -rf LLM-Audting-Experiments
# !git clone https://github.com/FabienRoger/LLM-Audting-Experiments.git
# os.chdir('/content/LLM-Audting-Experiments')
# !pwd
# !pip install -r requirements.txt

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from activation_utils import get_activations, get_all_activations, get_res_layers, run_and_modify, get_mlp_layers

from activation_ds import ActivationsDataset
from data_loading import PromptDataset
from inlp import inlp
from linear import get_linear_cut, get_multi_lin_cut
from logit_lense import print_logit_lense
from utils import make_projections, orthonormalize
from tqdm import tqdm
from metrics import get_avg_delta, get_perf_degradations
from itertools import product

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Load GPT-2 XL weights, then move the module onto the chosen device.
model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
model = model.to(device)
print(model.device)

In [None]:
# Prompt datasets to evaluate on; names appear to follow an "<a>_v_<b>" pattern
# (two contrasted groups). The "gpt2" argument is forwarded to PromptDataset as-is.
ds_names = ["men_v_women", "naive_black_v_white", "bash_v_powershell"]
prompt_dss = dict((name, PromptDataset("gpt2", name)) for name in ds_names)

# Layer groups to study, keyed "<kind><start>-<stop>" where the layer indices are
# range(start, stop) (stop is exclusive). "mlp" groups come from get_mlp_layers and
# "res" groups from get_res_layers; entries are built in the order
# mlp10-17, res10-17, mlp8-12, res8-12.
stud_layers_posibilities = {
  f"{kind}{start}-{stop}": (get_mlp_layers if kind == "mlp" else get_res_layers)(
    model, list(range(start, stop))
  )
  for start, stop in [(10, 17), (8, 12)]
  for kind in ("mlp", "res")
}
# Direction-finding methods: each takes an ActivationsDataset and returns an iterable
# of direction tensors. The `8` passed to each helper is presumably the number of
# directions requested — TODO confirm against inlp / get_multi_lin_cut.
dirs_fns = dict(
    torch_inlp=lambda act_ds: inlp(act_ds, 8, max_iters=200, use_torch=True, loading_bar=False),
    svm_inlp=lambda act_ds: inlp(act_ds, 8, max_iters=10_000, use_torch=False, loading_bar=False),
    # get_multi_lin_cut returns a tuple; only its first element is used as the directions.
    mlp=lambda act_ds: get_multi_lin_cut(act_ds, 8, epochs=500, progress_bar=False)[0],
)
def _make_proj_modif_fn(n_dirs):
  """Build a ``(dirs, std_layer_name) -> {layer: projection}`` modification builder.

  The returned function looks up the layers of ``std_layer_name`` in
  ``stud_layers_posibilities`` and maps each layer to
  ``make_projections(dirs[:n_dirs], is_rest=...)``. ``is_rest`` is True when the
  group name starts with ``"res"`` (presumably residual-stream groups — confirm
  against make_projections).

  Args:
    n_dirs: how many leading directions of ``dirs`` to use.
  """
  def modif_fn(dirs, std_layer_name):
    layers = stud_layers_posibilities[std_layer_name]
    is_rest = std_layer_name.startswith("res")
    # make_projections is called once per layer so each layer gets its own entry,
    # exactly as before.
    return {layer: make_projections(dirs[:n_dirs], is_rest=is_rest) for layer in layers}
  return modif_fn

# Model modifications keyed by name, in evaluation order:
#   "default" -> {} (no modification),
#   "projK"   -> every studied layer mapped to make_projections of the first K directions.
# A factory is used instead of copy-pasted lambdas so K is bound per entry.
get_modif_fns = {
    "default": lambda dirs, std_layer_name: {},
    **{f"proj{k}": _make_proj_modif_fn(k) for k in (1, 4, 8)},
}

def to_run_fn(modif):
  """Wrap the global ``model`` with activation modification ``modif``.

  Args:
    modif: mapping forwarded to ``run_and_modify``; ``{}`` means no modification.

  Returns:
    A function ``x -> torch.softmax(run_and_modify(x, model, modif).logits[0], -1)``,
    i.e. probabilities over the last logits dimension for the first element of
    ``logits`` (presumably the first batch item — TODO confirm).
  """
  def run_fn(x):
    # The result is always detached, so no autograd graph is needed; disabling it
    # avoids storing activations for backward during the gpt2-xl forward pass.
    with torch.no_grad():
      logits = run_and_modify(x, model, modif).logits[0].detach()
      return torch.softmax(logits, -1)
  return run_fn


# Every (layer group name, direction method name) pair to evaluate,
# with the layer group varying slowest.
to_eval_product = [
  (layers_name, method_name)
  for layers_name in stud_layers_posibilities.keys()
  for method_name in dirs_fns.keys()
]

print(len(to_eval_product))

In [None]:
# Evaluate every (layer group, direction method, prompt dataset, modification) combination.
# Each result is printed as before and also collected in `eval_results` for later analysis
# (the list is reset on every run of this cell, so re-running is idempotent).
eval_results = []
for stud_layers_n, dirs_fn_n in to_eval_product:
  stud_layers = stud_layers_posibilities[stud_layers_n]
  for prompt_ds_n, prompt_ds in prompt_dss.items():
    # NOTE(review): activations depend only on (stud_layers, prompt_ds) but are recomputed
    # for every direction method. They are not cached because the direction fns might
    # mutate `act_ds` — verify before adding a cache.
    activations = get_all_activations(prompt_ds, model, stud_layers)
    act_ds = ActivationsDataset.from_data(activations, stud_layers, device)
    dirs = [d.to(device) for d in dirs_fns[dirs_fn_n](act_ds)]
    # Baseline (unmodified) run function; it does not depend on the modification being
    # evaluated, so build it once per (config, dataset) instead of once per modification.
    run_fn_default = to_run_fn(get_modif_fns["default"](dirs, stud_layers_n))
    for get_modif_fn_n, get_modif_fn in get_modif_fns.items():
      print(stud_layers_n, dirs_fn_n, get_modif_fn_n, prompt_ds_n)
      run_fn = to_run_fn(get_modif_fn(dirs, stud_layers_n))
      avg_delta = get_avg_delta(prompt_ds, run_fn, loading_bar=False)
      perf_degradations = get_perf_degradations(prompt_ds, run_fn, run_fn_default, loading_bar=False)
      print(avg_delta, perf_degradations)
      eval_results.append(dict(
        stud_layers=stud_layers_n,
        dirs_fn=dirs_fn_n,
        modif=get_modif_fn_n,
        prompt_ds=prompt_ds_n,
        avg_delta=avg_delta,
        perf_degradations=perf_degradations,
      ))