In [1]:
%load_ext autoreload
%autoreload 2

from circuit_breaking.src import *
import torch
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import os
from circuit_breaking.src.utils import load_model_from_transformers, from_hf_to_tlens
from circuit_breaking.src.masks import MLPHiddenMask
from tqdm.auto import tqdm
#torch.autograd.set_detect_anomaly(True) 

In [2]:
model_type = "gemma"

# Gemma tokenizer used for training batches: EOS doubles as the pad token, and padding goes on the right.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

In [3]:
from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform
from tasks.facts.SportsTaskAdversarial import adversarial_sports_eval
from tasks.facts.SportsTaskSideEffects import run_side_effects_evals


train_batch_size = 4
eval_batch_size = 32

device = "cuda"
train_loss_type = "sports"

# Optional single sport to evaluate as "maintain" data, plus an optional held-out validation sport.
# Both are only used when `maintain_sport` is not None.
maintain_sport = None
val_sport = None

# Current configuration: forget a subset of `forget_athletes` athletes across all sports.
forget_sport = None
forget_athletes = 16
save_dir = f"results/localized_finetuning_{forget_athletes}_athletes"
forget_kwargs = {"forget_player_subset": forget_athletes, "is_forget_dataset": True, "train_test_split": False}
maintain_kwargs = {"forget_player_subset": forget_athletes, "is_forget_dataset": False, "train_test_split": True}
forget_loss_coef = 1

# Alternative configuration: forget an entire sport.
# forget_sport="basketball"
# forget_athletes = None
# save_dir = f"results/localized_finetuning_{forget_sport}"
# # save_dir = f"results/localized_finetuning_{forget_sport}_old"
# forget_kwargs = {"forget_sport_subset": {forget_sport}, "is_forget_dataset": True, "train_test_split": True}
# maintain_kwargs = {"forget_sport_subset": {forget_sport}, "is_forget_dataset": False, "train_test_split": True}
# forget_loss_coef=.2

os.makedirs(save_dir, exist_ok=True)

# Constructor arguments shared by every SportsTask below.
sports_common_kwargs = {"tokenizer": tokenizer, "device": device, "prep_acdcpp": False}

# Training mixture: "log_1_minus_p" loss on the forget set, cross-entropy on the maintain set, plus Pile text.
sports_1mp = SportsTask(batch_size=train_batch_size, criterion="log_1_minus_p", **sports_common_kwargs, **forget_kwargs)
# The maintain training task is built the same way whether or not `maintain_sport` is set.
maintain_sports = SportsTask(batch_size=train_batch_size, criterion="cross_entropy", **sports_common_kwargs, **maintain_kwargs)

train_pile = PileTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, ctx_length=100, shuffle=True, buffer_size=50000)
train_tasks = {"sports_1mp": (sports_1mp, forget_loss_coef), "maintain_sports": (maintain_sports, 1), "pile": (train_pile, 1)}

# Eval tasks (want to eval on other sports too).
forget_sport_eval = SportsTask(batch_size=eval_batch_size, criterion="cross_entropy", **sports_common_kwargs, **forget_kwargs)
test_pile = PileTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, ctx_length=100, shuffle=True, buffer_size=50000)

induction_eval = InductionTask(batch_size=eval_batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, device=device)
if maintain_sport is None:
    maintain_sports_eval = SportsTask(batch_size=eval_batch_size, criterion="cross_entropy", **sports_common_kwargs, **maintain_kwargs)
    eval_tasks = {"induction": induction_eval, "pile": test_pile, "forget_sport": forget_sport_eval, "maintain_sport": maintain_sports_eval}
else:
    # Later cells read `maintain_sports_eval`, so that name is bound in this branch too;
    # `maintain_sport_eval` is kept as an alias.
    maintain_sports_eval = SportsTask(batch_size=eval_batch_size, criterion="cross_entropy", forget_sport_subset={maintain_sport}, is_forget_dataset=True, **sports_common_kwargs)
    maintain_sport_eval = maintain_sports_eval
    eval_tasks = {"induction": induction_eval, "pile": test_pile, "forget_sport": forget_sport_eval, "maintain_sport": maintain_sports_eval}
    # Only build the validation-sport eval when one was configured.
    if val_sport is not None:
        val_sport_eval = SportsTask(batch_size=eval_batch_size, criterion="cross_entropy", forget_sport_subset={val_sport}, is_forget_dataset=True, **sports_common_kwargs)
        eval_tasks["val_sport"] = val_sport_eval


OpenAI API key not found, will not be able to run evaluations on Sports Trivia Task


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [4]:
# Forget-set athletes and prompts used by the "log_1_minus_p" unlearning task.
sports_1mp.train_df

Unnamed: 0.1,Unnamed: 0,athlete,sport,log_prob_one_shot,num_athlete_tokens,sport_index,sport_token,prompt
0,1642,DeForest Buckner,football,-0.492917,5,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...
1,738,Walter Payton,football,-0.105714,3,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...
2,16778,Anthony DeSclafani,baseball,-0.292668,6,0,14623,Fact: Tiger Woods plays the sport of golf\nFac...
3,14501,Kevin Millwood,baseball,-0.372979,3,0,14623,Fact: Tiger Woods plays the sport of golf\nFac...
4,188,Vonta Leach,football,-0.648644,5,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...
5,16887,Mitch Haniger,baseball,-0.116977,3,0,14623,Fact: Tiger Woods plays the sport of golf\nFac...
6,1371,Landon Collins,football,-0.201034,3,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...
7,930,Charlie Whitehurst,football,-0.370951,3,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...
8,14084,Mariano Rivera,baseball,-0.060939,3,0,14623,Fact: Tiger Woods plays the sport of golf\nFac...
9,5840,Boris Diaw,basketball,-0.155351,4,1,14648,Fact: Tiger Woods plays the sport of golf\nFac...


## Relearning Evals

In [5]:
from peft import get_peft_model, LoraConfig, TaskType
def do_relearning(model, train_tasks, n_iters, finetune_lora=False, lora_kwargs=None, learning_kwargs=None, eval_callback_fn=None):
    """Fine-tune `model` on a weighted mixture of tasks (e.g. to measure how fast unlearned facts come back).

    Args:
        model: model passed to each task's `get_train_loss`; expected to already be on the GPU.
        train_tasks: dict task_name -> (task, weight). Each step backpropagates `weight * task.get_train_loss(model)`.
        n_iters: number of optimizer steps.
        finetune_lora: if True, wrap `model` with a PEFT LoRA adapter and train only the trainable (adapter)
            parameters; otherwise do a full finetune.
        lora_kwargs: keys 'rank', 'alpha', 'dropout', 'target_modules'. Missing keys (or None) fall back to
            rank=64, alpha=32, dropout=0.05, target_modules='all-linear'.
        learning_kwargs: keys 'lr', 'weight_decay', 'use_cosine'. Missing keys (or None) fall back to
            lr=1e-2, weight_decay=0, use_cosine=False.
        eval_callback_fn: optional callable(model) invoked after every optimizer step; its return values are collected.

    Returns:
        `train_losses` (dict task_name -> list of per-step loss floats) when `eval_callback_fn` is None,
        otherwise the tuple `(train_losses, test_losses)` with one callback result per step.
    """
    from collections import defaultdict

    # Merge caller overrides onto the defaults (None avoids shared mutable default arguments).
    lora_kwargs = {'rank': 64, 'alpha': 32, 'dropout': 0.05, 'target_modules': 'all-linear', **(lora_kwargs or {})}
    learning_kwargs = {'lr': 1e-2, 'weight_decay': 0, 'use_cosine': False, **(learning_kwargs or {})}

    if finetune_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=lora_kwargs['rank'],
            lora_alpha=lora_kwargs['alpha'],
            lora_dropout=lora_kwargs['dropout'],
            target_modules=lora_kwargs['target_modules'],  # e.g. ["q_proj", "v_proj", ...]
        )
        model = get_peft_model(model, peft_config).cuda()

    # With LoRA the base weights are frozen; only the trainable parameters need optimizer state.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_kwargs['lr'], weight_decay=learning_kwargs['weight_decay'])

    scheduler = None
    if learning_kwargs['use_cosine']:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iters)

    train_losses = defaultdict(list)
    test_losses = []

    optimizer.zero_grad(set_to_none=True)
    for _ in tqdm(range(n_iters)):
        for task_name, (task, task_weight) in train_tasks.items():
            loss = task.get_train_loss(model)
            train_losses[task_name].append(loss.item())
            (loss * task_weight).backward()

        optimizer.step()
        # Drop gradients right after the step so they do not occupy GPU memory during the eval callback.
        optimizer.zero_grad(set_to_none=True)
        if scheduler is not None:
            scheduler.step()

        if eval_callback_fn is not None:
            test_losses.append(eval_callback_fn(model))

    # Return shape depends only on whether a callback was given (not on whether it ever ran).
    if eval_callback_fn is not None:
        return train_losses, test_losses
    return train_losses

In [6]:
n_eval_iters = 10
n_relearn_iters = 10
n_relearn_athletes = 2

# Relearning set: `n_relearn_athletes` forget-set athletes, restricted to the forgotten sport when a whole
# sport was unlearned.
relearn_kwargs = {"forget_player_subset": n_relearn_athletes, "train_test_split": False, "is_forget_dataset": True}
if forget_sport is not None:
    relearn_kwargs["forget_sport_subset"] = {forget_sport}
# Pass `device` like every other SportsTask in this notebook.
relearn_sport = SportsTask(batch_size=n_relearn_athletes, tokenizer=tokenizer, device=device, **relearn_kwargs)

relearn_sport.train_df

Unnamed: 0.1,Unnamed: 0,athlete,sport,log_prob_one_shot,num_athlete_tokens,sport_index,sport_token,prompt
0,1642,DeForest Buckner,football,-0.492917,5,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...
1,738,Walter Payton,football,-0.105714,3,2,5842,Fact: Tiger Woods plays the sport of golf\nFac...


In [7]:
n_eval_iters = 10
n_relearn_iters = 10
n_relearn_athletes = 2

# Relearning set: `n_relearn_athletes` forget-set athletes, restricted to the forgotten sport when a whole
# sport was unlearned.
relearn_kwargs = {"forget_player_subset": n_relearn_athletes, "train_test_split": False, "is_forget_dataset": True}
if forget_sport is not None:
    relearn_kwargs["forget_sport_subset"] = {forget_sport}
# Pass `device` like every other SportsTask in this notebook.
relearn_sport = SportsTask(batch_size=n_relearn_athletes, tokenizer=tokenizer, device=device, **relearn_kwargs)

# Relearn the athletes while also training on maintain-set facts and general text. The Pile data reuses
# `train_pile` (batch `train_batch_size`, ctx 100) from the setup cell; a separate 256-token Pile stream used to
# be constructed here but was never referenced, so it is no longer built.
train_tasks = {"relearn_athletes": (relearn_sport, .2), "maintain_athletes": (maintain_sports, 1), "pile": (train_pile, 1)}

from tasks.facts.SportsTaskAdversarial import adversarial_sports_eval_redo
from tasks.general_capabilities.MCTask_redo import run_general_evals

def eval_callback(model):
    """Evaluate general capability and sports knowledge after a relearning step.

    Returns:
        dict with "MMLU" (the "MMLU" entry of `run_general_evals`) and "adversarial" (the result of
        `adversarial_sports_eval_redo` with the "Normal" and "MC" evals on the forget/maintain sets).
    """
    # Evaluation only: no autograd graph is needed, which keeps activation memory down while the
    # relearning optimizer state is still resident on the GPU.
    with torch.no_grad():
        # Use the notebook-level `model_type` for both evals instead of hard-coding "gemma" for one of them.
        mmlu_score = run_general_evals(model, model_type=model_type)["MMLU"]
        adversarial_results = adversarial_sports_eval_redo(model, model_type=model_type, batch_size=eval_batch_size,
                        forget_task_init_kwargs={"use_system_prompt":True, "use_icl":False}|forget_kwargs,
                        maintain_task_init_kwargs={"use_system_prompt":True, "use_icl":False}|maintain_kwargs,
                        continuous=True, include_evals=["Normal", "MC"])

    return {"MMLU": mmlu_score, "adversarial": adversarial_results}

import gc

relearning_train_results = {}      # name -> {task_name: [train loss per step]}
relearning_test_results = {}       # name -> [eval_callback output per step]
relearning_regular_results = {}    # name -> {"<task>_ce": mean loss, "<task>_acc": mean accuracy}
relearning_adversarial_results = {}
relearning_side_effect_results = {}

# Locations (paths or hub ids) of the unlearned models to attack with relearning.
locations = ["orthogonalized_model"]

for name in locations:
    print(f"Running relearning for {name}")

    model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
    model.cuda()

    train_losses, test_losses = do_relearning(model, train_tasks, n_iters=n_relearn_iters, finetune_lora=True, learning_kwargs={'lr': 1e-4, 'weight_decay': 0, 'use_cosine': True}, eval_callback_fn=eval_callback)

    relearning_train_results[name] = train_losses
    relearning_test_results[name] = test_losses

    # Post-relearning evaluation only; skipping autograd bookkeeping keeps GPU memory down.
    with torch.no_grad():
        relearning_regular_results[name] = {}
        for task_name, test_task in [("forget_sport", forget_sport_eval), ("maintain_sports", maintain_sports_eval)]:
            # Average loss and accuracy over `n_eval_iters` test batches.
            task_loss = 0
            task_accuracy = 0
            for _ in range(n_eval_iters):
                task_loss += test_task.get_test_loss(model).item()
                task_accuracy += test_task.get_test_accuracy(model)
            relearning_regular_results[name][f"{task_name}_ce"] = task_loss / n_eval_iters
            relearning_regular_results[name][f"{task_name}_acc"] = task_accuracy / n_eval_iters

        adversarial_eval_results = adversarial_sports_eval_redo(model, model_type=model_type, batch_size=eval_batch_size,
                        forget_task_init_kwargs={"use_system_prompt":True, "use_icl":False}|forget_kwargs,
                        maintain_task_init_kwargs={"use_system_prompt":True, "use_icl":False}|maintain_kwargs,
                        continuous=True, include_evals=["Normal", "MC"])
        relearning_adversarial_results[name] = adversarial_eval_results

        side_effect_eval_results = run_side_effects_evals(model, model_type=model_type, batch_size=eval_batch_size, evals_to_run=["General"], general_batch_size=5)
        relearning_side_effect_results[name] = side_effect_eval_results

    # Release GPU memory before loading the next model.
    model.cpu()
    del model
    gc.collect()
    torch.cuda.empty_cache()


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Running relearning for orthogonalized_model


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Downloading readme:   0%|          | 0.00/6.38k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 175k/175k [00:00<00:00, 516kB/s]
Downloading data: 100%|██████████| 1.28M/1.28M [00:00<00:00, 6.95MB/s]
Downloading data: 100%|██████████| 208k/208k [00:00<00:00, 1.63MB/s]


Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1531 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/285 [00:00<?, ? examples/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.02 GiB. GPU 0 has a total capacity of 79.15 GiB of which 974.12 MiB is free. Process 2057057 has 33.06 GiB memory in use. Process 2068297 has 45.13 GiB memory in use. Of the allocated memory 42.42 GiB is allocated by PyTorch, and 2.21 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Per-step eval_callback outputs for each relearned model: {name: [{"MMLU": ..., "adversarial": ...}, ...]}
relearning_test_results

In [None]:
import matplotlib.pyplot as plt

# Function to plot relearning results
def plot_relearning_results(relearning_test_results, metric, title, ylabel):
    """Plot one eval-callback metric against relearning iteration, with one line per model.

    When `metric` is 'adversarial', the plotted value is the ['Normal']['forget'] entry of each step's
    adversarial results; for any other metric it is read directly from each per-step result dict.
    """
    plt.figure(figsize=(10, 6))
    for model_name, per_step in relearning_test_results.items():
        if metric == 'adversarial':
            ys = [step[metric]['Normal']['forget'] for step in per_step]
        else:
            ys = [step[metric] for step in per_step]
        plt.plot(range(len(ys)), ys, label=model_name, marker='o')

    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

# Human-readable description of what was unlearned: the sport when a whole sport was forgotten,
# otherwise the number of forgotten athletes.
forget_label = forget_sport.capitalize() if forget_sport is not None else f"{forget_athletes} Athletes"

# Plot MMLU
plot_relearning_results(relearning_test_results, 'MMLU', f'MMLU while Relearning {forget_label}', 'MMLU Score')

def plot_adversarial_results(relearning_test_results, adversarial_type, forget_or_maintain, title, ylabel):
    """Plot result['adversarial'][adversarial_type][forget_or_maintain] against relearning iteration, one line per model."""
    plt.figure(figsize=(10, 6))
    for name, results in relearning_test_results.items():
        values = [result['adversarial'][adversarial_type][forget_or_maintain] for result in results]
        plt.plot(range(len(values)), values, label=name, marker='o')
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot adversarial-normal-forget
plot_adversarial_results(relearning_test_results, 'Normal', 'forget', f'Normal {forget_label}-forget Accuracy Over Relearning Iterations', f'Normal {forget_label}-forget Accuracy')

# Plot adversarial-normal-maintain
plot_adversarial_results(relearning_test_results, 'Normal', 'maintain', 'Normal Sports-maintain Accuracy Over Relearning Iterations', 'Normal Sports-maintain Accuracy')

# Plot adversarial-mc-forget
plot_adversarial_results(relearning_test_results, 'MC', 'forget', f'MC {forget_label}-forget Accuracy Over Relearning Iterations', f'MC {forget_label}-forget Accuracy')



In [None]:
import pickle

os.makedirs(f"{save_dir}/results", exist_ok=True)
# NOTE(review): `combine_heads` is not assigned in any cell of this notebook, and `beta` is only assigned in the
# load cell further down; both must already exist in the kernel because they are part of the filename.
relearning_results_path = f"{save_dir}/results/relearning_{n_relearn_athletes=}_{n_relearn_iters=}_{model_type}_{combine_heads=}_{beta=}_unlearn_{forget_sport=}_{forget_athletes=}_results.pkl"
with open(relearning_results_path, "wb") as f:
    pickle.dump({"relearning_regular_results": relearning_regular_results, "relearning_adversarial_results": relearning_adversarial_results, "relearning_side_effect_results": relearning_side_effect_results, "relearning_train_results": relearning_train_results, "relearning_test_results": relearning_test_results}, f)

In [None]:
import pickle

n_relearn_athletes = 2
n_relearn_iters = 10
model_type = "gemma"
# combine_heads = False
beta = 3

# NOTE(review): `combine_heads` is not assigned anywhere in this notebook (the line above is commented out),
# yet it is part of the filename, so it must already exist in the kernel.
relearning_results_path = f"{save_dir}/results/relearning_{n_relearn_athletes=}_{n_relearn_iters=}_{model_type}_{combine_heads=}_{beta=}_unlearn_{forget_sport=}_{forget_athletes=}_results.pkl"
# pickle.load can execute arbitrary code: only load result files written by this project.
with open(relearning_results_path, "rb") as f:
    results = pickle.load(f)

relearning_regular_results = results['relearning_regular_results']
relearning_adversarial_results = results['relearning_adversarial_results']
relearning_side_effect_results = results['relearning_side_effect_results']
# The save cell also stores the per-step curves; restore them (if present) so the plotting cells work after a reload.
relearning_train_results = results.get('relearning_train_results', {})
relearning_test_results = results.get('relearning_test_results', {})


In [None]:
# Post-relearning forget/maintain metrics per model: {name: {"<task>_ce": mean loss, "<task>_acc": mean accuracy}}
relearning_regular_results

In [None]:
# Post-relearning adversarial sports evals ("Normal" and "MC") per model.
relearning_adversarial_results

## Latent Knowledge

In [None]:
# combine_heads = False
# NOTE(review): `localization_types` and `combine_heads` are not assigned in this notebook; they must already
# exist in the kernel (presumably from the notebook that trained these models — confirm).
model_paths_dict = {
    localization_type: f"{save_dir}/models/{model_type}_{localization_type}_{combine_heads=}_{beta=}_unlearn_{forget_sport=}_{forget_athletes=}.pt" for localization_type in localization_types
}
def model_init_and_load_func(mask_type):
    """Return a zero-argument factory that builds Gemma-7B in bfloat16 and loads the saved weights for `mask_type`."""
    model_path = model_paths_dict[mask_type]
    def get_model_fn():
        model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.bfloat16)
        # The fresh model is on CPU. Without map_location, a checkpoint saved from GPU tensors would first be
        # materialized on the GPU as an extra full copy of the weights.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        return model
    return get_model_fn
model_init_and_load_funcs = {mask_type: model_init_and_load_func(mask_type) for mask_type in localization_types}

In [None]:
# Second Gemma tokenizer that pads on the left, so sequence position -1 is each prompt's final real token
# when collecting last-token residuals.
left_tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
left_tokenizer.padding_side = "left"
left_tokenizer.pad_token_id = left_tokenizer.eos_token_id

from collections import defaultdict
def layer_hook_function(layer, outputs, last_token_only=True, store_cpu=False):
    """Build a forward hook that records a module's output into `outputs[layer]`.

    The hook takes the module output (its first element when the module returns a tuple), makes a detached
    copy, optionally keeps only the final sequence position (`[:, -1]`), optionally moves it to CPU, and
    appends it to the list `outputs[layer]`. It returns None, so the module's output is not modified.
    """
    def hook_fn(module, input, output):
        hidden = output[0] if isinstance(output, tuple) else output
        captured = hidden.clone().detach()
        if last_token_only:
            captured = captured[:, -1]
        if store_cpu:
            captured = captured.cpu()
        outputs[layer].append(captured)

    return hook_fn

def get_hf_residuals(texts, model, batch_size, last_token_only=True, layers_module=None, store_cpu=True, text_col="prompt", tokenizer=None):
    """Run `texts` through `model` and collect the output of every hooked layer.

    Args:
        texts: list of prompt strings.
        model: model to run; by default forward hooks are placed on each module of `model.model.layers`.
        batch_size: number of prompts per forward pass.
        last_token_only: keep only sequence position -1. This relies on a left-padding tokenizer, so that -1 is
            the last real token of every prompt.
        layers_module: iterable of modules to hook instead of `model.model.layers`.
        store_cpu: move captured activations to CPU as they are produced.
        text_col: unused; kept for backward compatibility.
        tokenizer: tokenizer to use; defaults to the notebook-level `left_tokenizer`.

    Returns:
        dict layer_index -> tensor of that layer's captured outputs concatenated along dim 0 over all texts.
    """
    if tokenizer is None:
        tokenizer = left_tokenizer
    if layers_module is None:
        layers_module = model.model.layers
    # Put inputs on the same device as the model's weights instead of assuming CUDA.
    input_device = next(model.parameters()).device

    outputs = defaultdict(list)
    hooks = [
        block.register_forward_hook(layer_hook_function(layer, outputs=outputs, last_token_only=last_token_only, store_cpu=store_cpu))
        for layer, block in enumerate(layers_module)
    ]
    try:
        for idx in tqdm(range(0, len(texts), batch_size)):
            batch_texts = texts[idx:idx + batch_size]
            tokenized = tokenizer(batch_texts, return_tensors="pt", padding=True)
            tokenized = {k: v.to(input_device) for k, v in tokenized.items()}
            with torch.no_grad():
                model(**tokenized)
    finally:
        # Always detach the hooks, even if a forward pass raises (e.g. OOM or an interrupt). Otherwise they stay
        # registered on the model and keep capturing every later forward pass into this orphaned dict.
        for hook in hooks:
            hook.remove()

    for layer in outputs:
        outputs[layer] = torch.cat(outputs[layer], dim=0)
        if store_cpu:
            outputs[layer] = outputs[layer].cpu()

    return outputs

In [None]:
batch_size = 16
def get_resids(sports_task, model):
    """Collect last-token residuals for a SportsTask's train and test prompts.

    Returns (train_outputs, test_outputs, train_labels, test_labels). The outputs are the per-layer dicts
    from `get_hf_residuals` (using the notebook-level `batch_size`), and the labels are the `sport` column of
    each split as a list.
    """
    split_dfs = (sports_task.train_df, sports_task.test_df)
    train_outputs, test_outputs = (
        get_hf_residuals(df["prompt"].tolist(), model, batch_size, last_token_only=True)
        for df in split_dfs
    )
    train_labels, test_labels = (df['sport'].tolist() for df in split_dfs)
    return train_outputs, test_outputs, train_labels, test_labels

# The forget set has its own train/test split only when a whole sport was forgotten.
forget_is_split = forget_sport is not None
if forget_is_split:
    forget_train_outputs_dict, forget_test_outputs_dict = {}, {}
    forget_train_labels_dict, forget_test_labels_dict = {}, {}
else:
    forget_outputs_dict, forget_labels_dict = {}, {}

maintain_train_outputs_dict, maintain_test_outputs_dict = {}, {}
maintain_train_labels_dict, maintain_test_labels_dict = {}, {}

# Load each model in turn, record residuals for the forget and maintain prompts, then move it off the GPU.
for model_name, load_model in model_init_and_load_funcs.items():
    model = load_model()
    model.cuda()

    forget_resids = get_resids(forget_sport_eval, model)
    if forget_is_split:
        (forget_train_outputs_dict[model_name], forget_test_outputs_dict[model_name],
         forget_train_labels_dict[model_name], forget_test_labels_dict[model_name]) = forget_resids
    else:
        # Unsplit forget set: keep only the train_df part (it is split manually afterwards).
        forget_outputs_dict[model_name] = forget_resids[0]
        forget_labels_dict[model_name] = forget_resids[2]

    (maintain_train_outputs_dict[model_name], maintain_test_outputs_dict[model_name],
     maintain_train_labels_dict[model_name], maintain_test_labels_dict[model_name]) = get_resids(maintain_sports_eval, model)

    model.cpu()
    del model

In [None]:
# set train and test splits
if not forget_is_split:
    print("Performing manual split of the unsplit training dataset")
    train_test_split = .5
    forget_train_outputs_dict = {}
    forget_test_outputs_dict = {}
    forget_train_labels_dict = {}
    forget_test_labels_dict = {}
    for model_name in model_init_and_load_funcs:
        num_train = int(len(forget_labels_dict[model_name]) * train_test_split)
        forget_train_labels_dict[model_name] = forget_labels_dict[model_name][:num_train]
        forget_test_labels_dict[model_name] = forget_labels_dict[model_name][num_train:]
        forget_train_outputs_dict[model_name] = {}
        forget_test_outputs_dict[model_name] = {}
        for layer in range(n_layers):
            forget_train_outputs_dict[model_name][layer] = forget_outputs_dict[model_name][layer][:num_train]
            forget_test_outputs_dict[model_name][layer] = forget_outputs_dict[model_name][layer][num_train:]


In [None]:
# Sanity check: are the forget-train activations of the 'localized_ap' and 'localized_ct' models identical per layer?
ap_acts = forget_train_outputs_dict['localized_ap']
ct_acts = forget_train_outputs_dict['localized_ct']
# Iterate over the recorded layers (`n_layers` is not assigned anywhere in this notebook).
for i in sorted(ap_acts):
    print(f"On layer {i}, AP == CT activations is: {torch.equal(ap_acts[i], ct_acts[i])}")

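### Approach 1: Train probes on pooled forget + maintain data, then test on forget and maintain data separately
For each model and layer, one logistic-regression probe per sport is trained on the combined forget and maintain training activations and then evaluated separately on the forget and maintain test sets. The procedure still needs to be double-checked (TODO), but it should generalize to forget sets made of individual athletes rather than whole sports.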

In [None]:
from sklearn.linear_model import LogisticRegression

def get_sport_labels(string_labels, return_np=True):
    """Build one-vs-rest binary labels for each sport.

    Returns a dict with keys "baseball", "football", "basketball" (in that order). Each value holds one 0/1 flag
    per entry of `string_labels`, which is 1 when that sport's name occurs in the label string. Values are numpy
    arrays if `return_np` is True, plain lists otherwise. Asserts that every label matches exactly one sport.
    """
    sport_names = ["baseball", "football", "basketball"]
    sport_labels = {
        name: [1 if name in label else 0 for label in string_labels]
        for name in sport_names
    }
    if return_np:
        sport_labels = {name: np.array(flags) for name, flags in sport_labels.items()}

    assert sum(sport_labels["baseball"]) + sum(sport_labels["football"]) + sum(sport_labels["basketball"]) == len(string_labels)
    # every position must belong to exactly one sport
    for idx in range(len(string_labels)):
        assert sum(sport_labels[name][idx] for name in sport_names) == 1
    return sport_labels

# train probes
all_probes = defaultdict(dict) # double-nested dictionary, first keys are model_name, second keys are layers, final values are dictionaries with keys "basketball", "football", "baseball" and values of probes

all_train_accs = defaultdict(dict)
all_test_accs = defaultdict(dict)
all_forget_accs = defaultdict(dict)
all_maintain_accs = defaultdict(dict)

combine_accuracies = True

shuffle_train = True
for model_name in tqdm(model_init_and_load_funcs):
    # train_acts = {}

    forget_test_acts = forget_test_outputs_dict[model_name]
    forget_test_labels = get_sport_labels(forget_test_labels_dict[model_name])
    maintain_test_acts = maintain_test_outputs_dict[model_name]
    maintain_test_labels = get_sport_labels(maintain_test_labels_dict[model_name])

    forget_train_acts = forget_train_outputs_dict[model_name]
    maintain_train_acts = maintain_train_outputs_dict[model_name]
    # forget_test_labels_dict[model_name] + maintain_test_labels_dict[model_name]
    train_labels = forget_train_labels_dict[model_name] + maintain_train_labels_dict[model_name]
    train_labels = get_sport_labels(train_labels)

    test_labels = forget_test_labels_dict[model_name] + maintain_test_labels_dict[model_name]
    test_labels = get_sport_labels(test_labels)

    if shuffle_train:
        shuffle_idx = torch.randperm(len(list(train_labels.values())[0]))

    if shuffle_train:
        for sport in train_labels:
            train_labels[sport] = train_labels[sport][shuffle_idx]
    
    # print(f"Labels look like {train_labels}")

    for layer in range(n_layers):
        layer_train_acts = torch.cat([forget_train_acts[layer], maintain_train_acts[layer]], dim=0).float().cpu().numpy()
        layer_test_acts = torch.cat([forget_test_acts[layer], maintain_test_acts[layer]], dim=0).float().cpu().numpy()
        layer_forget_test_acts = forget_test_acts[layer].float().cpu().numpy()
        layer_maintain_test_acts = maintain_test_acts[layer].float().cpu().numpy()

        if shuffle_train:
            layer_train_acts = layer_train_acts[shuffle_idx]
        all_probes[model_name][layer] = {}

        if not combine_accuracies:
            all_train_accs[model_name][layer] = {}
            all_test_accs[model_name][layer] = {}
            all_forget_accs[model_name][layer] = {}
            all_maintain_accs[model_name][layer] = {}

        sports_train_preds = {}
        sports_test_preds = {}
        sports_forget_preds = {}
        sports_maintain_preds = {}
        for sport in train_labels:
            if sum(train_labels[sport]) <= 0:
                print("No labels for sport", sport)
                continue
            probe = LogisticRegression(max_iter=10000)
            # print(f"Training probe for {sport} at layer {layer}, {layer_train_acts.shape=}, {train_labels[sport].shape=}, {train_labels[sport].mean()=}")
            probe.fit(layer_train_acts, train_labels[sport])
            all_probes[model_name][layer][sport] = probe

            # test probes
            # print(f"{sport=}, {layer_train_acts.shape=}, {train_labels[sport].shape=}, {train_labels[sport].mean()=}")
            train_preds = probe.predict(layer_train_acts)
            if not combine_accuracies:
                train_acc = (train_preds == train_labels[sport]).sum() / len(train_labels[sport])
                all_train_accs[model_name][layer][sport] = train_acc
            else:
                sports_train_preds[sport] = train_preds


            # print(f"Testing probe for {sport} at layer {layer}, {layer_test_acts.shape=}, {test_labels[sport].shape=}, {test_labels[sport].mean()=}")
            test_preds = probe.predict(layer_test_acts)
            if not combine_accuracies:
                test_acc = (test_preds == test_labels[sport]).sum() / len(test_labels[sport])
                all_forget_accs[model_name][layer][sport] = test_acc
            else:
                sports_test_preds[sport] = test_preds

            # print(f"{sport=}, {layer_forget_test_acts.shape=}, {forget_test_labels[sport].shape=}, {forget_test_labels[sport].mean()=}")
            forget_test_preds = probe.predict(layer_forget_test_acts)
            if not combine_accuracies:
                forget_acc = (forget_test_preds == forget_test_labels[sport]).sum() / len(forget_test_labels[sport])
                all_test_accs[model_name][layer][sport] = forget_acc
            else:
                sports_forget_preds[sport] = forget_test_preds

            # print(f"{sport=}, {layer_maintain_test_acts.shape=}, {maintain_test_labels[sport].shape=}, {maintain_test_labels[sport].mean()=}")
            maintain_test_preds = probe.predict(layer_maintain_test_acts)
            if not combine_accuracies:
                maintain_acc = (maintain_test_preds == maintain_test_labels[sport]).sum() / len(maintain_test_labels[sport])
                all_maintain_accs[model_name][layer][sport] = maintain_acc 
            else:
                sports_maintain_preds[sport] = maintain_test_preds

        if combine_accuracies:
            # combine accuracies by saying probes correct if all sports are correct
            train_correct = np.ones(len(train_labels["baseball"]))
            test_correct = np.ones(len(test_labels["baseball"]))
            forget_correct = np.ones(len(forget_test_labels["baseball"]))
            maintain_correct = np.ones(len(maintain_test_labels["baseball"]))
            for sport in train_labels:
                if sum(train_labels[sport]) > 0:
                    train_correct *= (sports_train_preds[sport] == train_labels[sport])
                else:
                    print("No train labels for sport", sport)
                if sum(test_labels[sport]) > 0:
                    test_correct *= (sports_test_preds[sport] == test_labels[sport])
                else:
                    print("No test labels for sport", sport)
                if sum(forget_test_labels[sport]) > 0:
                    forget_correct *= (sports_forget_preds[sport] == forget_test_labels[sport])
                else:
                    print("No forget labels for sport", sport)
                if sum(maintain_test_labels[sport]) > 0:
                    maintain_correct *= (sports_maintain_preds[sport] == maintain_test_labels[sport])
                else:
                    print("No maintain labels for sport", sport)

            all_train_accs[model_name][layer] = train_correct.mean()
            all_test_accs[model_name][layer] = test_correct.mean()
            all_forget_accs[model_name][layer] = forget_correct.mean()
            all_maintain_accs[model_name][layer] = maintain_correct.mean()

with open(f"{save_dir}/results/probes_{model_type}_{combine_heads=}_{beta=}_unlearn_{forget_sport=}_{forget_athletes=}.pkl", "wb") as f:
    pickle.dump({"all_probes": all_probes, "all_train_accs": all_train_accs, "all_test_accs": all_test_accs, "all_forget_accs": all_forget_accs, "all_maintain_accs": all_maintain_accs}, f)

In [None]:
# Per layer: total absolute difference between the basketball-probe weights of 'localized_ap' and 'localized_ct'.
# An absolute sum is used because a signed sum can cancel to ~0 even when the weights differ.
ap_probes = all_probes['localized_ap']
ct_probes = all_probes['localized_ct']
# Iterate over the layers that have probes (`n_layers` is not assigned anywhere in this notebook).
for i in sorted(ap_probes):
    print(f"At layer {i}, {np.abs(ap_probes[i]['basketball'].coef_ - ct_probes[i]['basketball'].coef_).sum()}")

In [None]:
import matplotlib.pyplot as plt
import pickle

# combine_heads = False #accidentally set this earlier, but its not actually False in the models
# pickle.load can execute arbitrary code: only load probe files written by this project.
with open(f"{save_dir}/results/probes_{model_type}_{combine_heads=}_{beta=}_unlearn_{forget_sport=}_{forget_athletes=}.pkl", "rb") as f:
    results = pickle.load(f)

all_probes = results['all_probes']
all_train_accs = results['all_train_accs']
all_test_accs = results['all_test_accs']
all_forget_accs = results['all_forget_accs']
all_maintain_accs = results['all_maintain_accs']

def _first_layer_value(accs):
    """Return the first per-layer value in a nested {model: {layer: value}} dict, or None if there is none."""
    for per_layer in accs.values():
        for value in per_layer.values():
            return value
    return None

# Combined accuracies are stored as one scalar per layer and per-sport accuracies as a dict per layer.
# Infer which kind was loaded instead of relying on a `combine_accuracies` flag set in another cell.
combine_accuracies = not isinstance(_first_layer_value(all_train_accs), dict)


def plot_combined_accuracies(all_train_accs, all_forget_accs, all_maintain_accs):
    """One figure per model: combined (all-sports-correct) probe accuracy per layer on train, forget and maintain data."""
    for model_name in all_train_accs.keys():
        layers = list(all_train_accs[model_name].keys())
        train_accs = [all_train_accs[model_name][layer] for layer in layers]
        forget_accs = [all_forget_accs[model_name][layer] for layer in layers]
        maintain_accs = [all_maintain_accs[model_name][layer] for layer in layers]

        plt.figure(figsize=(10, 6))
        plt.plot(layers, train_accs, label='Train Accuracy', alpha=0.5, linestyle='--')
        plt.plot(layers, forget_accs, label='Forget Accuracy')
        plt.plot(layers, maintain_accs, label='Maintain Accuracy')

        plt.xlabel('Layer')
        plt.ylabel('Accuracy')
        plt.title(f'Accuracy per Layer for {model_name}')
        plt.legend()
        plt.grid(True)
        plt.show()


def plot_per_sport_accuracies(all_train_accs, all_test_accs):
    """One figure per model: per-sport probe train/test accuracy per layer (0 where a sport has no probe)."""
    sports = ["baseball", "football", "basketball"]

    for model_name in all_train_accs.keys():
        layers = list(all_train_accs[model_name].keys())

        plt.figure(figsize=(12, 8))
        for sport in sports:
            train_accs = [all_train_accs[model_name][layer].get(sport, 0) for layer in layers]
            test_accs = [all_test_accs[model_name][layer].get(sport, 0) for layer in layers]
            plt.plot(layers, train_accs, label=f'Train Accuracy - {sport}', linestyle='--', alpha=0.5)
            plt.plot(layers, test_accs, label=f'Test Accuracy - {sport}')

        plt.xlabel('Layer')
        plt.ylabel('Accuracy')
        plt.title(f'Accuracy per Layer for {model_name}')
        plt.legend()
        plt.grid(True)
        plt.show()


if combine_accuracies:
    plot_combined_accuracies(all_train_accs, all_forget_accs, all_maintain_accs)
else:
    plot_per_sport_accuracies(all_train_accs, all_test_accs)

In [None]:
import matplotlib.pyplot as plt

def plot_final_accuracies(all_accs, formal=True, name_map=None, colors=None):
    """Plot per-layer accuracy for every model in `all_accs` on a new figure; title/legend/show are left to the caller.

    Args:
        all_accs: nested dict model_name -> layer -> scalar accuracy (combined accuracies).
        formal: if True, label lines with `name_map[model_name]` (falling back to the raw name);
            if False, use the raw model names.
        name_map: model_name -> display name. Defaults to a notebook-level `formal_name_dict` if one is defined.
        colors: model_name -> matplotlib color. Defaults to a notebook-level `color_map` if one is defined;
            models without an entry use matplotlib's default color cycle.
    """
    # `formal_name_dict` / `color_map` are not assigned in this notebook, so fall back gracefully when absent.
    if name_map is None:
        name_map = globals().get("formal_name_dict", {})
    if colors is None:
        colors = globals().get("color_map", {})

    plt.figure(figsize=(12, 8))

    # Iterate the accuracies that were passed in, rather than the global `all_train_accs`.
    for model_name, per_layer in all_accs.items():
        layers = list(per_layer.keys())
        accs = [per_layer[layer] for layer in layers]
        label = name_map.get(model_name, model_name) if formal else model_name
        plt.plot(layers, accs, label=label, color=colors.get(model_name), marker='o')

    plt.xlabel('Layer')
    plt.ylabel('Accuracy')

# Probe accuracy on the maintain athletes, one line per model.
plot_final_accuracies(all_maintain_accs)
ax = plt.gca()
ax.set_title('Probe Accuracy on Maintain Athletes')
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Probe accuracy on the forget athletes, one line per model.
plot_final_accuracies(all_forget_accs)
ax = plt.gca()
ax.set_title('Probe Accuracy on Forget Athletes')
ax.legend()
ax.grid(True)
plt.show()