In [1]:
%load_ext autoreload
%autoreload 2

import os
from dotenv import load_dotenv
from pathlib import Path
from research_tools.gpu import get_gpus_available

# Load variables from a local .env file (if any) before reading credentials.
load_dotenv()

# Hugging Face token; None when the variable is not set.
hf_access_token = os.environ.get("HUGGINGFACE_API_KEY")

# Requested number of GPUs, clamped to however many are reported as available.
n_gpus = 1
gpus_available = get_gpus_available()
n_gpus = min(n_gpus, len(gpus_available))
gpus = gpus_available[:n_gpus]

assert n_gpus > 0, "No GPUs available"

# Restrict CUDA to the selected devices (comma-separated list of their ids).
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus))

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional
import torch
from research_tools.utils import set_seed

# Fix the random seed up front so the run is repeatable.
set_seed(42)

# This notebook requires a GPU: fail with an explicit message rather than
# silently continuing on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device.type == "cuda", "CUDA is not available; this notebook requires a GPU"

model_id = "HuggingFaceH4/zephyr-7b-beta"

# trust_remote_code is deliberately NOT enabled: it permits executing Python
# code shipped in the model repository, and this checkpoint loads through the
# architecture classes bundled with transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, token=hf_access_token
).to(device)

# Use the same access token as the model load so both requests authenticate
# identically.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_access_token)

# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mnt/align4_drive/data/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/.no_exist/892b3d7a7b1cf10c7a701c60881cd93df615734c/chat_template.jinja'


In [3]:
from datasets import load_dataset
from relearn.datasets.utils import (
    load_dataset as local_load_dataset,
    DATASETS_DICT,
    Datasets,
)
from relearn.datasets.corpus import process as process_corpus
from relearn.datasets.mcq import process as process_mcq

# Configuration entry for the WMDP dataset, looked up in the project's
# DATASETS_DICT registry. The file lists used below are read from it.
dataset_config = DATASETS_DICT[Datasets.WMDP]

# Disabled: loading wikitext through the Hugging Face `datasets` library.
# retain_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Directory handed to local_load_dataset; a relative path, so it is resolved
# against the kernel's current working directory.
data_dir = Path("../data")


def get_dataset(train_files: List[str], val_files: List[str], max_length: int):
    """Load and preprocess one train/validation pair of file lists.

    Reads both file lists from the module-level ``data_dir`` and processes
    them with the module-level ``tokenizer``.

    Args:
        train_files: Files loaded and passed through ``process_corpus``.
        val_files: Files loaded and passed through ``process_mcq`` twice
            (once with default options, once with ``expand_choices=False``).
        max_length: Forwarded unchanged to every processing call.

    Returns:
        A dict with three entries:
            "corpus": output of ``process_corpus`` on the training files,
            "mcq": output of ``process_mcq`` with ``expand_choices=False``,
            "val": output of ``process_mcq`` with its default options.
    """
    # Training files are loaded first, then validation files.
    train_split, val_split = (
        local_load_dataset(data_dir, files) for files in (train_files, val_files)
    )

    corpus_records = process_corpus(train_split, tokenizer, max_length)
    default_mcq = process_mcq(val_split, tokenizer, max_length)
    unexpanded_mcq = process_mcq(
        val_split, tokenizer, max_length, expand_choices=False
    )

    return dict(corpus=corpus_records, mcq=unexpanded_mcq, val=default_mcq)

# Index at which the unlearn file lists are cut: entries before it form split
# "A", the remaining entries form split "B".
n_val_files = 4
max_length = 512

# Pull the four file lists out of the dataset configuration.
unlearn_files, val_unlearn_files, retain_files, val_retain_files = (
    dataset_config[key]
    for key in (
        "unlearn_files",
        "val_unlearn_files",
        "retain_files",
        "val_retain_files",
    )
)

first_part = slice(None, n_val_files)
remainder = slice(n_val_files, None)

# Processed datasets keyed by split name; built in the order A, B, retain.
store = {
    "A": get_dataset(
        unlearn_files[first_part], val_unlearn_files[first_part], max_length
    ),
    "B": get_dataset(
        unlearn_files[remainder], val_unlearn_files[remainder], max_length
    ),
    "retain": get_dataset(retain_files, val_retain_files, max_length),
}

Filter:   0%|          | 0/1884 [00:00<?, ? examples/s]

Map:   0%|          | 0/1884 [00:00<?, ? examples/s]

Filter:   0%|          | 0/471 [00:00<?, ? examples/s]

Map:   0%|          | 0/471 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [4]:
# Per-split evaluation sets: the "val" entry of every dataset in the store.
eval_dict = {split_name: splits["val"] for split_name, splits in store.items()}

In [5]:
import wandb

# Authenticate with Weights & Biases before creating the run.
wandb.login()

# Settings recorded with the run.
config = dict(
    model_id=model_id,
    magnitude=6.5,
    lr=1e-5,
    n_epochs=12,
    forget_alphas=dict(A=0.39422),
    retain_alphas=dict(B=13.51609, retain=1),
    datasets_config=dataset_config,
)

run = wandb.init(
    entity="12tqian",
    project="relearn",
    tags=["rmu", "debug"],
    config=config,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m12tqian[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
from pathlib import Path

# Location of the saved "init" checkpoint that replaces the model loaded above.
path = Path("../models/random_bd") / "init"

# Disabled: how a checkpoint could be written to `path`.
# os.makedirs(path, exist_ok=True)
# model.save_pretrained(path)

# torch.save(model.state_dict(), path)

# Drop the reference to the previously loaded model BEFORE loading the
# checkpoint. Otherwise the old weights stay alive until the assignment below
# completes, and both models occupy memory at the same time.
model = None

model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(
    device
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
from relearn.evaluate import run_eval

# Evaluate the current model on every split without logging to wandb.
# The fourth positional argument is passed as 0.
res = run_eval(model, tokenizer, eval_dict, 0, use_wandb=False)

# Print one "name: value" line per entry of the result.
for k in res:
    v = res[k]
    print(f"{k}: {v}")

100%|██████████| 314/314 [00:20<00:00, 14.98it/s]
100%|██████████| 79/79 [00:05<00:00, 14.84it/s]
100%|██████████| 393/393 [00:42<00:00,  9.25it/s]

epoch: 0
A/acc: 0.5398089171974523
B/acc: 0.47770700636942676
retain/acc: 0.5796178343949044



