In [1]:
%load_ext autoreload
%autoreload 2

import os
from dotenv import load_dotenv
from pathlib import Path
from research_tools.gpu import get_gpus_available

# Pull secrets (e.g. HUGGINGFACE_API_KEY) from a local .env file into os.environ.
load_dotenv()
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

# Request up to `n_gpus` devices and keep only as many as get_gpus_available() reports;
# slicing already clamps to the available count, so len(gpus) is the effective number.
n_gpus = 1
gpus_available = get_gpus_available()
gpus = gpus_available[:n_gpus]
n_gpus = len(gpus)

assert n_gpus > 0, "No GPUs available"

# Expose only the selected devices; this runs before torch is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus))

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional
import torch
from research_tools.utils import set_seed

set_seed(42)

# This notebook needs a GPU: fail fast instead of silently loading a 7B model on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda"), "CUDA device required"

model_id = "HuggingFaceH4/zephyr-7b-beta"

# NOTE(review): trust_remote_code=True executes Python shipped with the checkpoint;
# only keep it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, token=hf_access_token
).to(device)
# Use the same Hub token as the model so gated/private repos resolve for the tokenizer too.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_access_token)
# Padding reuses the EOS token (the tokenizer may not define a dedicated pad token).
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
from relearn.datasets.utils import (
    Datasets,
)
from relearn.datasets.corpus import process as process_corpus
from relearn.datasets.mcq import process as process_mcq
import pickle

data_dir = Path("../data")
cache_path = data_dir / "full.pickle"

# Explicit check (asserts are stripped under `python -O`) that reports where we looked.
if not cache_path.exists():
    raise FileNotFoundError(f"Cache file does not exist: {cache_path.resolve()}")

# SECURITY: pickle.load can execute arbitrary code; only load caches you built yourself.
with open(cache_path, "rb") as f:
    data = pickle.load(f)

In [None]:
from relearn.unlearn.rmu import train_rmu
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Optional
import itertools
import wandb
from relearn.evaluate import run_eval
import numpy as np
from relearn.datasets.folds import get_folds_shuffled, fold_name

import torch
from torch.optim import AdamW
from relearn.unlearn.rmu.utils import get_params

from torch.optim.lr_scheduler import ExponentialLR

# Run configuration: passed to wandb.init and written to my_config.json by the save cell.
config = {
    "model_id": model_id,
    "k_folds": 4,
    "lr": 1e-5,
    "lr_decay": 1,
    "prefix_forget": True,
    "sweeping": True,
    "use_wandb": True,
    "joint_train": True,
    "magnitude": 6.5,
    "forget_alpha": 0.39422,
    "retain_alpha": 32,
    "actual_retain_alpha": 1.0,
    "batch_size": 4,
    "epsilon": 1e-6,
    "activation_layer": 7,
    "train_layers": [5, 6, 7],
    "param_names": ["down_proj"],
    "epochs_per_fold": 12,
}

# Forget corpus = the WMDP split of the cache; retain corpus = its "retain" split.
forget_records_dict = data[Datasets.WMDP]
retain_records_dict = data["retain"]

# --- Expose config entries as plain names for the cells below ---

# Fold layout. prefix_forget=True unlearns folds 0..i cumulatively rather than fold i alone.
k_folds, epochs_per_fold = config["k_folds"], config["epochs_per_fold"]
prefix_forget, joint_train = config["prefix_forget"], config["joint_train"]

# Optimisation.
lr, lr_decay, batch_size = config["lr"], config["lr_decay"], config["batch_size"]

# RMU loss coefficients; epsilon is added to the alpha denominators (guards k_folds == 1).
magnitude = config["magnitude"]
forget_alpha, retain_alpha = config["forget_alpha"], config["retain_alpha"]
actual_retain_alpha = config["actual_retain_alpha"]
epsilon = config["epsilon"]

# Layer / parameter selection, forwarded to get_params and train_rmu.
activation_layer, train_layers = config["activation_layer"], config["train_layers"]
param_names = config["param_names"]

# Bookkeeping.
sweeping, use_wandb = config["sweeping"], config["use_wandb"]


# Fold to unlearn in this session. For cur_i > 0, resume from the checkpoint the save
# cell wrote for fold cur_i - 1 (model weights, optimizer/scheduler state dicts and
# accumulated control vectors).
cur_i = 0

if cur_i > 0:
    load_path = Path(f"../models/four/{cur_i - 1}")
    model = AutoModelForCausalLM.from_pretrained(
        load_path, torch_dtype=torch.bfloat16,
    ).to(device)

    # The save cell stores *state dicts*, so torch.load returns plain dicts. Rebuild the
    # optimizer over the reloaded model's parameters, then restore its state; likewise
    # for the scheduler (which must wrap the new optimizer).
    optimizer = AdamW(get_params(model, train_layers, param_names), lr=lr)
    optimizer_path = load_path / "optimizer.pt"
    optimizer.load_state_dict(torch.load(optimizer_path))

    scheduler = ExponentialLR(optimizer, gamma=lr_decay)
    scheduler_path = load_path / "scheduler.pt"
    scheduler.load_state_dict(torch.load(scheduler_path))

    control_vecs_path = load_path / "control_vecs.pt"
    control_vecs = torch.load(control_vecs_path)
else:
    optimizer = AdamW(get_params(model, train_layers, param_names), lr=lr)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay)
    control_vecs = {}


# mode="disabled" makes wandb a no-op when use_wandb is False, so the trailing
# run.finish() in this cell still works; None keeps wandb's default mode otherwise.
run = wandb.init(
    project="relearn",
    config=config,
    tags=["rmu", "fold", "debug"],
    entity="12tqian",
    mode=None if use_wandb else "disabled",
)





# Split the forget set into k shuffled folds; each fold has "corpus" and "val" splits.
folds = get_folds_shuffled(forget_records_dict, k_folds)
print(len(folds))


def get_data(fold_inds: List[int]) -> Dict[str, list]:
    """Collect the training corpora of the folds in ``fold_inds``.

    With ``joint_train`` every fold keeps its own key (``fold_name(i)``) so it can get its
    own loss weight; otherwise the requested folds' corpora are concatenated (in fold
    order) under the single key ``fold_name(-1)``. Returns ``{}`` when no folds are given.
    """
    if joint_train:
        return {fold_name(i): folds[i]["corpus"] for i in fold_inds}
    if not fold_inds:
        # e.g. no retain folds remain on the last fold: don't emit an empty corpus.
        return {}
    return {
        fold_name(-1): list(
            itertools.chain.from_iterable(
                fold["corpus"] for idx, fold in enumerate(folds) if idx in fold_inds
            )
        )
    }


# Validation sets: one per fold plus the generic retain set.
eval_records_dict = {fold_name(i): folds[i]["val"] for i in range(k_folds)}
eval_records_dict["retain"] = retain_records_dict["val"]

# Global epoch counter across folds; when resuming at fold cur_i, continue from where
# the earlier sessions (epochs_per_fold each) left off instead of restarting at 0.
base_epoch = cur_i * epochs_per_fold



# Unlearn one fold per session (choose it with `cur_i`; resume state is loaded above).
for i in [cur_i]:
    print(f"Unlearning fold {fold_name(i)}")

    # prefix_forget: fold i is unlearned together with all earlier folds.
    # The folds after i (not yet unlearned) are retained.
    forget_fold_inds = list(range(i + 1)) if prefix_forget else [i]
    retain_fold_inds = list(range(i + 1, k_folds))

    cur_lr = lr

    forget_dict = get_data(forget_fold_inds)
    retain_dict = get_data(retain_fold_inds)
    retain_dict["retain"] = retain_records_dict["corpus"]

    shared_forget_alpha = forget_alpha / (1 + epsilon)
    shared_retain_alpha = retain_alpha / (k_folds - 1 + epsilon)

    # Key under which the fold being unlearned now appears in forget_dict: its own name
    # when folds are trained jointly, otherwise the merged fold_name(-1) entry.
    # (Comparing by key rather than enumeration index keeps this right for
    # prefix_forget=False and joint_train=False.)
    current_forget_key = fold_name(i) if joint_train else fold_name(-1)

    # The current fold gets the full forget weight. Earlier, already-unlearned folds
    # (present only with prefix_forget) stay in the forget set but use the per-fold
    # retain weight (retain_alpha split over k_folds - 1). Per the original author,
    # this is meant to keep optimizer state shareable across folds.
    # NOTE(review): coefficients were reportedly tuned on a 2-fold setup — confirm they
    # transfer to the current k_folds.
    forget_alphas = {
        k: shared_forget_alpha if k == current_forget_key else shared_retain_alpha
        for k in forget_dict
    }

    # Remaining folds share retain_alpha equally. The generic "retain" corpus gets
    # actual_retain_alpha, plus the full retain_alpha on the last fold, where no
    # fold-level retain terms are left.
    # (An earlier variant split forget_alpha over the i + 1 prefix folds and
    # retain_alpha over the k_folds - i - 1 remaining folds.)
    retain_alphas = {k: shared_retain_alpha for k in retain_dict}
    retain_alphas["retain"] = actual_retain_alpha + (
        retain_alpha if i == k_folds - 1 else 0
    )

    model, control_vecs_next, eval_dict = train_rmu(
        model,
        forget_dict,
        retain_dict,
        eval_records_dict,
        magnitude=magnitude,
        forget_alphas=forget_alphas,
        retain_alphas=retain_alphas,
        lr=cur_lr,
        tokenizer=tokenizer,
        use_wandb=use_wandb,
        eval_at_start=True,
        n_epochs=epochs_per_fold,
        batch_size=batch_size,
        max_batches=None,
        base_epoch=base_epoch,
        return_control_vecs=True,
        control_vecs_init=control_vecs,
        print_evals=True,
        monitor_name=f"{fold_name(i)}/acc",
        activation_layer=activation_layer,
        train_layers=train_layers,
        param_names=param_names,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    control_vecs.update(control_vecs_next)
    base_epoch += epochs_per_fold
# Sweep objective: evaluate on the full (unsplit) forget and retain validation sets.
# The trailing -1 appears to be the epoch label (cf. the 'epoch': -1 eval in the output).
if sweeping:
    full_eval = dict(
        forget=forget_records_dict["val"],
        retain=retain_records_dict["val"],
    )
    res = run_eval(model, tokenizer, full_eval, -1)
    if use_wandb:
        wandb.log(res)

run.finish()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m12tqian[0m. Use [1m`wandb login --relogin`[0m to force relogin


4
Unlearning fold A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:06<00:00, 14.33it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 14.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 16.21it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 14.60it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 393/393 [00:42<00:00,  

{'epoch': -1, 'A/acc': 0.5939086294416244, 'B/acc': 0.5357142857142857, 'C/acc': 0.5153061224489796, 'D/acc': 0.5510204081632653, 'retain/acc': 0.5923566878980892}


  forget_loss = torch.nn.functional.mse_loss(
100%|███████████████████████████████████████████| 147/147 [02:45<00:00,  1.13s/it, A/forget_loss=0.0179, B/retain_loss=1.54e-5, C/retain_loss=1.75e-5, D/retain_loss=1.61e-5, retain/retain_loss=4.12e-7]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:06<00:00, 15.40it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 14.77it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 16.10it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

{'epoch': 0, 'A/acc': 0.5989847715736041, 'B/acc': 0.5408163265306123, 'C/acc': 0.5153061224489796, 'D/acc': 0.5459183673469388, 'retain/acc': 0.5910828025477707}


 35%|███████████████▌                            | 52/147 [00:58<01:47,  1.13s/it, A/forget_loss=0.0179, B/retain_loss=1.85e-5, C/retain_loss=1.79e-5, D/retain_loss=2.09e-5, retain/retain_loss=6.48e-7]

In [None]:
import json

# Checkpoint everything the resume branch needs to continue at fold cur_i + 1.
path = Path(f"../models/four/{cur_i}")
path.mkdir(parents=True, exist_ok=True)

model.save_pretrained(path)

# Keep the run configuration next to the weights for later reference.
config_path = path / "my_config.json"
with open(config_path, "w") as f:
    json.dump(config, f)

# Accumulated control vectors, then the optimizer and scheduler *state dicts*
# (the resume branch restores them with load_state_dict).
control_vecs_path, optimizer_path, scheduler_path = (
    path / "control_vecs.pt",
    path / "optimizer.pt",
    path / "scheduler.pt",
)
torch.save(control_vecs, control_vecs_path)
torch.save(optimizer.state_dict(), optimizer_path)
torch.save(scheduler.state_dict(), scheduler_path)