In [1]:
%load_ext autoreload
%autoreload 2

import os
from dotenv import load_dotenv
from pathlib import Path
from research_tools.gpu import get_gpus_available

load_dotenv()


hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

n_gpus = 1

gpus_available = get_gpus_available()
n_gpus = min(n_gpus, len(gpus_available))
gpus = gpus_available[:n_gpus]

assert n_gpus > 0, "No GPUs available"

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpus])

In [2]:
# Remove any inherited TRANSFORMERS_CACHE before transformers is imported (next cell),
# presumably so the library falls back to its default cache location.
# pop() with a default is a no-op when the variable is unset, replacing the previous
# "assign '' then del" workaround that existed only to avoid a KeyError.
os.environ.pop("TRANSFORMERS_CACHE", None)

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional
import torch
from research_tools.utils import set_seed

SEED = 42
set_seed(SEED)

# This notebook requires a GPU; fail fast with an explicit message.
assert torch.cuda.is_available(), "CUDA is not available; this notebook requires a GPU"
device = torch.device("cuda")

model_id = "HuggingFaceH4/zephyr-7b-beta"

# NOTE(review): trust_remote_code=True lets from_pretrained execute Python code shipped
# in the Hub repository. Presumably unnecessary for this model — confirm and drop it
# unless a model that needs custom code is swapped in.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, token=hf_access_token
).to(device)
# Use the same token as the model so gated/private repositories also work here.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_access_token)
# Pad with EOS. This is set unconditionally, overriding any pad token the tokenizer ships with.
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [4]:
from datasets import load_dataset
from relearn.datasets.utils import (
    load_dataset as local_load_dataset,
    DATASETS_DICT,
    Datasets,
)
from relearn.datasets.corpus import process as process_corpus
from relearn.datasets.mcq import process as process_mcq
import pickle

data_dir = Path("../data")
cache_path = data_dir / "full.pickle"

# When True, reuse the pickled cache if it exists; otherwise (or when False) rebuild it.
USE_CACHE = True
MAX_LENGTH = 512


def get_dataset(train_files: List[str], val_files: List[str], max_length: int) -> dict:
    """Load one train/validation file pair from ``data_dir`` and tokenise it.

    Uses the module-level ``tokenizer``. Returns a dict with keys:
      - "corpus": train files passed through ``process_corpus``
      - "mcq":    val files passed through ``process_mcq(..., expand_choices=False)``
      - "val":    val files passed through ``process_mcq`` with its default options
    """
    train = local_load_dataset(data_dir, train_files)
    val = local_load_dataset(data_dir, val_files)
    train_records = process_corpus(train, tokenizer, max_length)
    val_records = process_mcq(val, tokenizer, max_length)
    mcq_records = process_mcq(val, tokenizer, max_length, expand_choices=False)
    return {
        "corpus": train_records,
        "mcq": mcq_records,
        "val": val_records,
    }


if USE_CACHE and cache_path.exists():
    # NOTE: pickle.load can execute arbitrary code; only load cache files produced locally.
    # NOTE(review): the cache is not keyed on model_id / tokenizer / MAX_LENGTH; delete
    # it after changing any of those, or stale tokenisations will be reused silently.
    with open(cache_path, "rb") as f:
        data = pickle.load(f)
else:
    # Keys are Datasets enum members plus the string "retain".
    data = {}
    for name in Datasets:
        dataset_config = DATASETS_DICT[name]
        unlearn_files = dataset_config["unlearn_files"]
        val_unlearn_files = dataset_config["val_unlearn_files"]

        print(f"Processing {name}")
        data[name] = get_dataset(unlearn_files, val_unlearn_files, MAX_LENGTH)

        # The retain split is built once, from the first dataset config iterated.
        # NOTE(review): assumes all configs share the same retain files — confirm.
        if "retain" not in data:
            retain_files = dataset_config["retain_files"]
            val_retain_files = dataset_config["val_retain_files"]
            data["retain"] = get_dataset(retain_files, val_retain_files, MAX_LENGTH)

    with open(cache_path, "wb") as f:
        pickle.dump(data, f)

In [5]:
import wandb

# Authenticate with Weights & Biases (uses cached credentials when available).
wandb.login()

# Metadata attached to the run for the RMU unlearning step.
config = dict(model_id=model_id)

run = wandb.init(
    entity="12tqian",
    project="relearn",
    tags=["rmu", "fold", "debug"],
    config=config,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m12tqian[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from relearn.unlearn.rmu import super_rmu

forget_set = data[Datasets.WMDP]
retain_set = data["retain"]

# RMU hyperparameters. The high-precision floats look like they were copied from a
# hyperparameter sweep — TODO record which sweep/run they came from.
rmu_kwargs = dict(
    magnitude=6.5,
    forget_alpha=0.023105360391794554,
    retain_alpha=0.14482228107954087,
    epochs_per_fold=3,
    lr=0.00008156557999985112,
    lr_end=0.11877675152876664,
    joint_train=True,
    prefix_forget=True,
    sweeping=True,
)
# Record the hyperparameters on the active W&B run (opened in the previous cell) so the
# run is reproducible from its config. Namespaced under "rmu" to avoid colliding with
# any keys super_rmu may log itself.
run.config.update({"rmu": rmu_kwargs}, allow_val_change=True)

model, res = super_rmu(
    model,
    tokenizer,
    forget_set,
    retain_set,
    # Positional arguments passed through unchanged; 4 is presumably the number of
    # folds — TODO confirm against super_rmu's signature and pass by keyword.
    None,
    4,
    **rmu_kwargs,
)
print(res)

Unlearning fold A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:07<00:00, 14.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 14.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:06<00:00, 16.20it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:07<00:00, 12.30it/s]
 38%|██████████████████████████████████████████████████████████████▏                                                                                                   | 151/393 [00:22<00:59,  

In [None]:
# Close the W&B run opened for the RMU unlearning step.
run.finish()

# Persist the unlearned model together with its tokenizer (including the overridden
# pad_token_id), so the directory can be reloaded on its own via from_pretrained.
path = Path("../models/super_rmu")
path.mkdir(exist_ok=True, parents=True)
model.save_pretrained(path)
tokenizer.save_pretrained(path)

In [None]:
from relearn.attacks import super_rtt

config = {
    "model_id": model_id,
}
# Using the run as a context manager guarantees it is finished (with a non-zero exit
# code on error) even if the attack raises, instead of leaving the run open.
with wandb.init(
    project="relearn", config=config, tags=["rtt", "fold", "debug"], entity="12tqian"
) as run:
    # NOTE: rebinds `model`, so re-running this cell attacks the already-attacked model.
    # 4 is passed through unchanged — presumably the number of folds, TODO confirm.
    model = super_rtt(model, tokenizer, forget_set, 4, use_wandb=True)